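"""Sequence-model wrapper around the RWKV architecture.

Combines a configurable encoder, a stack of RWKVBlocks with a final
LayerNorm, and a configurable decoder into a trainable sequence model.
"""
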
import logging
from dataclasses import dataclass
from typing import Callable, Optional, Sequence

import torch
import torch.nn.functional as F
from torch import nn

from ...ml_utils.config import NameAndKwargs
from ..base import BaseSequenceModelTrain
from ..seq_enc_dec import create_decoder, create_encoder
from .rwkv_model import RWKVBlock, RWKVConfig

LOGGER = logging.getLogger(__name__)


@dataclass
class SequenceRWKVConfig(RWKVConfig):
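    """RWKVConfig extended with encoder/decoder sub-configs (name plus kwargs)."""
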
    encoder: Optional[NameAndKwargs] = None
    decoder: Optional[NameAndKwargs] = None


class SequenceRWKV(BaseSequenceModelTrain):
    config_class = SequenceRWKVConfig

    def __init__(self, config: SequenceRWKVConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        if config.wkv_config is not None:
            LOGGER.info("Using WKV cuda kernel.")
        else:
            LOGGER.info("Using WKV torch kernel.")

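        # Encoder maps raw inputs into the embedding space; decoder maps the
        # final hidden states to the output space.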
        self.encoder = create_encoder(config=config)
        self.decoder = create_decoder(config=config)

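        # The RWKV backbone: a stack of blocks followed by a final LayerNorm.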
        self.blocks = nn.ModuleList([RWKVBlock(config, block_idx=i) for i in range(config.num_layers)])
        self.blocks_ln = nn.LayerNorm(config.embedding_dim)

        self.reset_parameters()

    @property
    def context_length(self) -> int:
        return self.config.context_length

    @property
    def vocab_size(self) -> int:
        return self.config.vocab_size

    @property
    def input_dim(self) -> Sequence[int]:
        return self.config.input_dim

    @property
    def output_dim(self) -> Sequence[int]:
        return self.config.output_dim

    def reset_parameters(self):
        for block in self.blocks:
            block.reset_parameters()
        self.blocks_ln.reset_parameters()
        self.encoder.reset_parameters()
        self.decoder.reset_parameters()
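
    # All parameters go into a single optimizer group with weight decay disabled.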
    def _create_optim_groups(self, config: RWKVConfig):
        optim_groups = [{"params": list(self.parameters()), "weight_decay": 0.0}]
        return optim_groups
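
    # Training loss: token-level cross entropy over flattened logits; target
    # positions marked with -1 are ignored (e.g. padding).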
    def get_loss_func(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
        def loss_fn(logits, targets):
            assert not torch.any(torch.isnan(logits.view(-1)))
            assert not torch.any(torch.isnan(targets.view(-1)))
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            return loss

        return loss_fn
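
    # Forward pass: encode, run all RWKV blocks, apply the final LayerNorm, decode.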
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert (
            x.size(1) <= self.config.context_length
        ), f"Forward input sequence length {x.size(1)} is longer than context length {self.config.context_length}"

        y = self.encoder(x)

        for block in self.blocks:
            y = block(y)
        y = self.blocks_ln(y)

        y = self.decoder(y)

        return y
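

# Minimal usage sketch (illustrative only: the exact RWKVConfig fields and the
# NameAndKwargs constructor are assumptions based on how they are used above):
#
#   config = SequenceRWKVConfig(
#       vocab_size=50257,
#       context_length=1024,
#       embedding_dim=512,
#       num_layers=12,
#       encoder=NameAndKwargs(name="embedding_encoder", kwargs={}),
#       decoder=NameAndKwargs(name="linear_decoder", kwargs={}),
#   )
#   model = SequenceRWKV(config)
#   tokens = torch.randint(0, config.vocab_size, (1, 128))
#   logits = model(tokens)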