Files
state_interpretation/TestModels.ipynb
2023-08-05 17:35:11 +02:00

138 lines
3.4 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%reload_ext autoreload\n",
"# Loading the extension alone leaves automatic reloading switched off.\n",
"# Mode 2 reloads all modules (except those excluded via %aimport) before each cell executes.\n",
"%autoreload 2\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import importlib\n",
"import os\n",
"import sys\n",
"\n",
"# def import_from_file(module_name, file_path):\n",
"# spec = importlib.util.spec_from_file_location(module_name, file_path)\n",
"# module = importlib.util.module_from_spec(spec)\n",
"# sys.modules[module_name] = module\n",
"# spec.loader.exec_module(module)\n",
"# return module\n",
"\n",
"# os.environ['RWKV_JIT_ON'] = '1'\n",
"# os.environ['RWKV_T_MAX'] = '16384'\n",
"# os.environ['RWKV_FLOAT_MODE'] = 'fp32'\n",
"# os.environ['CUDA_HOME']\n",
"# path = os.path.abspath('RWKV-LM/RWKV-v4neo/src/model.py')\n",
"# rkwv_mod = import_from_file('rwkv_mod', path)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency.\n"
]
}
],
"source": [
"# Model imports and shared constants. Relies on `os` and `sys` imported in an earlier cell.\n",
"from models.rwkv import RWKV, RWKVConfig\n",
"# Add ./safari to the import search path before the safari imports below.\n",
"# Presumably safari's modules import each other relative to that directory - TODO confirm.\n",
"sys.path.append('./safari')\n",
"from safari.src.models.sequence.model import SequenceModel\n",
"from safari.src.models.sequence.h3 import H3\n",
"from safari.src.models.sequence.hyena import HyenaOperator\n",
"from omegaconf import OmegaConf\n",
"# NOTE(review): CUDA_HOME is assigned only after the imports above have run;\n",
"# confirm that nothing reads it at import time.\n",
"os.environ['CUDA_HOME'] = \".\"\n",
"\n",
"# Sizes passed to the model constructors in the cells below.\n",
"NUM_LAYERS = 3\n",
"EMBEDDING_DIM = 128\n",
"SEQUENCE_LENGTH = 1024\n",
"\n",
"def load_config(filename):\n",
"    \"\"\"Read a YAML file and build an OmegaConf config from its text.\n",
"\n",
"    Args:\n",
"        filename: Path of the YAML file to read.\n",
"\n",
"    Returns:\n",
"        The object produced by ``OmegaConf.create`` for the file's contents.\n",
"\n",
"    Raises:\n",
"        OSError: If the file cannot be opened.\n",
"    \"\"\"\n",
"    # Pin the encoding so parsing does not depend on the platform's locale default.\n",
"    with open(filename, encoding='utf-8') as fp:\n",
"        return OmegaConf.create(fp.read())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# RWKV model instance configured from the shared constants; wkv_config is left as None.\n",
"rwkv_full = RWKV(\n",
"    RWKVConfig(\n",
"        context_length=SEQUENCE_LENGTH,\n",
"        embedding_dim=EMBEDDING_DIM,\n",
"        num_layers=NUM_LAYERS,\n",
"        wkv_config=None,\n",
"    )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Multi-layer sequence models; each takes its layer settings from a YAML config file.\n",
"hyena_full = SequenceModel(\n",
"    layer=load_config('config/models/hyena.yaml'),\n",
"    d_model=EMBEDDING_DIM,\n",
"    n_layers=NUM_LAYERS,\n",
")\n",
"h3_full = SequenceModel(\n",
"    layer=load_config('config/models/h3.yaml'),\n",
"    d_model=EMBEDDING_DIM,\n",
"    n_layers=NUM_LAYERS,\n",
")\n",
"\n",
"# H3 and Hyena modules constructed directly rather than through SequenceModel.\n",
"h3 = H3(d_model=EMBEDDING_DIM)\n",
"# NOTE(review): l_max is hard-coded to 16384 while SEQUENCE_LENGTH is 1024 - confirm this is intended.\n",
"hyena = HyenaOperator(d_model=EMBEDDING_DIM, l_max=16384)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}