Add RWKV, H3, Hyena
469
models/rwkv/rwkv_model.py
Normal file
@@ -0,0 +1,469 @@
import logging
import math
from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from ..base import (
    BaseSequenceModelConfig,
    BaseSequenceModelTrain,
    ResettableParametersModule,
)
from .wkv_kernel import WKV, WKVConfig, WKVTorch

LOGGER = logging.getLogger(__name__)

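# L2Wrap adds a small auxiliary gradient on the logits: backward() places
# factor * max_logit at the argmax position of every (batch, time) slot, which is
# the gradient of an extra penalty of roughly 0.5 * factor * max_logit^2 per position
# (factor = 1e-4 / (B * T)). This gently pulls the largest logits towards 0 without
# changing the reported loss value.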
class L2Wrap(torch.autograd.Function):
    """L2 regularization for the logits."""

    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]
        # to encourage the logits to be close to 0
        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        return (grad_output, gy)


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    elif activation == "silu":
        return F.silu
    elif activation == "relu_squared":
        return lambda x: torch.square(torch.relu(x))
    elif activation == "selu":
        return F.selu
    elif activation == "elu":
        return F.elu
    else:
        raise ValueError(f"Unknown activation function {activation}")

@dataclass
class RWKVConfig(BaseSequenceModelConfig):
    embedding_dim: int = 640
    ffn_dim: int = 2048
    num_layers: int = 12
    attention_dim: int = -1
    wkv_config: Optional[Union[WKVConfig, Dict]] = field(
        default_factory=lambda: WKVConfig()
    )
    l2_logit_reg: bool = False
    use_timemix_timemix: bool = True
    use_timemix_channelmix: bool = True
    channelmix_act_fn: str = "relu_squared"
    init: str = "reproduce_init"  # options: "paper_init", "reproduce_init"
    # reproduce_init was reverse-engineered from the original code before the paper
    # was available and differs slightly from the init described in the paper.
    # It performs considerably better (ca. 0.5 lower train loss after 300 steps on
    # WikiText-103), so we use it as the default; it is also the init (most likely)
    # used in the original code.

    def __post_init__(self):
        # TODO: check whether this is actually needed
        if self.wkv_config is not None:
            self.wkv_config.T_max = max(
                self.wkv_config.T_max, self.context_length
            )
        if self.attention_dim <= 0:
            self.attention_dim = self.embedding_dim

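# Hypothetical usage sketch (field names such as context_length come from
# BaseSequenceModelConfig, which is defined outside this file):
#
#     cfg = RWKVConfig(embedding_dim=640, num_layers=12, context_length=1024)
#     model = RWKV(cfg)   # wkv_config set (default) -> CUDA WKV kernel
#     y = model(x)        # x: (B, T, embedding_dim); embedding and head are disabled here
#
# Setting wkv_config=None selects the pure PyTorch WKV implementation instead.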
class RWKV(BaseSequenceModelTrain):
    config_class = RWKVConfig

    def __init__(self, config: RWKVConfig, **kwargs):
        super().__init__()
        self.config = config
        self.cfg = config

        # self.embedding = nn.Embedding(
        #     num_embeddings=self.cfg.vocab_size,
        #     embedding_dim=self.cfg.embedding_dim,
        # )

        self.blocks = nn.ModuleList(
            [
                RWKVBlock(config=self.cfg, block_idx=i)
                for i in range(self.cfg.num_layers)
            ]
        )
        if self.cfg.wkv_config is not None:
            LOGGER.info("Using WKV cuda kernel.")
        else:
            LOGGER.info("Using WKV torch kernel.")

        # self.ln_out = nn.LayerNorm(self.cfg.embedding_dim)
        # self.head = nn.Linear(
        #     self.cfg.embedding_dim, self.cfg.vocab_size, bias=False
        # )
        self.reset_parameters()

    @property
    def context_length(self) -> int:
        return self.config.context_length

    # @property
    # def vocab_size(self) -> int:
    #     return self.config.vocab_size

    def reset_parameters(self) -> None:
        # init embedding
        # default init is zero  # TODO try this
        # we use a narrow uniform init; the original code uses the initial learning
        # rate as the init range, we simply set it to a small value
        # emb_init_range = 0.0008  # 1e-3
        # nn.init.uniform_(
        #     self.embedding.weight, a=-emb_init_range, b=emb_init_range
        # )
        # init blocks
        for b in self.blocks:
            b.reset_parameters()
        # init head and layer norm
        # self.head.reset_parameters()
        # self.ln_out.reset_parameters()

    def _create_optim_groups(self, config: RWKVConfig):
        optim_groups = [
            {"params": list(self.parameters()), "weight_decay": 0.0}
        ]
        return optim_groups

    def forward(self, x):
        # no embedding
        # # input shape: (B, T), T <= context_len, T are token ids
        # B, T = x.size()
        # assert T <= self.cfg.context_length, (
        #     f"input sequence length {T} exceeds model "
        #     f"context length {self.cfg.context_length}"
        # )

        # x = self.embedding(x)  # (B, T, C), C = embedding_dim

        for block in self.blocks:
            x = block(x)

        # x = self.ln_out(x)
        # no head
        # x = self.head(x)
        return x

    def get_loss_func(self):
        def loss_fn(y_hat, y):
            loss = F.cross_entropy(
                y_hat.view(-1, y_hat.size(-1)), y.view(-1), ignore_index=-1
            )
            if self.cfg.l2_logit_reg:
                loss = L2Wrap.apply(loss, y_hat)
            return loss

        return loss_fn

def _calc_gain(weight: torch.Tensor) -> float:
    """Calculate the gain value of the given weight tensor."""
    gain = 1.0
    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(weight)
    if fan_out > fan_in:
        gain = math.sqrt(fan_out / fan_in)
    return gain

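# Each RWKVBlock follows the pre-LayerNorm residual pattern:
#     x = x + TimeMix(LN1(x))      # token/time mixing via the WKV aggregation
#     x = x + ChannelMix(LN2(x))   # position-wise gated feed-forward
# Block 0 additionally applies an extra LayerNorm (ln0) to the raw input.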
class RWKVBlock(ResettableParametersModule):
    def __init__(self, config: RWKVConfig, block_idx: int):
        super().__init__()
        self.config = config
        self.block_idx = block_idx

        self.ln0 = None
        if self.block_idx == 0:
            self.ln0 = nn.LayerNorm(self.config.embedding_dim)
            # TODO 1) maybe additional positional embedding here (only in block 0)

        self.ln1 = nn.LayerNorm(self.config.embedding_dim)
        self.ln2 = nn.LayerNorm(self.config.embedding_dim)

        # TODO 2) maybe a pre-feedforward (channel mix) here, see line 325f in RWKV-v4neo/model.py
        self.attention_timemix = RWKVTimeMix(
            config=self.config, block_id=self.block_idx
        )
        self.ffn_channelmix = RWKVChannelMix(
            config=self.config, block_id=self.block_idx
        )

    def reset_parameters(self) -> None:
        if self.ln0 is not None:
            self.ln0.reset_parameters()
        self.ln1.reset_parameters()
        self.ln2.reset_parameters()
        self.attention_timemix.reset_parameters()
        self.ffn_channelmix.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.block_idx == 0 and self.ln0 is not None:
            x = self.ln0(x)
            # TODO 1) maybe positional embedding here (only in block 0)
            # x = x + pos_emb

        # TODO 2) maybe a pre-feedforward (channel mix) here, see line 325f in RWKV-v4neo/model.py
        # residual connection 1
        x = x + self.attention_timemix(self.ln1(x))
        # residual connection 2
        x = x + self.ffn_channelmix(self.ln2(x))
        return x

class RWKVTimeMix(ResettableParametersModule):
    def __init__(self, config: RWKVConfig, block_id: int):
        super().__init__()
        self.config = config
        self.block_id = block_id

        embedding_dim = self.config.embedding_dim
        attention_dim = self.config.attention_dim
        # init time mix constants
        req_grad = True  # TODO make this configurable
        self.time_mix_k = nn.Parameter(
            torch.empty((1, 1, embedding_dim)), requires_grad=req_grad
        )
        self.time_mix_v = nn.Parameter(
            torch.empty((1, 1, embedding_dim)), requires_grad=req_grad
        )
        self.time_mix_r = nn.Parameter(
            torch.empty((1, 1, embedding_dim)), requires_grad=req_grad
        )

        # init time decay
        self.time_decay = nn.Parameter(
            torch.empty((attention_dim,)), requires_grad=req_grad
        )
        self.time_first = nn.Parameter(
            torch.empty((attention_dim,)), requires_grad=req_grad
        )

        # init layers / parameters
        # shifts the time dimension forward by 1: pad a zero at the first time step
        # and drop the last time step
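        # e.g. for x of shape (B, T, C): xx = time_shift(x) gives xx[:, 0] = 0 and
        # xx[:, t] = x[:, t - 1] for t >= 1 (ZeroPad2d pads/crops the last two dims)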
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.key = nn.Linear(embedding_dim, attention_dim, bias=False)
        self.value = nn.Linear(embedding_dim, attention_dim, bias=False)
        self.receptance = nn.Linear(embedding_dim, attention_dim, bias=False)
        self.output = nn.Linear(attention_dim, embedding_dim, bias=False)

        if self.config.wkv_config is not None:
            # use CUDA implementation
            self.wkv = WKV(config=self.config.wkv_config)
        else:
            # use pure PyTorch implementation
            self.wkv = WKVTorch()

        self.reset_parameters()

    def reset_parameters(self) -> None:
        # init time mix constants
        time_mix_k, time_mix_v, time_mix_r = self._init_time_mix_constants()
        req_grad = True
        self.time_mix_k = nn.Parameter(time_mix_k, requires_grad=req_grad)
        self.time_mix_v = nn.Parameter(time_mix_v, requires_grad=req_grad)
        self.time_mix_r = nn.Parameter(time_mix_r, requires_grad=req_grad)
        # init time decay
        time_decay, time_first = self._init_time_decay_constants()
        self.time_decay = nn.Parameter(time_decay, requires_grad=req_grad)
        self.time_first = nn.Parameter(time_first, requires_grad=req_grad)
        # init layers / parameters
        if self.config.init == "paper_init":
            # ZERO INIT
            nn.init.zeros_(self.receptance.weight)
            nn.init.zeros_(self.key.weight)
            nn.init.zeros_(self.value.weight)
            # NORMAL INIT
            nn.init.normal_(
                self.output.weight,
                std=math.sqrt(self.config.ffn_dim / self.config.embedding_dim),
            )
        elif self.config.init == "reproduce_init":
            # ZERO INIT
            nn.init.zeros_(self.key.weight)
            nn.init.zeros_(self.receptance.weight)
            nn.init.zeros_(self.output.weight)
            # ORTHOGONAL INIT
            nn.init.orthogonal_(
                self.value.weight, gain=_calc_gain(self.value.weight)
            )
        else:
            raise ValueError(f"Unknown init method {self.config.init}")

    def _compute_rkv(self, x):
        if self.config.use_timemix_timemix:
            # mix x with the previous time step to produce xk, xv, xr
            xx = self.time_shift(x)
            xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
            xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
            xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        else:
            xk = x
            xv = x
            xr = x
        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)
        sr = torch.sigmoid(r)
        return sr, k, v
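
    # self.wkv performs the WKV aggregation (CUDA or pure PyTorch, depending on the
    # config). Roughly, per channel it computes a decaying weighted average of past
    # values, something like
    #     wkv_t = (sum_{i<t} exp(k_i - (t-1-i) * decay) * v_i + exp(time_first + k_t) * v_t)
    #             / (the same sums without v)
    # see wkv_kernel.py / the RWKV paper for the exact, numerically stable form.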
    def forward(self, x):
        B, T, C = x.size()  # x = (batch_size, seq_len, embedding_dim)
        attention_dim = self.config.attention_dim
        # sr, k, v = (batch_size, seq_len, attention_dim)
        sr, k, v = self._compute_rkv(x)
        # wkv cuda/torch kernel
        rwkv = sr * self.wkv(
            B, T, attention_dim, self.time_decay, self.time_first, k, v
        )
        return self.output(rwkv)

    def _init_time_mix_constants(
        self,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        num_blocks = self.config.num_layers
        embedding_dim = self.config.embedding_dim

        ratio_0_to_1 = self.block_id / max(1, num_blocks - 1)  # 0 to 1
        ratio_1_to_almost0 = 1.0 - (self.block_id / num_blocks)  # 1 to ~0

        # TODO does this make sense?
        # different time mix constants for each block and each embedding dim
        embed_dim_val = torch.ones(1, 1, embedding_dim)
        for i in range(embedding_dim):
            embed_dim_val[0, 0, i] = i / embedding_dim

        # TODO check constants 0.3 and 0.5
        time_mix_k = torch.pow(embed_dim_val, ratio_1_to_almost0)
        time_mix_v = (
            torch.pow(embed_dim_val, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
        )
        time_mix_r = torch.pow(embed_dim_val, 0.5 * ratio_1_to_almost0)

        return time_mix_k, time_mix_v, time_mix_r
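
    # decay_speed below runs from -5 (first channel) to +3 (last channel), so the
    # channels are initialized to decay past information at very different rates;
    # the zigzag term cycles through 0.0, +0.5, -0.5 to break symmetry across the
    # channels of time_first.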
    def _init_time_decay_constants(self) -> Tuple[torch.Tensor, torch.Tensor]:
        num_blocks = self.config.num_layers
        attention_dim = self.config.attention_dim
        ratio_0_to_1 = self.block_id / max(1, num_blocks - 1)  # 0 to 1

        # time decay
        # encourages the model to decay the information in different memory cells
        # (channel dimensions) at different speeds
        decay_speed = torch.ones(attention_dim)
        for h in range(attention_dim):
            decay_speed[h] = -5 + 8 * (h / (attention_dim - 1)) ** (
                0.7 + 1.3 * ratio_0_to_1
            )
        time_decay = decay_speed

        # time first
        # the alternating zigzag pattern creates subtle initial variations across the
        # tensor elements, intended to help the model treat different embedding
        # dimensions differently
        zigzag = (
            torch.tensor([(i + 1) % 3 - 1 for i in range(attention_dim)]) * 0.5
        )
        time_first = (
            torch.ones(attention_dim) * torch.log(torch.tensor(0.3)) + zigzag
        )

        return time_decay, time_first

class RWKVChannelMix(ResettableParametersModule):
    def __init__(self, config: RWKVConfig, block_id: int):
        super().__init__()
        self.config = config
        self.block_id = block_id

        self._act_fn = _get_activation_fn(self.config.channelmix_act_fn)

        embedding_dim = self.config.embedding_dim
        ffn_dim = self.config.ffn_dim
        # init time mix constants
        req_grad = True
        self.time_mix_k = nn.Parameter(
            torch.empty((1, 1, embedding_dim)), requires_grad=req_grad
        )
        self.time_mix_r = nn.Parameter(
            torch.empty((1, 1, embedding_dim)), requires_grad=req_grad
        )

        # init layers / parameters
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.key = nn.Linear(embedding_dim, ffn_dim, bias=False)
        self.receptance = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.value = nn.Linear(ffn_dim, embedding_dim, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        # init time mix constants
        time_mix_k, time_mix_r = self._init_time_mix_constants()
        req_grad = True
        self.time_mix_k = nn.Parameter(time_mix_k, requires_grad=req_grad)
        self.time_mix_r = nn.Parameter(time_mix_r, requires_grad=req_grad)
        # init layers / parameters
        if self.config.init == "paper_init":
            # ZERO INIT
            nn.init.zeros_(self.receptance.weight)
            nn.init.zeros_(self.key.weight)
            # NORMAL INIT
            nn.init.normal_(
                self.value.weight,
                std=math.sqrt(self.config.ffn_dim / self.config.embedding_dim),
            )
        elif self.config.init == "reproduce_init":
            # ZERO INIT
            nn.init.zeros_(self.receptance.weight)
            nn.init.zeros_(self.value.weight)
            # ORTHOGONAL INIT
            nn.init.orthogonal_(
                self.key.weight, gain=_calc_gain(self.key.weight)
            )
        else:
            raise ValueError(f"Unknown init method {self.config.init}")

    def _init_time_mix_constants(self) -> Tuple[torch.Tensor, torch.Tensor]:
        num_blocks = self.config.num_layers
        embedding_dim = self.config.embedding_dim

        ratio_1_to_almost0 = 1.0 - (self.block_id / num_blocks)  # 1 to ~0
        embed_dim_val = torch.ones(1, 1, embedding_dim)
        for i in range(embedding_dim):
            embed_dim_val[0, 0, i] = i / embedding_dim

        time_mix_k = torch.pow(embed_dim_val, ratio_1_to_almost0)
        time_mix_r = torch.pow(embed_dim_val, ratio_1_to_almost0)

        return time_mix_k, time_mix_r
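
    # The channel mix is a gated position-wise feed-forward:
    #     y = sigmoid(receptance(xr)) * value(act_fn(key(xk)))
    # with xk, xr being time-shift-mixed versions of the input (as in the time mix).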
    def forward(self, x):
        if self.config.use_timemix_channelmix:
            xx = self.time_shift(x)
            xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
            xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        else:
            xk = x
            xr = x
        k = self.key(xk)
        k = self._act_fn(k)
        kv = self.value(k)
        y = torch.sigmoid(self.receptance(xr)) * kv
        return y