hyena predict longer sequences

This commit is contained in:
2023-08-05 15:05:06 +02:00
parent bdef020bff
commit a71030547c
4 changed files with 1981 additions and 73 deletions

View File

@@ -12,7 +12,7 @@
"text": [
"torch.Size([1, 1024, 128]) torch.Size([1, 1024, 128])\n",
"Causality check: gradients should not flow \"from future to past\"\n",
"tensor(2.1330e-09) tensor(0.2463)\n"
"tensor(-4.1268e-09) tensor(0.0844)\n"
]
}
],
@@ -291,7 +291,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"id": "032ef08a-8cc6-491a-9eb8-4a6b3f2d165e",
"metadata": {},
"outputs": [
@@ -343,7 +343,9 @@
" # Increase the channel dimension from 1 to d_model\n",
" u = self.fc_before(u) \n",
" # Pass through the operator\n",
" #[B,1024,128] --> [B,1024,128]\n",
" u = self.operator(u)\n",
" \n",
" last_state = u[:,-1,:]\n",
" # Decrease the channel dimension back to 1\n",
" y = self.fc_after(last_state)\n",
@@ -371,7 +373,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 3,
"id": "80cde67b-992f-4cb0-8824-4a6b7e4984ca",
"metadata": {},
"outputs": [
@@ -379,16 +381,16 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Train Epoch: 1 [0/640 (0%)]\tLoss: 0.736030\n",
"Train Epoch: 2 [0/640 (0%)]\tLoss: 0.013385\n",
"Train Epoch: 3 [0/640 (0%)]\tLoss: 0.019001\n",
"Train Epoch: 4 [0/640 (0%)]\tLoss: 0.010262\n",
"Train Epoch: 5 [0/640 (0%)]\tLoss: 0.005347\n",
"Train Epoch: 6 [0/640 (0%)]\tLoss: 0.006345\n",
"Train Epoch: 7 [0/640 (0%)]\tLoss: 0.004454\n",
"Train Epoch: 8 [0/640 (0%)]\tLoss: 0.003857\n",
"Train Epoch: 9 [0/640 (0%)]\tLoss: 0.003062\n",
"Train Epoch: 10 [0/640 (0%)]\tLoss: 0.002607\n"
"Train Epoch: 1 [0/640 (0%)]\tLoss: 0.446847\n",
"Train Epoch: 2 [0/640 (0%)]\tLoss: 0.077979\n",
"Train Epoch: 3 [0/640 (0%)]\tLoss: 0.021656\n",
"Train Epoch: 4 [0/640 (0%)]\tLoss: 0.007355\n",
"Train Epoch: 5 [0/640 (0%)]\tLoss: 0.004926\n",
"Train Epoch: 6 [0/640 (0%)]\tLoss: 0.006014\n",
"Train Epoch: 7 [0/640 (0%)]\tLoss: 0.003400\n",
"Train Epoch: 8 [0/640 (0%)]\tLoss: 0.003720\n",
"Train Epoch: 9 [0/640 (0%)]\tLoss: 0.004267\n",
"Train Epoch: 10 [0/640 (0%)]\tLoss: 0.004081\n"
]
}
],
@@ -495,49 +497,155 @@
},
{
"cell_type": "code",
"execution_count": 17,
"id": "1b763e03-baab-4b02-bae0-5747461bca7f",
"execution_count": 4,
"id": "90330622-8b44-4b45-8158-6840538f768c",
"metadata": {},
"outputs": [],
"source": [
"def correlate(model, device, data_loader):\n",
" model.eval()\n",
" correlations = {key: [] for key in [\"frequency\", \"phase\", \"amplitude\", \"noise_sd\"]}\n",
" with torch.no_grad():\n",
" for data, target, params in data_loader:\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.metrics import r2_score\n",
"\n",
" data, target = data.to(device), target.to(device)\n",
" output, last_state = model(data)\n",
" last_state_np = last_state.cpu().numpy()\n",
" params_np = params.cpu().numpy()\n",
" #import pdb;pdb.set_trace()\n",
"\n",
" # Compute correlations between last_state and parameters\n",
" for i, key in enumerate(correlations.keys()):\n",
" correlation = np.corrcoef(last_state_np.squeeze(), params_np[:,i])[0,1]\n",
" correlations[key].append(correlation)\n",
"def fit_and_evaluate_linear_regression(outputs_and_params):\n",
" # Split the data into inputs (last_states) and targets (params)\n",
" inputs = np.concatenate([x[0] for x in outputs_and_params])\n",
" targets = np.concatenate([x[1] for x in outputs_and_params])\n",
" \n",
" # Average correlations over all batches\n",
" avg_correlations = {key: np.mean(value) for key, value in correlations.items()}\n",
" return avg_correlations"
" r2_scores = []\n",
" param_names = [\"frequency\", \"phase\", \"amplitude\", \"noise_sd\"]\n",
" \n",
" # Fit the linear regression model for each parameter and calculate the R^2 score\n",
" for i in range(targets.shape[1]):\n",
" model = LinearRegression().fit(inputs, targets[:, i])\n",
" pred = model.predict(inputs)\n",
" score = r2_score(targets[:, i], pred)\n",
" r2_scores.append(score)\n",
" print(f\"R^2 score for {param_names[i]}: {score:.2f}\")\n",
" \n",
" return r2_scores"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "f4c78c51-a538-4d24-ab7b-fb6035c78df8",
"execution_count": 5,
"id": "5eb62a22-cad8-43c4-b757-f36b6a01e9be",
"metadata": {},
"outputs": [],
"source": [
"def generate_outputs(model, device, data_loader):\n",
"    # Run the model over every batch of data_loader with gradients disabled and\n",
"    # collect (last_state, params) pairs as numpy arrays, for use as the\n",
"    # inputs/targets of the downstream linear-regression probe.\n",
"    # NOTE(review): `target` is moved to `device` but never used here -- confirm\n",
"    # whether it can be dropped from the unpacking.\n",
"    model.eval()\n",
"    outputs_and_params = []\n",
"    with torch.no_grad():\n",
"        for data, target, params in data_loader:\n",
"            data, target = data.to(device), target.to(device)\n",
"            # model returns (prediction, last hidden state)\n",
"            output, last_state = model(data)\n",
"            outputs_and_params.append((last_state.cpu().numpy(), params.cpu().numpy()))\n",
"    return outputs_and_params\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a95ee542-1c39-4f04-9184-e26c6983a018",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"> \u001b[0;32m/tmp/ipykernel_9454/3342195123.py\u001b[0m(14)\u001b[0;36mcorrelate\u001b[0;34m()\u001b[0m\n",
"\u001b[0;32m 12 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 13 \u001b[0;31m \u001b[0;31m# Compute correlations between last_state and parameters\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m---> 14 \u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcorrelations\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 15 \u001b[0;31m \u001b[0mcorrelation\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcorrcoef\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlast_state_np\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams_np\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 16 \u001b[0;31m \u001b[0mcorrelations\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcorrelation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"R^2 score for frequency: 0.77\n",
"R^2 score for phase: 0.66\n",
"R^2 score for amplitude: 0.99\n",
"R^2 score for noise_sd: 0.97\n"
]
}
],
"source": [
"outputs_and_params = generate_outputs(model, device, train_loader)\n",
"\n",
"# Fit the linear regression model and print the R^2 score for each parameter\n",
"r2_scores = fit_and_evaluate_linear_regression(outputs_and_params)"
]
},
{
"cell_type": "markdown",
"id": "de65d0d2-b0c6-4ac5-a87f-70fa7f90480b",
"metadata": {},
"source": [
"# Autoregressive"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "8ee139a6-aee5-4309-8685-bf0c28893279",
"metadata": {},
"outputs": [],
"source": [
"def predict_autoregressive(model, initial_data, n_steps):\n",
"    # Autoregressive rollout: for n_steps iterations, predict the next value\n",
"    # and feed it back in as the newest timestep of a sliding input window.\n",
"    # Fix: removed a leftover `import pdb;pdb.set_trace()` that halted every\n",
"    # loop iteration in the interactive debugger.\n",
"    # NOTE(review): assumes initial_data is (batch, seq_len, channels) and that\n",
"    # model(x) returns (prediction, last_state); `next_output[:, -1:, :]`\n",
"    # requires the prediction to be 3-D -- TODO confirm against the model's\n",
"    # forward(), which appears to return fc_after(last_state).\n",
"    model.eval()\n",
"    predictions = []\n",
"    current_input = initial_data\n",
"    with torch.no_grad():\n",
"        for _ in range(n_steps):\n",
"            # Get the prediction for the next step and save it\n",
"            next_output, _ = model(current_input)\n",
"            predictions.append(next_output)\n",
"\n",
"            # Slide the window: drop the oldest timestep, append the prediction\n",
"            next_input = torch.cat((current_input[:, 1:, :], next_output[:, -1:, :]), dim=1)\n",
"\n",
"            current_input = next_input\n",
"\n",
"    return torch.cat(predictions, dim=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "272040d6-4dfb-438b-a432-744e950effd1",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 23,
"id": "9f49cbf8-08e8-428a-af80-bcb2515a6327",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 1024, 1])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"initial_data = dataset[0]\n",
"initial_data[0][None,...].shape"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "c75464e3-e44d-4325-8804-29d8081e3a45",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"> \u001b[0;32m/tmp/ipykernel_9626/3477884695.py\u001b[0m(9)\u001b[0;36mpredict_autoregressive\u001b[0;34m()\u001b[0m\n",
"\u001b[0;32m 7 \u001b[0;31m \u001b[0;31m# Get the prediction for the next step and save it\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 8 \u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m;\u001b[0m\u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m----> 9 \u001b[0;31m \u001b[0mnext_output\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcurrent_input\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 10 \u001b[0;31m \u001b[0mpredictions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnext_output\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 11 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\n"
]
},
@@ -545,53 +653,34 @@
"name": "stdin",
"output_type": "stream",
"text": [
"ipdb> last_state_np.shape\n"
"ipdb> current_input\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(64, 128)\n"
"tensor([[-1.2113],\n",
" [-1.2368],\n",
" [-1.1851],\n",
" ...,\n",
" [-1.1084],\n",
" [-0.9648],\n",
" [-1.0586]])\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"ipdb> last_state_np[0]\n"
"ipdb> current_input.shape\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"array([ 0.13440676, -0.03606963, -0.4140934 , 0.5431792 , 0.36556095,\n",
" 0.2256596 , -0.4616309 , -0.05567896, 0.17625177, -0.23529659,\n",
" -0.5208519 , 0.29691923, 0.15615058, 0.31342992, -0.5054718 ,\n",
" -0.33130994, -0.03956199, 0.31403548, -0.15925817, -0.22006416,\n",
" 0.00838468, -0.30691615, -0.1828884 , -0.52498204, 0.07198659,\n",
" 0.38572663, -0.27560705, 0.12110637, -0.17199083, 0.3913066 ,\n",
" -0.03978934, -0.21167544, -0.43025637, 0.20562531, 0.3000516 ,\n",
" -0.6784174 , -0.04233613, 0.4706083 , 0.20292807, 0.49932548,\n",
" 0.00203749, 0.2665777 , -0.16989222, 0.40648764, 0.22203793,\n",
" -0.44289762, 0.20751204, -0.38801843, -0.001487 , -0.49365598,\n",
" 0.05991718, -0.10120638, 0.36523518, -0.15450253, 0.11142011,\n",
" -0.20295474, 0.12229299, 0.09449576, -0.3422598 , 0.18969077,\n",
" 0.517254 , 0.08046471, 0.02134303, -0.35802346, -0.26192123,\n",
" 0.26145002, 0.11439252, 0.03314593, -0.15331428, 0.42282102,\n",
" 0.6026961 , -0.04233361, -0.5652172 , 0.33544067, 0.05744396,\n",
" 0.43544483, 0.2176097 , 0.22265801, -0.03894311, -0.01405966,\n",
" 0.23479447, 0.32931918, 0.21597862, 0.40402904, 0.20630498,\n",
" 0.09036086, -0.16922598, -0.1774486 , -0.14753146, -0.22214624,\n",
" -0.19101782, -0.09274255, 0.10928088, -0.01354241, -0.3864469 ,\n",
" 0.46331462, -0.38134843, -0.07766411, 0.750954 , -0.06306303,\n",
" -0.33691666, -0.1798551 , 0.19826202, -0.13544285, 0.01956506,\n",
" 0.6431204 , -0.11272874, 0.1345196 , 0.23029736, 0.28865197,\n",
" 0.70087713, 0.3592593 , 0.30329305, -0.26943353, -0.11942452,\n",
" -0.21187985, 0.19452253, 0.05659255, -0.00958484, -0.33417243,\n",
" -0.14836329, 0.28580692, 0.20885246, 0.18010336, 0.56253076,\n",
" -0.25303417, 0.0189368 , 0.2504725 ], dtype=float32)\n"
"torch.Size([1024, 1])\n"
]
},
{
@@ -603,13 +692,13 @@
}
],
"source": [
"correlate(model,\"cpu\",train_loader)"
"predict_autoregressive(model,initial_data[0][None,...],100)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1558ffb-4699-4a8c-b05d-c3ac31a3829f",
"id": "90a37c56-59f3-49bc-bfbe-d9e126c42ed1",
"metadata": {},
"outputs": [],
"source": []

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long