diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..68bc17f --- /dev/null +++ b/.gitignore @@ -0,0 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..2553eea --- /dev/null +++ b/.gitmodules @@ -0,0 +1,6 @@ +[submodule "RWKV-LM"] + path = RWKV-LM + url = git@github.com:BlinkDL/RWKV-LM.git +[submodule "safari"] + path = safari + url = git@github.com:HazyResearch/safari.git diff --git a/RWKV-LM b/RWKV-LM new file mode 160000 index 0000000..69e6c50 --- /dev/null +++ b/RWKV-LM @@ -0,0 +1 @@ +Subproject commit 69e6c50001e8da742dcfdd7e53064f155a6c9ad1 diff --git a/TestModels.ipynb b/TestModels.ipynb new file mode 100644 index 0000000..151f66d --- /dev/null +++ b/TestModels.ipynb @@ -0,0 +1,137 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%reload_ext autoreload\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import importlib\n", + "import os\n", + "import sys\n", + "\n", + "# def import_from_file(module_name, file_path):\n", + "# spec = importlib.util.spec_from_file_location(module_name, file_path)\n", + "# module = importlib.util.module_from_spec(spec)\n", + "# sys.modules[module_name] = module\n", + "# spec.loader.exec_module(module)\n", + "# return module\n", + "\n", + "# os.environ['RWKV_JIT_ON'] = '1'\n", + "# os.environ['RWKV_T_MAX'] = '16384'\n", + "# os.environ['RWKV_FLOAT_MODE'] = 'fp32'\n", + "# os.environ['CUDA_HOME']\n", + "# path = os.path.abspath('RWKV-LM/RWKV-v4neo/src/model.py')\n", + "# rkwv_mod = import_from_file('rwkv_mod', path)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency.\n" + ] + } + ], + "source": [ + "from models.rwkv import RWKV, RWKVConfig\n", + "sys.path.append('./safari')\n", + "from safari.src.models.sequence.model import SequenceModel\n", + "from safari.src.models.sequence.h3 import H3\n", + "from safari.src.models.sequence.hyena import HyenaOperator\n", + "from omegaconf import OmegaConf\n", + "os.environ['CUDA_HOME'] = \".\"\n", + "\n", + "NUM_LAYERS = 3\n", + "EMBEDDING_DIM = 128\n", + "SEQUENCE_LENGTH = 1024\n", + "\n", + "def load_config(filename):\n", + " with open(filename) as fp:\n", + " cfg_yaml = fp.read()\n", + " return OmegaConf.create(cfg_yaml)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "rwkv_full = RWKV(RWKVConfig(context_length=SEQUENCE_LENGTH, embedding_dim=EMBEDDING_DIM, num_layers=NUM_LAYERS, wkv_config=None))" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "hyena_full = SequenceModel(layer=load_config('config/models/hyena.yaml'), d_model=EMBEDDING_DIM, n_layers=NUM_LAYERS)\n", + "h3_full = SequenceModel(layer=load_config('config/models/h3.yaml'), d_model=EMBEDDING_DIM, n_layers=NUM_LAYERS)\n", + "\n", + "\n", + "h3 = H3(d_model=EMBEDDING_DIM)\n", + "hyena = HyenaOperator(d_model=EMBEDDING_DIM, l_max=16384)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", 
+ "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/config/models/h3.yaml b/config/models/h3.yaml new file mode 100644 index 0000000..4aaea9e --- /dev/null +++ b/config/models/h3.yaml @@ -0,0 +1,7 @@ +_name_: h3 +d_state: 64 +head_dim: 1 +mode: diag +measure: diag-lin +# lr: ${eval:"min(0.001, ${optimizer.lr})"} +lr: 0.001 \ No newline at end of file diff --git a/config/models/hyena.yaml b/config/models/hyena.yaml new file mode 100644 index 0000000..6065a87 --- /dev/null +++ b/config/models/hyena.yaml @@ -0,0 +1,16 @@ +_name_: hyena +l_max: 1024 +order: 2 +filter_order: 64 +num_heads: 1 +inner_factor: 1 +num_blocks: 1 +fused_bias_fc: false +outer_mixing: false +dropout: 0.0 +filter_dropout: 0.0 +filter_cls: 'hyena-filter' +post_order_ffn: false +jit_filter: false +short_filter_order: 3 +activation: "id" \ No newline at end of file diff --git a/config/models/lstm.yaml b/config/models/lstm.yaml new file mode 100644 index 0000000..feb4b53 --- /dev/null +++ b/config/models/lstm.yaml @@ -0,0 +1 @@ +_name_: LSTM_Transformer \ No newline at end of file diff --git a/config/models/rwkv.yaml b/config/models/rwkv.yaml new file mode 100644 index 0000000..e69de29 diff --git a/config/models/selfattn.yaml b/config/models/selfattn.yaml new file mode 100644 index 0000000..e69de29 diff --git a/models/base.py b/models/base.py new file mode 100644 index 0000000..5caf756 --- /dev/null +++ b/models/base.py @@ -0,0 +1,73 @@ +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, List + +import torch +from torch import nn + +from .interfaces import ( + ModelTrainInterface, + SequenceConfigInterface, + SequenceInterface, +) +from .base_model import BaseModel + +LOGGER = logging.getLogger(__name__) + + +class LayerConfigInterface(ABC): + @abstractmethod + def assign_model_config_params(self, model_config): + pass + + +@dataclass +class OptimizerConfig: + lr: float = 1e-3 + betas: List = field(default_factory=lambda: [0.9, 0.95]) + weight_decay: float = 0.0 + device_type: str = "cuda" # TODO is this necessary? 
+ fused: bool = True + eps: float = 1e-6 + + +@dataclass +class BaseSequenceModelConfig(SequenceConfigInterface): + optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) + + shortname: str = "" # needed to give a model a more distinctive name used in configurations etc., temporary filled by hydra or OmegaConf + + +class ResettableParametersModule(nn.Module, ABC): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + @abstractmethod + def reset_parameters(self, **kwargs): + pass + + +class BaseSequenceModelTrain( + BaseModel, ModelTrainInterface, SequenceInterface +): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + @abstractmethod + def _create_optim_groups(self, **kwargs) -> list[dict[str, Any]]: + # TODO think of a nice way to separate functionality from child classes + # TODO move this into BaseModel and make it a separate interface too + pass + + def configure_optimizer(self) -> torch.optim.Optimizer: + optim_cfg = self.config.optimizer + optim_groups = self._create_optim_groups(self.config) + use_fused = optim_cfg.device_type == "cuda" and optim_cfg.fused + LOGGER.info(f"Using fused optimizer: {use_fused}") + extra_args = dict(fused=True) if use_fused else dict() + extra_args["eps"] = optim_cfg.eps + optimizer = torch.optim.AdamW( + optim_groups, lr=optim_cfg.lr, betas=optim_cfg.betas, **extra_args + ) + return optimizer diff --git a/models/base_model.py b/models/base_model.py new file mode 100644 index 0000000..30b8e53 --- /dev/null +++ b/models/base_model.py @@ -0,0 +1,165 @@ +import copy +from abc import ABC, abstractmethod +from dataclasses import asdict +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Union + +import torch +from torch import nn + +FN_MODEL_PREFIX = "model_" +FN_MODEL_FILE_EXT = ".p" + + +def get_device(device: Union[torch.device, str, int]) -> torch.device: + if device == "auto": + device = "cuda" + if isinstance(device, int): + if device < 0: + device = torch.device("cpu") + else: + device = torch.device(f"cuda:{device}") + else: + device = torch.device(device) + + if ( + device.type == torch.device("cuda").type + and not torch.cuda.is_available() + ): + LOGGER.warn(f"Device '{str(device)}' is not available! Using cpu now.") + return torch.device("cpu") + return device + + +class BaseModel(nn.Module, ABC): + """BaseModel class + Takes care of easy saving and loading. 
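+ Subclasses must implement forward() and get_loss_func().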
+ """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.config = None + + @abstractmethod + def forward(self, *args, **kwargs): + pass + + @abstractmethod + def get_loss_func(self, **kwargs): + pass + + def _get_constructor_parameters(self) -> dict: + if isinstance(self.config, dict): + return self.config + return asdict(self.config) + + def reset_parameters(self): + self.apply(self.get_init_fn()) + + def get_init_fn(self) -> Callable[[torch.Tensor], None]: + return None + + @property + def num_parameters(self) -> int: + return ( + torch.tensor([p.numel() for p in self.parameters()]).sum().item() + ) + + @property + def device(self) -> torch.device: + return next(iter(self.parameters())).device + + def copy_to_cpu(self) -> "BaseModel": + """Copy the model to CPU.""" + return copy.deepcopy(self).to(torch.device("cpu")) + + def get_checkpoint_data( + self, dict_key_prefix: str = "model_" + ) -> Dict[str, Any]: + checkpoint_dict = { + f"{dict_key_prefix}state_dict": self.state_dict(), + f"{dict_key_prefix}data": self._get_constructor_parameters(), + f"{dict_key_prefix}name": self.__class__.__name__, + f"{dict_key_prefix}class": self.__class__, + } + return checkpoint_dict + + def save( + self, + path: Union[str, Path], + model_name: str, + file_extension: Optional[str] = FN_MODEL_FILE_EXT, + dict_key_prefix: str = "model_", + ) -> None: + if isinstance(path, str): + path = Path(path) + save_path = path / (model_name + file_extension) + torch.save(self.get_checkpoint_data(dict_key_prefix), save_path) + + @staticmethod + def model_save_name( + idx: int, specifier: str = "epoch", num_digits: int = -1 + ) -> str: + """Get a consistnet the model save name. + + Args: + epoch (int): Epoch / iteration number. + specifier (str, optional): A specifier for the idx. Defaults to epoch. + num_digits (int, optional): The number of digits in the save name. Unused by default, + since this causes overrides when we have an overflow. Defaults to -1. + + Returns: + str: Model save name. + """ + if num_digits == -1: + return f"{FN_MODEL_PREFIX}{specifier}_{idx}" + else: + return f"{FN_MODEL_PREFIX}{specifier}_{idx:0{num_digits}}" + + @classmethod + def load( + cls, + path: Union[str, Path], + model_name: str = None, + file_extension: Optional[str] = ".p", + device: Union[torch.device, str, int] = "auto", + dict_key_prefix: str = "model_", + ) -> "BaseModel": + device = get_device(device) + if isinstance(path, str): + path = Path(path) + if model_name is None: + save_path = path + else: + save_path = path / (model_name + file_extension) + checkpoint = torch.load(save_path, map_location=device) + + return cls.params_from_checkpoint( + checkpoint=checkpoint, dict_key_prefix=dict_key_prefix + ) + + @classmethod + def params_from_checkpoint( + cls, checkpoint: Dict[str, Any], dict_key_prefix: str = "model_" + ) -> "BaseModel": + if hasattr(cls, "config_class"): + from dacite import from_dict + + config_cls = cls.config_class + + model_cfg = from_dict( + data_class=config_cls, + data=checkpoint[f"{dict_key_prefix}data"], + ) + model = cls(config=model_cfg) + else: + model = cls(**checkpoint[f"{dict_key_prefix}data"]) + model.load_state_dict(checkpoint[f"{dict_key_prefix}state_dict"]) + return model + + # @staticmethod + # def class_and_params_from_checkpoint(checkpoint: Dict[str, Any], dict_key_prefix: str = 'model_') -> 'BaseModel': + # from . 
import get_model_class + # model_class = get_model_class(checkpoint[f"{dict_key_prefix}name"]) + # model = model_class.params_from_checkpoint(checkpoint=checkpoint, dict_key_prefix=dict_key_prefix) + # return model diff --git a/models/interfaces.py b/models/interfaces.py new file mode 100644 index 0000000..1971fed --- /dev/null +++ b/models/interfaces.py @@ -0,0 +1,77 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Callable, Optional, Protocol, Sequence + +import torch + + +class ModelTrainInterface(ABC): + def configure_optimizer(self) -> Optional[torch.optim.Optimizer]: + return None + + @abstractmethod + def get_loss_func(self, **kwargs) -> Callable[[Any], torch.Tensor]: + pass + + +class SequenceInterface(ABC): + """This is a generic interface for a sequence. + In our case a sequence also includes its label. Therefore, the label (aka output_dim) + is also part of the sequence interface. + A sequence always has a length (=context_length). + + A sequence can have one of the following flavors: + Input sequence: + - sequence of tokens (e.g. words): vocab_size must be specified + - sequence of vectors: input_dim must be specified + + Output sequence: + - next token (e.g. word): `vocab_size` must be specified (e.g. Causal Language Modeling). + - (sequence of) vectors: output_dim must be specified (e.g. Forecasting) + - label: output_dim must be specified (e.g. Sequence Classification) + + Examples: + - Causal Language Modeling: input_dim = None, output_dim = None, vocab_size = int + - Forecasting: input_dim = int, output_dim = int, vocab_size = None + - Sequence Classification (General Sequence): input_dim = int, output_dim = int, vocab_size = None + - Sequence Classification (Text): input_dim = None, output_dim = int, vocab_size = int + """ + + @property + def input_dim(self) -> Optional[Sequence[int]]: + return None + + @property + def output_dim(self) -> Optional[Sequence[int]]: + return None + + @property + def vocab_size(self) -> Optional[int]: + return None + + @property + @abstractmethod + def context_length(self) -> int: + pass + + +@dataclass +class SequenceConfigInterface: + context_length: int + # vocab_size: Optional[int] = None + input_dim: Optional[Sequence[int]] = None + output_dim: Optional[Sequence[int]] = None + + +class Tokenizer(Protocol): + def __call__(self, **kwargs) -> Any: + ... + + def __len__(self) -> int: + ... 
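# A minimal sketch of the "Forecasting" flavor described in the SequenceInterface
# docstring above: vector inputs, vector outputs, vocab_size stays None. The
# Forecasting* names are illustrative examples only, not classes defined elsewhere
# in this repository; dataclass/Optional/Sequence are the imports already used above.
@dataclass
class ForecastingConfigExample(SequenceConfigInterface):
    input_dim: Optional[Sequence[int]] = (8,)   # 8 input features per timestep
    output_dim: Optional[Sequence[int]] = (1,)  # one regression target per timestep


class ForecastingSequenceExample(SequenceInterface):
    """Example implementation: exposes the config values through the interface."""

    def __init__(self, config: ForecastingConfigExample):
        self._config = config

    @property
    def context_length(self) -> int:
        return self._config.context_length

    @property
    def input_dim(self) -> Optional[Sequence[int]]:
        return self._config.input_dim

    @property
    def output_dim(self) -> Optional[Sequence[int]]:
        return self._config.output_dim


# Usage: ForecastingConfigExample(context_length=1024) -> input_dim=(8,), output_dim=(1,), vocab_size=None.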
+ + +class TokenizerInterface(ABC): + @property + def tokenizer(self) -> Optional[Tokenizer]: + return None diff --git a/models/rwkv/__init__.py b/models/rwkv/__init__.py new file mode 100644 index 0000000..00b02f1 --- /dev/null +++ b/models/rwkv/__init__.py @@ -0,0 +1,3 @@ +from .rwkv_model import RWKV, RWKVConfig + +__all__ = ["RWKV", "RWKVConfig"] diff --git a/models/rwkv/cuda/wkv_cuda.cu b/models/rwkv/cuda/wkv_cuda.cu new file mode 100644 index 0000000..8783791 --- /dev/null +++ b/models/rwkv/cuda/wkv_cuda.cu @@ -0,0 +1,145 @@ +#include +#include + +#define MIN_VALUE (-1e38) + + +template +__global__ void kernel_forward(const int B, // batch size + const int T, // sequence length + const int C, // dim_att (number of channels) + const F *__restrict__ const _w_timedecay, // -exp(time decay) [attention_dim] + const F *__restrict__ const _u_timefirst, // time first [attention_dim] + const F *__restrict__ const _k, // keys [batch_size, sequence_length, embedding_dim] + const F *__restrict__ const _v, // values [batch_size, sequence_length, embedding_dim] + F *__restrict__ const _y // output [batch_size, sequence_length, embedding_dim] + ) { + // - this block of code defines the area in the batch tensors that this thread will work on + const int idx = blockIdx.x * blockDim.x + threadIdx.x; // this idx is for the whole batch + const int _b = idx / C; // this idx is for the batch + const int _c = idx % C; // this idx is for the channel + const int _offset = _b * T * C + _c; // this idx is for the whole batch + // - + + // these are vectors of size C (channel dimension) element in embedding_dim + F u_timefirst = _u_timefirst[_c]; + F w_timedecay = _w_timedecay[_c]; + // these are tensors of size B x T x C (stored as array, access through pointers) + const F *__restrict__ const k = _k + _offset; + const F *__restrict__ const v = _v + _offset; + F *__restrict__ const y = _y + _offset; + + // aa and bb are running sums divided by exp(pp) (to avoid overflow) + F aa = 0, bb = 0, pp = MIN_VALUE; + // loop goes over time dimension + for (int i = 0; i < T; i++) { + const int ii = i * C; // index ii picks the correct channel + const F kk = k[ii]; + const F vv = v[ii]; + + F ww = u_timefirst + kk; // + F p = max(pp, ww); + F e1 = exp(pp - p); // exp(pp - p) is the same as exp(pp) / exp(p) + F e2 = exp(ww - p); + y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2); + + ww = w_timedecay + pp; // + p = max(ww, kk); + e1 = exp(ww - p); + e2 = exp(kk - p); + aa = e1 * aa + e2 * vv; + bb = e1 * bb + e2; + pp = p; + } +} + +template +__global__ void kernel_backward(const int B, const int T, const int C, + const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, + const F *__restrict__ const _y, const F *__restrict__ const _gy, + F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + F u = _u[_c]; + F w = _w[_c]; + const F *__restrict__ const k = _k + _offset; + const F *__restrict__ const v = _v + _offset; + const F *__restrict__ const y = _y + _offset; + const F *__restrict__ const gy = _gy + _offset; + F *__restrict__ const gk = _gk + _offset; + F *__restrict__ const gv = _gv + _offset; + + F q[Tmax], r[Tmax]; + + F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE; + for (int i = 0; i < T; i++) { + const int ii = i * C; + const F 
kk = k[ii]; + const F vv = v[ii]; + const F yy = y[ii]; + + F ww = u + kk; + F p = max(pp, ww); + F e1 = exp(pp - p); + F e2 = exp(ww - p); + const F qq = gy[ii] / (e1 * bb + e2); + gw += (ga - gb * yy) * e1 * qq; + gu += (vv - yy) * e2 * qq; + q[i] = qq; + r[i] = ww - p; + + ww = w + pp; + p = max(ww, kk); + e1 = exp(ww - p); + e2 = exp(kk - p); + ga = e1 * (aa + ga); + gb = e1 * (bb + gb); + aa = e1 * aa + e2 * vv; + bb = e1 * bb + e2; + pp = p; + } + const int _offsetBC = _b * C + _c; + _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward() + _gu[_offsetBC] = gu; + + aa = 0, bb = 0, pp = MIN_VALUE; + for (int i = T - 1; i >= 0; i--) { + const int ii = i * C; + const F kk = k[ii]; + const F vv = v[ii]; + const F yy = y[ii]; + const F qq = q[i]; + const F rr = r[i]; + + F e1 = qq * exp(rr); + F e2 = exp(kk + pp); + gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb); + gv[ii] = e1 + e2 * aa; + + const F ww = w + pp; + const F www = rr - u - kk; + const F p = max(ww, www); + e1 = exp(ww - p); + e2 = qq * exp(www - p); + aa = e1 * aa + e2; + bb = e1 * bb - e2 * yy; + pp = p; + } +} + +void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_forward<<>>(B, T, C, w, u, k, v, y); +} + +void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_backward<<>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv); +} diff --git a/models/rwkv/cuda/wkv_cuda_bf16.cu b/models/rwkv/cuda/wkv_cuda_bf16.cu new file mode 100644 index 0000000..881b6bb --- /dev/null +++ b/models/rwkv/cuda/wkv_cuda_bf16.cu @@ -0,0 +1,132 @@ +#include +#include +#include "ATen/ATen.h" +#define MIN_VALUE (-1e38) +typedef at::BFloat16 bf16; + +__global__ void kernel_forward(const int B, const int T, const int C, + const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, + bf16 *__restrict__ const _y) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + float u = float(_u[_c]); + float w = _w[_c]; + const bf16 *__restrict__ const k = _k + _offset; + const bf16 *__restrict__ const v = _v + _offset; + bf16 *__restrict__ const y = _y + _offset; + + // aa and bb are running sums divided by exp(pp) (to avoid overflow) + float aa = 0, bb = 0, pp = MIN_VALUE; + for (int i = 0; i < T; i++) { + const int ii = i * C; + const float kk = float(k[ii]); + const float vv = float(v[ii]); + + float ww = u + kk; + float p = max(pp, ww); + float e1 = exp(pp - p); + float e2 = exp(ww - p); + y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2)); + + ww = w + pp; + p = max(ww, kk); + e1 = exp(ww - p); + e2 = exp(kk - p); + aa = e1 * aa + e2 * vv; + bb = e1 * bb + e2; + pp = p; + } +} + +__global__ void kernel_backward(const int B, const int T, const int C, + const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, + const bf16 *__restrict__ const _y, const bf16 *__restrict__ const 
_gy, + bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu, bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + float u = float(_u[_c]); + float w = _w[_c]; + const bf16 *__restrict__ const k = _k + _offset; + const bf16 *__restrict__ const v = _v + _offset; + const bf16 *__restrict__ const y = _y + _offset; + const bf16 *__restrict__ const gy = _gy + _offset; + bf16 *__restrict__ const gk = _gk + _offset; + bf16 *__restrict__ const gv = _gv + _offset; + + float q[Tmax], r[Tmax]; + + float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE; + for (int i = 0; i < T; i++) { + const int ii = i * C; + const float kk = float(k[ii]); + const float vv = float(v[ii]); + const float yy = float(y[ii]); + + float ww = u + kk; + float p = max(pp, ww); + float e1 = exp(pp - p); + float e2 = exp(ww - p); + const float qq = float(gy[ii]) / (e1 * bb + e2); + gw += (ga - gb * yy) * e1 * qq; + gu += (vv - yy) * e2 * qq; + q[i] = qq; + r[i] = ww - p; + + ww = w + pp; + p = max(ww, kk); + e1 = exp(ww - p); + e2 = exp(kk - p); + ga = e1 * (aa + ga); + gb = e1 * (bb + gb); + aa = e1 * aa + e2 * vv; + bb = e1 * bb + e2; + pp = p; + } + const int _offsetBC = _b * C + _c; + _gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward() + _gu[_offsetBC] = bf16(gu); + + aa = 0, bb = 0, pp = MIN_VALUE; + for (int i = T - 1; i >= 0; i--) { + const int ii = i * C; + const float kk = float(k[ii]); + const float vv = float(v[ii]); + const float yy = float(y[ii]); + const float qq = q[i]; + const float rr = r[i]; + + float e1 = qq * exp(rr); + float e2 = exp(kk + pp); + gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb)); + gv[ii] = bf16(e1 + e2 * aa); + + const float ww = w + pp; + const float www = rr - u - kk; + const float p = max(ww, www); + e1 = exp(ww - p); + e2 = qq * exp(www - p); + aa = e1 * aa + e2; + bb = e1 * bb - e2 * yy; + pp = p; + } +} + +void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_forward<<>>(B, T, C, w, u, k, v, y); +} + +void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_backward<<>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv); +} diff --git a/models/rwkv/cuda/wkv_op.cpp b/models/rwkv/cuda/wkv_op.cpp new file mode 100644 index 0000000..605a7be --- /dev/null +++ b/models/rwkv/cuda/wkv_op.cpp @@ -0,0 +1,25 @@ +#include +// #include + +void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); +void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv); + +void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { + // TODO add this line const at::cuda::OptionalCUDAGuard device_guard(device_of(w)); + // (see https://discord.com/channels/992359628979568762/1084547464452907088/1091680712178016326) + // see 
https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/cuda/wrapper.cpp + cuda_forward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr()); +} +void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { + cuda_backward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr(), gy.data_ptr(), gw.data_ptr(), gu.data_ptr(), gk.data_ptr(), gv.data_ptr()); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &forward, "wkv forward"); + m.def("backward", &backward, "wkv backward"); +} + +TORCH_LIBRARY(wkv, m) { + m.def("forward", forward); + m.def("backward", backward); +} diff --git a/models/rwkv/cuda/wkv_op_bf16.cpp b/models/rwkv/cuda/wkv_op_bf16.cpp new file mode 100644 index 0000000..5783416 --- /dev/null +++ b/models/rwkv/cuda/wkv_op_bf16.cpp @@ -0,0 +1,25 @@ +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y); +void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv); + +void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { + cuda_forward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr()); +} +void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, + torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { + cuda_backward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr(), + gy.data_ptr(), gw.data_ptr(), gu.data_ptr(), gk.data_ptr(), gv.data_ptr()); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &forward, "wkv forward"); + m.def("backward", &backward, "wkv backward"); +} + +TORCH_LIBRARY(wkv, m) { + m.def("forward", forward); + m.def("backward", backward); +} diff --git a/models/rwkv/rwkv_model.py b/models/rwkv/rwkv_model.py new file mode 100644 index 0000000..e21b158 --- /dev/null +++ b/models/rwkv/rwkv_model.py @@ -0,0 +1,469 @@ +import logging +import math +from dataclasses import dataclass, field +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ..base import ( + BaseSequenceModelConfig, + BaseSequenceModelTrain, + ResettableParametersModule, +) +from .wkv_kernel import WKV, WKVConfig, WKVTorch + +LOGGER = logging.getLogger(__name__) + + +class L2Wrap(torch.autograd.Function): + """L2 regularization for the logits.""" + + @staticmethod + def forward(ctx, loss, y): + ctx.save_for_backward(y) + return loss + + @staticmethod + def backward(ctx, grad_output): + y = ctx.saved_tensors[0] + # to encourage the logits to be close to 0 + factor = 1e-4 / (y.shape[0] * y.shape[1]) + maxx, ids = torch.max(y, -1, keepdim=True) + gy = torch.zeros_like(y) + gy.scatter_(-1, ids, maxx * factor) + return (grad_output, gy) + + +def _get_activation_fn(activation): + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + elif activation == "silu": + return F.silu + elif activation == "relu_squared": + return lambda x: torch.square(torch.relu(x)) + elif activation == "selu": + return F.selu + elif 
activation == "elu": + return F.elu + else: + raise ValueError(f"Unknown activation function {activation}") + + +@dataclass +class RWKVConfig(BaseSequenceModelConfig): + embedding_dim: int = 640 + ffn_dim: int = 2048 + num_layers: int = 12 + attention_dim: int = -1 + wkv_config: Optional[Union[WKVConfig, Dict]] = field( + default_factory=lambda: WKVConfig() + ) + l2_logit_reg: bool = False + use_timemix_timemix: bool = True + use_timemix_channelmix: bool = True + channelmix_act_fn: str = "relu_squared" + init: str = "reproduce_init" # options "paper_init", "reproduce_init" (reproduce init is different from paper init) + # the reproduce_init was figured out when only the code was available, not the paper + # the init as it is described in the paper slightly differs from the reproduce_init + # the reproduce_init actually performs much better than the paper_init (ca. 0.5 better in train loss after 300 steps on WikiText-103) + # so we use this as default. reproduce_init is also (likely to be) used in the original code. + + def __post_init__(self): + # TODO: Check if this was actually needed + if self.wkv_config is not None: + self.wkv_config.T_max = max( + self.wkv_config.T_max, self.context_length + ) + if self.attention_dim <= 0: + self.attention_dim = self.embedding_dim + + +class RWKV(BaseSequenceModelTrain): + config_class = RWKVConfig + + def __init__(self, config: RWKVConfig, **kwargs): + super().__init__() + self.config = config + self.cfg = config + + # self.embedding = nn.Embedding( + # num_embeddings=self.cfg.vocab_size, + # embedding_dim=self.cfg.embedding_dim, + # ) + + self.blocks = nn.ModuleList( + [ + RWKVBlock(config=self.cfg, block_idx=i) + for i in range(self.cfg.num_layers) + ] + ) + if self.cfg.wkv_config is not None: + LOGGER.info("Using WKV cuda kernel.") + else: + LOGGER.info("Using WKV torch kernel.") + + # self.ln_out = nn.LayerNorm(self.cfg.embedding_dim) + # self.head = nn.Linear( + # self.cfg.embedding_dim, self.cfg.vocab_size, bias=False + # ) + self.reset_parameters() + + @property + def context_length(self) -> int: + return self.config.context_length + + # @property + # def vocab_size(self) -> int: + # return self.config.vocab_size + + def reset_parameters(self) -> None: + # init embedding + # default init is zero # TODO try this + # we use a narrow uniform init, in the original code they use the initial learning rate + # we just set it to a small value + # emb_init_range = 0.0008 # 1e-3 + # nn.init.uniform_( + # self.embedding.weight, a=-emb_init_range, b=emb_init_range + # ) + # init blocks + for b in self.blocks: + b.reset_parameters() + # init head and layer norm + # self.head.reset_parameters() + # self.ln_out.reset_parameters() + + def _create_optim_groups(self, config: RWKVConfig): + optim_groups = [ + {"params": [p for p in self.parameters()], "weight_decay": 0.0} + ] + return optim_groups + + def forward(self, x): + # no embedding + # # input shape: (B, T), T <= context_len, T are token ids + # B, T = x.size() + # assert T <= self.cfg.context_length, ( + # f"input sequence length {T} exceeds model " + # f"context length {self.cfg.context_length}" + # ) + + # x = self.embedding(x) # (B, T, C), C = embedding_dim + + for i, block in enumerate(self.blocks): + x = block(x) + + # x = self.ln_out(x) + # no head + # x = self.head(x) + return x + + def get_loss_func(self): + def loss_fn(y_hat, y): + loss = F.cross_entropy( + y_hat.view(-1, y_hat.size(-1)), y.view(-1), ignore_index=-1 + ) + if self.cfg.l2_logit_reg: + loss = L2Wrap.apply(loss, y_hat) + return 
loss + + return loss_fn + + +def _calc_gain(weight: torch.Tensor) -> float: + """Calculate the gain value of the given weight tensor.""" + gain = 1.0 + fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(weight) + if fan_out > fan_in: + gain = math.sqrt(fan_out / fan_in) + return gain + + +class RWKVBlock(ResettableParametersModule): + def __init__(self, config: RWKVConfig, block_idx: int): + super().__init__() + self.config = config + self.block_idx = block_idx + + self.ln0 = None + if self.block_idx == 0: + self.ln0 = nn.LayerNorm(self.config.embedding_dim) + # TODO 1) maybe additional positional embedding here (only in block 0) + + self.ln1 = nn.LayerNorm(self.config.embedding_dim) + self.ln2 = nn.LayerNorm(self.config.embedding_dim) + + # TODO 2) maybe pre feedforward here (channel mix) see line 325f in RWKV-v4neo/model.py + self.attention_timemix = RWKVTimeMix( + config=self.config, block_id=self.block_idx + ) + self.ffn_channelmix = RWKVChannelMix( + config=self.config, block_id=self.block_idx + ) + + def reset_parameters(self) -> None: + if self.ln0 is not None: + self.ln0.reset_parameters() + self.ln1.reset_parameters() + self.ln2.reset_parameters() + self.attention_timemix.reset_parameters() + self.ffn_channelmix.reset_parameters() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.block_idx == 0 and self.ln0 is not None: + x = self.ln0(x) + # TODO 1) maybe positional embedding here (only in block 0) + # x = x+pos_emb + + # TODO 2) maybe pre feedforward here (channel mix) see line 325f in RWKV-v4neo/model.py + # residual connection 1 + x = x + self.attention_timemix(self.ln1(x)) + # residual connection 2 + x = x + self.ffn_channelmix(self.ln2(x)) + return x + + +class RWKVTimeMix(ResettableParametersModule): + def __init__(self, config: RWKVConfig, block_id: int): + super().__init__() + self.config = config + self.block_id = block_id + + embedding_dim = self.config.embedding_dim + attention_dim = self.config.attention_dim + # init time mix constants + req_grad = True # TODO make this configurable + self.time_mix_k = nn.Parameter( + torch.empty((1, 1, embedding_dim)), requires_grad=req_grad + ) + self.time_mix_v = nn.Parameter( + torch.empty((1, 1, embedding_dim)), requires_grad=req_grad + ) + self.time_mix_r = nn.Parameter( + torch.empty((1, 1, embedding_dim)), requires_grad=req_grad + ) + + # init time decay + self.time_decay = nn.Parameter( + torch.empty((attention_dim,)), requires_grad=req_grad + ) + self.time_first = nn.Parameter( + torch.empty((attention_dim,)), requires_grad=req_grad + ) + + # init layers / parameters + # this shifts the time dimension by 1 forward, pad 0 at first time step, remove last time step: + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + self.key = nn.Linear(embedding_dim, attention_dim, bias=False) + self.value = nn.Linear(embedding_dim, attention_dim, bias=False) + self.receptance = nn.Linear(embedding_dim, attention_dim, bias=False) + self.output = nn.Linear(attention_dim, embedding_dim, bias=False) + + if self.config.wkv_config is not None: + # use CUDA implementation + self.wkv = WKV(config=self.config.wkv_config) + else: + # use pure PyTorch implementation + self.wkv = WKVTorch() + + self.reset_parameters() + + def reset_parameters(self) -> None: + # init time mix constants + time_mix_k, time_mix_v, time_mix_r = self._init_time_mix_constants() + req_grad = True + self.time_mix_k = nn.Parameter(time_mix_k, requires_grad=req_grad) + self.time_mix_v = nn.Parameter(time_mix_v, requires_grad=req_grad) + self.time_mix_r = 
nn.Parameter(time_mix_r, requires_grad=req_grad) + # init time decay + time_decay, time_first = self._init_time_decay_constants() + self.time_decay = nn.Parameter(time_decay, requires_grad=req_grad) + self.time_first = nn.Parameter(time_first, requires_grad=req_grad) + # init layers / parameters + if self.config.init == "paper_init": + # ZERO INIT + nn.init.zeros_(self.receptance.weight) + nn.init.zeros_(self.key.weight) + nn.init.zeros_(self.value.weight) + # NORMAL INIT + nn.init.normal_( + self.output.weight, + std=math.sqrt(self.config.ffn_dim / self.config.embedding_dim), + ) + elif self.config.init == "reproduce_init": + # ZERO INIT + nn.init.zeros_(self.key.weight) + nn.init.zeros_(self.receptance.weight) + nn.init.zeros_(self.output.weight) + # ORTHOGONAL INIT + nn.init.orthogonal_( + self.value.weight, gain=_calc_gain(self.value.weight) + ) + else: + raise ValueError(f"Unknown init method {self.config.init}") + + def _compute_rkv(self, x): + if self.config.use_timemix_timemix: + xx = self.time_shift( + x + ) # Mix x with the previous timestep to produce xk, xv, xr + xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) + xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) + xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) + else: + xk = x + xv = x + xr = x + k = self.key(xk) + v = self.value(xv) + r = self.receptance(xr) + sr = torch.sigmoid(r) + return sr, k, v + + def forward(self, x): + B, T, C = x.size() # x = (batch_size, seq_len, embedding_dim) + attention_dim = self.config.attention_dim + sr, k, v = self._compute_rkv( + x + ) # sr, k, v = (batch_size, seq_len, attention_dim) + # wkv cuda/torch kernel + rwkv = sr * self.wkv( + B, T, attention_dim, self.time_decay, self.time_first, k, v + ) + return self.output(rwkv) + + def _init_time_mix_constants( + self, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + num_blocks = self.config.num_layers + embedding_dim = self.config.embedding_dim + + ratio_0_to_1 = self.block_id / max(1, num_blocks - 1) # 0 to 1 + ratio_1_to_almost0 = 1.0 - (self.block_id / num_blocks) # 1 to ~0 + + # TODO does this make sense? 
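+ # (what this does: the loop below builds a 0 .. (D-1)/D ramp over the channels;
+ # block 0 keeps the full ramp (exponent 1.0), while later blocks raise it to an
+ # exponent near 0, pushing most channels towards 1, i.e. mixing mostly the
+ # current timestep in _compute_rkv)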
+ # different time mix constants for each block and each embedding dim + embed_dim_val = torch.ones(1, 1, embedding_dim) + for i in range(embedding_dim): + embed_dim_val[0, 0, i] = i / embedding_dim + + # TODO check constants 0.3 and 0.5 + time_mix_k = torch.pow(embed_dim_val, ratio_1_to_almost0) + time_mix_v = ( + torch.pow(embed_dim_val, ratio_1_to_almost0) + 0.3 * ratio_0_to_1 + ) + time_mix_r = torch.pow(embed_dim_val, 0.5 * ratio_1_to_almost0) + + return time_mix_k, time_mix_v, time_mix_r + + def _init_time_decay_constants(self) -> Tuple[torch.Tensor, torch.Tensor]: + num_blocks = self.config.num_layers + attention_dim = self.config.attention_dim + ratio_0_to_1 = self.block_id / max(1, num_blocks - 1) # 0 to 1 + + # time decay + # this encourages the model to decay the information in different memory cells (channel dimensions) + # at different speeds + decay_speed = torch.ones(attention_dim) + for h in range(attention_dim): + decay_speed[h] = -5 + 8 * (h / (attention_dim - 1)) ** ( + 0.7 + 1.3 * ratio_0_to_1 + ) + time_decay = decay_speed + + # time first + # The alternating zigzag pattern initially creates subtle variations in the tensor elements, + # which are intended to help the model treat different dimensions of the embedding differently + zigzag = ( + torch.tensor([(i + 1) % 3 - 1 for i in range(attention_dim)]) * 0.5 + ) + time_first = ( + torch.ones(attention_dim) * torch.log(torch.tensor(0.3)) + zigzag + ) + + return time_decay, time_first + + +class RWKVChannelMix(ResettableParametersModule): + def __init__(self, config: RWKVConfig, block_id: int): + super().__init__() + self.config = config + self.block_id = block_id + + self._act_fn = _get_activation_fn(self.config.channelmix_act_fn) + + embedding_dim = self.config.embedding_dim + ffn_dim = self.config.ffn_dim + # init time mix constants + req_grad = True + self.time_mix_k = nn.Parameter( + torch.empty((1, 1, embedding_dim)), requires_grad=req_grad + ) + self.time_mix_r = nn.Parameter( + torch.empty((1, 1, embedding_dim)), requires_grad=req_grad + ) + + # init layers / parameters + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + self.key = nn.Linear(embedding_dim, ffn_dim, bias=False) + self.receptance = nn.Linear(embedding_dim, embedding_dim, bias=False) + self.value = nn.Linear(ffn_dim, embedding_dim, bias=False) + self.reset_parameters() + + def reset_parameters(self): + # init time mix constants + time_mix_k, time_mix_r = self._init_time_mix_constants() + req_grad = True + self.time_mix_k = nn.Parameter(time_mix_k, requires_grad=req_grad) + self.time_mix_r = nn.Parameter(time_mix_r, requires_grad=req_grad) + # init layers / parameters + if self.config.init == "paper_init": + # ZERO INIT + nn.init.zeros_(self.receptance.weight) + nn.init.zeros_(self.key.weight) + # NORMAL INIT + nn.init.normal_( + self.value.weight, + std=math.sqrt(self.config.ffn_dim / self.config.embedding_dim), + ) + elif self.config.init == "reproduce_init": + # ZERO INIT + nn.init.zeros_(self.receptance.weight) + nn.init.zeros_(self.value.weight) + # ORTHOGONAL INIT + nn.init.orthogonal_( + self.key.weight, gain=_calc_gain(self.key.weight) + ) + else: + raise ValueError(f"Unknown init method {self.config.init}") + + def _init_time_mix_constants(self) -> Tuple[torch.Tensor, torch.Tensor]: + num_blocks = self.config.num_layers + embedding_dim = self.config.embedding_dim + + ratio_1_to_almost0 = 1.0 - (self.block_id / num_blocks) # 1 to ~0 + embed_dim_val = torch.ones(1, 1, embedding_dim) + for i in range(embedding_dim): + embed_dim_val[0, 0, i] = i 
/ embedding_dim + + time_mix_k = torch.pow(embed_dim_val, ratio_1_to_almost0) + time_mix_r = torch.pow(embed_dim_val, ratio_1_to_almost0) + + return time_mix_k, time_mix_r + + def forward(self, x): + if self.config.use_timemix_channelmix: + xx = self.time_shift(x) + xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) + xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) + else: + xk = x + xr = x + k = self.key(xk) + k = self._act_fn(k) + kv = self.value(k) + y = torch.sigmoid(self.receptance(xr)) * kv + return y diff --git a/models/rwkv/sequence_rwkv.py b/models/rwkv/sequence_rwkv.py new file mode 100644 index 0000000..f5c8348 --- /dev/null +++ b/models/rwkv/sequence_rwkv.py @@ -0,0 +1,92 @@ +import logging +from dataclasses import dataclass +from typing import Callable, Sequence + +import torch +from torch import nn + +from ...ml_utils.config import NameAndKwargs +from ..base import BaseSequenceModelTrain +from ..seq_enc_dec import create_decoder, create_encoder +from .rwkv_model import RWKVBlock, RWKVConfig + +LOGGER = logging.getLogger(__name__) + + +@dataclass +class SequenceRWKVConfig(RWKVConfig): + encoder: NameAndKwargs = None + decoder: NameAndKwargs = None + + +class SequenceRWKV(BaseSequenceModelTrain): + config_class = SequenceRWKVConfig + + def __init__(self, config: SequenceRWKVConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + if config.wkv_config is not None: + LOGGER.info("Using WKV cuda kernel.") + else: + LOGGER.info("Using WKV torch kernel.") + + self.encoder = create_encoder(config=config) + self.decoder = create_decoder(config=config) + + self.blocks = nn.ModuleList([RWKVBlock(config, block_idx=i) for i in range(config.num_layers)]) + self.blocks_ln = nn.LayerNorm(config.embedding_dim) + + self.reset_parameters() + + @property + def context_length(self) -> int: + return self.config.context_length + + @property + def vocab_size(self) -> int: + return self.config.vocab_size + + @property + def input_dim(self) -> Sequence[int]: + return self.config.input_dim + + @property + def output_dim(self) -> Sequence[int]: + return self.config.output_dim + + def reset_parameters(self): + for block in self.blocks: + block.reset_parameters() + self.blocks_ln.reset_parameters() + self.encoder.reset_parameters() + self.decoder.reset_parameters() + + def _create_optim_groups(self, config: RWKVConfig): + optim_groups = [{"params": [p for p in self.parameters()], "weight_decay": 0.0}] + return optim_groups + + def get_loss_func(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]: + import torch.nn.functional as F + + def loss_fn(logits, targets): + assert not torch.any(torch.isnan(logits.view(-1))) + assert not torch.any(torch.isnan(targets.view(-1))) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) + return loss + + return loss_fn + + def forward(self, x: torch.Tensor) -> torch.Tensor: + assert ( + x.size(1) <= self.config.context_length + ), f"Forward input sequence length {x.size(1)} is longer than context length {self.config.context_length}" + + y = self.encoder(x) + + for block in self.blocks: + y = block(y) + y = self.blocks_ln(y) + + y = self.decoder(y) + + return y diff --git a/models/rwkv/wkv_kernel.py b/models/rwkv/wkv_kernel.py new file mode 100644 index 0000000..606852f --- /dev/null +++ b/models/rwkv/wkv_kernel.py @@ -0,0 +1,259 @@ +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import List, Union + +import torch +from einops import rearrange, repeat 
+from torch import nn +from torch.utils.cpp_extension import load + +"""This module is a wrapper for the C++/CUDA implementation of the WKV kernel. +""" + + +@dataclass +class WKVConfig: + T_max: int = 1024 # max sequence length within cuda operations + cpp_ext_name: str = "wkv" + device: Union[str, torch.device] = "cuda" + float_mode: str = "fp32" # options: fp32, fp16, bfloat16 + + def __post_init__(self): + self.device = torch.device(self.device) + + def float_mode_to_dtype(self): + if self.float_mode == "fp32" or "32" in str(self.float_mode): + return torch.float32 + elif self.float_mode == "fp16" or "16" in str(self.float_mode): + return torch.float16 + elif self.float_mode == "bfloat16": + return torch.bfloat16 + else: + raise ValueError(f"Unknown float_mode: {self.float_mode}") + + +class WKV(nn.Module): + _instance = None # for singleton + + class _WKV(torch.autograd.Function): + wkv_cuda = None + wkv_config: WKVConfig = None + + @classmethod + def forward( + cls, + ctx, + batch_size, + seq_len, + embedding_dim, + time_decay, # TODO replace embedding_dim with attention_dim + time_first, + k, + v, + ): + wkv_cuda = cls.wkv_cuda + wkv_config = cls.wkv_config + # setup context # TODO for PyTorch 2.0 use extra setup_context() function + ctx.batch_size = batch_size + ctx.seq_len = seq_len + ctx.embedding_dim = embedding_dim + ctx.wkv_cuda = wkv_cuda + ctx.wkv_config = wkv_config + assert ( + seq_len <= wkv_config.T_max + ), f"Sequence length {seq_len} exceeds the maximum allowed T_max={wkv_config.T_max}" + # TODO what does this assert do? Why necessary? + assert ( + batch_size * embedding_dim % min(embedding_dim, wkv_config.T_max) == 0 + ), "batch_size * embedding_dim must be divisible by min(embedding_dim, T_max)" + # + dtype = torch.float32 # convert all tensors to float32 (for cuda kernel) + device = wkv_config.device + # convert input tensors + time_decay = time_decay.to(dtype=dtype, device=device, memory_format=torch.contiguous_format) + time_first = time_first.to(dtype=dtype, device=device, memory_format=torch.contiguous_format) + k = k.to(dtype=dtype, device=device, memory_format=torch.contiguous_format) + v = v.to(dtype=dtype, device=device, memory_format=torch.contiguous_format) + + # allocate output tensor + y = torch.empty( + batch_size, seq_len, embedding_dim, dtype=dtype, device=device, memory_format=torch.contiguous_format + ) + + # call cuda kernel + time_decay = -torch.exp(time_decay) # TODO why is this necessary? 
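+ # time_decay is learned in log space: negating its exponential makes the effective
+ # decay strictly negative, so the kernel's per-step factor exp(w_timedecay) stays
+ # below 1 and the running sums decay instead of growing (the backward kernel's
+ # "gw * _w[_c]" accounts for this reparameterization).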
+ wkv_cuda.forward(batch_size, seq_len, embedding_dim, time_decay, time_first, k, v, y) + ctx.save_for_backward(time_decay, time_first, k, v, y) + + # convert output tensor to correct dtype + y = y.to(dtype=wkv_config.float_mode_to_dtype()) + return y + + @staticmethod + def backward(ctx, gy): + batch_size = ctx.batch_size + seq_len = ctx.seq_len + embedding_dim = ctx.embedding_dim + assert ( + seq_len <= ctx.wkv_config.T_max + ), f"Sequence length {seq_len} exceeds the maximum allowed T_max={ctx.wkv_config.T_max}" + assert ( + batch_size * embedding_dim % min(embedding_dim, ctx.wkv_config.T_max) == 0 + ), "batch_size * embedding_dim must be divisible by min(embedding_dim, T_max)" + + time_decay, time_first, k, v, y = ctx.saved_tensors + + device = ctx.wkv_config.device + # allocate gradient tensors + gtime_decay = torch.zeros((batch_size, embedding_dim), device=device, dtype=torch.float32).contiguous() + gtime_first = torch.zeros((batch_size, embedding_dim), device=device, dtype=torch.float32).contiguous() + gk = torch.zeros((batch_size, seq_len, embedding_dim), device=device, dtype=torch.float32).contiguous() + gv = torch.zeros((batch_size, seq_len, embedding_dim), device=device, dtype=torch.float32).contiguous() + + # call cuda kernel + gy = gy.to(dtype=torch.float32, memory_format=torch.contiguous_format) + # arg0: int, arg1: int, arg2: int, arg3: at::Tensor, arg4: at::Tensor, arg5: at::Tensor, arg6: at::Tensor, + # arg7: at::Tensor, arg8: at::Tensor, arg9: at::Tensor, arg10: at::Tensor, arg11: at::Tensor, arg12: at::Tensor + ctx.wkv_cuda.backward( + batch_size, + seq_len, + embedding_dim, + time_decay, + time_first, + k, + v, + y, + gy, + gtime_decay, + gtime_first, + gk, + gv, + ) + + gtime_decay = gtime_decay.sum(dim=0) + gtime_first = gtime_first.sum(dim=0) + + # convert gradient tensors to correct dtype + out_dtype = ctx.wkv_config.float_mode_to_dtype() + + return ( + None, + None, + None, + gtime_decay.to(dtype=out_dtype), + gtime_first.to(dtype=out_dtype), + gk.to(dtype=out_dtype), + gv.to(dtype=out_dtype), + ) + + def __new__(cls, config: WKVConfig = WKVConfig()): + if cls._instance is None: + cls._instance = super(WKV, cls).__new__(cls) + cls._instance._setup(config) + return cls._instance + + def __init__(self, *args, **kwargs): + # Dummy to avoid multiple calls to self._load_cuda() + pass + + def _setup(self, config: WKVConfig = WKVConfig()): + """Setup the WKV module. 
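+ Loads the compiled CUDA extension once and caches it on the singleton instance.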
This is called by __new__ as constructor.""" + super().__init__() + self.cfg = config + self.wkv_cuda = self._load_cuda() + self.device = self.cfg.device + + def _load_cuda(self): + cfg = self.cfg + os.environ["CUDA_LIB"] = os.path.join( + os.path.split(torch.utils.cpp_extension.include_paths(cuda=True)[-1])[0], "lib" + ) + print(os.environ.get("LD_LIBRARY_PATH", "")) + print(os.environ["CUDA_LIB"]) + + cpp_ext_sources_float32: List[str] = [ + str(Path(__file__).parent / "cuda/wkv_op.cpp"), + str(Path(__file__).parent / "cuda/wkv_cuda.cu"), + ] + cpp_ext_sources_bfloat16: List[str] = [ + str(Path(__file__).parent / "cuda/wkv_op_bf16.cpp"), + str(Path(__file__).parent / "cuda/wkv_cuda_bf16.cu"), + ] + if cfg.float_mode_to_dtype() == torch.float32: + cpp_ext_sources = cpp_ext_sources_float32 + elif cfg.float_mode_to_dtype() == torch.bfloat16: + cpp_ext_sources = cpp_ext_sources_bfloat16 + else: + raise ValueError(f"Unsupported float mode {cfg.float_mode}") + + myargs = { + "verbose": True, + "with_cuda": True, + "extra_ldflags": [f"-L{os.environ['CUDA_LIB']}", "-lcublas"], + "extra_cuda_cflags": [ + # "-gencode", + # "arch=compute_70,code=compute_70", + "-gencode", + "arch=compute_80,code=compute_80", + "-res-usage", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + "--extra-device-vectorization", + f"-DTmax={cfg.T_max}", + ], + } + + cuda_module = load(name=cfg.cpp_ext_name, sources=cpp_ext_sources, **myargs) + + return cuda_module + + def to(self, **kwargs): + device = kwargs.get("device", None) + if device is not None: + self.device = self.cfg.device = torch.device(device) + return super().to(**kwargs) + + def forward(self, batch_size, seq_len, embeding_dim, time_decay, time_first, k, v): + self.device = self.cfg.device = v.device + assert self.device != torch.device("cpu"), "WKV is not implemented for CPU" + self._WKV.wkv_cuda = self.wkv_cuda + self._WKV.wkv_config = self.cfg + return self._WKV.apply(batch_size, seq_len, embeding_dim, time_decay, time_first, k, v) + + +class WKVTorch(nn.Module): + def __init__(self): + super().__init__() + # did not add time_decay = -torch.exp(time_decay) here, it is not necessary + # this is done in the forward function of WKV cuda kernel + # if it is added here, the outputs will be different from the cuda kernel + + def forward(self, batch_size, seq_len, embedding_dim, time_decay, time_first, k, v): + dtype = k.dtype + device = k.device + y = torch.zeros(batch_size, seq_len, embedding_dim, dtype=dtype, device=device) + MIN_VAL = 0.0 # -1e38 + # reshape inputs + k_ = torch.permute(k, (1, 0, 2)) # rearrange(k, 'b s e -> s b e') + v_ = torch.permute(v, (1, 0, 2)) # rearrange(v, 'b s e -> s b e') + y_ = torch.permute(y, (1, 0, 2)) # rearrange(y, 'b s e -> s b e') + tf = time_first.repeat(batch_size, 1) # repeat(time_first, 'e -> b e', b=batch_size) + td = time_decay.repeat(batch_size, 1) # repeat(time_decay, 'e -> b e', b=batch_size) + # running sums + aa = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device) + bb = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device) + eps = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device) + for t in range(seq_len): + e_tf_k = torch.exp(tf + k_[t] - eps) + y_[t] = (aa + v_[t] * e_tf_k) / (bb + e_tf_k) + eps_next = torch.max(td + eps, k_[t]) + + e_td_k = torch.exp(td + eps - eps_next) + e_k = torch.exp(k_[t] - eps_next) + aa = aa * e_td_k + v_[t] * e_k + bb = bb * e_td_k + e_k + eps = eps_next + y = rearrange(y_, "s b e -> b s e") + return y diff --git a/safari 
b/safari new file mode 160000 index 0000000..23a1d81 --- /dev/null +++ b/safari @@ -0,0 +1 @@ +Subproject commit 23a1d81b2fce85f12d78812c427c4e8e9cb0ac42
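For reference, a minimal usage sketch of the pure-PyTorch path, mirroring TestModels.ipynb: passing wkv_config=None selects WKVTorch, so no CUDA extension is compiled. Since the model's embedding and head layers are commented out, the random tensor below stands in for pre-computed embeddings.

import torch

from models.rwkv import RWKV, RWKVConfig

NUM_LAYERS = 3
EMBEDDING_DIM = 128
SEQUENCE_LENGTH = 1024

# wkv_config=None -> WKVTorch (pure PyTorch), runs on CPU as well.
model = RWKV(
    RWKVConfig(
        context_length=SEQUENCE_LENGTH,
        embedding_dim=EMBEDDING_DIM,
        num_layers=NUM_LAYERS,
        wkv_config=None,
    )
)

# The blocks map (batch, time, embedding_dim) -> (batch, time, embedding_dim).
x = torch.randn(2, SEQUENCE_LENGTH, EMBEDDING_DIM)
y = model(x)
assert y.shape == x.shape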