{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "6c6e33cb-72f9-42fa-936a-33b5fe338d15", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([1, 1024, 128]) torch.Size([1, 1024, 128])\n", "Causality check: gradients should not flow \"from future to past\"\n", "tensor(-4.1268e-09) tensor(0.0844)\n" ] } ], "source": [ "# %load standalone_hyena.py\n", "\"\"\"\n", "Simplified standalone version of Hyena: https://arxiv.org/abs/2302.10866, designed for quick experimentation.\n", "A complete version is available under `src.models.sequence.hyena`.\n", "\"\"\"\n", "\n", "import math\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from einops import rearrange\n", "\n", "\n", "def fftconv(u, k, D):\n", " seqlen = u.shape[-1]\n", " fft_size = 2 * seqlen\n", " \n", " k_f = torch.fft.rfft(k, n=fft_size) / fft_size\n", " u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)\n", " \n", " if len(u.shape) > 3: k_f = k_f.unsqueeze(1)\n", " y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]\n", "\n", " out = y + u * D.unsqueeze(-1)\n", " return out.to(dtype=u.dtype)\n", "\n", "\n", "@torch.jit.script \n", "def mul_sum(q, y):\n", " return (q * y).sum(dim=1)\n", "\n", "class OptimModule(nn.Module):\n", " \"\"\" Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters \"\"\"\n", "\n", " def register(self, name, tensor, lr=None, wd=0.0):\n", " \"\"\"Register a tensor with a configurable learning rate and 0 weight decay\"\"\"\n", "\n", " if lr == 0.0:\n", " self.register_buffer(name, tensor)\n", " else:\n", " self.register_parameter(name, nn.Parameter(tensor))\n", "\n", " optim = {}\n", " if lr is not None: optim[\"lr\"] = lr\n", " if wd is not None: optim[\"weight_decay\"] = wd\n", " setattr(getattr(self, name), \"_optim\", optim)\n", " \n", "\n", "class Sin(nn.Module):\n", " def __init__(self, dim, w=10, train_freq=True):\n", " super().__init__()\n", " self.freq = nn.Parameter(w * torch.ones(1, dim)) if train_freq else w * torch.ones(1, dim)\n", "\n", " def forward(self, x):\n", " return torch.sin(self.freq * x)\n", " \n", " \n", "class PositionalEmbedding(OptimModule):\n", " def __init__(self, emb_dim: int, seq_len: int, lr_pos_emb: float=1e-5, **kwargs): \n", " \"\"\"Complex exponential positional embeddings for Hyena filters.\"\"\" \n", " super().__init__()\n", " \n", " self.seq_len = seq_len\n", " # The time embedding fed to the filteres is normalized so that t_f = 1\n", " t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1\n", " \n", " if emb_dim > 1:\n", " bands = (emb_dim - 1) // 2 \n", " # To compute the right embeddings we use the \"proper\" linspace \n", " t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]\n", " w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1 \n", " \n", " f = torch.linspace(1e-4, bands - 1, bands)[None, None] \n", " z = torch.exp(-1j * f * w)\n", " z = torch.cat([t, z.real, z.imag], dim=-1)\n", " self.register(\"z\", z, lr=lr_pos_emb) \n", " self.register(\"t\", t, lr=0.0)\n", " \n", " def forward(self, L):\n", " return self.z[:, :L], self.t[:, :L]\n", " \n", "\n", "class ExponentialModulation(OptimModule):\n", " def __init__(\n", " self,\n", " d_model,\n", " fast_decay_pct=0.3,\n", " slow_decay_pct=1.5,\n", " target=1e-2,\n", " modulation_lr=0.0,\n", " modulate: bool=True,\n", " shift: float = 0.0,\n", " **kwargs\n", " ):\n", " super().__init__()\n", " self.modulate = modulate\n", " 
self.shift = shift\n", " max_decay = math.log(target) / fast_decay_pct\n", " min_decay = math.log(target) / slow_decay_pct\n", " deltas = torch.linspace(min_decay, max_decay, d_model)[None, None]\n", " self.register(\"deltas\", deltas, lr=modulation_lr)\n", " \n", " def forward(self, t, x):\n", " if self.modulate:\n", " decay = torch.exp(-t * self.deltas.abs()) \n", " x = x * (decay + self.shift)\n", " return x \n", "\n", "\n", "class HyenaFilter(OptimModule):\n", " def __init__(\n", " self, \n", " d_model,\n", " emb_dim=3, # dim of input to MLP, augments with positional encoding\n", " order=16, # width of the implicit MLP \n", " fused_fft_conv=False,\n", " seq_len=1024, \n", " lr=1e-3, \n", " lr_pos_emb=1e-5,\n", " dropout=0.0, \n", " w=1, # frequency of periodic activations \n", " wd=0, # weight decay of kernel parameters \n", " bias=True,\n", " num_inner_mlps=2,\n", " normalized=False,\n", " **kwargs\n", " ):\n", " \"\"\"\n", " Implicit long filter with modulation.\n", " \n", " Args:\n", " d_model: number of channels in the input\n", " emb_dim: dimension of the positional encoding (`emb_dim` - 1) // 2 is the number of bands\n", " order: width of the FFN\n", " num_inner_mlps: number of inner linear layers inside filter MLP\n", " \"\"\"\n", " super().__init__()\n", " self.d_model = d_model\n", " self.use_bias = bias\n", " self.fused_fft_conv = fused_fft_conv\n", " self.bias = nn.Parameter(torch.randn(self.d_model))\n", " self.dropout = nn.Dropout(dropout)\n", " \n", " act = Sin(dim=order, w=w)\n", " self.emb_dim = emb_dim\n", " assert emb_dim % 2 != 0 and emb_dim >= 3, \"emb_dim must be odd and greater or equal to 3 (time, sine and cosine)\"\n", " self.seq_len = seq_len\n", " \n", " self.pos_emb = PositionalEmbedding(emb_dim, seq_len, lr_pos_emb)\n", " \n", " self.implicit_filter = nn.Sequential(\n", " nn.Linear(emb_dim, order),\n", " act,\n", " )\n", " for i in range(num_inner_mlps):\n", " self.implicit_filter.append(nn.Linear(order, order))\n", " self.implicit_filter.append(act)\n", "\n", " self.implicit_filter.append(nn.Linear(order, d_model, bias=False))\n", " \n", " self.modulation = ExponentialModulation(d_model, **kwargs)\n", " \n", " self.normalized = normalized\n", " for c in self.implicit_filter.children():\n", " for name, v in c.state_dict().items(): \n", " optim = {\"weight_decay\": wd, \"lr\": lr}\n", " setattr(getattr(c, name), \"_optim\", optim)\n", "\n", " def filter(self, L, *args, **kwargs):\n", " z, t = self.pos_emb(L)\n", " h = self.implicit_filter(z)\n", " h = self.modulation(t, h)\n", " return h\n", "\n", " def forward(self, x, L, k=None, bias=None, *args, **kwargs):\n", " if k is None: k = self.filter(L)\n", " \n", " # Ensure compatibility with filters that return a tuple \n", " k = k[0] if type(k) is tuple else k \n", "\n", " y = fftconv(x, k, bias)\n", " return y\n", " \n", " \n", "class HyenaOperator(nn.Module):\n", " def __init__(\n", " self,\n", " d_model,\n", " l_max,\n", " order=2, \n", " filter_order=64,\n", " dropout=0.0, \n", " filter_dropout=0.0, \n", " **filter_args,\n", " ):\n", " r\"\"\"\n", " Hyena operator described in the paper https://arxiv.org/pdf/2302.10866.pdf\n", " \n", " Args:\n", " d_model (int): Dimension of the input and output embeddings (width of the layer)\n", " l_max: (int): Maximum input sequence length. Defaults to None\n", " order: (int): Depth of the Hyena recurrence. Defaults to 2\n", " dropout: (float): Dropout probability. Defaults to 0.0\n", " filter_dropout: (float): Dropout probability for the filter. 
Defaults to 0.0\n", " \"\"\"\n", " super().__init__()\n", " self.d_model = d_model\n", " self.l_max = l_max\n", " self.order = order\n", " inner_width = d_model * (order + 1)\n", " self.dropout = nn.Dropout(dropout)\n", " self.in_proj = nn.Linear(d_model, inner_width)\n", " self.out_proj = nn.Linear(d_model, d_model)\n", " \n", " self.short_filter = nn.Conv1d(\n", " inner_width, \n", " inner_width, \n", " 3,\n", " padding=2,\n", " groups=inner_width\n", " )\n", " self.filter_fn = HyenaFilter(\n", " d_model * (order - 1), \n", " order=filter_order, \n", " seq_len=l_max,\n", " channels=1, \n", " dropout=filter_dropout, \n", " **filter_args\n", " ) \n", "\n", " def forward(self, u, *args, **kwargs):\n", " l = u.size(-2)\n", " l_filter = min(l, self.l_max)\n", " u = self.in_proj(u)\n", " u = rearrange(u, 'b l d -> b d l')\n", " \n", " uc = self.short_filter(u)[...,:l_filter] \n", " *x, v = uc.split(self.d_model, dim=1)\n", " \n", " k = self.filter_fn.filter(l_filter)[0]\n", " k = rearrange(k, 'l (o d) -> o d l', o=self.order - 1)\n", " bias = rearrange(self.filter_fn.bias, '(o d) -> o d', o=self.order - 1)\n", " \n", " for o, x_i in enumerate(reversed(x[1:])):\n", " v = self.dropout(v * x_i)\n", " v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o])\n", "\n", " y = rearrange(v * x[0], 'b d l -> b l d')\n", "\n", " y = self.out_proj(y)\n", " return y\n", "\n", " \n", " \n", "if __name__ == \"__main__\":\n", " layer = HyenaOperator(\n", " d_model=128, \n", " l_max=1024, \n", " order=2, \n", " filter_order=64\n", " )\n", " x = torch.randn(1, 1024, 128, requires_grad=True)\n", " y = layer(x)\n", " \n", " print(x.shape, y.shape)\n", " \n", " grad = torch.autograd.grad(y[:, 10, :].sum(), x)[0]\n", " print('Causality check: gradients should not flow \"from future to past\"')\n", " print(grad[0, 11, :].sum(), grad[0, 9, :].sum())\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "032ef08a-8cc6-491a-9eb8-4a6b3f2d165e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([1, 1023, 1]) torch.Size([1, 1])\n" ] } ], "source": [ "class HyenaOperatorAutoregressive1D(nn.Module):\n", " def __init__(\n", " self,\n", " d_model,\n", " l_max,\n", " order=2, \n", " filter_order=64,\n", " dropout=0.0, \n", " filter_dropout=0.0, \n", " **filter_args,\n", " ):\n", " super().__init__()\n", "\n", " self.d_model = d_model\n", " self.l_max = l_max\n", " self.order = order\n", "\n", " self.fc_before = nn.Linear(1, d_model) # Fully connected layer before the main layer\n", " self.fc_after = nn.Linear(d_model, 1) # Fully connected layer after the main layer\n", "\n", " self.operator = HyenaOperator(\n", " d_model=d_model,\n", " l_max=l_max,\n", " order=order, \n", " filter_order=filter_order,\n", " dropout=dropout, \n", " filter_dropout=filter_dropout, \n", " **filter_args,\n", " )\n", "\n", " def forward(self, u, *args, **kwargs):\n", " # Increase the channel dimension from 1 to d_model\n", " u = self.fc_before(u) \n", " # Pass through the operator\n", " u = self.operator(u)\n", " last_state = u[:, -1, :]\n", " # Decrease the channel dimension back to 1\n", " y = self.fc_after(last_state)\n", " return y, last_state\n", "\n", "\n", "
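# Smoke test: a single noisy 1D series of length 1023 goes in; the wrapper returns a\n", " # scalar next-step prediction y of shape [1, 1] together with the d_model-dimensional\n", " # last_state, which is probed with a linear regression further down.\n", "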
\n", " filter_order=64\n", " )\n", "\n", " x = torch.randn(1, 1023, 1, requires_grad=True) # 1D time series input\n", " y, last_state = layer(x)\n", "\n", " #import pdb;pdb.set_trace()\n", " print(x.shape, y.shape) # should now be [1, 1024, 1]\n", "\n", " #grad = torch.autograd.grad(y[:, 10, 0].sum(), x)[0]\n", " #print('Causality check: gradients should not flow \"from future to past\"')\n", " #print(grad[0, 11, 0].sum(), grad[0, 9, 0].sum())" ] }, { "cell_type": "code", "execution_count": 3, "id": "80cde67b-992f-4cb0-8824-4a6b7e4984ca", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train Epoch: 1 [0/640 (0%)]\tLoss: 0.446847\n", "Train Epoch: 2 [0/640 (0%)]\tLoss: 0.077979\n", "Train Epoch: 3 [0/640 (0%)]\tLoss: 0.021656\n", "Train Epoch: 4 [0/640 (0%)]\tLoss: 0.007355\n", "Train Epoch: 5 [0/640 (0%)]\tLoss: 0.004926\n", "Train Epoch: 6 [0/640 (0%)]\tLoss: 0.006014\n", "Train Epoch: 7 [0/640 (0%)]\tLoss: 0.003400\n", "Train Epoch: 8 [0/640 (0%)]\tLoss: 0.003720\n", "Train Epoch: 9 [0/640 (0%)]\tLoss: 0.004267\n", "Train Epoch: 10 [0/640 (0%)]\tLoss: 0.004081\n" ] } ], "source": [ "import torch\n", "import torch.optim as optim\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader, Dataset\n", "import numpy as np\n", "\n", "def generate_sine_with_noise(n_points, frequency, phase, amplitude, noise_sd):\n", " # Generate an array of points from 0 to 2*pi\n", " x = np.linspace(0, 2*np.pi, n_points)\n", " \n", " # Generate the sine wave\n", " sine_wave = amplitude * np.sin(frequency * x + phase)\n", " \n", " # Generate Gaussian noise\n", " noise = np.random.normal(scale=noise_sd, size=n_points)\n", " \n", " # Add the noise to the sine wave\n", " sine_wave_noise = sine_wave + noise\n", " \n", " # Stack the sine wave and the noisy sine wave into a 2D array\n", " output = np.column_stack((sine_wave, sine_wave_noise))\n", " \n", " return output\n", " \n", " \n", "class SineDataset(Dataset):\n", " def __init__(self, n_samples, n_points, frequency_range, phase_range, amplitude_range, noise_sd_range):\n", " self.n_samples = n_samples\n", " self.n_points = n_points\n", " self.frequency_range = frequency_range\n", " self.phase_range = phase_range\n", " self.amplitude_range = amplitude_range\n", " self.noise_sd_range = noise_sd_range\n", "\n", " def __len__(self):\n", " return self.n_samples\n", "\n", " def __getitem__(self, idx):\n", " # Generate random attributes\n", " frequency = np.random.uniform(*self.frequency_range)\n", " phase = np.random.uniform(*self.phase_range)\n", " amplitude = np.random.uniform(*self.amplitude_range)\n", " noise_sd = np.random.uniform(*self.noise_sd_range)\n", "\n", " # Generate sine wave with the random attributes\n", " sine_wave = generate_sine_with_noise(self.n_points, frequency, phase, amplitude, noise_sd)\n", "\n", " # Return the sine wave and the parameters\n", " return torch.Tensor(sine_wave[:-1, 1, None]), torch.Tensor(sine_wave[-1:, 0]), torch.Tensor([frequency, phase, amplitude, noise_sd])\n", "\n", "\n", "\n", "# Usage:\n", "dataset = SineDataset(640, 1025, (1, 3), (0, 2*np.pi), (0.5, 1.5), (0.05, 0.15))\n", "\n", "def train(model, device, train_loader, optimizer, epoch):\n", " model.train()\n", " for batch_idx, (data, target, params) in enumerate(train_loader):\n", " #data = data[...,None]\n", " data, target = data.to(device), target.to(device)\n", " optimizer.zero_grad()\n", " output,last_state = model(data)\n", " #import pdb;pdb.set_trace()\n", "\n", " loss = F.mse_loss(output, target)\n", " 
loss.backward()\n", " optimizer.step()\n", " if batch_idx % 10 == 0:\n", " print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", " epoch, batch_idx * len(data), len(train_loader.dataset),\n", " 100. * batch_idx / len(train_loader), loss.item()))\n", "\n", "if __name__ == \"__main__\":\n", " device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "\n", " model = HyenaOperatorAutoregressive1D(\n", " d_model=128, \n", " l_max=1024, \n", " order=2, \n", " filter_order=64\n", " ).to(device)\n", "\n", " optimizer = optim.Adam(model.parameters())\n", "\n", " # Assume 10000 samples in the dataset\n", " #dataset = SineDataset(10000, 1025, 2, 0, 1, 0.1)\n", " train_loader = DataLoader(dataset, batch_size=64, shuffle=True)\n", "\n", " for epoch in range(1, 11): # Train for 10 epochs\n", " train(model, device, train_loader, optimizer, epoch)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "cc9f9031-5ee1-49f8-a70f-ad85ca015596", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 4, "id": "90330622-8b44-4b45-8158-6840538f768c", "metadata": {}, "outputs": [], "source": [ "from sklearn.linear_model import LinearRegression\n", "from sklearn.metrics import r2_score\n", "\n", "def fit_and_evaluate_linear_regression(outputs_and_params):\n", " # Split the data into inputs (last_states) and targets (params)\n", " inputs = np.concatenate([x[0] for x in outputs_and_params])\n", " targets = np.concatenate([x[1] for x in outputs_and_params])\n", " \n", " r2_scores = []\n", " param_names = [\"frequency\", \"phase\", \"amplitude\", \"noise_sd\"]\n", " \n", " # Fit the linear regression model for each parameter and calculate the R^2 score\n", " for i in range(targets.shape[1]):\n", " model = LinearRegression().fit(inputs, targets[:, i])\n", " pred = model.predict(inputs)\n", " score = r2_score(targets[:, i], pred)\n", " r2_scores.append(score)\n", " print(f\"R^2 score for {param_names[i]}: {score:.2f}\")\n", " \n", " return r2_scores" ] }, { "cell_type": "code", "execution_count": 5, "id": "5eb62a22-cad8-43c4-b757-f36b6a01e9be", "metadata": {}, "outputs": [], "source": [ "def generate_outputs(model, device, data_loader):\n", " model.eval()\n", " outputs_and_params = []\n", " with torch.no_grad():\n", " for data, target, params in data_loader:\n", " data, target = data.to(device), target.to(device)\n", " output, last_state = model(data)\n", " outputs_and_params.append((last_state.cpu().numpy(), params.cpu().numpy()))\n", " return outputs_and_params\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "a95ee542-1c39-4f04-9184-e26c6983a018", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "R^2 score for frequency: 0.77\n", "R^2 score for phase: 0.66\n", "R^2 score for amplitude: 0.99\n", "R^2 score for noise_sd: 0.97\n" ] } ], "source": [ "outputs_and_params = generate_outputs(model, device, train_loader)\n", "\n", "# Fit the linear regression model and print the R^2 score for each parameter\n", "r2_scores = fit_and_evaluate_linear_regression(outputs_and_params)" ] }, { "cell_type": "code", "execution_count": null, "id": "61c7e1db-ffee-4ef2-a8c6-5756e326904c", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", 
"pygments_lexer": "ipython3", "version": "3.10.10" } }, "nbformat": 4, "nbformat_minor": 5 }