hyena first try, trains
@@ -0,0 +1,537 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6c6e33cb-72f9-42fa-936a-33b5fe338d15",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 1024, 128]) torch.Size([1, 1024, 128])\n",
      "Causality check: gradients should not flow \"from future to past\"\n",
      "tensor(3.2471e-09) tensor(0.4080)\n"
     ]
    }
   ],
   "source": [
    "# %load standalone_hyena.py\n",
    "\"\"\"\n",
    "Simplified standalone version of Hyena: https://arxiv.org/abs/2302.10866, designed for quick experimentation.\n",
    "A complete version is available under `src.models.sequence.hyena`.\n",
    "\"\"\"\n",
    "\n",
    "import math\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from einops import rearrange\n",
    "\n",
    "\n",
    "def fftconv(u, k, D):\n",
    "    seqlen = u.shape[-1]\n",
    "    fft_size = 2 * seqlen\n",
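    "    # Padding the FFT length to 2 * seqlen turns the circular convolution\n",
    "    # computed by rfft/irfft into a linear one, so a causal filter cannot\n",
    "    # wrap future samples around to the past.\n",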
    "\n",
    "    k_f = torch.fft.rfft(k, n=fft_size) / fft_size\n",
    "    u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)\n",
    "\n",
    "    if len(u.shape) > 3: k_f = k_f.unsqueeze(1)\n",
    "    y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]\n",
    "\n",
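    "    # D acts as a learned per-channel skip connection (the filter's bias term).\n",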
    "    out = y + u * D.unsqueeze(-1)\n",
    "    return out.to(dtype=u.dtype)\n",
    "\n",
    "\n",
    "@torch.jit.script\n",
    "def mul_sum(q, y):\n",
    "    return (q * y).sum(dim=1)\n",
    "\n",
    "class OptimModule(nn.Module):\n",
    "    \"\"\" Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters \"\"\"\n",
    "\n",
    "    def register(self, name, tensor, lr=None, wd=0.0):\n",
    "        \"\"\"Register a tensor with a configurable learning rate and 0 weight decay\"\"\"\n",
    "\n",
    "        if lr == 0.0:\n",
    "            self.register_buffer(name, tensor)\n",
    "        else:\n",
    "            self.register_parameter(name, nn.Parameter(tensor))\n",
    "\n",
    "            optim = {}\n",
    "            if lr is not None: optim[\"lr\"] = lr\n",
    "            if wd is not None: optim[\"weight_decay\"] = wd\n",
    "            setattr(getattr(self, name), \"_optim\", optim)\n",
    "\n",
    "\n",
    "class Sin(nn.Module):\n",
    "    def __init__(self, dim, w=10, train_freq=True):\n",
    "        super().__init__()\n",
    "        self.freq = nn.Parameter(w * torch.ones(1, dim)) if train_freq else w * torch.ones(1, dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return torch.sin(self.freq * x)\n",
    "\n",
    "\n",
    "class PositionalEmbedding(OptimModule):\n",
    "    def __init__(self, emb_dim: int, seq_len: int, lr_pos_emb: float=1e-5, **kwargs):\n",
    "        \"\"\"Complex exponential positional embeddings for Hyena filters.\"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        self.seq_len = seq_len\n",
    "        # The time embedding fed to the filters is normalized so that t_f = 1\n",
    "        t = torch.linspace(0, 1, self.seq_len)[None, :, None]  # 1, L, 1\n",
    "\n",
    "        if emb_dim > 1:\n",
    "            bands = (emb_dim - 1) // 2\n",
    "        # To compute the right embeddings we use the \"proper\" linspace\n",
    "        t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]\n",
    "        w = 2 * math.pi * t_rescaled / seq_len  # 1, L, 1\n",
    "\n",
    "        f = torch.linspace(1e-4, bands - 1, bands)[None, None]\n",
    "        z = torch.exp(-1j * f * w)\n",
    "        z = torch.cat([t, z.real, z.imag], dim=-1)\n",
    "        self.register(\"z\", z, lr=lr_pos_emb)\n",
    "        self.register(\"t\", t, lr=0.0)\n",
    "\n",
    "    def forward(self, L):\n",
    "        return self.z[:, :L], self.t[:, :L]\n",
    "\n",
    "\n",
    "class ExponentialModulation(OptimModule):\n",
    "    def __init__(\n",
    "        self,\n",
    "        d_model,\n",
    "        fast_decay_pct=0.3,\n",
    "        slow_decay_pct=1.5,\n",
    "        target=1e-2,\n",
    "        modulation_lr=0.0,\n",
    "        modulate: bool=True,\n",
    "        shift: float = 0.0,\n",
    "        **kwargs\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.modulate = modulate\n",
    "        self.shift = shift\n",
    "        max_decay = math.log(target) / fast_decay_pct\n",
    "        min_decay = math.log(target) / slow_decay_pct\n",
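    "        # One decay rate per channel, linearly spaced between the slow and the fast decay.\n",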
    "        deltas = torch.linspace(min_decay, max_decay, d_model)[None, None]\n",
    "        self.register(\"deltas\", deltas, lr=modulation_lr)\n",
    "\n",
    "    def forward(self, t, x):\n",
    "        if self.modulate:\n",
    "            decay = torch.exp(-t * self.deltas.abs())\n",
    "            x = x * (decay + self.shift)\n",
    "        return x\n",
    "\n",
    "\n",
    "class HyenaFilter(OptimModule):\n",
    "    def __init__(\n",
    "        self,\n",
    "        d_model,\n",
    "        emb_dim=3,  # dim of input to MLP, augments with positional encoding\n",
    "        order=16,  # width of the implicit MLP\n",
    "        fused_fft_conv=False,\n",
    "        seq_len=1024,\n",
    "        lr=1e-3,\n",
    "        lr_pos_emb=1e-5,\n",
    "        dropout=0.0,\n",
    "        w=1,  # frequency of periodic activations\n",
    "        wd=0,  # weight decay of kernel parameters\n",
    "        bias=True,\n",
    "        num_inner_mlps=2,\n",
    "        normalized=False,\n",
    "        **kwargs\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Implicit long filter with modulation.\n",
    "\n",
    "        Args:\n",
    "            d_model: number of channels in the input\n",
    "            emb_dim: dimension of the positional encoding; (`emb_dim` - 1) // 2 is the number of bands\n",
    "            order: width of the FFN\n",
    "            num_inner_mlps: number of inner linear layers inside filter MLP\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.d_model = d_model\n",
    "        self.use_bias = bias\n",
    "        self.fused_fft_conv = fused_fft_conv\n",
    "        self.bias = nn.Parameter(torch.randn(self.d_model))\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "        act = Sin(dim=order, w=w)\n",
    "        self.emb_dim = emb_dim\n",
    "        assert emb_dim % 2 != 0 and emb_dim >= 3, \"emb_dim must be odd and greater than or equal to 3 (time, sine and cosine)\"\n",
    "        self.seq_len = seq_len\n",
    "\n",
    "        self.pos_emb = PositionalEmbedding(emb_dim, seq_len, lr_pos_emb)\n",
    "\n",
    "        self.implicit_filter = nn.Sequential(\n",
    "            nn.Linear(emb_dim, order),\n",
    "            act,\n",
    "        )\n",
    "        for i in range(num_inner_mlps):\n",
    "            self.implicit_filter.append(nn.Linear(order, order))\n",
    "            self.implicit_filter.append(act)\n",
    "\n",
    "        self.implicit_filter.append(nn.Linear(order, d_model, bias=False))\n",
    "\n",
    "        self.modulation = ExponentialModulation(d_model, **kwargs)\n",
    "\n",
    "        self.normalized = normalized\n",
    "        for c in self.implicit_filter.children():\n",
    "            for name, v in c.state_dict().items():\n",
    "                optim = {\"weight_decay\": wd, \"lr\": lr}\n",
    "                setattr(getattr(c, name), \"_optim\", optim)\n",
    "\n",
    "    def filter(self, L, *args, **kwargs):\n",
    "        z, t = self.pos_emb(L)\n",
    "        h = self.implicit_filter(z)\n",
    "        h = self.modulation(t, h)\n",
    "        return h\n",
    "\n",
    "    def forward(self, x, L, k=None, bias=None, *args, **kwargs):\n",
    "        if k is None: k = self.filter(L)\n",
    "\n",
    "        # Ensure compatibility with filters that return a tuple\n",
    "        k = k[0] if type(k) is tuple else k\n",
    "\n",
    "        y = fftconv(x, k, bias)\n",
    "        return y\n",
    "\n",
    "\n",
    "class HyenaOperator(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        d_model,\n",
    "        l_max,\n",
    "        order=2,\n",
    "        filter_order=64,\n",
    "        dropout=0.0,\n",
    "        filter_dropout=0.0,\n",
    "        **filter_args,\n",
    "    ):\n",
    "        r\"\"\"\n",
    "        Hyena operator described in the paper https://arxiv.org/pdf/2302.10866.pdf\n",
    "\n",
    "        Args:\n",
    "            d_model (int): Dimension of the input and output embeddings (width of the layer)\n",
    "            l_max (int): Maximum input sequence length\n",
    "            order (int): Depth of the Hyena recurrence. Defaults to 2\n",
    "            dropout (float): Dropout probability. Defaults to 0.0\n",
    "            filter_dropout (float): Dropout probability for the filter. Defaults to 0.0\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.d_model = d_model\n",
    "        self.l_max = l_max\n",
    "        self.order = order\n",
    "        inner_width = d_model * (order + 1)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.in_proj = nn.Linear(d_model, inner_width)\n",
    "        self.out_proj = nn.Linear(d_model, d_model)\n",
    "\n",
    "        self.short_filter = nn.Conv1d(\n",
    "            inner_width,\n",
    "            inner_width,\n",
    "            3,\n",
    "            padding=2,\n",
    "            groups=inner_width\n",
    "        )\n",
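    "        # Depthwise short conv: kernel size 3 with padding 2; truncating its output\n",
    "        # back to the sequence length in forward() keeps it causal.\n",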
    "        self.filter_fn = HyenaFilter(\n",
    "            d_model * (order - 1),\n",
    "            order=filter_order,\n",
    "            seq_len=l_max,\n",
    "            channels=1,\n",
    "            dropout=filter_dropout,\n",
    "            **filter_args\n",
    "        )\n",
    "\n",
    "    def forward(self, u, *args, **kwargs):\n",
    "        l = u.size(-2)\n",
    "        l_filter = min(l, self.l_max)\n",
    "        u = self.in_proj(u)\n",
    "        u = rearrange(u, 'b l d -> b d l')\n",
    "\n",
    "        uc = self.short_filter(u)[..., :l_filter]\n",
    "        *x, v = uc.split(self.d_model, dim=1)\n",
    "\n",
    "        k = self.filter_fn.filter(l_filter)[0]\n",
    "        k = rearrange(k, 'l (o d) -> o d l', o=self.order - 1)\n",
    "        bias = rearrange(self.filter_fn.bias, '(o d) -> o d', o=self.order - 1)\n",
    "\n",
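    "        # Hyena recurrence: alternate element-wise gating by the projections x_i\n",
    "        # with long implicit convolutions using the filters k[o].\n",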
    "        for o, x_i in enumerate(reversed(x[1:])):\n",
    "            v = self.dropout(v * x_i)\n",
    "            v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o])\n",
    "\n",
    "        y = rearrange(v * x[0], 'b d l -> b l d')\n",
    "\n",
    "        y = self.out_proj(y)\n",
    "        return y\n",
    "\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    layer = HyenaOperator(\n",
    "        d_model=128,\n",
    "        l_max=1024,\n",
    "        order=2,\n",
    "        filter_order=64\n",
    "    )\n",
    "    x = torch.randn(1, 1024, 128, requires_grad=True)\n",
    "    y = layer(x)\n",
    "\n",
    "    print(x.shape, y.shape)\n",
    "\n",
    "    grad = torch.autograd.grad(y[:, 10, :].sum(), x)[0]\n",
    "    print('Causality check: gradients should not flow \"from future to past\"')\n",
    "    print(grad[0, 11, :].sum(), grad[0, 9, :].sum())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "032ef08a-8cc6-491a-9eb8-4a6b3f2d165e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 1023, 1]) torch.Size([1, 1])\n"
     ]
    }
   ],
   "source": [
    "class HyenaOperatorAutoregressive1D(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        d_model,\n",
    "        l_max,\n",
    "        order=2,\n",
    "        filter_order=64,\n",
    "        dropout=0.0,\n",
    "        filter_dropout=0.0,\n",
    "        **filter_args,\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        self.d_model = d_model\n",
    "        self.l_max = l_max\n",
    "        self.order = order\n",
    "        inner_width = d_model * (order + 1)\n",
    "\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.in_proj = nn.Linear(d_model, inner_width)\n",
    "        self.out_proj = nn.Linear(d_model, d_model)\n",
    "        self.fc_before = nn.Linear(1, d_model)  # Fully connected layer before the main layer\n",
    "        self.fc_after = nn.Linear(d_model, 1)  # Fully connected layer after the main layer\n",
    "\n",
    "        self.operator = HyenaOperator(\n",
    "            d_model=d_model,\n",
    "            l_max=l_max,\n",
    "            order=order,\n",
    "            filter_order=filter_order,\n",
    "            dropout=dropout,\n",
    "            filter_dropout=filter_dropout,\n",
    "            **filter_args,\n",
    "        )\n",
    "\n",
    "    def forward(self, u, *args, **kwargs):\n",
    "        # Increase the channel dimension from 1 to d_model\n",
    "        u = self.fc_before(u)\n",
    "        # Pass through the operator\n",
    "        u = self.operator(u)\n",
    "        last_state = u[:, -1, :]\n",
    "        # Decrease the channel dimension back to 1\n",
    "        y = self.fc_after(last_state)\n",
    "        return y\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    layer = HyenaOperatorAutoregressive1D(\n",
    "        d_model=128,\n",
    "        l_max=1024,\n",
    "        order=2,\n",
    "        filter_order=64\n",
    "    )\n",
    "\n",
    "    x = torch.randn(1, 1023, 1, requires_grad=True)  # 1D time series input\n",
    "    y = layer(x)\n",
    "\n",
    "    #import pdb;pdb.set_trace()\n",
    "    print(x.shape, y.shape)  # x: [1, 1023, 1], y: [1, 1]\n",
    "\n",
    "    #grad = torch.autograd.grad(y[:, 10, 0].sum(), x)[0]\n",
    "    #print('Causality check: gradients should not flow \"from future to past\"')\n",
    "    #print(grad[0, 11, 0].sum(), grad[0, 9, 0].sum())"
   ]
  },
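  {
   "cell_type": "markdown",
   "id": "rollout-note",
   "metadata": {},
   "source": [
    "A minimal sketch of how the one-step model above could be rolled out for multi-step forecasting: feed each prediction back into the input window. The `rollout` helper, the 1023-sample window, and the step count are illustrative choices, not part of the model or training code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "rollout-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rollout(model, context, n_steps=8):\n",
    "    \"\"\"Autoregressive rollout sketch; context is (batch, length, 1).\"\"\"\n",
    "    model.eval()\n",
    "    seq = context\n",
    "    preds = []\n",
    "    with torch.no_grad():\n",
    "        for _ in range(n_steps):\n",
    "            # Predict the next value from the most recent window (l_max - 1 samples).\n",
    "            y = model(seq[:, -1023:, :])\n",
    "            preds.append(y)\n",
    "            # Append the prediction as the newest sample and slide forward.\n",
    "            seq = torch.cat([seq, y[:, None, :]], dim=1)\n",
    "    return torch.cat(preds, dim=1)  # (batch, n_steps)\n"
   ]
  },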
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "80cde67b-992f-4cb0-8824-4a6b7e4984ca",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Epoch: 1 [0/640 (0%)]\tLoss: 0.433575\n",
      "Train Epoch: 2 [0/640 (0%)]\tLoss: 0.054185\n",
      "Train Epoch: 3 [0/640 (0%)]\tLoss: 0.007312\n",
      "Train Epoch: 4 [0/640 (0%)]\tLoss: 0.004312\n",
      "Train Epoch: 5 [0/640 (0%)]\tLoss: 0.003393\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[32], line 87\u001b[0m\n\u001b[1;32m 84\u001b[0m train_loader \u001b[38;5;241m=\u001b[39m DataLoader(dataset, batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m64\u001b[39m, shuffle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m11\u001b[39m): \u001b[38;5;66;03m# Train for 10 epochs\u001b[39;00m\n\u001b[0;32m---> 87\u001b[0m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepoch\u001b[49m\u001b[43m)\u001b[49m\n",
"Cell \u001b[0;32mIn[32], line 59\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(model, device, train_loader, optimizer, epoch)\u001b[0m\n\u001b[1;32m 57\u001b[0m data, target \u001b[38;5;241m=\u001b[39m data\u001b[38;5;241m.\u001b[39mto(device), target\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 58\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m---> 59\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 60\u001b[0m \u001b[38;5;66;03m#import pdb;pdb.set_trace()\u001b[39;00m\n\u001b[1;32m 62\u001b[0m loss \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39mmse_loss(output, target)\n",
"File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"Cell \u001b[0;32mIn[2], line 40\u001b[0m, in \u001b[0;36mHyenaOperatorAutoregressive1D.forward\u001b[0;34m(self, u, *args, **kwargs)\u001b[0m\n\u001b[1;32m 38\u001b[0m u \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfc_before(u) \n\u001b[1;32m 39\u001b[0m \u001b[38;5;66;03m# Pass through the operator\u001b[39;00m\n\u001b[0;32m---> 40\u001b[0m u \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moperator\u001b[49m\u001b[43m(\u001b[49m\u001b[43mu\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 41\u001b[0m last_state \u001b[38;5;241m=\u001b[39m u[:,\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m,:]\n\u001b[1;32m 42\u001b[0m \u001b[38;5;66;03m# Decrease the channel dimension back to 1\u001b[39;00m\n",
"File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"Cell \u001b[0;32mIn[1], line 237\u001b[0m, in \u001b[0;36mHyenaOperator.forward\u001b[0;34m(self, u, *args, **kwargs)\u001b[0m\n\u001b[1;32m 234\u001b[0m u \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_proj(u)\n\u001b[1;32m 235\u001b[0m u \u001b[38;5;241m=\u001b[39m rearrange(u, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mb l d -> b d l\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m--> 237\u001b[0m uc \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshort_filter\u001b[49m\u001b[43m(\u001b[49m\u001b[43mu\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m,:l_filter] \n\u001b[1;32m 238\u001b[0m \u001b[38;5;241m*\u001b[39mx, v \u001b[38;5;241m=\u001b[39m uc\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39md_model, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 240\u001b[0m k \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfilter_fn\u001b[38;5;241m.\u001b[39mfilter(l_filter)[\u001b[38;5;241m0\u001b[39m]\n",
"File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/conv.py:313\u001b[0m, in \u001b[0;36mConv1d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 312\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 313\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_conv_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/conv.py:309\u001b[0m, in \u001b[0;36mConv1d._conv_forward\u001b[0;34m(self, input, weight, bias)\u001b[0m\n\u001b[1;32m 305\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mzeros\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 306\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mconv1d(F\u001b[38;5;241m.\u001b[39mpad(\u001b[38;5;28minput\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reversed_padding_repeated_twice, mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode),\n\u001b[1;32m 307\u001b[0m weight, bias, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstride,\n\u001b[1;32m 308\u001b[0m _single(\u001b[38;5;241m0\u001b[39m), \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdilation, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgroups)\n\u001b[0;32m--> 309\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv1d\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstride\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 310\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdilation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgroups\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "import numpy as np\n",
    "\n",
    "def generate_sine_with_noise(n_points, frequency, phase, amplitude, noise_sd):\n",
    "    # Generate an array of points from 0 to 2*pi\n",
    "    x = np.linspace(0, 2*np.pi, n_points)\n",
    "\n",
    "    # Generate the sine wave\n",
    "    sine_wave = amplitude * np.sin(frequency * x + phase)\n",
    "\n",
    "    # Generate Gaussian noise\n",
    "    noise = np.random.normal(scale=noise_sd, size=n_points)\n",
    "\n",
    "    # Add the noise to the sine wave\n",
    "    sine_wave_noise = sine_wave + noise\n",
    "\n",
    "    # Stack the sine wave and the noisy sine wave into a 2D array\n",
    "    output = np.column_stack((sine_wave, sine_wave_noise))\n",
    "\n",
    "    return output\n",
    "\n",
    "\n",
    "class SineDataset(Dataset):\n",
    "    def __init__(self, n_samples, n_points, frequency_range, phase_range, amplitude_range, noise_sd_range):\n",
    "        self.n_samples = n_samples\n",
    "        self.n_points = n_points\n",
    "        self.frequency_range = frequency_range\n",
    "        self.phase_range = phase_range\n",
    "        self.amplitude_range = amplitude_range\n",
    "        self.noise_sd_range = noise_sd_range\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.n_samples\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # Generate random attributes\n",
    "        frequency = np.random.uniform(*self.frequency_range)\n",
    "        phase = np.random.uniform(*self.phase_range)\n",
    "        amplitude = np.random.uniform(*self.amplitude_range)\n",
    "        noise_sd = np.random.uniform(*self.noise_sd_range)\n",
    "\n",
    "        # Generate sine wave with the random attributes\n",
    "        sine_wave = generate_sine_with_noise(self.n_points, frequency, phase, amplitude, noise_sd)\n",
    "\n",
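    "        # Input: the noisy series without its last point; target: the clean value\n",
    "        # at that last point (one-step-ahead prediction of the denoised signal).\n",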
    "        return torch.Tensor(sine_wave[:-1, 1, None]), torch.Tensor(sine_wave[-1:, 0])\n",
    "\n",
    "# Usage:\n",
    "dataset = SineDataset(640, 1025, (1, 3), (0, 2*np.pi), (0.5, 1.5), (0.05, 0.15))\n",
    "\n",
    "def train(model, device, train_loader, optimizer, epoch):\n",
    "    model.train()\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        #data = data[...,None]\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        #import pdb;pdb.set_trace()\n",
    "\n",
    "        loss = F.mse_loss(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if batch_idx % 10 == 0:\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                epoch, batch_idx * len(data), len(train_loader.dataset),\n",
    "                100. * batch_idx / len(train_loader), loss.item()))\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "    model = HyenaOperatorAutoregressive1D(\n",
    "        d_model=128,\n",
    "        l_max=1024,\n",
    "        order=2,\n",
    "        filter_order=64\n",
    "    ).to(device)\n",
    "\n",
    "    optimizer = optim.Adam(model.parameters())\n",
    "\n",
    "    # Assume 10000 samples in the dataset\n",
    "    #dataset = SineDataset(10000, 1025, 2, 0, 1, 0.1)\n",
    "    train_loader = DataLoader(dataset, batch_size=64, shuffle=True)\n",
    "\n",
    "    for epoch in range(1, 11):  # Train for 10 epochs\n",
    "        train(model, device, train_loader, optimizer, epoch)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc9f9031-5ee1-49f8-a70f-ad85ca015596",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b763e03-baab-4b02-bae0-5747461bca7f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
hyena_test/.ipynb_checkpoints/standalone_hyena-checkpoint.py: 268 lines, new file (Jupyter autosave duplicating the code in the notebook's first cell)