Add RWKV, H3, Hyena
This commit is contained in:
137
TestModels.ipynb
Normal file
137
TestModels.ipynb
Normal file
@@ -0,0 +1,137 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%reload_ext autoreload\n",
|
||||
"import torch"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import importlib\n",
|
||||
"import os\n",
|
||||
"import sys\n",
|
||||
"\n",
|
||||
"# def import_from_file(module_name, file_path):\n",
|
||||
"# spec = importlib.util.spec_from_file_location(module_name, file_path)\n",
|
||||
"# module = importlib.util.module_from_spec(spec)\n",
|
||||
"# sys.modules[module_name] = module\n",
|
||||
"# spec.loader.exec_module(module)\n",
|
||||
"# return module\n",
|
||||
"\n",
|
||||
"# os.environ['RWKV_JIT_ON'] = '1'\n",
|
||||
"# os.environ['RWKV_T_MAX'] = '16384'\n",
|
||||
"# os.environ['RWKV_FLOAT_MODE'] = 'fp32'\n",
|
||||
"# os.environ['CUDA_HOME']\n",
|
||||
"# path = os.path.abspath('RWKV-LM/RWKV-v4neo/src/model.py')\n",
|
||||
"# rkwv_mod = import_from_file('rwkv_mod', path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from models.rwkv import RWKV, RWKVConfig\n",
|
||||
"sys.path.append('./safari')\n",
|
||||
"from safari.src.models.sequence.model import SequenceModel\n",
|
||||
"from safari.src.models.sequence.h3 import H3\n",
|
||||
"from safari.src.models.sequence.hyena import HyenaOperator\n",
|
||||
"from omegaconf import OmegaConf\n",
|
||||
"os.environ['CUDA_HOME'] = \".\"\n",
|
||||
"\n",
|
||||
"NUM_LAYERS = 3\n",
|
||||
"EMBEDDING_DIM = 128\n",
|
||||
"SEQUENCE_LENGTH = 1024\n",
|
||||
"\n",
|
||||
"def load_config(filename):\n",
|
||||
" with open(filename) as fp:\n",
|
||||
" cfg_yaml = fp.read()\n",
|
||||
" return OmegaConf.create(cfg_yaml)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"rwkv_full = RWKV(RWKVConfig(context_length=SEQUENCE_LENGTH, embedding_dim=EMBEDDING_DIM, num_layers=NUM_LAYERS, wkv_config=None))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"hyena_full = SequenceModel(layer=load_config('config/models/hyena.yaml'), d_model=EMBEDDING_DIM, n_layers=NUM_LAYERS)\n",
|
||||
"h3_full = SequenceModel(layer=load_config('config/models/h3.yaml'), d_model=EMBEDDING_DIM, n_layers=NUM_LAYERS)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"h3 = H3(d_model=EMBEDDING_DIM)\n",
|
||||
"hyena = HyenaOperator(d_model=EMBEDDING_DIM, l_max=16384)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "base",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
Reference in New Issue
Block a user