Add RWKV, H3, Hyena
models/rwkv/wkv_kernel.py (new file, 259 lines)
@@ -0,0 +1,259 @@
"""This module is a wrapper for the C++/CUDA implementation of the WKV kernel."""

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Union

import torch
from einops import rearrange, repeat
from torch import nn
from torch.utils.cpp_extension import load


@dataclass
class WKVConfig:
    T_max: int = 1024  # max sequence length within cuda operations
    cpp_ext_name: str = "wkv"
    device: Union[str, torch.device] = "cuda"
    float_mode: str = "fp32"  # options: fp32, fp16, bfloat16

    def __post_init__(self):
        self.device = torch.device(self.device)

    def float_mode_to_dtype(self):
        mode = str(self.float_mode)
        # check bfloat16 first: "bfloat16" also contains "16" and would otherwise
        # be caught by the fp16 branch below
        if mode == "bfloat16" or "bf16" in mode:
            return torch.bfloat16
        elif mode == "fp32" or "32" in mode:
            return torch.float32
        elif mode == "fp16" or "16" in mode:
            return torch.float16
        else:
            raise ValueError(f"Unknown float_mode: {self.float_mode}")

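# Illustrative mapping (a sketch of how float_mode resolves with the defaults above;
# the shorthand strings are examples, not an exhaustive list):
#   WKVConfig(float_mode="fp32").float_mode_to_dtype()      # -> torch.float32
#   WKVConfig(float_mode="bf16").float_mode_to_dtype()      # -> torch.bfloat16
#   WKVConfig(float_mode="bfloat16").float_mode_to_dtype()  # -> torch.bfloat16

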
class WKV(nn.Module):
    _instance = None  # for singleton

    class _WKV(torch.autograd.Function):
        wkv_cuda = None
        wkv_config: WKVConfig = None

        @classmethod
        def forward(
            cls,
            ctx,
            batch_size,
            seq_len,
            embedding_dim,  # TODO replace embedding_dim with attention_dim
            time_decay,
            time_first,
            k,
            v,
        ):
            wkv_cuda = cls.wkv_cuda
            wkv_config = cls.wkv_config
            # setup context  # TODO for PyTorch 2.0 use the separate setup_context() function
            ctx.batch_size = batch_size
            ctx.seq_len = seq_len
            ctx.embedding_dim = embedding_dim
            ctx.wkv_cuda = wkv_cuda
            ctx.wkv_config = wkv_config
            assert (
                seq_len <= wkv_config.T_max
            ), f"Sequence length {seq_len} exceeds the maximum allowed T_max={wkv_config.T_max}"
            # presumably required because the CUDA launch tiles the batch_size * embedding_dim
            # work items into blocks of min(embedding_dim, T_max) threads (see cuda/wkv_cuda.cu)
            assert (
                batch_size * embedding_dim % min(embedding_dim, wkv_config.T_max) == 0
            ), "batch_size * embedding_dim must be divisible by min(embedding_dim, T_max)"

            dtype = torch.float32  # convert all tensors to float32 (for the cuda kernel)
            device = wkv_config.device
            # convert input tensors
            time_decay = time_decay.to(dtype=dtype, device=device, memory_format=torch.contiguous_format)
            time_first = time_first.to(dtype=dtype, device=device, memory_format=torch.contiguous_format)
            k = k.to(dtype=dtype, device=device, memory_format=torch.contiguous_format)
            v = v.to(dtype=dtype, device=device, memory_format=torch.contiguous_format)

            # allocate output tensor
            y = torch.empty(
                batch_size, seq_len, embedding_dim, dtype=dtype, device=device, memory_format=torch.contiguous_format
            )

            # call cuda kernel
            # reparameterize the decay: -exp(time_decay) is always negative, so the effective
            # per-step decay exp(time_decay) stays in (0, 1) for any learned parameter value
            time_decay = -torch.exp(time_decay)
            wkv_cuda.forward(batch_size, seq_len, embedding_dim, time_decay, time_first, k, v, y)
            ctx.save_for_backward(time_decay, time_first, k, v, y)

            # convert output tensor to the configured dtype
            y = y.to(dtype=wkv_config.float_mode_to_dtype())
            return y

        @staticmethod
        def backward(ctx, gy):
            batch_size = ctx.batch_size
            seq_len = ctx.seq_len
            embedding_dim = ctx.embedding_dim
            assert (
                seq_len <= ctx.wkv_config.T_max
            ), f"Sequence length {seq_len} exceeds the maximum allowed T_max={ctx.wkv_config.T_max}"
            assert (
                batch_size * embedding_dim % min(embedding_dim, ctx.wkv_config.T_max) == 0
            ), "batch_size * embedding_dim must be divisible by min(embedding_dim, T_max)"

            time_decay, time_first, k, v, y = ctx.saved_tensors

            device = ctx.wkv_config.device
            # allocate gradient tensors
            gtime_decay = torch.zeros((batch_size, embedding_dim), device=device, dtype=torch.float32).contiguous()
            gtime_first = torch.zeros((batch_size, embedding_dim), device=device, dtype=torch.float32).contiguous()
            gk = torch.zeros((batch_size, seq_len, embedding_dim), device=device, dtype=torch.float32).contiguous()
            gv = torch.zeros((batch_size, seq_len, embedding_dim), device=device, dtype=torch.float32).contiguous()

            # call cuda kernel
            gy = gy.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            # signature: backward(B, T, C, time_decay, time_first, k, v, y, gy,
            #                     gtime_decay, gtime_first, gk, gv)
            ctx.wkv_cuda.backward(
                batch_size,
                seq_len,
                embedding_dim,
                time_decay,
                time_first,
                k,
                v,
                y,
                gy,
                gtime_decay,
                gtime_first,
                gk,
                gv,
            )

            # time_decay/time_first are per-channel parameters shared across the batch,
            # so their per-sample gradients are summed over the batch dimension
            gtime_decay = gtime_decay.sum(dim=0)
            gtime_first = gtime_first.sum(dim=0)

            # convert gradient tensors to the configured dtype
            out_dtype = ctx.wkv_config.float_mode_to_dtype()

            return (
                None,  # batch_size
                None,  # seq_len
                None,  # embedding_dim
                gtime_decay.to(dtype=out_dtype),
                gtime_first.to(dtype=out_dtype),
                gk.to(dtype=out_dtype),
                gv.to(dtype=out_dtype),
            )

    def __new__(cls, config: WKVConfig = WKVConfig()):
        if cls._instance is None:
            cls._instance = super(WKV, cls).__new__(cls)
            cls._instance._setup(config)
        return cls._instance

    def __init__(self, *args, **kwargs):
        # Dummy to avoid multiple calls to self._load_cuda()
        pass

    def _setup(self, config: WKVConfig = WKVConfig()):
        """Set up the WKV module. Called by __new__ as the constructor."""
        super().__init__()
        self.cfg = config
        self.wkv_cuda = self._load_cuda()
        self.device = self.cfg.device

    def _load_cuda(self):
        cfg = self.cfg
        os.environ["CUDA_LIB"] = os.path.join(
            os.path.split(torch.utils.cpp_extension.include_paths(cuda=True)[-1])[0], "lib"
        )
        print(os.environ.get("LD_LIBRARY_PATH", ""))
        print(os.environ["CUDA_LIB"])

        cpp_ext_sources_float32: List[str] = [
            str(Path(__file__).parent / "cuda/wkv_op.cpp"),
            str(Path(__file__).parent / "cuda/wkv_cuda.cu"),
        ]
        cpp_ext_sources_bfloat16: List[str] = [
            str(Path(__file__).parent / "cuda/wkv_op_bf16.cpp"),
            str(Path(__file__).parent / "cuda/wkv_cuda_bf16.cu"),
        ]
        if cfg.float_mode_to_dtype() == torch.float32:
            cpp_ext_sources = cpp_ext_sources_float32
        elif cfg.float_mode_to_dtype() == torch.bfloat16:
            cpp_ext_sources = cpp_ext_sources_bfloat16
        else:
            raise ValueError(f"Unsupported float mode {cfg.float_mode}")

        myargs = {
            "verbose": True,
            "with_cuda": True,
            "extra_ldflags": [f"-L{os.environ['CUDA_LIB']}", "-lcublas"],
            "extra_cuda_cflags": [
                # "-gencode",
                # "arch=compute_70,code=compute_70",
                "-gencode",
                "arch=compute_80,code=compute_80",
                "-res-usage",
                "--use_fast_math",
                "-O3",
                "-Xptxas -O3",
                "--extra-device-vectorization",
                f"-DTmax={cfg.T_max}",  # T_max is baked into the kernel at compile time
            ],
        }

        cuda_module = load(name=cfg.cpp_ext_name, sources=cpp_ext_sources, **myargs)

        return cuda_module

    def to(self, **kwargs):
        device = kwargs.get("device", None)
        if device is not None:
            self.device = self.cfg.device = torch.device(device)
        return super().to(**kwargs)

    def forward(self, batch_size, seq_len, embedding_dim, time_decay, time_first, k, v):
        self.device = self.cfg.device = v.device
        assert self.device != torch.device("cpu"), "WKV is not implemented for CPU"
        self._WKV.wkv_cuda = self.wkv_cuda
        self._WKV.wkv_config = self.cfg
        return self._WKV.apply(batch_size, seq_len, embedding_dim, time_decay, time_first, k, v)

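# Reference formula implemented by WKVTorch below (a paraphrase in RWKV notation, where
# u = time_first and w = time_decay after the -exp reparameterization; see the RWKV paper
# for the exact definition):
#
#   wkv_t = ( sum_{i<t} exp(w * (t - 1 - i) + k_i) * v_i + exp(u + k_t) * v_t )
#           / ( sum_{i<t} exp(w * (t - 1 - i) + k_i) + exp(u + k_t) )
#
# The loop keeps running sums aa (numerator) and bb (denominator), both scaled by
# exp(-eps), where eps tracks the running maximum exponent for numerical stability.

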
class WKVTorch(nn.Module):
    """Pure-PyTorch reference implementation of the WKV recurrence."""

    def __init__(self):
        super().__init__()
        # time_decay = -torch.exp(time_decay) is deliberately not applied here;
        # the CUDA wrapper (WKV) applies it inside its own forward. Applying it
        # here as well would make the outputs diverge from the cuda kernel.

    def forward(self, batch_size, seq_len, embedding_dim, time_decay, time_first, k, v):
        dtype = k.dtype
        device = k.device
        y = torch.zeros(batch_size, seq_len, embedding_dim, dtype=dtype, device=device)
        MIN_VAL = 0.0  # -1e38
        # reshape inputs to time-major layout
        k_ = torch.permute(k, (1, 0, 2))  # rearrange(k, 'b s e -> s b e')
        v_ = torch.permute(v, (1, 0, 2))  # rearrange(v, 'b s e -> s b e')
        y_ = torch.permute(y, (1, 0, 2))  # rearrange(y, 'b s e -> s b e')
        tf = time_first.repeat(batch_size, 1)  # repeat(time_first, 'e -> b e', b=batch_size)
        td = time_decay.repeat(batch_size, 1)  # repeat(time_decay, 'e -> b e', b=batch_size)
        # running numerator (aa), denominator (bb) and running maximum exponent (eps)
        aa = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
        bb = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
        eps = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
        for t in range(seq_len):
            # output at step t: the current token enters weighted by exp(time_first + k_t)
            e_tf_k = torch.exp(tf + k_[t] - eps)
            y_[t] = (aa + v_[t] * e_tf_k) / (bb + e_tf_k)
            # update the running maximum, then decay the state and absorb the current token
            eps_next = torch.max(td + eps, k_[t])
            e_td_k = torch.exp(td + eps - eps_next)
            e_k = torch.exp(k_[t] - eps_next)
            aa = aa * e_td_k + v_[t] * e_k
            bb = bb * e_td_k + e_k
            eps = eps_next
        y = rearrange(y_, "s b e -> b s e")
        return y
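
if __name__ == "__main__":
    # Minimal sanity-check sketch (illustrative only; shapes and the comparison below are
    # arbitrary assumptions): run the pure-PyTorch reference on random inputs and, if a
    # CUDA device is available, compare it against the fused kernel. The reference is fed
    # -exp(time_decay) directly, mirroring the reparameterization that WKV._WKV.forward
    # applies internally before calling the CUDA kernel.
    batch_size, seq_len, embedding_dim = 2, 32, 64
    time_decay = torch.randn(embedding_dim)
    time_first = torch.randn(embedding_dim)
    k = torch.randn(batch_size, seq_len, embedding_dim)
    v = torch.randn(batch_size, seq_len, embedding_dim)

    y_ref = WKVTorch()(batch_size, seq_len, embedding_dim, -torch.exp(time_decay), time_first, k, v)
    print("reference output shape:", tuple(y_ref.shape))

    if torch.cuda.is_available():
        wkv = WKV(WKVConfig(T_max=seq_len))  # triggers JIT compilation of the extension
        y_cuda = wkv(
            batch_size, seq_len, embedding_dim,
            time_decay.cuda(), time_first.cuda(), k.cuda(), v.cuda(),
        )
        print("max abs diff vs CUDA kernel:", (y_cuda.float().cpu() - y_ref.float()).abs().max().item())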