Add RWKV, H3, Hyena

This commit is contained in:
2023-08-05 17:33:32 +02:00
parent a71030547c
commit 7b15a413d4
22 changed files with 1794 additions and 0 deletions

View File

@@ -0,0 +1,25 @@
#include <torch/extension.h>
// #include <c10/cuda/CUDAGuard.h>
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
// TODO add this line const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
// (see https://discord.com/channels/992359628979568762/1084547464452907088/1091680712178016326)
// see https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/cuda/wrapper.cpp
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "wkv forward");
m.def("backward", &backward, "wkv backward");
}
TORCH_LIBRARY(wkv, m) {
m.def("forward", forward);
m.def("backward", backward);
}