"""This module is a wrapper for the C++/CUDA implementation of the WKV kernel."""
import os
from dataclasses import dataclass
from pathlib import Path
from typing import List, Union

import torch
import torch.utils.cpp_extension
from einops import rearrange
from torch import nn
from torch.utils.cpp_extension import load
@dataclass
class WKVConfig:
    T_max: int = 1024  # max sequence length within cuda operations
    cpp_ext_name: str = "wkv"
    device: Union[str, torch.device] = "cuda"
    float_mode: str = "fp32"  # options: fp32, fp16, bfloat16

    def __post_init__(self):
        self.device = torch.device(self.device)

    def float_mode_to_dtype(self) -> torch.dtype:
        # check bfloat16 first: the "16" substring test below would otherwise
        # shadow it and incorrectly map "bfloat16" to torch.float16
        if self.float_mode == "bfloat16":
            return torch.bfloat16
        elif self.float_mode == "fp32" or "32" in str(self.float_mode):
            return torch.float32
        elif self.float_mode == "fp16" or "16" in str(self.float_mode):
            return torch.float16
        else:
            raise ValueError(f"Unknown float_mode: {self.float_mode}")
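
# Example (a minimal sketch; the values are hypothetical):
#   cfg = WKVConfig(T_max=2048, float_mode="bfloat16", device="cuda")
#   cfg.float_mode_to_dtype()  # -> torch.bfloat16 (matched before the "16" substring check)
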
class WKV(nn.Module):
    _instance = None  # for singleton

    class _WKV(torch.autograd.Function):
        wkv_cuda = None
        wkv_config: WKVConfig = None
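        # NOTE: forward is a classmethod (rather than the usual staticmethod) so it
        # can read wkv_cuda / wkv_config from the class; Function.apply still passes
        # ctx as the first explicit argument after cls.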
        @classmethod
        def forward(
            cls,
            ctx,
            batch_size,
            seq_len,
            embedding_dim,  # TODO replace embedding_dim with attention_dim
            time_decay,
            time_first,
            k,
            v,
        ):
            wkv_cuda = cls.wkv_cuda
            wkv_config = cls.wkv_config
            # setup context  # TODO: for PyTorch 2.0, use the separate setup_context() function
            ctx.batch_size = batch_size
            ctx.seq_len = seq_len
            ctx.embedding_dim = embedding_dim
            ctx.wkv_cuda = wkv_cuda
            ctx.wkv_config = wkv_config
            assert (
                seq_len <= wkv_config.T_max
            ), f"Sequence length {seq_len} exceeds the maximum allowed T_max={wkv_config.T_max}"
            # the CUDA kernel launches thread blocks of min(embedding_dim, T_max)
            # threads, so batch_size * embedding_dim must split into whole blocks
            assert (
                batch_size * embedding_dim % min(embedding_dim, wkv_config.T_max) == 0
            ), "batch_size * embedding_dim must be divisible by min(embedding_dim, T_max)"
            dtype = torch.float32  # convert all tensors to float32 (for the cuda kernel)
            device = wkv_config.device
            # convert input tensors
            time_decay = time_decay.to(dtype=dtype, device=device, memory_format=torch.contiguous_format)
            time_first = time_first.to(dtype=dtype, device=device, memory_format=torch.contiguous_format)
            k = k.to(dtype=dtype, device=device, memory_format=torch.contiguous_format)
            v = v.to(dtype=dtype, device=device, memory_format=torch.contiguous_format)
            # allocate output tensor
            y = torch.empty(
                batch_size, seq_len, embedding_dim, dtype=dtype, device=device, memory_format=torch.contiguous_format
            )
            # call cuda kernel
            # time_decay is stored as a log-scale parameter; the kernel expects the
            # actual (negative) decay, so w = -exp(time_decay) guarantees that the
            # per-step decay factor exp(w) lies in (0, 1)
            time_decay = -torch.exp(time_decay)
            wkv_cuda.forward(batch_size, seq_len, embedding_dim, time_decay, time_first, k, v, y)
            ctx.save_for_backward(time_decay, time_first, k, v, y)
            # convert output tensor to the configured dtype
            y = y.to(dtype=wkv_config.float_mode_to_dtype())
            return y
        @staticmethod
        def backward(ctx, gy):
            batch_size = ctx.batch_size
            seq_len = ctx.seq_len
            embedding_dim = ctx.embedding_dim
            assert (
                seq_len <= ctx.wkv_config.T_max
            ), f"Sequence length {seq_len} exceeds the maximum allowed T_max={ctx.wkv_config.T_max}"
            assert (
                batch_size * embedding_dim % min(embedding_dim, ctx.wkv_config.T_max) == 0
            ), "batch_size * embedding_dim must be divisible by min(embedding_dim, T_max)"
            time_decay, time_first, k, v, y = ctx.saved_tensors
            device = ctx.wkv_config.device
            # allocate gradient tensors
            gtime_decay = torch.zeros((batch_size, embedding_dim), device=device, dtype=torch.float32).contiguous()
            gtime_first = torch.zeros((batch_size, embedding_dim), device=device, dtype=torch.float32).contiguous()
            gk = torch.zeros((batch_size, seq_len, embedding_dim), device=device, dtype=torch.float32).contiguous()
            gv = torch.zeros((batch_size, seq_len, embedding_dim), device=device, dtype=torch.float32).contiguous()
            # call cuda kernel
            gy = gy.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            # signature: (B, T, C, time_decay, time_first, k, v, y, gy, gtime_decay, gtime_first, gk, gv)
            ctx.wkv_cuda.backward(
                batch_size,
                seq_len,
                embedding_dim,
                time_decay,
                time_first,
                k,
                v,
                y,
                gy,
                gtime_decay,
                gtime_first,
                gk,
                gv,
            )
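            # time_decay / time_first are shared across the batch, so their
            # per-sample gradients are accumulated over the batch dimension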
            gtime_decay = gtime_decay.sum(dim=0)
            gtime_first = gtime_first.sum(dim=0)
            # convert gradient tensors to the configured dtype
            out_dtype = ctx.wkv_config.float_mode_to_dtype()
            return (
                None,
                None,
                None,
                gtime_decay.to(dtype=out_dtype),
                gtime_first.to(dtype=out_dtype),
                gk.to(dtype=out_dtype),
                gv.to(dtype=out_dtype),
            )
    def __new__(cls, config: WKVConfig = WKVConfig()):
        if cls._instance is None:
            cls._instance = super(WKV, cls).__new__(cls)
            cls._instance._setup(config)
        return cls._instance

    def __init__(self, *args, **kwargs):
        # Intentionally a no-op: construction goes through __new__/_setup, and a
        # real __init__ would re-run _load_cuda() on every instantiation.
        pass

    def _setup(self, config: WKVConfig = WKVConfig()):
        """Setup the WKV module. This is called by __new__ as the constructor."""
        super().__init__()
        self.cfg = config
        self.wkv_cuda = self._load_cuda()
        self.device = self.cfg.device
    def _load_cuda(self):
        cfg = self.cfg
        os.environ["CUDA_LIB"] = os.path.join(
            os.path.split(torch.utils.cpp_extension.include_paths(cuda=True)[-1])[0], "lib"
        )
        print(os.environ.get("LD_LIBRARY_PATH", ""))
        print(os.environ["CUDA_LIB"])
        cpp_ext_sources_float32: List[str] = [
            str(Path(__file__).parent / "cuda/wkv_op.cpp"),
            str(Path(__file__).parent / "cuda/wkv_cuda.cu"),
        ]
        cpp_ext_sources_bfloat16: List[str] = [
            str(Path(__file__).parent / "cuda/wkv_op_bf16.cpp"),
            str(Path(__file__).parent / "cuda/wkv_cuda_bf16.cu"),
        ]
        if cfg.float_mode_to_dtype() == torch.float32:
            cpp_ext_sources = cpp_ext_sources_float32
        elif cfg.float_mode_to_dtype() == torch.bfloat16:
            cpp_ext_sources = cpp_ext_sources_bfloat16
        else:
            raise ValueError(f"Unsupported float mode {cfg.float_mode}")
        myargs = {
            "verbose": True,
            "with_cuda": True,
            "extra_ldflags": [f"-L{os.environ['CUDA_LIB']}", "-lcublas"],
            "extra_cuda_cflags": [
                # "-gencode",
                # "arch=compute_70,code=compute_70",
                "-gencode",
                "arch=compute_80,code=compute_80",
                "-res-usage",
                "--use_fast_math",
                "-O3",
                "-Xptxas -O3",
                "--extra-device-vectorization",
                f"-DTmax={cfg.T_max}",
            ],
        }
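        # NOTE: arch=compute_80 targets Ampere GPUs; for other generations add the
        # matching -gencode pair (e.g. the commented compute_70 entries for Volta).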
        cuda_module = load(name=cfg.cpp_ext_name, sources=cpp_ext_sources, **myargs)
        return cuda_module
    def to(self, *args, **kwargs):
        # also accept the positional form `module.to("cuda:0")`
        device = kwargs.get("device", None)
        if device is None and args and isinstance(args[0], (str, torch.device)):
            device = args[0]
        if device is not None:
            self.device = self.cfg.device = torch.device(device)
        return super().to(*args, **kwargs)
    def forward(self, batch_size, seq_len, embedding_dim, time_decay, time_first, k, v):
        self.device = self.cfg.device = v.device
        assert self.device != torch.device("cpu"), "WKV is not implemented for CPU"
        self._WKV.wkv_cuda = self.wkv_cuda
        self._WKV.wkv_config = self.cfg
        return self._WKV.apply(batch_size, seq_len, embedding_dim, time_decay, time_first, k, v)
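
# Usage sketch (hypothetical shapes; a CUDA device is required):
#   wkv = WKV(WKVConfig(T_max=1024, float_mode="fp32"))  # singleton: repeated calls return the same instance
#   y = wkv(batch_size, seq_len, embedding_dim, time_decay, time_first, k, v)
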
class WKVTorch(nn.Module):
    """Pure-PyTorch reference implementation of the WKV recurrence."""

    def __init__(self):
        super().__init__()
        # NOTE: time_decay = -torch.exp(time_decay) is deliberately not applied
        # here; WKV.forward applies that transform internally before calling the
        # CUDA kernel. Applying it here as well would make the outputs differ
        # from the cuda kernel.

    def forward(self, batch_size, seq_len, embedding_dim, time_decay, time_first, k, v):
        dtype = k.dtype
        device = k.device
        y = torch.zeros(batch_size, seq_len, embedding_dim, dtype=dtype, device=device)
        MIN_VAL = 0.0  # alternatively -1e38 (effectively -inf) for a true running max
        # reshape inputs to time-major layout
        k_ = torch.permute(k, (1, 0, 2))  # rearrange(k, 'b s e -> s b e')
        v_ = torch.permute(v, (1, 0, 2))  # rearrange(v, 'b s e -> s b e')
        y_ = torch.permute(y, (1, 0, 2))  # rearrange(y, 'b s e -> s b e')
        tf = time_first.repeat(batch_size, 1)  # repeat(time_first, 'e -> b e', b=batch_size)
        td = time_decay.repeat(batch_size, 1)  # repeat(time_decay, 'e -> b e', b=batch_size)
        # running sums
        aa = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
        bb = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
        eps = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
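        # aa and bb hold the numerator / denominator sums of the WKV weighted
        # average, scaled by exp(-eps); eps tracks the running maximum exponent so
        # the exp() arguments in the state update stay <= 0 (log-sum-exp trick)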
        for t in range(seq_len):
            e_tf_k = torch.exp(tf + k_[t] - eps)
            y_[t] = (aa + v_[t] * e_tf_k) / (bb + e_tf_k)
            eps_next = torch.max(td + eps, k_[t])
            e_td_k = torch.exp(td + eps - eps_next)
            e_k = torch.exp(k_[t] - eps_next)
            aa = aa * e_td_k + v_[t] * e_k
            bb = bb * e_td_k + e_k
            eps = eps_next
        y = rearrange(y_, "s b e -> b s e")
        return y
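

# A minimal parity-check sketch (assumptions: a CUDA device is available and the
# CUDA sources compile; the shapes below are hypothetical). WKV applies
# time_decay = -torch.exp(time_decay) internally, so the pure-PyTorch reference is
# fed the pre-transformed value to make the two paths comparable.
if __name__ == "__main__":
    B, T, E = 2, 8, 16
    time_decay = torch.randn(E, device="cuda")
    time_first = torch.randn(E, device="cuda")
    k = torch.randn(B, T, E, device="cuda")
    v = torch.randn(B, T, E, device="cuda")

    y_cuda = WKV(WKVConfig(T_max=1024))(B, T, E, time_decay, time_first, k, v)
    y_torch = WKVTorch()(B, T, E, -torch.exp(time_decay), time_first, k, v)
    print("max abs diff:", (y_cuda - y_torch).abs().max().item())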