"""This module is a wrapper for the C++/CUDA implementation of the WKV kernel."""

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Union

import torch
from einops import rearrange, repeat
from torch import nn
from torch.utils.cpp_extension import load


@dataclass
class WKVConfig:
    T_max: int = 1024  # max sequence length within cuda operations
    cpp_ext_name: str = "wkv"
    device: Union[str, torch.device] = "cuda"
    float_mode: str = "fp32"  # options: fp32, fp16, bfloat16

    def __post_init__(self):
        self.device = torch.device(self.device)

    def float_mode_to_dtype(self):
        # check bfloat16 first: the substring test for "16" below would otherwise
        # also match "bfloat16" and wrongly return torch.float16
        if self.float_mode == "bfloat16":
            return torch.bfloat16
        elif self.float_mode == "fp32" or "32" in str(self.float_mode):
            return torch.float32
        elif self.float_mode == "fp16" or "16" in str(self.float_mode):
            return torch.float16
        else:
            raise ValueError(f"Unknown float_mode: {self.float_mode}")


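# Illustrative usage of WKVConfig (example values, not from the original code):
#     cfg = WKVConfig(T_max=2048, float_mode="bfloat16", device="cuda")
#     cfg.float_mode_to_dtype()  # -> torch.bfloat16
# Note that the autograd wrapper below always casts its inputs to float32 before
# calling the CUDA kernel and only casts the output back to this dtype.

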
class WKV(nn.Module):
    _instance = None  # for singleton

    class _WKV(torch.autograd.Function):
        wkv_cuda = None
        wkv_config: WKVConfig = None

        @classmethod
        def forward(
            cls,
            ctx,
            batch_size,
            seq_len,
            embedding_dim,  # TODO replace embedding_dim with attention_dim
            time_decay,
            time_first,
            k,
            v,
        ):
            wkv_cuda = cls.wkv_cuda
            wkv_config = cls.wkv_config
            # setup context  # TODO for PyTorch 2.0 use extra setup_context() function
            ctx.batch_size = batch_size
            ctx.seq_len = seq_len
            ctx.embedding_dim = embedding_dim
            ctx.wkv_cuda = wkv_cuda
            ctx.wkv_config = wkv_config
            assert (
                seq_len <= wkv_config.T_max
            ), f"Sequence length {seq_len} exceeds the maximum allowed T_max={wkv_config.T_max}"
            # TODO what does this assert do? Why necessary?
            assert (
                batch_size * embedding_dim % min(embedding_dim, wkv_config.T_max) == 0
            ), "batch_size * embedding_dim must be divisible by min(embedding_dim, T_max)"

            dtype = torch.float32  # convert all tensors to float32 (for cuda kernel)
            device = wkv_config.device
            # convert input tensors
            time_decay = time_decay.to(dtype=dtype, device=device, memory_format=torch.contiguous_format)
            time_first = time_first.to(dtype=dtype, device=device, memory_format=torch.contiguous_format)
            k = k.to(dtype=dtype, device=device, memory_format=torch.contiguous_format)
            v = v.to(dtype=dtype, device=device, memory_format=torch.contiguous_format)

            # allocate output tensor
            y = torch.empty(
                batch_size, seq_len, embedding_dim, dtype=dtype, device=device, memory_format=torch.contiguous_format
            )

            # call cuda kernel
            # time_decay is kept in log space; -exp() yields a strictly negative
            # per-channel decay, which keeps the recurrence numerically stable
            time_decay = -torch.exp(time_decay)
            wkv_cuda.forward(batch_size, seq_len, embedding_dim, time_decay, time_first, k, v, y)
            ctx.save_for_backward(time_decay, time_first, k, v, y)

            # convert output tensor to correct dtype
            y = y.to(dtype=wkv_config.float_mode_to_dtype())
            return y

        @staticmethod
        def backward(ctx, gy):
            batch_size = ctx.batch_size
            seq_len = ctx.seq_len
            embedding_dim = ctx.embedding_dim
            assert (
                seq_len <= ctx.wkv_config.T_max
            ), f"Sequence length {seq_len} exceeds the maximum allowed T_max={ctx.wkv_config.T_max}"
            assert (
                batch_size * embedding_dim % min(embedding_dim, ctx.wkv_config.T_max) == 0
            ), "batch_size * embedding_dim must be divisible by min(embedding_dim, T_max)"

            time_decay, time_first, k, v, y = ctx.saved_tensors

            device = ctx.wkv_config.device
            # allocate gradient tensors
            gtime_decay = torch.zeros((batch_size, embedding_dim), device=device, dtype=torch.float32).contiguous()
            gtime_first = torch.zeros((batch_size, embedding_dim), device=device, dtype=torch.float32).contiguous()
            gk = torch.zeros((batch_size, seq_len, embedding_dim), device=device, dtype=torch.float32).contiguous()
            gv = torch.zeros((batch_size, seq_len, embedding_dim), device=device, dtype=torch.float32).contiguous()

            # call cuda kernel
            gy = gy.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            # arg0: int, arg1: int, arg2: int, arg3: at::Tensor, arg4: at::Tensor, arg5: at::Tensor, arg6: at::Tensor,
            # arg7: at::Tensor, arg8: at::Tensor, arg9: at::Tensor, arg10: at::Tensor, arg11: at::Tensor, arg12: at::Tensor
            ctx.wkv_cuda.backward(
                batch_size,
                seq_len,
                embedding_dim,
                time_decay,
                time_first,
                k,
                v,
                y,
                gy,
                gtime_decay,
                gtime_first,
                gk,
                gv,
            )

            # the kernel returns per-sample gradients for the shared time parameters; sum over the batch
            gtime_decay = gtime_decay.sum(dim=0)
            gtime_first = gtime_first.sum(dim=0)

            # convert gradient tensors to correct dtype
            out_dtype = ctx.wkv_config.float_mode_to_dtype()

            return (
                None,
                None,
                None,
                gtime_decay.to(dtype=out_dtype),
                gtime_first.to(dtype=out_dtype),
                gk.to(dtype=out_dtype),
                gv.to(dtype=out_dtype),
            )

    def __new__(cls, config: WKVConfig = WKVConfig()):
        if cls._instance is None:
            cls._instance = super(WKV, cls).__new__(cls)
            cls._instance._setup(config)
        return cls._instance

    def __init__(self, *args, **kwargs):
        # Dummy to avoid multiple calls to self._load_cuda()
        pass

    def _setup(self, config: WKVConfig = WKVConfig()):
        """Set up the WKV module; called by __new__ as the constructor."""
        super().__init__()
        self.cfg = config
        self.wkv_cuda = self._load_cuda()
        self.device = self.cfg.device

    def _load_cuda(self):
        cfg = self.cfg
        os.environ["CUDA_LIB"] = os.path.join(
            os.path.split(torch.utils.cpp_extension.include_paths(cuda=True)[-1])[0], "lib"
        )
        print(os.environ.get("LD_LIBRARY_PATH", ""))
        print(os.environ["CUDA_LIB"])

        cpp_ext_sources_float32: List[str] = [
            str(Path(__file__).parent / "cuda/wkv_op.cpp"),
            str(Path(__file__).parent / "cuda/wkv_cuda.cu"),
        ]
        cpp_ext_sources_bfloat16: List[str] = [
            str(Path(__file__).parent / "cuda/wkv_op_bf16.cpp"),
            str(Path(__file__).parent / "cuda/wkv_cuda_bf16.cu"),
        ]
        if cfg.float_mode_to_dtype() == torch.float32:
            cpp_ext_sources = cpp_ext_sources_float32
        elif cfg.float_mode_to_dtype() == torch.bfloat16:
            cpp_ext_sources = cpp_ext_sources_bfloat16
        else:
            raise ValueError(f"Unsupported float mode {cfg.float_mode}")

        myargs = {
            "verbose": True,
            "with_cuda": True,
            "extra_ldflags": [f"-L{os.environ['CUDA_LIB']}", "-lcublas"],
            "extra_cuda_cflags": [
                # "-gencode",
                # "arch=compute_70,code=compute_70",
                "-gencode",
                "arch=compute_80,code=compute_80",
                "-res-usage",
                "--use_fast_math",
                "-O3",
                "-Xptxas -O3",
                "--extra-device-vectorization",
                f"-DTmax={cfg.T_max}",
            ],
        }

        cuda_module = load(name=cfg.cpp_ext_name, sources=cpp_ext_sources, **myargs)

        return cuda_module

    def to(self, **kwargs):
        device = kwargs.get("device", None)
        if device is not None:
            self.device = self.cfg.device = torch.device(device)
        return super().to(**kwargs)

    def forward(self, batch_size, seq_len, embedding_dim, time_decay, time_first, k, v):
        self.device = self.cfg.device = v.device
        assert self.device != torch.device("cpu"), "WKV is not implemented for CPU"
        self._WKV.wkv_cuda = self.wkv_cuda
        self._WKV.wkv_config = self.cfg
        return self._WKV.apply(batch_size, seq_len, embedding_dim, time_decay, time_first, k, v)


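# Illustrative call pattern for the wrapper above (assumed shapes, mirroring
# WKVTorch.forward below: k and v are (batch, seq, embedding); time_decay and
# time_first are (embedding,)):
#     wkv = WKV(WKVConfig(T_max=1024))
#     y = wkv(batch_size, seq_len, embedding_dim, time_decay, time_first, k, v)

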
class WKVTorch(nn.Module):
    """Pure-PyTorch reference implementation of the WKV recurrence, numerically
    stabilized with a running maximum exponent. Unlike the fused CUDA path, it
    can also run on CPU."""

    def __init__(self):
        super().__init__()
        # time_decay = -torch.exp(time_decay) is intentionally not applied here;
        # the WKV cuda wrapper applies it inside its forward function. If it were
        # added here, the outputs would differ from the cuda kernel.

    def forward(self, batch_size, seq_len, embedding_dim, time_decay, time_first, k, v):
        dtype = k.dtype
        device = k.device
        y = torch.zeros(batch_size, seq_len, embedding_dim, dtype=dtype, device=device)
        MIN_VAL = 0.0  # -1e38
        # reshape inputs
        k_ = torch.permute(k, (1, 0, 2))  # rearrange(k, 'b s e -> s b e')
        v_ = torch.permute(v, (1, 0, 2))  # rearrange(v, 'b s e -> s b e')
        y_ = torch.permute(y, (1, 0, 2))  # rearrange(y, 'b s e -> s b e')
        tf = time_first.repeat(batch_size, 1)  # repeat(time_first, 'e -> b e', b=batch_size)
        td = time_decay.repeat(batch_size, 1)  # repeat(time_decay, 'e -> b e', b=batch_size)
        # running sums
        aa = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
        bb = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
        eps = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
        for t in range(seq_len):
            e_tf_k = torch.exp(tf + k_[t] - eps)
            y_[t] = (aa + v_[t] * e_tf_k) / (bb + e_tf_k)
            eps_next = torch.max(td + eps, k_[t])

            e_td_k = torch.exp(td + eps - eps_next)
            e_k = torch.exp(k_[t] - eps_next)
            aa = aa * e_td_k + v_[t] * e_k
            bb = bb * e_td_k + e_k
            eps = eps_next
        y = rearrange(y_, "s b e -> b s e")
        return y
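

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original module):
    # run the pure-PyTorch reference WKVTorch on random inputs and, if CUDA is
    # available, compare it against the fused CUDA kernel wrapped by WKV.
    batch_size, seq_len, embedding_dim = 2, 8, 16
    time_decay = torch.randn(embedding_dim)
    time_first = torch.randn(embedding_dim)
    k = torch.randn(batch_size, seq_len, embedding_dim)
    v = torch.randn(batch_size, seq_len, embedding_dim)

    # The CUDA wrapper applies -exp() to time_decay internally (see _WKV.forward),
    # while WKVTorch uses its decay argument as-is; feeding the reference the
    # pre-transformed decay is assumed here so both paths use the same effective decay.
    y_ref = WKVTorch()(batch_size, seq_len, embedding_dim, -torch.exp(time_decay), time_first, k, v)
    print("WKVTorch output shape:", tuple(y_ref.shape))

    if torch.cuda.is_available():
        wkv = WKV(WKVConfig(T_max=seq_len))
        y_cuda = wkv(
            batch_size, seq_len, embedding_dim,
            time_decay.cuda(), time_first.cuda(), k.cuda(), v.cuda(),
        )
        print("max |cuda - torch|:", (y_cuda.float().cpu() - y_ref).abs().max().item())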