import logging
from dataclasses import dataclass
from typing import Callable, Optional, Sequence

import torch
import torch.nn.functional as F
from torch import nn

from ...ml_utils.config import NameAndKwargs
from ..base import BaseSequenceModelTrain
from ..seq_enc_dec import create_decoder, create_encoder
from .rwkv_model import RWKVBlock, RWKVConfig

LOGGER = logging.getLogger(__name__)


@dataclass
class SequenceRWKVConfig(RWKVConfig):
    # None lets the encoder/decoder factory functions fall back to their defaults.
    encoder: Optional[NameAndKwargs] = None
    decoder: Optional[NameAndKwargs] = None


class SequenceRWKV(BaseSequenceModelTrain):
    """RWKV block stack wrapped with a configurable sequence encoder and decoder."""

    config_class = SequenceRWKVConfig

    def __init__(self, config: SequenceRWKVConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        if config.wkv_config is not None:
            LOGGER.info("Using WKV CUDA kernel.")
        else:
            LOGGER.info("Using WKV torch kernel.")
        self.encoder = create_encoder(config=config)
        self.decoder = create_decoder(config=config)
        self.blocks = nn.ModuleList(
            [RWKVBlock(config, block_idx=i) for i in range(config.num_layers)]
        )
        # Final layer norm between the block stack and the decoder.
        self.blocks_ln = nn.LayerNorm(config.embedding_dim)
        self.reset_parameters()

    @property
    def context_length(self) -> int:
        return self.config.context_length

    @property
    def vocab_size(self) -> int:
        return self.config.vocab_size

    @property
    def input_dim(self) -> Sequence[int]:
        return self.config.input_dim

    @property
    def output_dim(self) -> Sequence[int]:
        return self.config.output_dim

    def reset_parameters(self) -> None:
        for block in self.blocks:
            block.reset_parameters()
        self.blocks_ln.reset_parameters()
        self.encoder.reset_parameters()
        self.decoder.reset_parameters()

    def _create_optim_groups(self, config: RWKVConfig):
        # A single parameter group with weight decay disabled.
        optim_groups = [{"params": list(self.parameters()), "weight_decay": 0.0}]
        return optim_groups

    def get_loss_func(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
        def loss_fn(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
            assert not torch.any(torch.isnan(logits.view(-1)))
            assert not torch.any(torch.isnan(targets.view(-1)))
            # Flatten (batch, seq_len, vocab) -> (batch * seq_len, vocab);
            # target positions labeled -1 are excluded from the loss.
            return F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
            )

        return loss_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.size(1) <= self.config.context_length, (
            f"Forward input sequence length {x.size(1)} is longer than "
            f"context length {self.config.context_length}"
        )
        # Encode, run the RWKV block stack, normalize, then decode to logits.
        y = self.encoder(x)
        for block in self.blocks:
            y = block(y)
        y = self.blocks_ln(y)
        y = self.decoder(y)
        return y
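

# --------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, kept commented out): it assumes
# SequenceRWKVConfig can be constructed with just the fields referenced in
# this module (vocab_size, embedding_dim, num_layers, context_length,
# wkv_config, encoder, decoder); the real RWKVConfig may require more, and
# the encoder/decoder names below are hypothetical.
#
#     config = SequenceRWKVConfig(
#         vocab_size=256,
#         embedding_dim=64,
#         num_layers=2,
#         context_length=128,
#         wkv_config=None,  # None -> torch WKV kernel instead of the CUDA one
#         encoder=NameAndKwargs(name="embedding", kwargs={}),   # hypothetical name
#         decoder=NameAndKwargs(name="linear_head", kwargs={}),  # hypothetical name
#     )
#     model = SequenceRWKV(config)
#     tokens = torch.randint(0, config.vocab_size, (2, 32))  # (batch, seq_len)
#     logits = model(tokens)  # expected shape: (2, 32, config.vocab_size)
# --------------------------------------------------------------------------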