Add RWKV, H3, Hyena
160
.gitignore
vendored
Normal file
@@ -0,0 +1,160 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
6
.gitmodules
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
[submodule "RWKV-LM"]
|
||||
path = RWKV-LM
|
||||
url = git@github.com:BlinkDL/RWKV-LM.git
|
||||
[submodule "safari"]
|
||||
path = safari
|
||||
url = git@github.com:HazyResearch/safari.git
|
||||
1
RWKV-LM
Submodule
Submodule RWKV-LM added at 69e6c50001
137
TestModels.ipynb
Normal file
@@ -0,0 +1,137 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%reload_ext autoreload\n",
|
||||
"import torch"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import importlib\n",
|
||||
"import os\n",
|
||||
"import sys\n",
|
||||
"\n",
|
||||
"# def import_from_file(module_name, file_path):\n",
|
||||
"# spec = importlib.util.spec_from_file_location(module_name, file_path)\n",
|
||||
"# module = importlib.util.module_from_spec(spec)\n",
|
||||
"# sys.modules[module_name] = module\n",
|
||||
"# spec.loader.exec_module(module)\n",
|
||||
"# return module\n",
|
||||
"\n",
|
||||
"# os.environ['RWKV_JIT_ON'] = '1'\n",
|
||||
"# os.environ['RWKV_T_MAX'] = '16384'\n",
|
||||
"# os.environ['RWKV_FLOAT_MODE'] = 'fp32'\n",
|
||||
"# os.environ['CUDA_HOME']\n",
|
||||
"# path = os.path.abspath('RWKV-LM/RWKV-v4neo/src/model.py')\n",
|
||||
"# rkwv_mod = import_from_file('rwkv_mod', path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from models.rwkv import RWKV, RWKVConfig\n",
|
||||
"sys.path.append('./safari')\n",
|
||||
"from safari.src.models.sequence.model import SequenceModel\n",
|
||||
"from safari.src.models.sequence.h3 import H3\n",
|
||||
"from safari.src.models.sequence.hyena import HyenaOperator\n",
|
||||
"from omegaconf import OmegaConf\n",
|
||||
"os.environ['CUDA_HOME'] = \".\"\n",
|
||||
"\n",
|
||||
"NUM_LAYERS = 3\n",
|
||||
"EMBEDDING_DIM = 128\n",
|
||||
"SEQUENCE_LENGTH = 1024\n",
|
||||
"\n",
|
||||
"def load_config(filename):\n",
|
||||
" with open(filename) as fp:\n",
|
||||
" cfg_yaml = fp.read()\n",
|
||||
" return OmegaConf.create(cfg_yaml)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"rwkv_full = RWKV(RWKVConfig(context_length=SEQUENCE_LENGTH, embedding_dim=EMBEDDING_DIM, num_layers=NUM_LAYERS, wkv_config=None))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"hyena_full = SequenceModel(layer=load_config('config/models/hyena.yaml'), d_model=EMBEDDING_DIM, n_layers=NUM_LAYERS)\n",
|
||||
"h3_full = SequenceModel(layer=load_config('config/models/h3.yaml'), d_model=EMBEDDING_DIM, n_layers=NUM_LAYERS)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"h3 = H3(d_model=EMBEDDING_DIM)\n",
|
||||
"hyena = HyenaOperator(d_model=EMBEDDING_DIM, l_max=16384)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "base",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
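Note on the notebook above: it only constructs the models and never runs them. A minimal smoke test of the RWKV module built there might look like the sketch below; the batch size and random input are illustrative, and it assumes the pure-PyTorch WKVTorch fallback (selected via wkv_config=None, not shown in this diff) mirrors the CUDA call signature. Because the embedding and head are commented out in rwkv_model.py, the model consumes and returns tensors of shape (batch, seq_len, embedding_dim).

import torch

from models.rwkv import RWKV, RWKVConfig

# hypothetical smoke test; dimensions mirror the notebook constants
NUM_LAYERS, EMBEDDING_DIM, SEQUENCE_LENGTH = 3, 128, 1024

rwkv = RWKV(RWKVConfig(context_length=SEQUENCE_LENGTH, embedding_dim=EMBEDDING_DIM,
                       num_layers=NUM_LAYERS, wkv_config=None))

x = torch.randn(2, SEQUENCE_LENGTH, EMBEDDING_DIM)  # pre-embedded inputs (B, T, C)
with torch.no_grad():
    y = rwkv(x)
print(y.shape)  # expected: torch.Size([2, 1024, 128])
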
0
__init__.py
Normal file
7
config/models/h3.yaml
Normal file
@@ -0,0 +1,7 @@
|
||||
_name_: h3
|
||||
d_state: 64
|
||||
head_dim: 1
|
||||
mode: diag
|
||||
measure: diag-lin
|
||||
# lr: ${eval:"min(0.001, ${optimizer.lr})"}
|
||||
lr: 0.001
|
||||
16
config/models/hyena.yaml
Normal file
@@ -0,0 +1,16 @@
|
||||
_name_: hyena
|
||||
l_max: 1024
|
||||
order: 2
|
||||
filter_order: 64
|
||||
num_heads: 1
|
||||
inner_factor: 1
|
||||
num_blocks: 1
|
||||
fused_bias_fc: false
|
||||
outer_mixing: false
|
||||
dropout: 0.0
|
||||
filter_dropout: 0.0
|
||||
filter_cls: 'hyena-filter'
|
||||
post_order_ffn: false
|
||||
jit_filter: false
|
||||
short_filter_order: 3
|
||||
activation: "id"
|
||||
1
config/models/lstm.yaml
Normal file
@@ -0,0 +1 @@
|
||||
_name_: LSTM_Transformer
|
||||
0
config/models/rwkv.yaml
Normal file
0
config/models/selfattn.yaml
Normal file
73
models/base.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, List
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .interfaces import (
|
||||
ModelTrainInterface,
|
||||
SequenceConfigInterface,
|
||||
SequenceInterface,
|
||||
)
|
||||
from .base_model import BaseModel
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LayerConfigInterface(ABC):
|
||||
@abstractmethod
|
||||
def assign_model_config_params(self, model_config):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptimizerConfig:
|
||||
lr: float = 1e-3
|
||||
betas: List = field(default_factory=lambda: [0.9, 0.95])
|
||||
weight_decay: float = 0.0
|
||||
device_type: str = "cuda" # TODO is this necessary?
|
||||
fused: bool = True
|
||||
eps: float = 1e-6
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseSequenceModelConfig(SequenceConfigInterface):
|
||||
optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
|
||||
|
||||
    shortname: str = ""  # used to give the model a more distinctive name in configurations etc.; temporarily filled in by Hydra or OmegaConf
|
||||
|
||||
|
||||
class ResettableParametersModule(nn.Module, ABC):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def reset_parameters(self, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class BaseSequenceModelTrain(
|
||||
BaseModel, ModelTrainInterface, SequenceInterface
|
||||
):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _create_optim_groups(self, **kwargs) -> list[dict[str, Any]]:
|
||||
# TODO think of a nice way to separate functionality from child classes
|
||||
# TODO move this into BaseModel and make it a separate interface too
|
||||
pass
|
||||
|
||||
def configure_optimizer(self) -> torch.optim.Optimizer:
|
||||
optim_cfg = self.config.optimizer
|
||||
optim_groups = self._create_optim_groups(self.config)
|
||||
use_fused = optim_cfg.device_type == "cuda" and optim_cfg.fused
|
||||
LOGGER.info(f"Using fused optimizer: {use_fused}")
|
||||
extra_args = dict(fused=True) if use_fused else dict()
|
||||
extra_args["eps"] = optim_cfg.eps
|
||||
optimizer = torch.optim.AdamW(
|
||||
optim_groups, lr=optim_cfg.lr, betas=optim_cfg.betas, **extra_args
|
||||
)
|
||||
return optimizer
|
||||
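configure_optimizer() in models/base.py delegates parameter grouping to the abstract _create_optim_groups(). As a minimal sketch (not part of this commit) of what a subclass could return, the version below follows the common convention of exempting biases and norm parameters from weight decay; that split rule is an assumption, not the repo's method.

from typing import Any

def _create_optim_groups(self, config) -> list[dict[str, Any]]:
    decay, no_decay = [], []
    for name, param in self.named_parameters():
        if not param.requires_grad:
            continue
        # heuristic: 1-D tensors (biases, norm scales) get no weight decay
        (no_decay if param.ndim <= 1 else decay).append(param)
    return [
        {"params": decay, "weight_decay": config.optimizer.weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
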
165
models/base_model.py
Normal file
@@ -0,0 +1,165 @@
|
||||
import copy
import logging
from abc import ABC, abstractmethod
from dataclasses import asdict
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union

import torch
from torch import nn

FN_MODEL_PREFIX = "model_"
FN_MODEL_FILE_EXT = ".p"

LOGGER = logging.getLogger(__name__)


def get_device(device: Union[torch.device, str, int]) -> torch.device:
|
||||
if device == "auto":
|
||||
device = "cuda"
|
||||
if isinstance(device, int):
|
||||
if device < 0:
|
||||
device = torch.device("cpu")
|
||||
else:
|
||||
device = torch.device(f"cuda:{device}")
|
||||
else:
|
||||
device = torch.device(device)
|
||||
|
||||
if (
|
||||
device.type == torch.device("cuda").type
|
||||
and not torch.cuda.is_available()
|
||||
):
|
||||
        LOGGER.warning(f"Device '{device}' is not available! Using cpu now.")
|
||||
return torch.device("cpu")
|
||||
return device
|
||||
|
||||
|
||||
class BaseModel(nn.Module, ABC):
|
||||
"""BaseModel class
|
||||
Takes care of easy saving and loading.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.config = None
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_loss_func(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _get_constructor_parameters(self) -> dict:
|
||||
if isinstance(self.config, dict):
|
||||
return self.config
|
||||
return asdict(self.config)
|
||||
|
||||
    def reset_parameters(self):
        init_fn = self.get_init_fn()
        if init_fn is not None:
            self.apply(init_fn)

    def get_init_fn(self) -> Optional[Callable[[nn.Module], None]]:
        return None
|
||||
|
||||
@property
|
||||
def num_parameters(self) -> int:
|
||||
return (
|
||||
torch.tensor([p.numel() for p in self.parameters()]).sum().item()
|
||||
)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
def copy_to_cpu(self) -> "BaseModel":
|
||||
"""Copy the model to CPU."""
|
||||
return copy.deepcopy(self).to(torch.device("cpu"))
|
||||
|
||||
def get_checkpoint_data(
|
||||
self, dict_key_prefix: str = "model_"
|
||||
) -> Dict[str, Any]:
|
||||
checkpoint_dict = {
|
||||
f"{dict_key_prefix}state_dict": self.state_dict(),
|
||||
f"{dict_key_prefix}data": self._get_constructor_parameters(),
|
||||
f"{dict_key_prefix}name": self.__class__.__name__,
|
||||
f"{dict_key_prefix}class": self.__class__,
|
||||
}
|
||||
return checkpoint_dict
|
||||
|
||||
def save(
|
||||
self,
|
||||
path: Union[str, Path],
|
||||
model_name: str,
|
||||
file_extension: Optional[str] = FN_MODEL_FILE_EXT,
|
||||
dict_key_prefix: str = "model_",
|
||||
) -> None:
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
save_path = path / (model_name + file_extension)
|
||||
torch.save(self.get_checkpoint_data(dict_key_prefix), save_path)
|
||||
|
||||
@staticmethod
|
||||
def model_save_name(
|
||||
idx: int, specifier: str = "epoch", num_digits: int = -1
|
||||
) -> str:
|
||||
"""Get a consistnet the model save name.
|
||||
|
||||
Args:
|
||||
epoch (int): Epoch / iteration number.
|
||||
specifier (str, optional): A specifier for the idx. Defaults to epoch.
|
||||
num_digits (int, optional): The number of digits in the save name. Unused by default,
|
||||
since this causes overrides when we have an overflow. Defaults to -1.
|
||||
|
||||
Returns:
|
||||
str: Model save name.
|
||||
"""
|
||||
if num_digits == -1:
|
||||
return f"{FN_MODEL_PREFIX}{specifier}_{idx}"
|
||||
else:
|
||||
return f"{FN_MODEL_PREFIX}{specifier}_{idx:0{num_digits}}"
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
path: Union[str, Path],
|
||||
        model_name: Optional[str] = None,
|
||||
file_extension: Optional[str] = ".p",
|
||||
device: Union[torch.device, str, int] = "auto",
|
||||
dict_key_prefix: str = "model_",
|
||||
) -> "BaseModel":
|
||||
device = get_device(device)
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
if model_name is None:
|
||||
save_path = path
|
||||
else:
|
||||
save_path = path / (model_name + file_extension)
|
||||
checkpoint = torch.load(save_path, map_location=device)
|
||||
|
||||
return cls.params_from_checkpoint(
|
||||
checkpoint=checkpoint, dict_key_prefix=dict_key_prefix
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def params_from_checkpoint(
|
||||
cls, checkpoint: Dict[str, Any], dict_key_prefix: str = "model_"
|
||||
) -> "BaseModel":
|
||||
if hasattr(cls, "config_class"):
|
||||
from dacite import from_dict
|
||||
|
||||
config_cls = cls.config_class
|
||||
|
||||
model_cfg = from_dict(
|
||||
data_class=config_cls,
|
||||
data=checkpoint[f"{dict_key_prefix}data"],
|
||||
)
|
||||
model = cls(config=model_cfg)
|
||||
else:
|
||||
model = cls(**checkpoint[f"{dict_key_prefix}data"])
|
||||
model.load_state_dict(checkpoint[f"{dict_key_prefix}state_dict"])
|
||||
return model
|
||||
|
||||
# @staticmethod
|
||||
# def class_and_params_from_checkpoint(checkpoint: Dict[str, Any], dict_key_prefix: str = 'model_') -> 'BaseModel':
|
||||
# from . import get_model_class
|
||||
# model_class = get_model_class(checkpoint[f"{dict_key_prefix}name"])
|
||||
# model = model_class.params_from_checkpoint(checkpoint=checkpoint, dict_key_prefix=dict_key_prefix)
|
||||
# return model
|
||||
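For reference, a round trip through the save()/load() API defined above could look like the sketch below; the checkpoint directory and dimensions are placeholders. It relies on RWKV (added later in this commit) exposing config_class, so load() goes through the dacite from_dict branch.

from pathlib import Path
import torch

from models.rwkv import RWKV, RWKVConfig

ckpt_dir = Path("checkpoints")  # hypothetical output directory
ckpt_dir.mkdir(exist_ok=True)

model = RWKV(RWKVConfig(context_length=256, embedding_dim=64, num_layers=2, wkv_config=None))
name = RWKV.model_save_name(idx=3, specifier="epoch")  # -> "model_epoch_3"
model.save(ckpt_dir, model_name=name)                  # writes checkpoints/model_epoch_3.p

restored = RWKV.load(ckpt_dir, model_name=name, device="cpu")
x = torch.randn(1, 256, 64)
assert torch.allclose(model(x), restored(x))           # same parameters, same output
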
77
models/interfaces.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional, Protocol, Sequence
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class ModelTrainInterface(ABC):
|
||||
def configure_optimizer(self) -> Optional[torch.optim.Optimizer]:
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def get_loss_func(self, **kwargs) -> Callable[[Any], torch.Tensor]:
|
||||
pass
|
||||
|
||||
|
||||
class SequenceInterface(ABC):
|
||||
"""This is a generic interface for a sequence.
|
||||
In our case a sequence also includes its label. Therefore, the label (aka output_dim)
|
||||
is also part of the sequence interface.
|
||||
A sequence always has a length (=context_length).
|
||||
|
||||
A sequence can have one of the following flavors:
|
||||
Input sequence:
|
||||
- sequence of tokens (e.g. words): vocab_size must be specified
|
||||
- sequence of vectors: input_dim must be specified
|
||||
|
||||
Output sequence:
|
||||
- next token (e.g. word): `vocab_size` must be specified (e.g. Causal Language Modeling).
|
||||
- (sequence of) vectors: output_dim must be specified (e.g. Forecasting)
|
||||
- label: output_dim must be specified (e.g. Sequence Classification)
|
||||
|
||||
Examples:
|
||||
- Causal Language Modeling: input_dim = None, output_dim = None, vocab_size = int
|
||||
- Forecasting: input_dim = int, output_dim = int, vocab_size = None
|
||||
- Sequence Classification (General Sequence): input_dim = int, output_dim = int, vocab_size = None
|
||||
- Sequence Classification (Text): input_dim = None, output_dim = int, vocab_size = int
|
||||
"""
|
||||
|
||||
@property
|
||||
def input_dim(self) -> Optional[Sequence[int]]:
|
||||
return None
|
||||
|
||||
@property
|
||||
def output_dim(self) -> Optional[Sequence[int]]:
|
||||
return None
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> Optional[int]:
|
||||
return None
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def context_length(self) -> int:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class SequenceConfigInterface:
|
||||
context_length: int
|
||||
# vocab_size: Optional[int] = None
|
||||
input_dim: Optional[Sequence[int]] = None
|
||||
output_dim: Optional[Sequence[int]] = None
|
||||
|
||||
|
||||
class Tokenizer(Protocol):
|
||||
def __call__(self, **kwargs) -> Any:
|
||||
...
|
||||
|
||||
def __len__(self) -> int:
|
||||
...
|
||||
|
||||
|
||||
class TokenizerInterface(ABC):
|
||||
@property
|
||||
def tokenizer(self) -> Optional[Tokenizer]:
|
||||
return None
|
||||
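The SequenceInterface docstring above lists the input/output flavors in prose. As a small illustration, the plain config dataclass can describe the non-token cases directly (vocab_size is currently commented out of SequenceConfigInterface); the concrete dimensions below are made up.

from models.interfaces import SequenceConfigInterface

# Forecasting: real-valued inputs and outputs
forecasting_cfg = SequenceConfigInterface(context_length=512, input_dim=(7,), output_dim=(7,))

# Sequence classification on generic feature vectors
classification_cfg = SequenceConfigInterface(context_length=256, input_dim=(16,), output_dim=(10,))
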
3
models/rwkv/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .rwkv_model import RWKV, RWKVConfig
|
||||
|
||||
__all__ = ["RWKV", "RWKVConfig"]
|
||||
145
models/rwkv/cuda/wkv_cuda.cu
Normal file
@@ -0,0 +1,145 @@
|
||||
#include <stdio.h>
|
||||
#include <assert.h>
|
||||
|
||||
#define MIN_VALUE (-1e38)
|
||||
|
||||
|
||||
template <typename F>
|
||||
__global__ void kernel_forward(const int B, // batch size
|
||||
const int T, // sequence length
|
||||
const int C, // dim_att (number of channels)
|
||||
const F *__restrict__ const _w_timedecay, // -exp(time decay) [attention_dim]
|
||||
const F *__restrict__ const _u_timefirst, // time first [attention_dim]
|
||||
const F *__restrict__ const _k, // keys [batch_size, sequence_length, embedding_dim]
|
||||
const F *__restrict__ const _v, // values [batch_size, sequence_length, embedding_dim]
|
||||
F *__restrict__ const _y // output [batch_size, sequence_length, embedding_dim]
|
||||
) {
|
||||
    // determine which (batch, channel) series this thread works on
    const int idx = blockIdx.x * blockDim.x + threadIdx.x; // global thread index over all (batch, channel) pairs
    const int _b = idx / C;              // batch index
    const int _c = idx % C;              // channel index
    const int _offset = _b * T * C + _c; // start offset of this (batch, channel) series in the flattened (B, T, C) tensors

    // per-channel constants (one element of the size-C time_first / time_decay vectors)
|
||||
F u_timefirst = _u_timefirst[_c];
|
||||
F w_timedecay = _w_timedecay[_c];
|
||||
// these are tensors of size B x T x C (stored as array, access through pointers)
|
||||
const F *__restrict__ const k = _k + _offset;
|
||||
const F *__restrict__ const v = _v + _offset;
|
||||
F *__restrict__ const y = _y + _offset;
|
||||
|
||||
// aa and bb are running sums divided by exp(pp) (to avoid overflow)
|
||||
F aa = 0, bb = 0, pp = MIN_VALUE;
|
||||
// loop goes over time dimension
|
||||
for (int i = 0; i < T; i++) {
|
||||
const int ii = i * C; // index ii picks the correct channel
|
||||
const F kk = k[ii];
|
||||
const F vv = v[ii];
|
||||
|
||||
        F ww = u_timefirst + kk; // add the "bonus" u for the current token
|
||||
F p = max(pp, ww);
|
||||
F e1 = exp(pp - p); // exp(pp - p) is the same as exp(pp) / exp(p)
|
||||
F e2 = exp(ww - p);
|
||||
y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
|
||||
|
||||
        ww = w_timedecay + pp; // decay the running state before absorbing the current token
|
||||
p = max(ww, kk);
|
||||
e1 = exp(ww - p);
|
||||
e2 = exp(kk - p);
|
||||
aa = e1 * aa + e2 * vv;
|
||||
bb = e1 * bb + e2;
|
||||
pp = p;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
__global__ void kernel_backward(const int B, const int T, const int C,
|
||||
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
|
||||
const F *__restrict__ const _y, const F *__restrict__ const _gy,
|
||||
F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int _b = idx / C;
|
||||
const int _c = idx % C;
|
||||
const int _offset = _b * T * C + _c;
|
||||
|
||||
F u = _u[_c];
|
||||
F w = _w[_c];
|
||||
const F *__restrict__ const k = _k + _offset;
|
||||
const F *__restrict__ const v = _v + _offset;
|
||||
const F *__restrict__ const y = _y + _offset;
|
||||
const F *__restrict__ const gy = _gy + _offset;
|
||||
F *__restrict__ const gk = _gk + _offset;
|
||||
F *__restrict__ const gv = _gv + _offset;
|
||||
|
||||
F q[Tmax], r[Tmax];
|
||||
|
||||
F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
|
||||
for (int i = 0; i < T; i++) {
|
||||
const int ii = i * C;
|
||||
const F kk = k[ii];
|
||||
const F vv = v[ii];
|
||||
const F yy = y[ii];
|
||||
|
||||
F ww = u + kk;
|
||||
F p = max(pp, ww);
|
||||
F e1 = exp(pp - p);
|
||||
F e2 = exp(ww - p);
|
||||
const F qq = gy[ii] / (e1 * bb + e2);
|
||||
gw += (ga - gb * yy) * e1 * qq;
|
||||
gu += (vv - yy) * e2 * qq;
|
||||
q[i] = qq;
|
||||
r[i] = ww - p;
|
||||
|
||||
ww = w + pp;
|
||||
p = max(ww, kk);
|
||||
e1 = exp(ww - p);
|
||||
e2 = exp(kk - p);
|
||||
ga = e1 * (aa + ga);
|
||||
gb = e1 * (bb + gb);
|
||||
aa = e1 * aa + e2 * vv;
|
||||
bb = e1 * bb + e2;
|
||||
pp = p;
|
||||
}
|
||||
const int _offsetBC = _b * C + _c;
|
||||
_gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
|
||||
_gu[_offsetBC] = gu;
|
||||
|
||||
aa = 0, bb = 0, pp = MIN_VALUE;
|
||||
for (int i = T - 1; i >= 0; i--) {
|
||||
const int ii = i * C;
|
||||
const F kk = k[ii];
|
||||
const F vv = v[ii];
|
||||
const F yy = y[ii];
|
||||
const F qq = q[i];
|
||||
const F rr = r[i];
|
||||
|
||||
F e1 = qq * exp(rr);
|
||||
F e2 = exp(kk + pp);
|
||||
gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
|
||||
gv[ii] = e1 + e2 * aa;
|
||||
|
||||
const F ww = w + pp;
|
||||
const F www = rr - u - kk;
|
||||
const F p = max(ww, www);
|
||||
e1 = exp(ww - p);
|
||||
e2 = qq * exp(www - p);
|
||||
aa = e1 * aa + e2;
|
||||
bb = e1 * bb - e2 * yy;
|
||||
pp = p;
|
||||
}
|
||||
}
|
||||
|
||||
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
|
||||
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
|
||||
assert(B * C % threadsPerBlock.x == 0);
|
||||
dim3 numBlocks(B * C / threadsPerBlock.x);
|
||||
kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
|
||||
}
|
||||
|
||||
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
|
||||
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
|
||||
assert(B * C % threadsPerBlock.x == 0);
|
||||
dim3 numBlocks(B * C / threadsPerBlock.x);
|
||||
kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
|
||||
}
|
||||
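The forward kernel above computes the WKV recurrence with a running max (log-sum-exp style) to avoid overflow in the exponentials. A naive PyTorch reference of the same forward pass, useful as a sketch for checking the CUDA path, is given below; it assumes w has already been mapped to -exp(time_decay), mirroring the transform applied on the Python side in wkv_kernel.py.

import torch

def wkv_forward_reference(w, u, k, v):
    """Naive O(T) time loop for k, v of shape (B, T, C); w, u are per-channel
    vectors of shape (C,). w is assumed to already be -exp(time_decay)."""
    B, T, C = k.shape
    y = torch.empty_like(v)
    aa = torch.zeros(B, C, dtype=k.dtype, device=k.device)
    bb = torch.zeros(B, C, dtype=k.dtype, device=k.device)
    pp = torch.full((B, C), -1e38, dtype=k.dtype, device=k.device)
    for t in range(T):
        kk, vv = k[:, t], v[:, t]
        ww = u + kk                      # "bonus" for the current token
        p = torch.maximum(pp, ww)
        e1, e2 = torch.exp(pp - p), torch.exp(ww - p)
        y[:, t] = (e1 * aa + e2 * vv) / (e1 * bb + e2)
        ww = w + pp                      # decay the running state
        p = torch.maximum(ww, kk)
        e1, e2 = torch.exp(ww - p), torch.exp(kk - p)
        aa = e1 * aa + e2 * vv
        bb = e1 * bb + e2
        pp = p
    return y
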
132
models/rwkv/cuda/wkv_cuda_bf16.cu
Normal file
@@ -0,0 +1,132 @@
|
||||
#include <stdio.h>
|
||||
#include <assert.h>
|
||||
#include "ATen/ATen.h"
|
||||
#define MIN_VALUE (-1e38)
|
||||
typedef at::BFloat16 bf16;
|
||||
|
||||
__global__ void kernel_forward(const int B, const int T, const int C,
|
||||
const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
|
||||
bf16 *__restrict__ const _y) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int _b = idx / C;
|
||||
const int _c = idx % C;
|
||||
const int _offset = _b * T * C + _c;
|
||||
|
||||
float u = float(_u[_c]);
|
||||
float w = _w[_c];
|
||||
const bf16 *__restrict__ const k = _k + _offset;
|
||||
const bf16 *__restrict__ const v = _v + _offset;
|
||||
bf16 *__restrict__ const y = _y + _offset;
|
||||
|
||||
// aa and bb are running sums divided by exp(pp) (to avoid overflow)
|
||||
float aa = 0, bb = 0, pp = MIN_VALUE;
|
||||
for (int i = 0; i < T; i++) {
|
||||
const int ii = i * C;
|
||||
const float kk = float(k[ii]);
|
||||
const float vv = float(v[ii]);
|
||||
|
||||
float ww = u + kk;
|
||||
float p = max(pp, ww);
|
||||
float e1 = exp(pp - p);
|
||||
float e2 = exp(ww - p);
|
||||
y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2));
|
||||
|
||||
ww = w + pp;
|
||||
p = max(ww, kk);
|
||||
e1 = exp(ww - p);
|
||||
e2 = exp(kk - p);
|
||||
aa = e1 * aa + e2 * vv;
|
||||
bb = e1 * bb + e2;
|
||||
pp = p;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void kernel_backward(const int B, const int T, const int C,
|
||||
const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
|
||||
const bf16 *__restrict__ const _y, const bf16 *__restrict__ const _gy,
|
||||
bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu, bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int _b = idx / C;
|
||||
const int _c = idx % C;
|
||||
const int _offset = _b * T * C + _c;
|
||||
|
||||
float u = float(_u[_c]);
|
||||
float w = _w[_c];
|
||||
const bf16 *__restrict__ const k = _k + _offset;
|
||||
const bf16 *__restrict__ const v = _v + _offset;
|
||||
const bf16 *__restrict__ const y = _y + _offset;
|
||||
const bf16 *__restrict__ const gy = _gy + _offset;
|
||||
bf16 *__restrict__ const gk = _gk + _offset;
|
||||
bf16 *__restrict__ const gv = _gv + _offset;
|
||||
|
||||
float q[Tmax], r[Tmax];
|
||||
|
||||
float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
|
||||
for (int i = 0; i < T; i++) {
|
||||
const int ii = i * C;
|
||||
const float kk = float(k[ii]);
|
||||
const float vv = float(v[ii]);
|
||||
const float yy = float(y[ii]);
|
||||
|
||||
float ww = u + kk;
|
||||
float p = max(pp, ww);
|
||||
float e1 = exp(pp - p);
|
||||
float e2 = exp(ww - p);
|
||||
const float qq = float(gy[ii]) / (e1 * bb + e2);
|
||||
gw += (ga - gb * yy) * e1 * qq;
|
||||
gu += (vv - yy) * e2 * qq;
|
||||
q[i] = qq;
|
||||
r[i] = ww - p;
|
||||
|
||||
ww = w + pp;
|
||||
p = max(ww, kk);
|
||||
e1 = exp(ww - p);
|
||||
e2 = exp(kk - p);
|
||||
ga = e1 * (aa + ga);
|
||||
gb = e1 * (bb + gb);
|
||||
aa = e1 * aa + e2 * vv;
|
||||
bb = e1 * bb + e2;
|
||||
pp = p;
|
||||
}
|
||||
const int _offsetBC = _b * C + _c;
|
||||
_gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward()
|
||||
_gu[_offsetBC] = bf16(gu);
|
||||
|
||||
aa = 0, bb = 0, pp = MIN_VALUE;
|
||||
for (int i = T - 1; i >= 0; i--) {
|
||||
const int ii = i * C;
|
||||
const float kk = float(k[ii]);
|
||||
const float vv = float(v[ii]);
|
||||
const float yy = float(y[ii]);
|
||||
const float qq = q[i];
|
||||
const float rr = r[i];
|
||||
|
||||
float e1 = qq * exp(rr);
|
||||
float e2 = exp(kk + pp);
|
||||
gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb));
|
||||
gv[ii] = bf16(e1 + e2 * aa);
|
||||
|
||||
const float ww = w + pp;
|
||||
const float www = rr - u - kk;
|
||||
const float p = max(ww, www);
|
||||
e1 = exp(ww - p);
|
||||
e2 = qq * exp(www - p);
|
||||
aa = e1 * aa + e2;
|
||||
bb = e1 * bb - e2 * yy;
|
||||
pp = p;
|
||||
}
|
||||
}
|
||||
|
||||
void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) {
|
||||
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
|
||||
assert(B * C % threadsPerBlock.x == 0);
|
||||
dim3 numBlocks(B * C / threadsPerBlock.x);
|
||||
kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
|
||||
}
|
||||
|
||||
void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) {
|
||||
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
|
||||
assert(B * C % threadsPerBlock.x == 0);
|
||||
dim3 numBlocks(B * C / threadsPerBlock.x);
|
||||
kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
|
||||
}
|
||||
25
models/rwkv/cuda/wkv_op.cpp
Normal file
@@ -0,0 +1,25 @@
|
||||
#include <torch/extension.h>
|
||||
// #include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
|
||||
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
|
||||
|
||||
void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
|
||||
// TODO add this line const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
|
||||
// (see https://discord.com/channels/992359628979568762/1084547464452907088/1091680712178016326)
|
||||
// see https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/cuda/wrapper.cpp
|
||||
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
|
||||
}
|
||||
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
|
||||
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &forward, "wkv forward");
|
||||
m.def("backward", &backward, "wkv backward");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY(wkv, m) {
|
||||
m.def("forward", forward);
|
||||
m.def("backward", backward);
|
||||
}
|
||||
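The binding above is meant to be JIT-compiled with torch.utils.cpp_extension.load, which is what wkv_kernel.py (further below) imports. A sketch of such a build call is given here; the exact flag set is an assumption modelled on the RWKV reference code, but Tmax must be defined at compile time because the backward kernel allocates q[Tmax] / r[Tmax] on the stack, and the --maxrregcount hint comes from the comment in wkv_cuda.cu.

from torch.utils.cpp_extension import load

T_MAX = 1024  # must be >= the longest sequence passed to the kernel

wkv_cuda = load(
    name="wkv",
    sources=["models/rwkv/cuda/wkv_op.cpp", "models/rwkv/cuda/wkv_cuda.cu"],
    verbose=True,
    extra_cuda_cflags=[
        "-O3",
        "--use_fast_math",
        "--maxrregcount=60",   # see the comment in wkv_cuda.cu
        f"-DTmax={T_MAX}",     # compile-time bound for the per-thread buffers
    ],
)
# wkv_cuda.forward(B, T, C, w, u, k, v, y) and wkv_cuda.backward(...) then match wkv_op.cpp
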
25
models/rwkv/cuda/wkv_op_bf16.cpp
Normal file
@@ -0,0 +1,25 @@
|
||||
#include <torch/extension.h>
|
||||
#include "ATen/ATen.h"
|
||||
typedef at::BFloat16 bf16;
|
||||
|
||||
void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y);
|
||||
void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv);
|
||||
|
||||
void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
|
||||
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>());
|
||||
}
|
||||
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
|
||||
torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
|
||||
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(),
|
||||
gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &forward, "wkv forward");
|
||||
m.def("backward", &backward, "wkv backward");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY(wkv, m) {
|
||||
m.def("forward", forward);
|
||||
m.def("backward", backward);
|
||||
}
|
||||
469
models/rwkv/rwkv_model.py
Normal file
@@ -0,0 +1,469 @@
|
||||
import logging
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..base import (
|
||||
BaseSequenceModelConfig,
|
||||
BaseSequenceModelTrain,
|
||||
ResettableParametersModule,
|
||||
)
|
||||
from .wkv_kernel import WKV, WKVConfig, WKVTorch
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class L2Wrap(torch.autograd.Function):
|
||||
"""L2 regularization for the logits."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, loss, y):
|
||||
ctx.save_for_backward(y)
|
||||
return loss
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
y = ctx.saved_tensors[0]
|
||||
# to encourage the logits to be close to 0
|
||||
factor = 1e-4 / (y.shape[0] * y.shape[1])
|
||||
maxx, ids = torch.max(y, -1, keepdim=True)
|
||||
gy = torch.zeros_like(y)
|
||||
gy.scatter_(-1, ids, maxx * factor)
|
||||
return (grad_output, gy)
|
||||
|
||||
|
||||
def _get_activation_fn(activation):
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
elif activation == "gelu":
|
||||
return F.gelu
|
||||
elif activation == "silu":
|
||||
return F.silu
|
||||
elif activation == "relu_squared":
|
||||
return lambda x: torch.square(torch.relu(x))
|
||||
elif activation == "selu":
|
||||
return F.selu
|
||||
elif activation == "elu":
|
||||
return F.elu
|
||||
else:
|
||||
raise ValueError(f"Unknown activation function {activation}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RWKVConfig(BaseSequenceModelConfig):
|
||||
embedding_dim: int = 640
|
||||
ffn_dim: int = 2048
|
||||
num_layers: int = 12
|
||||
attention_dim: int = -1
|
||||
wkv_config: Optional[Union[WKVConfig, Dict]] = field(
|
||||
default_factory=lambda: WKVConfig()
|
||||
)
|
||||
l2_logit_reg: bool = False
|
||||
use_timemix_timemix: bool = True
|
||||
use_timemix_channelmix: bool = True
|
||||
channelmix_act_fn: str = "relu_squared"
|
||||
    init: str = "reproduce_init"  # options: "paper_init", "reproduce_init"
    # "reproduce_init" was reverse-engineered from the original code before the paper was available;
    # the init described in the paper differs slightly from it.
    # "reproduce_init" performs considerably better than "paper_init" (ca. 0.5 lower train loss after
    # 300 steps on WikiText-103), so it is the default; it is also the init (most likely) used in the original code.
|
||||
|
||||
def __post_init__(self):
|
||||
# TODO: Check if this was actually needed
|
||||
if self.wkv_config is not None:
|
||||
self.wkv_config.T_max = max(
|
||||
self.wkv_config.T_max, self.context_length
|
||||
)
|
||||
if self.attention_dim <= 0:
|
||||
self.attention_dim = self.embedding_dim
|
||||
|
||||
|
||||
class RWKV(BaseSequenceModelTrain):
|
||||
config_class = RWKVConfig
|
||||
|
||||
def __init__(self, config: RWKVConfig, **kwargs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.cfg = config
|
||||
|
||||
# self.embedding = nn.Embedding(
|
||||
# num_embeddings=self.cfg.vocab_size,
|
||||
# embedding_dim=self.cfg.embedding_dim,
|
||||
# )
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
RWKVBlock(config=self.cfg, block_idx=i)
|
||||
for i in range(self.cfg.num_layers)
|
||||
]
|
||||
)
|
||||
if self.cfg.wkv_config is not None:
|
||||
LOGGER.info("Using WKV cuda kernel.")
|
||||
else:
|
||||
LOGGER.info("Using WKV torch kernel.")
|
||||
|
||||
# self.ln_out = nn.LayerNorm(self.cfg.embedding_dim)
|
||||
# self.head = nn.Linear(
|
||||
# self.cfg.embedding_dim, self.cfg.vocab_size, bias=False
|
||||
# )
|
||||
self.reset_parameters()
|
||||
|
||||
@property
|
||||
def context_length(self) -> int:
|
||||
return self.config.context_length
|
||||
|
||||
# @property
|
||||
# def vocab_size(self) -> int:
|
||||
# return self.config.vocab_size
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
# init embedding
|
||||
# default init is zero # TODO try this
|
||||
# we use a narrow uniform init, in the original code they use the initial learning rate
|
||||
# we just set it to a small value
|
||||
# emb_init_range = 0.0008 # 1e-3
|
||||
# nn.init.uniform_(
|
||||
# self.embedding.weight, a=-emb_init_range, b=emb_init_range
|
||||
# )
|
||||
# init blocks
|
||||
for b in self.blocks:
|
||||
b.reset_parameters()
|
||||
# init head and layer norm
|
||||
# self.head.reset_parameters()
|
||||
# self.ln_out.reset_parameters()
|
||||
|
||||
def _create_optim_groups(self, config: RWKVConfig):
|
||||
optim_groups = [
|
||||
{"params": [p for p in self.parameters()], "weight_decay": 0.0}
|
||||
]
|
||||
return optim_groups
|
||||
|
||||
def forward(self, x):
|
||||
# no embedding
|
||||
# # input shape: (B, T), T <= context_len, T are token ids
|
||||
# B, T = x.size()
|
||||
# assert T <= self.cfg.context_length, (
|
||||
# f"input sequence length {T} exceeds model "
|
||||
# f"context length {self.cfg.context_length}"
|
||||
# )
|
||||
|
||||
# x = self.embedding(x) # (B, T, C), C = embedding_dim
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
x = block(x)
|
||||
|
||||
# x = self.ln_out(x)
|
||||
# no head
|
||||
# x = self.head(x)
|
||||
return x
|
||||
|
||||
def get_loss_func(self):
|
||||
def loss_fn(y_hat, y):
|
||||
loss = F.cross_entropy(
|
||||
y_hat.view(-1, y_hat.size(-1)), y.view(-1), ignore_index=-1
|
||||
)
|
||||
if self.cfg.l2_logit_reg:
|
||||
loss = L2Wrap.apply(loss, y_hat)
|
||||
return loss
|
||||
|
||||
return loss_fn
|
||||
|
||||
|
||||
def _calc_gain(weight: torch.Tensor) -> float:
|
||||
"""Calculate the gain value of the given weight tensor."""
|
||||
gain = 1.0
|
||||
fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(weight)
|
||||
if fan_out > fan_in:
|
||||
gain = math.sqrt(fan_out / fan_in)
|
||||
return gain
|
||||
|
||||
|
||||
class RWKVBlock(ResettableParametersModule):
|
||||
def __init__(self, config: RWKVConfig, block_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.block_idx = block_idx
|
||||
|
||||
self.ln0 = None
|
||||
if self.block_idx == 0:
|
||||
self.ln0 = nn.LayerNorm(self.config.embedding_dim)
|
||||
# TODO 1) maybe additional positional embedding here (only in block 0)
|
||||
|
||||
self.ln1 = nn.LayerNorm(self.config.embedding_dim)
|
||||
self.ln2 = nn.LayerNorm(self.config.embedding_dim)
|
||||
|
||||
# TODO 2) maybe pre feedforward here (channel mix) see line 325f in RWKV-v4neo/model.py
|
||||
self.attention_timemix = RWKVTimeMix(
|
||||
config=self.config, block_id=self.block_idx
|
||||
)
|
||||
self.ffn_channelmix = RWKVChannelMix(
|
||||
config=self.config, block_id=self.block_idx
|
||||
)
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
if self.ln0 is not None:
|
||||
self.ln0.reset_parameters()
|
||||
self.ln1.reset_parameters()
|
||||
self.ln2.reset_parameters()
|
||||
self.attention_timemix.reset_parameters()
|
||||
self.ffn_channelmix.reset_parameters()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.block_idx == 0 and self.ln0 is not None:
|
||||
x = self.ln0(x)
|
||||
# TODO 1) maybe positional embedding here (only in block 0)
|
||||
# x = x+pos_emb
|
||||
|
||||
# TODO 2) maybe pre feedforward here (channel mix) see line 325f in RWKV-v4neo/model.py
|
||||
# residual connection 1
|
||||
x = x + self.attention_timemix(self.ln1(x))
|
||||
# residual connection 2
|
||||
x = x + self.ffn_channelmix(self.ln2(x))
|
||||
return x
|
||||
|
||||
|
||||
class RWKVTimeMix(ResettableParametersModule):
|
||||
def __init__(self, config: RWKVConfig, block_id: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.block_id = block_id
|
||||
|
||||
embedding_dim = self.config.embedding_dim
|
||||
attention_dim = self.config.attention_dim
|
||||
# init time mix constants
|
||||
req_grad = True # TODO make this configurable
|
||||
self.time_mix_k = nn.Parameter(
|
||||
torch.empty((1, 1, embedding_dim)), requires_grad=req_grad
|
||||
)
|
||||
self.time_mix_v = nn.Parameter(
|
||||
torch.empty((1, 1, embedding_dim)), requires_grad=req_grad
|
||||
)
|
||||
self.time_mix_r = nn.Parameter(
|
||||
torch.empty((1, 1, embedding_dim)), requires_grad=req_grad
|
||||
)
|
||||
|
||||
# init time decay
|
||||
self.time_decay = nn.Parameter(
|
||||
torch.empty((attention_dim,)), requires_grad=req_grad
|
||||
)
|
||||
self.time_first = nn.Parameter(
|
||||
torch.empty((attention_dim,)), requires_grad=req_grad
|
||||
)
|
||||
|
||||
# init layers / parameters
|
||||
        # shift the time dimension forward by one step: pad zeros at the first time step and drop the last time step
|
||||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||||
self.key = nn.Linear(embedding_dim, attention_dim, bias=False)
|
||||
self.value = nn.Linear(embedding_dim, attention_dim, bias=False)
|
||||
self.receptance = nn.Linear(embedding_dim, attention_dim, bias=False)
|
||||
self.output = nn.Linear(attention_dim, embedding_dim, bias=False)
|
||||
|
||||
if self.config.wkv_config is not None:
|
||||
# use CUDA implementation
|
||||
self.wkv = WKV(config=self.config.wkv_config)
|
||||
else:
|
||||
# use pure PyTorch implementation
|
||||
self.wkv = WKVTorch()
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
# init time mix constants
|
||||
time_mix_k, time_mix_v, time_mix_r = self._init_time_mix_constants()
|
||||
req_grad = True
|
||||
self.time_mix_k = nn.Parameter(time_mix_k, requires_grad=req_grad)
|
||||
self.time_mix_v = nn.Parameter(time_mix_v, requires_grad=req_grad)
|
||||
self.time_mix_r = nn.Parameter(time_mix_r, requires_grad=req_grad)
|
||||
# init time decay
|
||||
time_decay, time_first = self._init_time_decay_constants()
|
||||
self.time_decay = nn.Parameter(time_decay, requires_grad=req_grad)
|
||||
self.time_first = nn.Parameter(time_first, requires_grad=req_grad)
|
||||
# init layers / parameters
|
||||
if self.config.init == "paper_init":
|
||||
# ZERO INIT
|
||||
nn.init.zeros_(self.receptance.weight)
|
||||
nn.init.zeros_(self.key.weight)
|
||||
nn.init.zeros_(self.value.weight)
|
||||
# NORMAL INIT
|
||||
nn.init.normal_(
|
||||
self.output.weight,
|
||||
std=math.sqrt(self.config.ffn_dim / self.config.embedding_dim),
|
||||
)
|
||||
elif self.config.init == "reproduce_init":
|
||||
# ZERO INIT
|
||||
nn.init.zeros_(self.key.weight)
|
||||
nn.init.zeros_(self.receptance.weight)
|
||||
nn.init.zeros_(self.output.weight)
|
||||
# ORTHOGONAL INIT
|
||||
nn.init.orthogonal_(
|
||||
self.value.weight, gain=_calc_gain(self.value.weight)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown init method {self.config.init}")
|
||||
|
||||
def _compute_rkv(self, x):
|
||||
if self.config.use_timemix_timemix:
|
||||
xx = self.time_shift(
|
||||
x
|
||||
) # Mix x with the previous timestep to produce xk, xv, xr
|
||||
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||||
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
|
||||
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||||
else:
|
||||
xk = x
|
||||
xv = x
|
||||
xr = x
|
||||
k = self.key(xk)
|
||||
v = self.value(xv)
|
||||
r = self.receptance(xr)
|
||||
sr = torch.sigmoid(r)
|
||||
return sr, k, v
|
||||
|
||||
def forward(self, x):
|
||||
B, T, C = x.size() # x = (batch_size, seq_len, embedding_dim)
|
||||
attention_dim = self.config.attention_dim
|
||||
sr, k, v = self._compute_rkv(
|
||||
x
|
||||
) # sr, k, v = (batch_size, seq_len, attention_dim)
|
||||
# wkv cuda/torch kernel
|
||||
rwkv = sr * self.wkv(
|
||||
B, T, attention_dim, self.time_decay, self.time_first, k, v
|
||||
)
|
||||
return self.output(rwkv)
|
||||
|
||||
def _init_time_mix_constants(
|
||||
self,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
num_blocks = self.config.num_layers
|
||||
embedding_dim = self.config.embedding_dim
|
||||
|
||||
ratio_0_to_1 = self.block_id / max(1, num_blocks - 1) # 0 to 1
|
||||
ratio_1_to_almost0 = 1.0 - (self.block_id / num_blocks) # 1 to ~0
|
||||
|
||||
# TODO does this make sense?
|
||||
# different time mix constants for each block and each embedding dim
|
||||
embed_dim_val = torch.ones(1, 1, embedding_dim)
|
||||
for i in range(embedding_dim):
|
||||
embed_dim_val[0, 0, i] = i / embedding_dim
|
||||
|
||||
# TODO check constants 0.3 and 0.5
|
||||
time_mix_k = torch.pow(embed_dim_val, ratio_1_to_almost0)
|
||||
time_mix_v = (
|
||||
torch.pow(embed_dim_val, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
|
||||
)
|
||||
time_mix_r = torch.pow(embed_dim_val, 0.5 * ratio_1_to_almost0)
|
||||
|
||||
return time_mix_k, time_mix_v, time_mix_r
|
||||
|
||||
def _init_time_decay_constants(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
num_blocks = self.config.num_layers
|
||||
attention_dim = self.config.attention_dim
|
||||
ratio_0_to_1 = self.block_id / max(1, num_blocks - 1) # 0 to 1
|
||||
|
||||
# time decay
|
||||
# this encourages the model to decay the information in different memory cells (channel dimensions)
|
||||
# at different speeds
|
||||
decay_speed = torch.ones(attention_dim)
|
||||
for h in range(attention_dim):
|
||||
decay_speed[h] = -5 + 8 * (h / (attention_dim - 1)) ** (
|
||||
0.7 + 1.3 * ratio_0_to_1
|
||||
)
|
||||
time_decay = decay_speed
|
||||
|
||||
# time first
|
||||
# The alternating zigzag pattern initially creates subtle variations in the tensor elements,
|
||||
# which are intended to help the model treat different dimensions of the embedding differently
|
||||
zigzag = (
|
||||
torch.tensor([(i + 1) % 3 - 1 for i in range(attention_dim)]) * 0.5
|
||||
)
|
||||
time_first = (
|
||||
torch.ones(attention_dim) * torch.log(torch.tensor(0.3)) + zigzag
|
||||
)
|
||||
|
||||
return time_decay, time_first
|
||||
|
||||
|
||||
class RWKVChannelMix(ResettableParametersModule):
|
||||
def __init__(self, config: RWKVConfig, block_id: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.block_id = block_id
|
||||
|
||||
self._act_fn = _get_activation_fn(self.config.channelmix_act_fn)
|
||||
|
||||
embedding_dim = self.config.embedding_dim
|
||||
ffn_dim = self.config.ffn_dim
|
||||
# init time mix constants
|
||||
req_grad = True
|
||||
self.time_mix_k = nn.Parameter(
|
||||
torch.empty((1, 1, embedding_dim)), requires_grad=req_grad
|
||||
)
|
||||
self.time_mix_r = nn.Parameter(
|
||||
torch.empty((1, 1, embedding_dim)), requires_grad=req_grad
|
||||
)
|
||||
|
||||
# init layers / parameters
|
||||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||||
self.key = nn.Linear(embedding_dim, ffn_dim, bias=False)
|
||||
self.receptance = nn.Linear(embedding_dim, embedding_dim, bias=False)
|
||||
self.value = nn.Linear(ffn_dim, embedding_dim, bias=False)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
# init time mix constants
|
||||
time_mix_k, time_mix_r = self._init_time_mix_constants()
|
||||
req_grad = True
|
||||
self.time_mix_k = nn.Parameter(time_mix_k, requires_grad=req_grad)
|
||||
self.time_mix_r = nn.Parameter(time_mix_r, requires_grad=req_grad)
|
||||
# init layers / parameters
|
||||
if self.config.init == "paper_init":
|
||||
# ZERO INIT
|
||||
nn.init.zeros_(self.receptance.weight)
|
||||
nn.init.zeros_(self.key.weight)
|
||||
# NORMAL INIT
|
||||
nn.init.normal_(
|
||||
self.value.weight,
|
||||
std=math.sqrt(self.config.ffn_dim / self.config.embedding_dim),
|
||||
)
|
||||
elif self.config.init == "reproduce_init":
|
||||
# ZERO INIT
|
||||
nn.init.zeros_(self.receptance.weight)
|
||||
nn.init.zeros_(self.value.weight)
|
||||
# ORTHOGONAL INIT
|
||||
nn.init.orthogonal_(
|
||||
self.key.weight, gain=_calc_gain(self.key.weight)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown init method {self.config.init}")
|
||||
|
||||
def _init_time_mix_constants(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
num_blocks = self.config.num_layers
|
||||
embedding_dim = self.config.embedding_dim
|
||||
|
||||
ratio_1_to_almost0 = 1.0 - (self.block_id / num_blocks) # 1 to ~0
|
||||
embed_dim_val = torch.ones(1, 1, embedding_dim)
|
||||
for i in range(embedding_dim):
|
||||
embed_dim_val[0, 0, i] = i / embedding_dim
|
||||
|
||||
time_mix_k = torch.pow(embed_dim_val, ratio_1_to_almost0)
|
||||
time_mix_r = torch.pow(embed_dim_val, ratio_1_to_almost0)
|
||||
|
||||
return time_mix_k, time_mix_r
|
||||
|
||||
def forward(self, x):
|
||||
if self.config.use_timemix_channelmix:
|
||||
xx = self.time_shift(x)
|
||||
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||||
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||||
else:
|
||||
xk = x
|
||||
xr = x
|
||||
k = self.key(xk)
|
||||
k = self._act_fn(k)
|
||||
kv = self.value(k)
|
||||
y = torch.sigmoid(self.receptance(xr)) * kv
|
||||
return y
|
||||
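Both RWKVTimeMix._compute_rkv and RWKVChannelMix.forward rely on the token shift implemented with nn.ZeroPad2d((0, 0, 1, -1)). The small illustration below (not part of the commit, toy numbers only) shows what the shift and the per-channel interpolation do.

import torch
from torch import nn

x = torch.arange(1.0, 7.0).reshape(1, 3, 2)      # (B=1, T=3, C=2)
time_shift = nn.ZeroPad2d((0, 0, 1, -1))
xx = time_shift(x)                               # row t now holds x[t-1]; the first row is zeros
time_mix_k = torch.tensor([[[0.75, 0.25]]])      # per-channel mixing weights of shape (1, 1, C)
xk = x * time_mix_k + xx * (1 - time_mix_k)
print(xx[0])  # tensor([[0., 0.], [1., 2.], [3., 4.]])
print(xk[0])  # channel 0 leans on the current token, channel 1 on the previous one
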
92
models/rwkv/sequence_rwkv.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Sequence
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...ml_utils.config import NameAndKwargs
|
||||
from ..base import BaseSequenceModelTrain
|
||||
from ..seq_enc_dec import create_decoder, create_encoder
|
||||
from .rwkv_model import RWKVBlock, RWKVConfig
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SequenceRWKVConfig(RWKVConfig):
|
||||
encoder: NameAndKwargs = None
|
||||
decoder: NameAndKwargs = None
|
||||
|
||||
|
||||
class SequenceRWKV(BaseSequenceModelTrain):
|
||||
config_class = SequenceRWKVConfig
|
||||
|
||||
def __init__(self, config: SequenceRWKVConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.config = config
|
||||
if config.wkv_config is not None:
|
||||
LOGGER.info("Using WKV cuda kernel.")
|
||||
else:
|
||||
LOGGER.info("Using WKV torch kernel.")
|
||||
|
||||
self.encoder = create_encoder(config=config)
|
||||
self.decoder = create_decoder(config=config)
|
||||
|
||||
self.blocks = nn.ModuleList([RWKVBlock(config, block_idx=i) for i in range(config.num_layers)])
|
||||
self.blocks_ln = nn.LayerNorm(config.embedding_dim)
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
@property
|
||||
def context_length(self) -> int:
|
||||
return self.config.context_length
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self.config.vocab_size
|
||||
|
||||
@property
|
||||
def input_dim(self) -> Sequence[int]:
|
||||
return self.config.input_dim
|
||||
|
||||
@property
|
||||
def output_dim(self) -> Sequence[int]:
|
||||
return self.config.output_dim
|
||||
|
||||
def reset_parameters(self):
|
||||
for block in self.blocks:
|
||||
block.reset_parameters()
|
||||
self.blocks_ln.reset_parameters()
|
||||
self.encoder.reset_parameters()
|
||||
self.decoder.reset_parameters()
|
||||
|
||||
def _create_optim_groups(self, config: RWKVConfig):
|
||||
optim_groups = [{"params": [p for p in self.parameters()], "weight_decay": 0.0}]
|
||||
return optim_groups
|
||||
|
||||
def get_loss_func(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
|
||||
import torch.nn.functional as F
|
||||
|
||||
def loss_fn(logits, targets):
|
||||
assert not torch.any(torch.isnan(logits.view(-1)))
|
||||
assert not torch.any(torch.isnan(targets.view(-1)))
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
|
||||
return loss
|
||||
|
||||
return loss_fn
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
assert (
|
||||
x.size(1) <= self.config.context_length
|
||||
), f"Forward input sequence length {x.size(1)} is longer than context length {self.config.context_length}"
|
||||
|
||||
y = self.encoder(x)
|
||||
|
||||
for block in self.blocks:
|
||||
y = block(y)
|
||||
y = self.blocks_ln(y)
|
||||
|
||||
y = self.decoder(y)
|
||||
|
||||
return y
|
||||
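seq_enc_dec and ml_utils.config are not included in this commit, so the exact encoder/decoder API is unknown. The sketch below only illustrates the contract SequenceRWKV appears to assume (encoder/decoder are nn.Modules exposing reset_parameters(), mapping tokens to embeddings and embeddings to logits); all names and signatures here are hypothetical.

import torch
from torch import nn

class TokenEncoder(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

    def reset_parameters(self) -> None:
        self.embedding.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # (B, T) int64 -> (B, T, C)
        return self.embedding(x)

class LogitDecoder(nn.Module):
    def __init__(self, embedding_dim: int, vocab_size: int):
        super().__init__()
        self.head = nn.Linear(embedding_dim, vocab_size, bias=False)

    def reset_parameters(self) -> None:
        self.head.reset_parameters()

    def forward(self, y: torch.Tensor) -> torch.Tensor:  # (B, T, C) -> (B, T, vocab)
        return self.head(y)
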
259
models/rwkv/wkv_kernel.py
Normal file
@@ -0,0 +1,259 @@
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from torch import nn
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
"""This module is a wrapper for the C++/CUDA implementation of the WKV kernel.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class WKVConfig:
|
||||
T_max: int = 1024 # max sequence length within cuda operations
|
||||
cpp_ext_name: str = "wkv"
|
||||
device: Union[str, torch.device] = "cuda"
|
||||
float_mode: str = "fp32" # options: fp32, fp16, bfloat16
|
||||
|
||||
def __post_init__(self):
|
||||
self.device = torch.device(self.device)
|
||||
|
||||
    def float_mode_to_dtype(self):
        # check "bfloat16" first: the substring test '"16" in ...' would otherwise
        # match "bfloat16" and wrongly return torch.float16
        if self.float_mode == "bfloat16":
            return torch.bfloat16
        elif self.float_mode == "fp32" or "32" in str(self.float_mode):
            return torch.float32
        elif self.float_mode == "fp16" or "16" in str(self.float_mode):
            return torch.float16
        else:
            raise ValueError(f"Unknown float_mode: {self.float_mode}")
|
||||
|
||||
|
||||
class WKV(nn.Module):
|
||||
_instance = None # for singleton
|
||||
|
||||
class _WKV(torch.autograd.Function):
|
||||
wkv_cuda = None
|
||||
wkv_config: WKVConfig = None
|
||||
|
||||
@classmethod
|
||||
def forward(
|
||||
cls,
|
||||
ctx,
|
||||
batch_size,
|
||||
seq_len,
|
||||
            embedding_dim,  # TODO replace embedding_dim with attention_dim
            time_decay,
|
||||
time_first,
|
||||
k,
|
||||
v,
|
||||
):
|
||||
wkv_cuda = cls.wkv_cuda
|
||||
wkv_config = cls.wkv_config
|
||||
# setup context # TODO for PyTorch 2.0 use extra setup_context() function
|
||||
ctx.batch_size = batch_size
|
||||
ctx.seq_len = seq_len
|
||||
ctx.embedding_dim = embedding_dim
|
||||
ctx.wkv_cuda = wkv_cuda
|
||||
ctx.wkv_config = wkv_config
|
||||
assert (
|
||||
seq_len <= wkv_config.T_max
|
||||
), f"Sequence length {seq_len} exceeds the maximum allowed T_max={wkv_config.T_max}"
|
||||
# TODO what does this assert do? Why necessary?
|
||||
assert (
|
||||
batch_size * embedding_dim % min(embedding_dim, wkv_config.T_max) == 0
|
||||
), "batch_size * embedding_dim must be divisible by min(embedding_dim, T_max)"
|
||||
#
|
||||
dtype = torch.float32 # convert all tensors to float32 (for cuda kernel)
|
||||
device = wkv_config.device
|
||||
# convert input tensors
|
||||
time_decay = time_decay.to(dtype=dtype, device=device, memory_format=torch.contiguous_format)
|
||||
time_first = time_first.to(dtype=dtype, device=device, memory_format=torch.contiguous_format)
|
||||
k = k.to(dtype=dtype, device=device, memory_format=torch.contiguous_format)
|
||||
v = v.to(dtype=dtype, device=device, memory_format=torch.contiguous_format)
|
||||
|
||||
# allocate output tensor
|
||||
y = torch.empty(
|
||||
batch_size, seq_len, embedding_dim, dtype=dtype, device=device, memory_format=torch.contiguous_format
|
||||
)
|
||||
|
||||
# call cuda kernel
|
||||
time_decay = -torch.exp(time_decay) # TODO why is this necessary?
|
||||
wkv_cuda.forward(batch_size, seq_len, embedding_dim, time_decay, time_first, k, v, y)
|
||||
ctx.save_for_backward(time_decay, time_first, k, v, y)
|
||||
|
||||
# convert output tensor to correct dtype
|
||||
y = y.to(dtype=wkv_config.float_mode_to_dtype())
|
||||
return y
|
||||
|
||||

        @staticmethod
        def backward(ctx, gy):
            batch_size = ctx.batch_size
            seq_len = ctx.seq_len
            embedding_dim = ctx.embedding_dim
            assert (
                seq_len <= ctx.wkv_config.T_max
            ), f"Sequence length {seq_len} exceeds the maximum allowed T_max={ctx.wkv_config.T_max}"
            assert (
                batch_size * embedding_dim % min(embedding_dim, ctx.wkv_config.T_max) == 0
            ), "batch_size * embedding_dim must be divisible by min(embedding_dim, T_max)"

            time_decay, time_first, k, v, y = ctx.saved_tensors

            device = ctx.wkv_config.device
            # allocate gradient tensors
            gtime_decay = torch.zeros((batch_size, embedding_dim), device=device, dtype=torch.float32).contiguous()
            gtime_first = torch.zeros((batch_size, embedding_dim), device=device, dtype=torch.float32).contiguous()
            gk = torch.zeros((batch_size, seq_len, embedding_dim), device=device, dtype=torch.float32).contiguous()
            gv = torch.zeros((batch_size, seq_len, embedding_dim), device=device, dtype=torch.float32).contiguous()

            # call cuda kernel
            gy = gy.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            # arg0: int, arg1: int, arg2: int, arg3: at::Tensor, arg4: at::Tensor, arg5: at::Tensor, arg6: at::Tensor,
            # arg7: at::Tensor, arg8: at::Tensor, arg9: at::Tensor, arg10: at::Tensor, arg11: at::Tensor, arg12: at::Tensor
            ctx.wkv_cuda.backward(
                batch_size,
                seq_len,
                embedding_dim,
                time_decay,
                time_first,
                k,
                v,
                y,
                gy,
                gtime_decay,
                gtime_first,
                gk,
                gv,
            )

            # the per-sample parameter gradients are summed over the batch dimension
            gtime_decay = gtime_decay.sum(dim=0)
            gtime_first = gtime_first.sum(dim=0)

            # convert gradient tensors to correct dtype
            out_dtype = ctx.wkv_config.float_mode_to_dtype()

            # one return value per forward() input; batch_size, seq_len and
            # embedding_dim are plain ints, so their gradients are None
            return (
                None,
                None,
                None,
                gtime_decay.to(dtype=out_dtype),
                gtime_first.to(dtype=out_dtype),
                gk.to(dtype=out_dtype),
                gv.to(dtype=out_dtype),
            )

    def __new__(cls, config: WKVConfig = WKVConfig()):
        if cls._instance is None:
            cls._instance = super(WKV, cls).__new__(cls)
            cls._instance._setup(config)
        return cls._instance

    def __init__(self, *args, **kwargs):
        # Dummy to avoid multiple calls to self._load_cuda()
        pass

    def _setup(self, config: WKVConfig = WKVConfig()):
        """Setup the WKV module. This is called by __new__ as constructor."""
        super().__init__()
        self.cfg = config
        self.wkv_cuda = self._load_cuda()
        self.device = self.cfg.device

    def _load_cuda(self):
        cfg = self.cfg
        os.environ["CUDA_LIB"] = os.path.join(
            os.path.split(torch.utils.cpp_extension.include_paths(cuda=True)[-1])[0], "lib"
        )
        print(os.environ.get("LD_LIBRARY_PATH", ""))
        print(os.environ["CUDA_LIB"])

        cpp_ext_sources_float32: List[str] = [
            str(Path(__file__).parent / "cuda/wkv_op.cpp"),
            str(Path(__file__).parent / "cuda/wkv_cuda.cu"),
        ]
        cpp_ext_sources_bfloat16: List[str] = [
            str(Path(__file__).parent / "cuda/wkv_op_bf16.cpp"),
            str(Path(__file__).parent / "cuda/wkv_cuda_bf16.cu"),
        ]
        if cfg.float_mode_to_dtype() == torch.float32:
            cpp_ext_sources = cpp_ext_sources_float32
        elif cfg.float_mode_to_dtype() == torch.bfloat16:
            cpp_ext_sources = cpp_ext_sources_bfloat16
        else:
            raise ValueError(f"Unsupported float mode {cfg.float_mode}")

        myargs = {
            "verbose": True,
            "with_cuda": True,
            "extra_ldflags": [f"-L{os.environ['CUDA_LIB']}", "-lcublas"],
            "extra_cuda_cflags": [
                # "-gencode",
                # "arch=compute_70,code=compute_70",
                "-gencode",
                "arch=compute_80,code=compute_80",
                "-res-usage",
                "--use_fast_math",
                "-O3",
                "-Xptxas -O3",
                "--extra-device-vectorization",
                f"-DTmax={cfg.T_max}",
            ],
        }

        cuda_module = load(name=cfg.cpp_ext_name, sources=cpp_ext_sources, **myargs)

        return cuda_module
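
    # Note on the build flags above (informational): "arch=compute_80" targets
    # Ampere GPUs (A100 / RTX 30xx); the commented-out compute_70 pair would
    # target Volta. The "-DTmax=..." define bakes the maximum sequence length
    # into the compiled kernel, so changing T_max requires recompiling the
    # extension.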

    def to(self, **kwargs):
        device = kwargs.get("device", None)
        if device is not None:
            self.device = self.cfg.device = torch.device(device)
        return super().to(**kwargs)

    def forward(self, batch_size, seq_len, embedding_dim, time_decay, time_first, k, v):
        self.device = self.cfg.device = v.device
        assert self.device != torch.device("cpu"), "WKV is not implemented for CPU"
        self._WKV.wkv_cuda = self.wkv_cuda
        self._WKV.wkv_config = self.cfg
        return self._WKV.apply(batch_size, seq_len, embedding_dim, time_decay, time_first, k, v)
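
# Usage sketch (hypothetical shapes; requires a CUDA device and an nvcc toolchain,
# since the extension is JIT-compiled on first instantiation):
#   wkv = WKV(WKVConfig(T_max=1024, float_mode="fp32"))
#   B, T, C = 2, 128, 512
#   time_decay = torch.zeros(C, device="cuda")  # raw log-decay parameter (not -exp'ed)
#   time_first = torch.zeros(C, device="cuda")
#   k = torch.randn(B, T, C, device="cuda")
#   v = torch.randn(B, T, C, device="cuda")
#   y = wkv(B, T, C, time_decay, time_first, k, v)  # -> (B, T, C)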


class WKVTorch(nn.Module):
    """Pure-PyTorch reference implementation of the WKV recurrence."""

    def __init__(self):
        super().__init__()
        # Note: time_decay = -torch.exp(time_decay) is intentionally NOT applied here;
        # the CUDA path applies it inside _WKV.forward, so applying it here as well
        # would make the outputs diverge from the CUDA kernel.

    def forward(self, batch_size, seq_len, embedding_dim, time_decay, time_first, k, v):
        dtype = k.dtype
        device = k.device
        y = torch.zeros(batch_size, seq_len, embedding_dim, dtype=dtype, device=device)
        MIN_VAL = 0.0  # -1e38
        # reshape inputs
        k_ = torch.permute(k, (1, 0, 2))  # rearrange(k, 'b s e -> s b e')
        v_ = torch.permute(v, (1, 0, 2))  # rearrange(v, 'b s e -> s b e')
        y_ = torch.permute(y, (1, 0, 2))  # rearrange(y, 'b s e -> s b e')
        tf = time_first.repeat(batch_size, 1)  # repeat(time_first, 'e -> b e', b=batch_size)
        td = time_decay.repeat(batch_size, 1)  # repeat(time_decay, 'e -> b e', b=batch_size)
        # running sums (aa: weighted value sum, bb: weight sum) and eps, a running
        # maximum that is subtracted inside every exp() for numerical stability
        aa = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
        bb = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
        eps = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
        for t in range(seq_len):
            e_tf_k = torch.exp(tf + k_[t] - eps)
            y_[t] = (aa + v_[t] * e_tf_k) / (bb + e_tf_k)
            eps_next = torch.max(td + eps, k_[t])

            e_td_k = torch.exp(td + eps - eps_next)
            e_k = torch.exp(k_[t] - eps_next)
            aa = aa * e_td_k + v_[t] * e_k
            bb = bb * e_td_k + e_k
            eps = eps_next
        y = rearrange(y_, "s b e -> b s e")
        return y
1
safari
Submodule
Submodule safari added at 23a1d81b2f