import logging import math from dataclasses import dataclass, field from typing import Dict, Optional, Tuple, Union import torch import torch.nn.functional as F from torch import nn from ..base import ( BaseSequenceModelConfig, BaseSequenceModelTrain, ResettableParametersModule, ) from .wkv_kernel import WKV, WKVConfig, WKVTorch LOGGER = logging.getLogger(__name__) class L2Wrap(torch.autograd.Function): """L2 regularization for the logits.""" @staticmethod def forward(ctx, loss, y): ctx.save_for_backward(y) return loss @staticmethod def backward(ctx, grad_output): y = ctx.saved_tensors[0] # to encourage the logits to be close to 0 factor = 1e-4 / (y.shape[0] * y.shape[1]) maxx, ids = torch.max(y, -1, keepdim=True) gy = torch.zeros_like(y) gy.scatter_(-1, ids, maxx * factor) return (grad_output, gy) def _get_activation_fn(activation): if activation == "relu": return F.relu elif activation == "gelu": return F.gelu elif activation == "silu": return F.silu elif activation == "relu_squared": return lambda x: torch.square(torch.relu(x)) elif activation == "selu": return F.selu elif activation == "elu": return F.elu else: raise ValueError(f"Unknown activation function {activation}") @dataclass class RWKVConfig(BaseSequenceModelConfig): embedding_dim: int = 640 ffn_dim: int = 2048 num_layers: int = 12 attention_dim: int = -1 wkv_config: Optional[Union[WKVConfig, Dict]] = field( default_factory=lambda: WKVConfig() ) l2_logit_reg: bool = False use_timemix_timemix: bool = True use_timemix_channelmix: bool = True channelmix_act_fn: str = "relu_squared" init: str = "reproduce_init" # options "paper_init", "reproduce_init" (reproduce init is different from paper init) # the reproduce_init was figured out when only the code was available, not the paper # the init as it is described in the paper slightly differs from the reproduce_init # the reproduce_init actually performs much better than the paper_init (ca. 0.5 better in train loss after 300 steps on WikiText-103) # so we use this as default. reproduce_init is also (likely to be) used in the original code. def __post_init__(self): # TODO: Check if this was actually needed if self.wkv_config is not None: self.wkv_config.T_max = max( self.wkv_config.T_max, self.context_length ) if self.attention_dim <= 0: self.attention_dim = self.embedding_dim class RWKV(BaseSequenceModelTrain): config_class = RWKVConfig def __init__(self, config: RWKVConfig, **kwargs): super().__init__() self.config = config self.cfg = config # self.embedding = nn.Embedding( # num_embeddings=self.cfg.vocab_size, # embedding_dim=self.cfg.embedding_dim, # ) self.blocks = nn.ModuleList( [ RWKVBlock(config=self.cfg, block_idx=i) for i in range(self.cfg.num_layers) ] ) if self.cfg.wkv_config is not None: LOGGER.info("Using WKV cuda kernel.") else: LOGGER.info("Using WKV torch kernel.") # self.ln_out = nn.LayerNorm(self.cfg.embedding_dim) # self.head = nn.Linear( # self.cfg.embedding_dim, self.cfg.vocab_size, bias=False # ) self.reset_parameters() @property def context_length(self) -> int: return self.config.context_length # @property # def vocab_size(self) -> int: # return self.config.vocab_size def reset_parameters(self) -> None: # init embedding # default init is zero # TODO try this # we use a narrow uniform init, in the original code they use the initial learning rate # we just set it to a small value # emb_init_range = 0.0008 # 1e-3 # nn.init.uniform_( # self.embedding.weight, a=-emb_init_range, b=emb_init_range # ) # init blocks for b in self.blocks: b.reset_parameters() # init head and layer norm # self.head.reset_parameters() # self.ln_out.reset_parameters() def _create_optim_groups(self, config: RWKVConfig): optim_groups = [ {"params": [p for p in self.parameters()], "weight_decay": 0.0} ] return optim_groups def forward(self, x): # no embedding # # input shape: (B, T), T <= context_len, T are token ids # B, T = x.size() # assert T <= self.cfg.context_length, ( # f"input sequence length {T} exceeds model " # f"context length {self.cfg.context_length}" # ) # x = self.embedding(x) # (B, T, C), C = embedding_dim for i, block in enumerate(self.blocks): x = block(x) # x = self.ln_out(x) # no head # x = self.head(x) return x def get_loss_func(self): def loss_fn(y_hat, y): loss = F.cross_entropy( y_hat.view(-1, y_hat.size(-1)), y.view(-1), ignore_index=-1 ) if self.cfg.l2_logit_reg: loss = L2Wrap.apply(loss, y_hat) return loss return loss_fn def _calc_gain(weight: torch.Tensor) -> float: """Calculate the gain value of the given weight tensor.""" gain = 1.0 fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(weight) if fan_out > fan_in: gain = math.sqrt(fan_out / fan_in) return gain class RWKVBlock(ResettableParametersModule): def __init__(self, config: RWKVConfig, block_idx: int): super().__init__() self.config = config self.block_idx = block_idx self.ln0 = None if self.block_idx == 0: self.ln0 = nn.LayerNorm(self.config.embedding_dim) # TODO 1) maybe additional positional embedding here (only in block 0) self.ln1 = nn.LayerNorm(self.config.embedding_dim) self.ln2 = nn.LayerNorm(self.config.embedding_dim) # TODO 2) maybe pre feedforward here (channel mix) see line 325f in RWKV-v4neo/model.py self.attention_timemix = RWKVTimeMix( config=self.config, block_id=self.block_idx ) self.ffn_channelmix = RWKVChannelMix( config=self.config, block_id=self.block_idx ) def reset_parameters(self) -> None: if self.ln0 is not None: self.ln0.reset_parameters() self.ln1.reset_parameters() self.ln2.reset_parameters() self.attention_timemix.reset_parameters() self.ffn_channelmix.reset_parameters() def forward(self, x: torch.Tensor) -> torch.Tensor: if self.block_idx == 0 and self.ln0 is not None: x = self.ln0(x) # TODO 1) maybe positional embedding here (only in block 0) # x = x+pos_emb # TODO 2) maybe pre feedforward here (channel mix) see line 325f in RWKV-v4neo/model.py # residual connection 1 x = x + self.attention_timemix(self.ln1(x)) # residual connection 2 x = x + self.ffn_channelmix(self.ln2(x)) return x class RWKVTimeMix(ResettableParametersModule): def __init__(self, config: RWKVConfig, block_id: int): super().__init__() self.config = config self.block_id = block_id embedding_dim = self.config.embedding_dim attention_dim = self.config.attention_dim # init time mix constants req_grad = True # TODO make this configurable self.time_mix_k = nn.Parameter( torch.empty((1, 1, embedding_dim)), requires_grad=req_grad ) self.time_mix_v = nn.Parameter( torch.empty((1, 1, embedding_dim)), requires_grad=req_grad ) self.time_mix_r = nn.Parameter( torch.empty((1, 1, embedding_dim)), requires_grad=req_grad ) # init time decay self.time_decay = nn.Parameter( torch.empty((attention_dim,)), requires_grad=req_grad ) self.time_first = nn.Parameter( torch.empty((attention_dim,)), requires_grad=req_grad ) # init layers / parameters # this shifts the time dimension by 1 forward, pad 0 at first time step, remove last time step: self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) self.key = nn.Linear(embedding_dim, attention_dim, bias=False) self.value = nn.Linear(embedding_dim, attention_dim, bias=False) self.receptance = nn.Linear(embedding_dim, attention_dim, bias=False) self.output = nn.Linear(attention_dim, embedding_dim, bias=False) if self.config.wkv_config is not None: # use CUDA implementation self.wkv = WKV(config=self.config.wkv_config) else: # use pure PyTorch implementation self.wkv = WKVTorch() self.reset_parameters() def reset_parameters(self) -> None: # init time mix constants time_mix_k, time_mix_v, time_mix_r = self._init_time_mix_constants() req_grad = True self.time_mix_k = nn.Parameter(time_mix_k, requires_grad=req_grad) self.time_mix_v = nn.Parameter(time_mix_v, requires_grad=req_grad) self.time_mix_r = nn.Parameter(time_mix_r, requires_grad=req_grad) # init time decay time_decay, time_first = self._init_time_decay_constants() self.time_decay = nn.Parameter(time_decay, requires_grad=req_grad) self.time_first = nn.Parameter(time_first, requires_grad=req_grad) # init layers / parameters if self.config.init == "paper_init": # ZERO INIT nn.init.zeros_(self.receptance.weight) nn.init.zeros_(self.key.weight) nn.init.zeros_(self.value.weight) # NORMAL INIT nn.init.normal_( self.output.weight, std=math.sqrt(self.config.ffn_dim / self.config.embedding_dim), ) elif self.config.init == "reproduce_init": # ZERO INIT nn.init.zeros_(self.key.weight) nn.init.zeros_(self.receptance.weight) nn.init.zeros_(self.output.weight) # ORTHOGONAL INIT nn.init.orthogonal_( self.value.weight, gain=_calc_gain(self.value.weight) ) else: raise ValueError(f"Unknown init method {self.config.init}") def _compute_rkv(self, x): if self.config.use_timemix_timemix: xx = self.time_shift( x ) # Mix x with the previous timestep to produce xk, xv, xr xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) else: xk = x xv = x xr = x k = self.key(xk) v = self.value(xv) r = self.receptance(xr) sr = torch.sigmoid(r) return sr, k, v def forward(self, x): B, T, C = x.size() # x = (batch_size, seq_len, embedding_dim) attention_dim = self.config.attention_dim sr, k, v = self._compute_rkv( x ) # sr, k, v = (batch_size, seq_len, attention_dim) # wkv cuda/torch kernel rwkv = sr * self.wkv( B, T, attention_dim, self.time_decay, self.time_first, k, v ) return self.output(rwkv) def _init_time_mix_constants( self, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: num_blocks = self.config.num_layers embedding_dim = self.config.embedding_dim ratio_0_to_1 = self.block_id / max(1, num_blocks - 1) # 0 to 1 ratio_1_to_almost0 = 1.0 - (self.block_id / num_blocks) # 1 to ~0 # TODO does this make sense? # different time mix constants for each block and each embedding dim embed_dim_val = torch.ones(1, 1, embedding_dim) for i in range(embedding_dim): embed_dim_val[0, 0, i] = i / embedding_dim # TODO check constants 0.3 and 0.5 time_mix_k = torch.pow(embed_dim_val, ratio_1_to_almost0) time_mix_v = ( torch.pow(embed_dim_val, ratio_1_to_almost0) + 0.3 * ratio_0_to_1 ) time_mix_r = torch.pow(embed_dim_val, 0.5 * ratio_1_to_almost0) return time_mix_k, time_mix_v, time_mix_r def _init_time_decay_constants(self) -> Tuple[torch.Tensor, torch.Tensor]: num_blocks = self.config.num_layers attention_dim = self.config.attention_dim ratio_0_to_1 = self.block_id / max(1, num_blocks - 1) # 0 to 1 # time decay # this encourages the model to decay the information in different memory cells (channel dimensions) # at different speeds decay_speed = torch.ones(attention_dim) for h in range(attention_dim): decay_speed[h] = -5 + 8 * (h / (attention_dim - 1)) ** ( 0.7 + 1.3 * ratio_0_to_1 ) time_decay = decay_speed # time first # The alternating zigzag pattern initially creates subtle variations in the tensor elements, # which are intended to help the model treat different dimensions of the embedding differently zigzag = ( torch.tensor([(i + 1) % 3 - 1 for i in range(attention_dim)]) * 0.5 ) time_first = ( torch.ones(attention_dim) * torch.log(torch.tensor(0.3)) + zigzag ) return time_decay, time_first class RWKVChannelMix(ResettableParametersModule): def __init__(self, config: RWKVConfig, block_id: int): super().__init__() self.config = config self.block_id = block_id self._act_fn = _get_activation_fn(self.config.channelmix_act_fn) embedding_dim = self.config.embedding_dim ffn_dim = self.config.ffn_dim # init time mix constants req_grad = True self.time_mix_k = nn.Parameter( torch.empty((1, 1, embedding_dim)), requires_grad=req_grad ) self.time_mix_r = nn.Parameter( torch.empty((1, 1, embedding_dim)), requires_grad=req_grad ) # init layers / parameters self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) self.key = nn.Linear(embedding_dim, ffn_dim, bias=False) self.receptance = nn.Linear(embedding_dim, embedding_dim, bias=False) self.value = nn.Linear(ffn_dim, embedding_dim, bias=False) self.reset_parameters() def reset_parameters(self): # init time mix constants time_mix_k, time_mix_r = self._init_time_mix_constants() req_grad = True self.time_mix_k = nn.Parameter(time_mix_k, requires_grad=req_grad) self.time_mix_r = nn.Parameter(time_mix_r, requires_grad=req_grad) # init layers / parameters if self.config.init == "paper_init": # ZERO INIT nn.init.zeros_(self.receptance.weight) nn.init.zeros_(self.key.weight) # NORMAL INIT nn.init.normal_( self.value.weight, std=math.sqrt(self.config.ffn_dim / self.config.embedding_dim), ) elif self.config.init == "reproduce_init": # ZERO INIT nn.init.zeros_(self.receptance.weight) nn.init.zeros_(self.value.weight) # ORTHOGONAL INIT nn.init.orthogonal_( self.key.weight, gain=_calc_gain(self.key.weight) ) else: raise ValueError(f"Unknown init method {self.config.init}") def _init_time_mix_constants(self) -> Tuple[torch.Tensor, torch.Tensor]: num_blocks = self.config.num_layers embedding_dim = self.config.embedding_dim ratio_1_to_almost0 = 1.0 - (self.block_id / num_blocks) # 1 to ~0 embed_dim_val = torch.ones(1, 1, embedding_dim) for i in range(embedding_dim): embed_dim_val[0, 0, i] = i / embedding_dim time_mix_k = torch.pow(embed_dim_val, ratio_1_to_almost0) time_mix_r = torch.pow(embed_dim_val, ratio_1_to_almost0) return time_mix_k, time_mix_r def forward(self, x): if self.config.use_timemix_channelmix: xx = self.time_shift(x) xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) else: xk = x xr = x k = self.key(xk) k = self._act_fn(k) kv = self.value(k) y = torch.sigmoid(self.receptance(xr)) * kv return y