From bdef020bff32eb3f7040d43ce5f839a0fb852ac4 Mon Sep 17 00:00:00 2001 From: martin Date: Sat, 5 Aug 2023 12:14:16 +0200 Subject: [PATCH] R^2 for hyena --- .../simple_hyena_model-checkpoint.ipynb | 168 ++++++++++++++---- hyena_test/simple_hyena_model.ipynb | 124 +++++++++---- 2 files changed, 227 insertions(+), 65 deletions(-) diff --git a/hyena_test/.ipynb_checkpoints/simple_hyena_model-checkpoint.ipynb b/hyena_test/.ipynb_checkpoints/simple_hyena_model-checkpoint.ipynb index 0b8e949..2ef8861 100644 --- a/hyena_test/.ipynb_checkpoints/simple_hyena_model-checkpoint.ipynb +++ b/hyena_test/.ipynb_checkpoints/simple_hyena_model-checkpoint.ipynb @@ -12,7 +12,7 @@ "text": [ "torch.Size([1, 1024, 128]) torch.Size([1, 1024, 128])\n", "Causality check: gradients should not flow \"from future to past\"\n", - "tensor(3.2471e-09) tensor(0.4080)\n" + "tensor(2.1330e-09) tensor(0.2463)\n" ] } ], @@ -291,7 +291,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "id": "032ef08a-8cc6-491a-9eb8-4a6b3f2d165e", "metadata": {}, "outputs": [ @@ -347,7 +347,7 @@ " last_state = u[:,-1,:]\n", " # Decrease the channel dimension back to 1\n", " y = self.fc_after(last_state)\n", - " return y\n", + " return y,last_state\n", "\n", "\n", "if __name__ == \"__main__\":\n", @@ -359,7 +359,7 @@ " )\n", "\n", " x = torch.randn(1, 1023, 1, requires_grad=True) # 1D time series input\n", - " y = layer(x)\n", + " y, last_state = layer(x)\n", "\n", " #import pdb;pdb.set_trace()\n", " print(x.shape, y.shape) # should now be [1, 1024, 1]\n", @@ -371,7 +371,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 10, "id": "80cde67b-992f-4cb0-8824-4a6b7e4984ca", "metadata": {}, "outputs": [ @@ -379,30 +379,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "Train Epoch: 1 [0/640 (0%)]\tLoss: 0.433575\n", - "Train Epoch: 2 [0/640 (0%)]\tLoss: 0.054185\n", - "Train Epoch: 3 [0/640 (0%)]\tLoss: 0.007312\n", - "Train Epoch: 4 [0/640 (0%)]\tLoss: 0.004312\n", - "Train Epoch: 5 [0/640 (0%)]\tLoss: 0.003393\n" - ] - }, - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[32], line 87\u001b[0m\n\u001b[1;32m 84\u001b[0m train_loader \u001b[38;5;241m=\u001b[39m DataLoader(dataset, batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m64\u001b[39m, shuffle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m11\u001b[39m): \u001b[38;5;66;03m# Train for 10 epochs\u001b[39;00m\n\u001b[0;32m---> 87\u001b[0m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepoch\u001b[49m\u001b[43m)\u001b[49m\n", - "Cell \u001b[0;32mIn[32], line 59\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(model, device, train_loader, optimizer, epoch)\u001b[0m\n\u001b[1;32m 57\u001b[0m data, target \u001b[38;5;241m=\u001b[39m data\u001b[38;5;241m.\u001b[39mto(device), target\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 58\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m---> 59\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 60\u001b[0m \u001b[38;5;66;03m#import pdb;pdb.set_trace()\u001b[39;00m\n\u001b[1;32m 62\u001b[0m loss \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39mmse_loss(output, target)\n", - "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", - "Cell \u001b[0;32mIn[2], line 40\u001b[0m, in \u001b[0;36mHyenaOperatorAutoregressive1D.forward\u001b[0;34m(self, u, *args, **kwargs)\u001b[0m\n\u001b[1;32m 38\u001b[0m u \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfc_before(u) \n\u001b[1;32m 39\u001b[0m \u001b[38;5;66;03m# Pass through the operator\u001b[39;00m\n\u001b[0;32m---> 40\u001b[0m u \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moperator\u001b[49m\u001b[43m(\u001b[49m\u001b[43mu\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 41\u001b[0m last_state \u001b[38;5;241m=\u001b[39m u[:,\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m,:]\n\u001b[1;32m 42\u001b[0m \u001b[38;5;66;03m# Decrease the channel dimension back to 1\u001b[39;00m\n", - "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", - "Cell \u001b[0;32mIn[1], line 237\u001b[0m, in \u001b[0;36mHyenaOperator.forward\u001b[0;34m(self, u, *args, **kwargs)\u001b[0m\n\u001b[1;32m 234\u001b[0m u \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_proj(u)\n\u001b[1;32m 235\u001b[0m u \u001b[38;5;241m=\u001b[39m rearrange(u, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mb l d -> b d l\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m--> 237\u001b[0m uc \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshort_filter\u001b[49m\u001b[43m(\u001b[49m\u001b[43mu\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m,:l_filter] \n\u001b[1;32m 238\u001b[0m \u001b[38;5;241m*\u001b[39mx, v \u001b[38;5;241m=\u001b[39m uc\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39md_model, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 240\u001b[0m k \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfilter_fn\u001b[38;5;241m.\u001b[39mfilter(l_filter)[\u001b[38;5;241m0\u001b[39m]\n", - "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", - "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/conv.py:313\u001b[0m, in \u001b[0;36mConv1d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 312\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 313\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_conv_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/conv.py:309\u001b[0m, in \u001b[0;36mConv1d._conv_forward\u001b[0;34m(self, input, weight, bias)\u001b[0m\n\u001b[1;32m 305\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mzeros\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 306\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mconv1d(F\u001b[38;5;241m.\u001b[39mpad(\u001b[38;5;28minput\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reversed_padding_repeated_twice, mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode),\n\u001b[1;32m 307\u001b[0m weight, bias, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstride,\n\u001b[1;32m 308\u001b[0m _single(\u001b[38;5;241m0\u001b[39m), \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdilation, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgroups)\n\u001b[0;32m--> 309\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv1d\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstride\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 310\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdilation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgroups\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + "Train Epoch: 1 [0/640 (0%)]\tLoss: 0.736030\n", + "Train Epoch: 2 [0/640 (0%)]\tLoss: 0.013385\n", + "Train Epoch: 3 [0/640 (0%)]\tLoss: 0.019001\n", + "Train Epoch: 4 [0/640 (0%)]\tLoss: 0.010262\n", + "Train Epoch: 5 [0/640 (0%)]\tLoss: 0.005347\n", + "Train Epoch: 6 [0/640 (0%)]\tLoss: 0.006345\n", + "Train Epoch: 7 [0/640 (0%)]\tLoss: 0.004454\n", + "Train Epoch: 8 [0/640 (0%)]\tLoss: 0.003857\n", + "Train Epoch: 9 [0/640 (0%)]\tLoss: 0.003062\n", + "Train Epoch: 10 [0/640 (0%)]\tLoss: 0.002607\n" ] } ], @@ -454,18 +440,21 @@ " # Generate sine wave with the random attributes\n", " sine_wave = generate_sine_with_noise(self.n_points, frequency, phase, amplitude, noise_sd)\n", "\n", - " return torch.Tensor(sine_wave[:-1, 1, None]), torch.Tensor(sine_wave[-1:, 0])\n", + " # Return the sine wave and the parameters\n", + " return torch.Tensor(sine_wave[:-1, 1, None]), torch.Tensor(sine_wave[-1:, 0]), torch.Tensor([frequency, phase, amplitude, noise_sd])\n", + "\n", + "\n", "\n", "# Usage:\n", "dataset = SineDataset(640, 1025, (1, 3), (0, 2*np.pi), (0.5, 1.5), (0.05, 0.15))\n", "\n", "def train(model, device, train_loader, optimizer, epoch):\n", " model.train()\n", - " for batch_idx, (data, target) in enumerate(train_loader):\n", + " for batch_idx, (data, target, params) in enumerate(train_loader):\n", " #data = data[...,None]\n", " data, target = data.to(device), target.to(device)\n", " optimizer.zero_grad()\n", - " output = model(data)\n", + " output,last_state = model(data)\n", " #import pdb;pdb.set_trace()\n", "\n", " loss = F.mse_loss(output, target)\n", @@ -506,10 +495,123 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "id": "1b763e03-baab-4b02-bae0-5747461bca7f", "metadata": {}, "outputs": [], + "source": [ + "def correlate(model, device, data_loader):\n", + " model.eval()\n", + " correlations = {key: [] for key in [\"frequency\", \"phase\", \"amplitude\", \"noise_sd\"]}\n", + " with torch.no_grad():\n", + " for data, target, params in data_loader:\n", + "\n", + " data, target = data.to(device), target.to(device)\n", + " output, last_state = model(data)\n", + " last_state_np = last_state.cpu().numpy()\n", + " params_np = params.cpu().numpy()\n", + " #import pdb;pdb.set_trace()\n", + "\n", + " # Compute correlations between last_state and parameters\n", + " for i, key in enumerate(correlations.keys()):\n", + " correlation = np.corrcoef(last_state_np.squeeze(), params_np[:,i])[0,1]\n", + " correlations[key].append(correlation)\n", + " \n", + " # Average correlations over all batches\n", + " avg_correlations = {key: np.mean(value) for key, value in correlations.items()}\n", + " return avg_correlations" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "f4c78c51-a538-4d24-ab7b-fb6035c78df8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> \u001b[0;32m/tmp/ipykernel_9454/3342195123.py\u001b[0m(14)\u001b[0;36mcorrelate\u001b[0;34m()\u001b[0m\n", + "\u001b[0;32m 12 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0m\u001b[0;32m 13 \u001b[0;31m \u001b[0;31m# Compute correlations between last_state and parameters\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0m\u001b[0;32m---> 14 \u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcorrelations\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0m\u001b[0;32m 15 \u001b[0;31m \u001b[0mcorrelation\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcorrcoef\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlast_state_np\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams_np\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0m\u001b[0;32m 16 \u001b[0;31m \u001b[0mcorrelations\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcorrelation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0m\n" + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + "ipdb> last_state_np.shape\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(64, 128)\n" + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + "ipdb> last_state_np[0]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "array([ 0.13440676, -0.03606963, -0.4140934 , 0.5431792 , 0.36556095,\n", + " 0.2256596 , -0.4616309 , -0.05567896, 0.17625177, -0.23529659,\n", + " -0.5208519 , 0.29691923, 0.15615058, 0.31342992, -0.5054718 ,\n", + " -0.33130994, -0.03956199, 0.31403548, -0.15925817, -0.22006416,\n", + " 0.00838468, -0.30691615, -0.1828884 , -0.52498204, 0.07198659,\n", + " 0.38572663, -0.27560705, 0.12110637, -0.17199083, 0.3913066 ,\n", + " -0.03978934, -0.21167544, -0.43025637, 0.20562531, 0.3000516 ,\n", + " -0.6784174 , -0.04233613, 0.4706083 , 0.20292807, 0.49932548,\n", + " 0.00203749, 0.2665777 , -0.16989222, 0.40648764, 0.22203793,\n", + " -0.44289762, 0.20751204, -0.38801843, -0.001487 , -0.49365598,\n", + " 0.05991718, -0.10120638, 0.36523518, -0.15450253, 0.11142011,\n", + " -0.20295474, 0.12229299, 0.09449576, -0.3422598 , 0.18969077,\n", + " 0.517254 , 0.08046471, 0.02134303, -0.35802346, -0.26192123,\n", + " 0.26145002, 0.11439252, 0.03314593, -0.15331428, 0.42282102,\n", + " 0.6026961 , -0.04233361, -0.5652172 , 0.33544067, 0.05744396,\n", + " 0.43544483, 0.2176097 , 0.22265801, -0.03894311, -0.01405966,\n", + " 0.23479447, 0.32931918, 0.21597862, 0.40402904, 0.20630498,\n", + " 0.09036086, -0.16922598, -0.1774486 , -0.14753146, -0.22214624,\n", + " -0.19101782, -0.09274255, 0.10928088, -0.01354241, -0.3864469 ,\n", + " 0.46331462, -0.38134843, -0.07766411, 0.750954 , -0.06306303,\n", + " -0.33691666, -0.1798551 , 0.19826202, -0.13544285, 0.01956506,\n", + " 0.6431204 , -0.11272874, 0.1345196 , 0.23029736, 0.28865197,\n", + " 0.70087713, 0.3592593 , 0.30329305, -0.26943353, -0.11942452,\n", + " -0.21187985, 0.19452253, 0.05659255, -0.00958484, -0.33417243,\n", + " -0.14836329, 0.28580692, 0.20885246, 0.18010336, 0.56253076,\n", + " -0.25303417, 0.0189368 , 0.2504725 ], dtype=float32)\n" + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + "ipdb> q\n" + ] + } + ], + "source": [ + "correlate(model,\"cpu\",train_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e1558ffb-4699-4a8c-b05d-c3ac31a3829f", + "metadata": {}, + "outputs": [], "source": [] } ], diff --git a/hyena_test/simple_hyena_model.ipynb b/hyena_test/simple_hyena_model.ipynb index 0b8e949..2088487 100644 --- a/hyena_test/simple_hyena_model.ipynb +++ b/hyena_test/simple_hyena_model.ipynb @@ -12,7 +12,7 @@ "text": [ "torch.Size([1, 1024, 128]) torch.Size([1, 1024, 128])\n", "Causality check: gradients should not flow \"from future to past\"\n", - "tensor(3.2471e-09) tensor(0.4080)\n" + "tensor(-4.1268e-09) tensor(0.0844)\n" ] } ], @@ -347,7 +347,7 @@ " last_state = u[:,-1,:]\n", " # Decrease the channel dimension back to 1\n", " y = self.fc_after(last_state)\n", - " return y\n", + " return y,last_state\n", "\n", "\n", "if __name__ == \"__main__\":\n", @@ -359,7 +359,7 @@ " )\n", "\n", " x = torch.randn(1, 1023, 1, requires_grad=True) # 1D time series input\n", - " y = layer(x)\n", + " y, last_state = layer(x)\n", "\n", " #import pdb;pdb.set_trace()\n", " print(x.shape, y.shape) # should now be [1, 1024, 1]\n", @@ -371,7 +371,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 3, "id": "80cde67b-992f-4cb0-8824-4a6b7e4984ca", "metadata": {}, "outputs": [ @@ -379,30 +379,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "Train Epoch: 1 [0/640 (0%)]\tLoss: 0.433575\n", - "Train Epoch: 2 [0/640 (0%)]\tLoss: 0.054185\n", - "Train Epoch: 3 [0/640 (0%)]\tLoss: 0.007312\n", - "Train Epoch: 4 [0/640 (0%)]\tLoss: 0.004312\n", - "Train Epoch: 5 [0/640 (0%)]\tLoss: 0.003393\n" - ] - }, - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[32], line 87\u001b[0m\n\u001b[1;32m 84\u001b[0m train_loader \u001b[38;5;241m=\u001b[39m DataLoader(dataset, batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m64\u001b[39m, shuffle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m11\u001b[39m): \u001b[38;5;66;03m# Train for 10 epochs\u001b[39;00m\n\u001b[0;32m---> 87\u001b[0m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepoch\u001b[49m\u001b[43m)\u001b[49m\n", - "Cell \u001b[0;32mIn[32], line 59\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(model, device, train_loader, optimizer, epoch)\u001b[0m\n\u001b[1;32m 57\u001b[0m data, target \u001b[38;5;241m=\u001b[39m data\u001b[38;5;241m.\u001b[39mto(device), target\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 58\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m---> 59\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 60\u001b[0m \u001b[38;5;66;03m#import pdb;pdb.set_trace()\u001b[39;00m\n\u001b[1;32m 62\u001b[0m loss \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39mmse_loss(output, target)\n", - "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", - "Cell \u001b[0;32mIn[2], line 40\u001b[0m, in \u001b[0;36mHyenaOperatorAutoregressive1D.forward\u001b[0;34m(self, u, *args, **kwargs)\u001b[0m\n\u001b[1;32m 38\u001b[0m u \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfc_before(u) \n\u001b[1;32m 39\u001b[0m \u001b[38;5;66;03m# Pass through the operator\u001b[39;00m\n\u001b[0;32m---> 40\u001b[0m u \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moperator\u001b[49m\u001b[43m(\u001b[49m\u001b[43mu\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 41\u001b[0m last_state \u001b[38;5;241m=\u001b[39m u[:,\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m,:]\n\u001b[1;32m 42\u001b[0m \u001b[38;5;66;03m# Decrease the channel dimension back to 1\u001b[39;00m\n", - "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", - "Cell \u001b[0;32mIn[1], line 237\u001b[0m, in \u001b[0;36mHyenaOperator.forward\u001b[0;34m(self, u, *args, **kwargs)\u001b[0m\n\u001b[1;32m 234\u001b[0m u \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_proj(u)\n\u001b[1;32m 235\u001b[0m u \u001b[38;5;241m=\u001b[39m rearrange(u, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mb l d -> b d l\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m--> 237\u001b[0m uc \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshort_filter\u001b[49m\u001b[43m(\u001b[49m\u001b[43mu\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m,:l_filter] \n\u001b[1;32m 238\u001b[0m \u001b[38;5;241m*\u001b[39mx, v \u001b[38;5;241m=\u001b[39m uc\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39md_model, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 240\u001b[0m k \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfilter_fn\u001b[38;5;241m.\u001b[39mfilter(l_filter)[\u001b[38;5;241m0\u001b[39m]\n", - "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", - "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/conv.py:313\u001b[0m, in \u001b[0;36mConv1d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 312\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 313\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_conv_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/conv.py:309\u001b[0m, in \u001b[0;36mConv1d._conv_forward\u001b[0;34m(self, input, weight, bias)\u001b[0m\n\u001b[1;32m 305\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mzeros\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 306\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mconv1d(F\u001b[38;5;241m.\u001b[39mpad(\u001b[38;5;28minput\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reversed_padding_repeated_twice, mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode),\n\u001b[1;32m 307\u001b[0m weight, bias, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstride,\n\u001b[1;32m 308\u001b[0m _single(\u001b[38;5;241m0\u001b[39m), \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdilation, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgroups)\n\u001b[0;32m--> 309\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv1d\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstride\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 310\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdilation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgroups\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + "Train Epoch: 1 [0/640 (0%)]\tLoss: 0.446847\n", + "Train Epoch: 2 [0/640 (0%)]\tLoss: 0.077979\n", + "Train Epoch: 3 [0/640 (0%)]\tLoss: 0.021656\n", + "Train Epoch: 4 [0/640 (0%)]\tLoss: 0.007355\n", + "Train Epoch: 5 [0/640 (0%)]\tLoss: 0.004926\n", + "Train Epoch: 6 [0/640 (0%)]\tLoss: 0.006014\n", + "Train Epoch: 7 [0/640 (0%)]\tLoss: 0.003400\n", + "Train Epoch: 8 [0/640 (0%)]\tLoss: 0.003720\n", + "Train Epoch: 9 [0/640 (0%)]\tLoss: 0.004267\n", + "Train Epoch: 10 [0/640 (0%)]\tLoss: 0.004081\n" ] } ], @@ -454,18 +440,21 @@ " # Generate sine wave with the random attributes\n", " sine_wave = generate_sine_with_noise(self.n_points, frequency, phase, amplitude, noise_sd)\n", "\n", - " return torch.Tensor(sine_wave[:-1, 1, None]), torch.Tensor(sine_wave[-1:, 0])\n", + " # Return the sine wave and the parameters\n", + " return torch.Tensor(sine_wave[:-1, 1, None]), torch.Tensor(sine_wave[-1:, 0]), torch.Tensor([frequency, phase, amplitude, noise_sd])\n", + "\n", + "\n", "\n", "# Usage:\n", "dataset = SineDataset(640, 1025, (1, 3), (0, 2*np.pi), (0.5, 1.5), (0.05, 0.15))\n", "\n", "def train(model, device, train_loader, optimizer, epoch):\n", " model.train()\n", - " for batch_idx, (data, target) in enumerate(train_loader):\n", + " for batch_idx, (data, target, params) in enumerate(train_loader):\n", " #data = data[...,None]\n", " data, target = data.to(device), target.to(device)\n", " optimizer.zero_grad()\n", - " output = model(data)\n", + " output,last_state = model(data)\n", " #import pdb;pdb.set_trace()\n", "\n", " loss = F.mse_loss(output, target)\n", @@ -504,10 +493,81 @@ "outputs": [], "source": [] }, + { + "cell_type": "code", + "execution_count": 4, + "id": "90330622-8b44-4b45-8158-6840538f768c", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.linear_model import LinearRegression\n", + "from sklearn.metrics import r2_score\n", + "\n", + "def fit_and_evaluate_linear_regression(outputs_and_params):\n", + " # Split the data into inputs (last_states) and targets (params)\n", + " inputs = np.concatenate([x[0] for x in outputs_and_params])\n", + " targets = np.concatenate([x[1] for x in outputs_and_params])\n", + " \n", + " r2_scores = []\n", + " param_names = [\"frequency\", \"phase\", \"amplitude\", \"noise_sd\"]\n", + " \n", + " # Fit the linear regression model for each parameter and calculate the R^2 score\n", + " for i in range(targets.shape[1]):\n", + " model = LinearRegression().fit(inputs, targets[:, i])\n", + " pred = model.predict(inputs)\n", + " score = r2_score(targets[:, i], pred)\n", + " r2_scores.append(score)\n", + " print(f\"R^2 score for {param_names[i]}: {score:.2f}\")\n", + " \n", + " return r2_scores" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "5eb62a22-cad8-43c4-b757-f36b6a01e9be", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_outputs(model, device, data_loader):\n", + " model.eval()\n", + " outputs_and_params = []\n", + " with torch.no_grad():\n", + " for data, target, params in data_loader:\n", + " data, target = data.to(device), target.to(device)\n", + " output, last_state = model(data)\n", + " outputs_and_params.append((last_state.cpu().numpy(), params.cpu().numpy()))\n", + " return outputs_and_params\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "a95ee542-1c39-4f04-9184-e26c6983a018", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "R^2 score for frequency: 0.77\n", + "R^2 score for phase: 0.66\n", + "R^2 score for amplitude: 0.99\n", + "R^2 score for noise_sd: 0.97\n" + ] + } + ], + "source": [ + "outputs_and_params = generate_outputs(model, device, train_loader)\n", + "\n", + "# Fit the linear regression model and print the R^2 score for each parameter\n", + "r2_scores = fit_and_evaluate_linear_regression(outputs_and_params)" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "1b763e03-baab-4b02-bae0-5747461bca7f", + "id": "61c7e1db-ffee-4ef2-a8c6-5756e326904c", "metadata": {}, "outputs": [], "source": []