hyena predict longer sequences

This commit is contained in:
2023-08-05 15:05:06 +02:00
parent bdef020bff
commit a71030547c
4 changed files with 1981 additions and 73 deletions

View File

@@ -12,7 +12,7 @@
"text": [
"torch.Size([1, 1024, 128]) torch.Size([1, 1024, 128])\n",
"Causality check: gradients should not flow \"from future to past\"\n",
"tensor(2.1330e-09) tensor(0.2463)\n"
"tensor(-4.1268e-09) tensor(0.0844)\n"
]
}
],
@@ -291,7 +291,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"id": "032ef08a-8cc6-491a-9eb8-4a6b3f2d165e",
"metadata": {},
"outputs": [
@@ -343,7 +343,9 @@
" # Increase the channel dimension from 1 to d_model\n",
" u = self.fc_before(u) \n",
" # Pass through the operator\n",
" #[B,1024,128] --> [B,1024,128]\n",
" u = self.operator(u)\n",
" \n",
" last_state = u[:,-1,:]\n",
" # Decrease the channel dimension back to 1\n",
" y = self.fc_after(last_state)\n",
@@ -371,7 +373,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 3,
"id": "80cde67b-992f-4cb0-8824-4a6b7e4984ca",
"metadata": {},
"outputs": [
@@ -379,16 +381,16 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Train Epoch: 1 [0/640 (0%)]\tLoss: 0.736030\n",
"Train Epoch: 2 [0/640 (0%)]\tLoss: 0.013385\n",
"Train Epoch: 3 [0/640 (0%)]\tLoss: 0.019001\n",
"Train Epoch: 4 [0/640 (0%)]\tLoss: 0.010262\n",
"Train Epoch: 5 [0/640 (0%)]\tLoss: 0.005347\n",
"Train Epoch: 6 [0/640 (0%)]\tLoss: 0.006345\n",
"Train Epoch: 7 [0/640 (0%)]\tLoss: 0.004454\n",
"Train Epoch: 8 [0/640 (0%)]\tLoss: 0.003857\n",
"Train Epoch: 9 [0/640 (0%)]\tLoss: 0.003062\n",
"Train Epoch: 10 [0/640 (0%)]\tLoss: 0.002607\n"
"Train Epoch: 1 [0/640 (0%)]\tLoss: 0.446847\n",
"Train Epoch: 2 [0/640 (0%)]\tLoss: 0.077979\n",
"Train Epoch: 3 [0/640 (0%)]\tLoss: 0.021656\n",
"Train Epoch: 4 [0/640 (0%)]\tLoss: 0.007355\n",
"Train Epoch: 5 [0/640 (0%)]\tLoss: 0.004926\n",
"Train Epoch: 6 [0/640 (0%)]\tLoss: 0.006014\n",
"Train Epoch: 7 [0/640 (0%)]\tLoss: 0.003400\n",
"Train Epoch: 8 [0/640 (0%)]\tLoss: 0.003720\n",
"Train Epoch: 9 [0/640 (0%)]\tLoss: 0.004267\n",
"Train Epoch: 10 [0/640 (0%)]\tLoss: 0.004081\n"
]
}
],
@@ -495,49 +497,155 @@
},
{
"cell_type": "code",
"execution_count": 17,
"id": "1b763e03-baab-4b02-bae0-5747461bca7f",
"execution_count": 4,
"id": "90330622-8b44-4b45-8158-6840538f768c",
"metadata": {},
"outputs": [],
"source": [
"def correlate(model, device, data_loader):\n",
" model.eval()\n",
" correlations = {key: [] for key in [\"frequency\", \"phase\", \"amplitude\", \"noise_sd\"]}\n",
" with torch.no_grad():\n",
" for data, target, params in data_loader:\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.metrics import r2_score\n",
"\n",
" data, target = data.to(device), target.to(device)\n",
" output, last_state = model(data)\n",
" last_state_np = last_state.cpu().numpy()\n",
" params_np = params.cpu().numpy()\n",
" #import pdb;pdb.set_trace()\n",
"\n",
" # Compute correlations between last_state and parameters\n",
" for i, key in enumerate(correlations.keys()):\n",
" correlation = np.corrcoef(last_state_np.squeeze(), params_np[:,i])[0,1]\n",
" correlations[key].append(correlation)\n",
"def fit_and_evaluate_linear_regression(outputs_and_params):\n",
" # Split the data into inputs (last_states) and targets (params)\n",
" inputs = np.concatenate([x[0] for x in outputs_and_params])\n",
" targets = np.concatenate([x[1] for x in outputs_and_params])\n",
" \n",
" # Average correlations over all batches\n",
" avg_correlations = {key: np.mean(value) for key, value in correlations.items()}\n",
" return avg_correlations"
" r2_scores = []\n",
" param_names = [\"frequency\", \"phase\", \"amplitude\", \"noise_sd\"]\n",
" \n",
" # Fit the linear regression model for each parameter and calculate the R^2 score\n",
" for i in range(targets.shape[1]):\n",
" model = LinearRegression().fit(inputs, targets[:, i])\n",
" pred = model.predict(inputs)\n",
" score = r2_score(targets[:, i], pred)\n",
" r2_scores.append(score)\n",
" print(f\"R^2 score for {param_names[i]}: {score:.2f}\")\n",
" \n",
" return r2_scores"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "f4c78c51-a538-4d24-ab7b-fb6035c78df8",
"execution_count": 5,
"id": "5eb62a22-cad8-43c4-b757-f36b6a01e9be",
"metadata": {},
"outputs": [],
"source": [
"def generate_outputs(model, device, data_loader):\n",
" model.eval()\n",
" outputs_and_params = []\n",
" with torch.no_grad():\n",
" for data, target, params in data_loader:\n",
" data, target = data.to(device), target.to(device)\n",
" output, last_state = model(data)\n",
" outputs_and_params.append((last_state.cpu().numpy(), params.cpu().numpy()))\n",
" return outputs_and_params\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a95ee542-1c39-4f04-9184-e26c6983a018",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"> \u001b[0;32m/tmp/ipykernel_9454/3342195123.py\u001b[0m(14)\u001b[0;36mcorrelate\u001b[0;34m()\u001b[0m\n",
"\u001b[0;32m 12 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 13 \u001b[0;31m \u001b[0;31m# Compute correlations between last_state and parameters\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m---> 14 \u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcorrelations\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 15 \u001b[0;31m \u001b[0mcorrelation\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcorrcoef\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlast_state_np\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams_np\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 16 \u001b[0;31m \u001b[0mcorrelations\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcorrelation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"R^2 score for frequency: 0.77\n",
"R^2 score for phase: 0.66\n",
"R^2 score for amplitude: 0.99\n",
"R^2 score for noise_sd: 0.97\n"
]
}
],
"source": [
"outputs_and_params = generate_outputs(model, device, train_loader)\n",
"\n",
"# Fit the linear regression model and print the R^2 score for each parameter\n",
"r2_scores = fit_and_evaluate_linear_regression(outputs_and_params)"
]
},
{
"cell_type": "markdown",
"id": "de65d0d2-b0c6-4ac5-a87f-70fa7f90480b",
"metadata": {},
"source": [
"# Autoregressive"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "8ee139a6-aee5-4309-8685-bf0c28893279",
"metadata": {},
"outputs": [],
"source": [
"def predict_autoregressive(model, initial_data, n_steps):\n",
" model.eval()\n",
" predictions = []\n",
" current_input = initial_data\n",
" with torch.no_grad():\n",
" for _ in range(n_steps):\n",
" # Get the prediction for the next step and save it\n",
" import pdb;pdb.set_trace()\n",
" next_output, _ = model(current_input)\n",
" predictions.append(next_output)\n",
"\n",
" # Prepare the input for the next step\n",
" next_input = torch.cat((current_input[:, 1:, :], next_output[:, -1:, :]), dim=1)\n",
"\n",
" current_input = next_input\n",
"\n",
" return torch.cat(predictions, dim=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "272040d6-4dfb-438b-a432-744e950effd1",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 23,
"id": "9f49cbf8-08e8-428a-af80-bcb2515a6327",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 1024, 1])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"initial_data = dataset[0]\n",
"initial_data[0][None,...].shape"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "c75464e3-e44d-4325-8804-29d8081e3a45",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"> \u001b[0;32m/tmp/ipykernel_9626/3477884695.py\u001b[0m(9)\u001b[0;36mpredict_autoregressive\u001b[0;34m()\u001b[0m\n",
"\u001b[0;32m 7 \u001b[0;31m \u001b[0;31m# Get the prediction for the next step and save it\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 8 \u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m;\u001b[0m\u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m----> 9 \u001b[0;31m \u001b[0mnext_output\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcurrent_input\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 10 \u001b[0;31m \u001b[0mpredictions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnext_output\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 11 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\n"
]
},
@@ -545,53 +653,34 @@
"name": "stdin",
"output_type": "stream",
"text": [
"ipdb> last_state_np.shape\n"
"ipdb> current_input\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(64, 128)\n"
"tensor([[-1.2113],\n",
" [-1.2368],\n",
" [-1.1851],\n",
" ...,\n",
" [-1.1084],\n",
" [-0.9648],\n",
" [-1.0586]])\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"ipdb> last_state_np[0]\n"
"ipdb> current_input.shape\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"array([ 0.13440676, -0.03606963, -0.4140934 , 0.5431792 , 0.36556095,\n",
" 0.2256596 , -0.4616309 , -0.05567896, 0.17625177, -0.23529659,\n",
" -0.5208519 , 0.29691923, 0.15615058, 0.31342992, -0.5054718 ,\n",
" -0.33130994, -0.03956199, 0.31403548, -0.15925817, -0.22006416,\n",
" 0.00838468, -0.30691615, -0.1828884 , -0.52498204, 0.07198659,\n",
" 0.38572663, -0.27560705, 0.12110637, -0.17199083, 0.3913066 ,\n",
" -0.03978934, -0.21167544, -0.43025637, 0.20562531, 0.3000516 ,\n",
" -0.6784174 , -0.04233613, 0.4706083 , 0.20292807, 0.49932548,\n",
" 0.00203749, 0.2665777 , -0.16989222, 0.40648764, 0.22203793,\n",
" -0.44289762, 0.20751204, -0.38801843, -0.001487 , -0.49365598,\n",
" 0.05991718, -0.10120638, 0.36523518, -0.15450253, 0.11142011,\n",
" -0.20295474, 0.12229299, 0.09449576, -0.3422598 , 0.18969077,\n",
" 0.517254 , 0.08046471, 0.02134303, -0.35802346, -0.26192123,\n",
" 0.26145002, 0.11439252, 0.03314593, -0.15331428, 0.42282102,\n",
" 0.6026961 , -0.04233361, -0.5652172 , 0.33544067, 0.05744396,\n",
" 0.43544483, 0.2176097 , 0.22265801, -0.03894311, -0.01405966,\n",
" 0.23479447, 0.32931918, 0.21597862, 0.40402904, 0.20630498,\n",
" 0.09036086, -0.16922598, -0.1774486 , -0.14753146, -0.22214624,\n",
" -0.19101782, -0.09274255, 0.10928088, -0.01354241, -0.3864469 ,\n",
" 0.46331462, -0.38134843, -0.07766411, 0.750954 , -0.06306303,\n",
" -0.33691666, -0.1798551 , 0.19826202, -0.13544285, 0.01956506,\n",
" 0.6431204 , -0.11272874, 0.1345196 , 0.23029736, 0.28865197,\n",
" 0.70087713, 0.3592593 , 0.30329305, -0.26943353, -0.11942452,\n",
" -0.21187985, 0.19452253, 0.05659255, -0.00958484, -0.33417243,\n",
" -0.14836329, 0.28580692, 0.20885246, 0.18010336, 0.56253076,\n",
" -0.25303417, 0.0189368 , 0.2504725 ], dtype=float32)\n"
"torch.Size([1024, 1])\n"
]
},
{
@@ -603,13 +692,13 @@
}
],
"source": [
"correlate(model,\"cpu\",train_loader)"
"predict_autoregressive(model,initial_data[0][None,...],100)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1558ffb-4699-4a8c-b05d-c3ac31a3829f",
"id": "90a37c56-59f3-49bc-bfbe-d9e126c42ed1",
"metadata": {},
"outputs": [],
"source": []

File diff suppressed because one or more lines are too long