R^2 for hyena

2023-08-05 12:14:16 +02:00
parent ffd2d2f1a7
commit bdef020bff
2 changed files with 227 additions and 65 deletions


@@ -12,7 +12,7 @@
"text": [
"torch.Size([1, 1024, 128]) torch.Size([1, 1024, 128])\n",
"Causality check: gradients should not flow \"from future to past\"\n",
"tensor(3.2471e-09) tensor(0.4080)\n"
"tensor(-4.1268e-09) tensor(0.0844)\n"
]
}
],
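The two tensors printed above are the notebook's causality check: gradient mass reaching the input from "future" output positions (which should be ~0 for a causal operator) and from "past" positions (which should be nonzero). A minimal sketch of such a check, using a left-padded Conv1d as a hypothetical stand-in, since the notebook's own check code is not part of this diff:

import torch
import torch.nn as nn

# Stand-in for the layer under test: a Conv1d made causal by left-padding only.
d_model, seq_len = 128, 1024
model = nn.Sequential(nn.ConstantPad1d((2, 0), 0.0), nn.Conv1d(d_model, d_model, 3))

x = torch.randn(1, d_model, seq_len, requires_grad=True)
y = model(x)                       # [1, d_model, seq_len] -> [1, d_model, seq_len]
t = seq_len // 2
y[0, :, t].sum().backward()        # backprop from a single output time step

print(x.grad[0, :, t + 1:].sum(),        # gradient from the "future": ~0 if causal
      x.grad[0, :, :t + 1].abs().sum())  # gradient from the "past": nonzero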
@@ -347,7 +347,7 @@
" last_state = u[:,-1,:]\n",
" # Decrease the channel dimension back to 1\n",
" y = self.fc_after(last_state)\n",
" return y\n",
" return y,last_state\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
@@ -359,7 +359,7 @@
" )\n",
"\n",
" x = torch.randn(1, 1023, 1, requires_grad=True) # 1D time series input\n",
" y = layer(x)\n",
" y, last_state = layer(x)\n",
"\n",
" #import pdb;pdb.set_trace()\n",
" print(x.shape, y.shape) # should now be [1, 1024, 1]\n",
@@ -371,7 +371,7 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 3,
"id": "80cde67b-992f-4cb0-8824-4a6b7e4984ca",
"metadata": {},
"outputs": [
@@ -379,30 +379,16 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Train Epoch: 1 [0/640 (0%)]\tLoss: 0.433575\n",
"Train Epoch: 2 [0/640 (0%)]\tLoss: 0.054185\n",
"Train Epoch: 3 [0/640 (0%)]\tLoss: 0.007312\n",
"Train Epoch: 4 [0/640 (0%)]\tLoss: 0.004312\n",
"Train Epoch: 5 [0/640 (0%)]\tLoss: 0.003393\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[32], line 87\u001b[0m\n\u001b[1;32m 84\u001b[0m train_loader \u001b[38;5;241m=\u001b[39m DataLoader(dataset, batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m64\u001b[39m, shuffle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m11\u001b[39m): \u001b[38;5;66;03m# Train for 10 epochs\u001b[39;00m\n\u001b[0;32m---> 87\u001b[0m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepoch\u001b[49m\u001b[43m)\u001b[49m\n",
"Cell \u001b[0;32mIn[32], line 59\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(model, device, train_loader, optimizer, epoch)\u001b[0m\n\u001b[1;32m 57\u001b[0m data, target \u001b[38;5;241m=\u001b[39m data\u001b[38;5;241m.\u001b[39mto(device), target\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 58\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m---> 59\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 60\u001b[0m \u001b[38;5;66;03m#import pdb;pdb.set_trace()\u001b[39;00m\n\u001b[1;32m 62\u001b[0m loss \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39mmse_loss(output, target)\n",
"File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"Cell \u001b[0;32mIn[2], line 40\u001b[0m, in \u001b[0;36mHyenaOperatorAutoregressive1D.forward\u001b[0;34m(self, u, *args, **kwargs)\u001b[0m\n\u001b[1;32m 38\u001b[0m u \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfc_before(u) \n\u001b[1;32m 39\u001b[0m \u001b[38;5;66;03m# Pass through the operator\u001b[39;00m\n\u001b[0;32m---> 40\u001b[0m u \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moperator\u001b[49m\u001b[43m(\u001b[49m\u001b[43mu\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 41\u001b[0m last_state \u001b[38;5;241m=\u001b[39m u[:,\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m,:]\n\u001b[1;32m 42\u001b[0m \u001b[38;5;66;03m# Decrease the channel dimension back to 1\u001b[39;00m\n",
"File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"Cell \u001b[0;32mIn[1], line 237\u001b[0m, in \u001b[0;36mHyenaOperator.forward\u001b[0;34m(self, u, *args, **kwargs)\u001b[0m\n\u001b[1;32m 234\u001b[0m u \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_proj(u)\n\u001b[1;32m 235\u001b[0m u \u001b[38;5;241m=\u001b[39m rearrange(u, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mb l d -> b d l\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m--> 237\u001b[0m uc \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshort_filter\u001b[49m\u001b[43m(\u001b[49m\u001b[43mu\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m,:l_filter] \n\u001b[1;32m 238\u001b[0m \u001b[38;5;241m*\u001b[39mx, v \u001b[38;5;241m=\u001b[39m uc\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39md_model, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 240\u001b[0m k \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfilter_fn\u001b[38;5;241m.\u001b[39mfilter(l_filter)[\u001b[38;5;241m0\u001b[39m]\n",
"File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/conv.py:313\u001b[0m, in \u001b[0;36mConv1d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 312\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 313\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_conv_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/conv.py:309\u001b[0m, in \u001b[0;36mConv1d._conv_forward\u001b[0;34m(self, input, weight, bias)\u001b[0m\n\u001b[1;32m 305\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mzeros\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 306\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mconv1d(F\u001b[38;5;241m.\u001b[39mpad(\u001b[38;5;28minput\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reversed_padding_repeated_twice, mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode),\n\u001b[1;32m 307\u001b[0m weight, bias, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstride,\n\u001b[1;32m 308\u001b[0m _single(\u001b[38;5;241m0\u001b[39m), \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdilation, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgroups)\n\u001b[0;32m--> 309\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv1d\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstride\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 310\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdilation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgroups\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
"Train Epoch: 1 [0/640 (0%)]\tLoss: 0.446847\n",
"Train Epoch: 2 [0/640 (0%)]\tLoss: 0.077979\n",
"Train Epoch: 3 [0/640 (0%)]\tLoss: 0.021656\n",
"Train Epoch: 4 [0/640 (0%)]\tLoss: 0.007355\n",
"Train Epoch: 5 [0/640 (0%)]\tLoss: 0.004926\n",
"Train Epoch: 6 [0/640 (0%)]\tLoss: 0.006014\n",
"Train Epoch: 7 [0/640 (0%)]\tLoss: 0.003400\n",
"Train Epoch: 8 [0/640 (0%)]\tLoss: 0.003720\n",
"Train Epoch: 9 [0/640 (0%)]\tLoss: 0.004267\n",
"Train Epoch: 10 [0/640 (0%)]\tLoss: 0.004081\n"
]
}
],
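The removed traceback shows the old run was interrupted by hand after epoch 5; the replacement output is a full 10-epoch run. The removed frames also pin down the driver code around line 87 of the training cell, roughly:

from torch.utils.data import DataLoader

# Reconstructed from the removed traceback frames (Cell In[32], lines 84-87);
# dataset, model, device, optimizer and train() are defined elsewhere in the notebook.
train_loader = DataLoader(dataset, batch_size=64, shuffle=True)

for epoch in range(1, 11):  # train for 10 epochs
    train(model, device, train_loader, optimizer, epoch)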
@@ -454,18 +440,21 @@
" # Generate sine wave with the random attributes\n",
" sine_wave = generate_sine_with_noise(self.n_points, frequency, phase, amplitude, noise_sd)\n",
"\n",
" return torch.Tensor(sine_wave[:-1, 1, None]), torch.Tensor(sine_wave[-1:, 0])\n",
" # Return the sine wave and the parameters\n",
" return torch.Tensor(sine_wave[:-1, 1, None]), torch.Tensor(sine_wave[-1:, 0]), torch.Tensor([frequency, phase, amplitude, noise_sd])\n",
"\n",
"\n",
"\n",
"# Usage:\n",
"dataset = SineDataset(640, 1025, (1, 3), (0, 2*np.pi), (0.5, 1.5), (0.05, 0.15))\n",
"\n",
"def train(model, device, train_loader, optimizer, epoch):\n",
" model.train()\n",
" for batch_idx, (data, target) in enumerate(train_loader):\n",
" for batch_idx, (data, target, params) in enumerate(train_loader):\n",
" #data = data[...,None]\n",
" data, target = data.to(device), target.to(device)\n",
" optimizer.zero_grad()\n",
" output = model(data)\n",
" output,last_state = model(data)\n",
" #import pdb;pdb.set_trace()\n",
"\n",
" loss = F.mse_loss(output, target)\n",
@@ -504,10 +493,81 @@
"outputs": [],
"source": []
},
+{
+"cell_type": "code",
+"execution_count": 4,
+"id": "90330622-8b44-4b45-8158-6840538f768c",
+"metadata": {},
+"outputs": [],
+"source": [
+"from sklearn.linear_model import LinearRegression\n",
+"from sklearn.metrics import r2_score\n",
+"\n",
+"def fit_and_evaluate_linear_regression(outputs_and_params):\n",
+" # Split the data into inputs (last_states) and targets (params)\n",
+" inputs = np.concatenate([x[0] for x in outputs_and_params])\n",
+" targets = np.concatenate([x[1] for x in outputs_and_params])\n",
+" \n",
+" r2_scores = []\n",
+" param_names = [\"frequency\", \"phase\", \"amplitude\", \"noise_sd\"]\n",
+" \n",
+" # Fit the linear regression model for each parameter and calculate the R^2 score\n",
+" for i in range(targets.shape[1]):\n",
+" model = LinearRegression().fit(inputs, targets[:, i])\n",
+" pred = model.predict(inputs)\n",
+" score = r2_score(targets[:, i], pred)\n",
+" r2_scores.append(score)\n",
+" print(f\"R^2 score for {param_names[i]}: {score:.2f}\")\n",
+" \n",
+" return r2_scores"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 5,
+"id": "5eb62a22-cad8-43c4-b757-f36b6a01e9be",
+"metadata": {},
+"outputs": [],
+"source": [
+"def generate_outputs(model, device, data_loader):\n",
+" model.eval()\n",
+" outputs_and_params = []\n",
+" with torch.no_grad():\n",
+" for data, target, params in data_loader:\n",
+" data, target = data.to(device), target.to(device)\n",
+" output, last_state = model(data)\n",
+" outputs_and_params.append((last_state.cpu().numpy(), params.cpu().numpy()))\n",
+" return outputs_and_params\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 6,
+"id": "a95ee542-1c39-4f04-9184-e26c6983a018",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"R^2 score for frequency: 0.77\n",
+"R^2 score for phase: 0.66\n",
+"R^2 score for amplitude: 0.99\n",
+"R^2 score for noise_sd: 0.97\n"
+]
+}
+],
+"source": [
+"outputs_and_params = generate_outputs(model, device, train_loader)\n",
+"\n",
+"# Fit the linear regression model and print the R^2 score for each parameter\n",
+"r2_scores = fit_and_evaluate_linear_regression(outputs_and_params)"
+]
+},
{
"cell_type": "code",
"execution_count": null,
"id": "1b763e03-baab-4b02-bae0-5747461bca7f",
"id": "61c7e1db-ffee-4ef2-a8c6-5756e326904c",
"metadata": {},
"outputs": [],
"source": []