Add RWKV, H3, Hyena

This commit is contained in:
2023-08-05 17:33:32 +02:00
parent a71030547c
commit 7b15a413d4
22 changed files with 1794 additions and 0 deletions

View File

@@ -0,0 +1,92 @@
import logging
from dataclasses import dataclass
from typing import Callable, Sequence
import torch
from torch import nn
from ...ml_utils.config import NameAndKwargs
from ..base import BaseSequenceModelTrain
from ..seq_enc_dec import create_decoder, create_encoder
from .rwkv_model import RWKVBlock, RWKVConfig
LOGGER = logging.getLogger(__name__)
@dataclass
class SequenceRWKVConfig(RWKVConfig):
encoder: NameAndKwargs = None
decoder: NameAndKwargs = None
class SequenceRWKV(BaseSequenceModelTrain):
config_class = SequenceRWKVConfig
def __init__(self, config: SequenceRWKVConfig, **kwargs):
super().__init__(**kwargs)
self.config = config
if config.wkv_config is not None:
LOGGER.info("Using WKV cuda kernel.")
else:
LOGGER.info("Using WKV torch kernel.")
self.encoder = create_encoder(config=config)
self.decoder = create_decoder(config=config)
self.blocks = nn.ModuleList([RWKVBlock(config, block_idx=i) for i in range(config.num_layers)])
self.blocks_ln = nn.LayerNorm(config.embedding_dim)
self.reset_parameters()
@property
def context_length(self) -> int:
return self.config.context_length
@property
def vocab_size(self) -> int:
return self.config.vocab_size
@property
def input_dim(self) -> Sequence[int]:
return self.config.input_dim
@property
def output_dim(self) -> Sequence[int]:
return self.config.output_dim
def reset_parameters(self):
for block in self.blocks:
block.reset_parameters()
self.blocks_ln.reset_parameters()
self.encoder.reset_parameters()
self.decoder.reset_parameters()
def _create_optim_groups(self, config: RWKVConfig):
optim_groups = [{"params": [p for p in self.parameters()], "weight_decay": 0.0}]
return optim_groups
def get_loss_func(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
import torch.nn.functional as F
def loss_fn(logits, targets):
assert not torch.any(torch.isnan(logits.view(-1)))
assert not torch.any(torch.isnan(targets.view(-1)))
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
return loss
return loss_fn
def forward(self, x: torch.Tensor) -> torch.Tensor:
assert (
x.size(1) <= self.config.context_length
), f"Forward input sequence length {x.size(1)} is longer than context length {self.config.context_length}"
y = self.encoder(x)
for block in self.blocks:
y = block(y)
y = self.blocks_ln(y)
y = self.decoder(y)
return y