Add RWKV, H3, Hyena
92  models/rwkv/sequence_rwkv.py  Normal file
@@ -0,0 +1,92 @@
import logging
from dataclasses import dataclass
from typing import Callable, Optional, Sequence

import torch
from torch import nn

from ...ml_utils.config import NameAndKwargs
from ..base import BaseSequenceModelTrain
from ..seq_enc_dec import create_decoder, create_encoder
from .rwkv_model import RWKVBlock, RWKVConfig

LOGGER = logging.getLogger(__name__)


@dataclass
class SequenceRWKVConfig(RWKVConfig):
    # Optional specs for the input encoder / output decoder modules,
    # resolved by create_encoder / create_decoder below.
    encoder: Optional[NameAndKwargs] = None
    decoder: Optional[NameAndKwargs] = None


class SequenceRWKV(BaseSequenceModelTrain):
    config_class = SequenceRWKVConfig

    def __init__(self, config: SequenceRWKVConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        # Select between the fused CUDA WKV kernel and the pure-torch fallback.
        if config.wkv_config is not None:
            LOGGER.info("Using WKV cuda kernel.")
        else:
            LOGGER.info("Using WKV torch kernel.")

        self.encoder = create_encoder(config=config)
        self.decoder = create_decoder(config=config)

        # Stack of RWKV blocks followed by a final LayerNorm.
        self.blocks = nn.ModuleList([RWKVBlock(config, block_idx=i) for i in range(config.num_layers)])
        self.blocks_ln = nn.LayerNorm(config.embedding_dim)

        self.reset_parameters()

    @property
    def context_length(self) -> int:
        return self.config.context_length

    @property
    def vocab_size(self) -> int:
        return self.config.vocab_size

    @property
    def input_dim(self) -> Sequence[int]:
        return self.config.input_dim

    @property
    def output_dim(self) -> Sequence[int]:
        return self.config.output_dim

    def reset_parameters(self):
        for block in self.blocks:
            block.reset_parameters()
        self.blocks_ln.reset_parameters()
        self.encoder.reset_parameters()
        self.decoder.reset_parameters()

    def _create_optim_groups(self, config: RWKVConfig):
        # A single optimizer group; weight decay disabled for all parameters.
        optim_groups = [{"params": list(self.parameters()), "weight_decay": 0.0}]
        return optim_groups

    def get_loss_func(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
        import torch.nn.functional as F

        def loss_fn(logits, targets):
            # Guard against NaNs before computing the loss.
            assert not torch.any(torch.isnan(logits.view(-1)))
            assert not torch.any(torch.isnan(targets.view(-1)))
            # Flatten (B, T, V) logits and (B, T) targets; -1 marks ignored positions.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            return loss

        return loss_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert (
            x.size(1) <= self.config.context_length
        ), f"Forward input sequence length {x.size(1)} is longer than context length {self.config.context_length}"

        y = self.encoder(x)

        # Apply the RWKV blocks, then the final LayerNorm.
        for block in self.blocks:
            y = block(y)
        y = self.blocks_ln(y)

        y = self.decoder(y)

        return y
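
The RWKVBlock and the WKV kernels referenced above are defined in rwkv_model.py and are not part of this file. For orientation only, here is a minimal sketch of the RWKV-4 WKV recurrence that a pure-torch fallback kernel of this kind typically computes; the function name naive_wkv and the tensor shapes are illustrative assumptions, not this repository's API:

import torch

def naive_wkv(w: torch.Tensor, u: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Sketch only, not the kernel in rwkv_model.py.
    # w: (C,) positive per-channel decay, u: (C,) "bonus" for the current
    # token, k and v: (B, T, C) keys and values. Sequential scan over time.
    B, T, C = k.shape
    a = torch.zeros(B, C, dtype=v.dtype, device=v.device)  # numerator state
    b = torch.zeros(B, C, dtype=v.dtype, device=v.device)  # denominator state
    out = torch.empty_like(v)
    for t in range(T):
        kt, vt = k[:, t], v[:, t]
        e = torch.exp(u + kt)               # current token gets the u bonus
        out[:, t] = (a + e * vt) / (b + e)  # weighted average of past and current values
        a = torch.exp(-w) * a + torch.exp(kt) * vt
        b = torch.exp(-w) * b + torch.exp(kt)
    return out

Real kernels subtract a running maximum from the exponents for numerical stability; the plain form above overflows for large k and is only meant to show the recurrence.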
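
A hypothetical end-to-end use of the new class; every RWKVConfig field and the NameAndKwargs signature shown here are assumptions (the real definitions live in rwkv_model.py and ml_utils/config.py):

config = SequenceRWKVConfig(
    num_layers=4,          # assumed RWKVConfig field
    embedding_dim=256,     # assumed RWKVConfig field
    context_length=1024,   # assumed RWKVConfig field
    vocab_size=32000,      # assumed RWKVConfig field
    encoder=NameAndKwargs(name="embedding"),  # assumed NameAndKwargs signature
    decoder=NameAndKwargs(name="linear"),     # assumed NameAndKwargs signature
)
model = SequenceRWKV(config)

tokens = torch.randint(0, config.vocab_size, (2, 128))      # (batch, seq_len)
inputs, targets = tokens[:, :-1].contiguous(), tokens[:, 1:].contiguous()
logits = model(inputs)                                      # -> (2, 127, vocab_size)
loss = model.get_loss_func()(logits, targets)               # next-token cross entropy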