Add RWKV, H3, Hyena

commit 7b15a413d4, parent a71030547c, 2023-08-05 17:33:32 +02:00
22 changed files with 1794 additions and 0 deletions

models/base_model.py Normal file

@@ -0,0 +1,165 @@
import copy
import logging
from abc import ABC, abstractmethod
from dataclasses import asdict
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union

import torch
from torch import nn

LOGGER = logging.getLogger(__name__)

FN_MODEL_PREFIX = "model_"
FN_MODEL_FILE_EXT = ".p"
def get_device(device: Union[torch.device, str, int]) -> torch.device:
    """Resolve a device specifier ("auto", index, or name) to a torch.device."""
    if device == "auto":
        device = "cuda"
    if isinstance(device, int):
        # Negative indices select the CPU, non-negative ones a CUDA device.
        device = torch.device("cpu") if device < 0 else torch.device(f"cuda:{device}")
    else:
        device = torch.device(device)
    if device.type == torch.device("cuda").type and not torch.cuda.is_available():
        LOGGER.warning(f"Device '{device}' is not available! Using cpu now.")
        return torch.device("cpu")
    return device
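
# Illustrative behavior (assuming a machine where CUDA is available):
#   get_device("auto") -> device(type='cuda')
#   get_device(1)      -> device(type='cuda', index=1)
#   get_device(-1)     -> device(type='cpu')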
class BaseModel(nn.Module, ABC):
    """BaseModel class.

    Takes care of easy saving and loading.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.config = None

    @abstractmethod
    def forward(self, *args, **kwargs):
        pass

    @abstractmethod
    def get_loss_func(self, **kwargs):
        pass

    def _get_constructor_parameters(self) -> dict:
        if isinstance(self.config, dict):
            return self.config
        return asdict(self.config)

    def reset_parameters(self):
        # No-op unless a subclass provides an init function.
        init_fn = self.get_init_fn()
        if init_fn is not None:
            self.apply(init_fn)

    def get_init_fn(self) -> Optional[Callable[[nn.Module], None]]:
        return None

    @property
    def num_parameters(self) -> int:
        return sum(p.numel() for p in self.parameters())

    @property
    def device(self) -> torch.device:
        return next(iter(self.parameters())).device

    def copy_to_cpu(self) -> "BaseModel":
        """Copy the model to CPU."""
        return copy.deepcopy(self).to(torch.device("cpu"))
    def get_checkpoint_data(
        self, dict_key_prefix: str = "model_"
    ) -> Dict[str, Any]:
        return {
            f"{dict_key_prefix}state_dict": self.state_dict(),
            f"{dict_key_prefix}data": self._get_constructor_parameters(),
            f"{dict_key_prefix}name": self.__class__.__name__,
            f"{dict_key_prefix}class": self.__class__,
        }
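
    # Illustrative checkpoint layout for the default prefix ("MyModel" stands
    # in for a concrete subclass):
    #   {"model_state_dict": <state dict>, "model_data": <constructor params>,
    #    "model_name": "MyModel", "model_class": <class MyModel>}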
    def save(
        self,
        path: Union[str, Path],
        model_name: str,
        file_extension: Optional[str] = FN_MODEL_FILE_EXT,
        dict_key_prefix: str = "model_",
    ) -> None:
        if isinstance(path, str):
            path = Path(path)
        save_path = path / (model_name + file_extension)
        torch.save(self.get_checkpoint_data(dict_key_prefix), save_path)
    @staticmethod
    def model_save_name(
        idx: int, specifier: str = "epoch", num_digits: int = -1
    ) -> str:
        """Get a consistent model save name.

        Args:
            idx (int): Epoch / iteration number.
            specifier (str, optional): A specifier for the idx. Defaults to "epoch".
            num_digits (int, optional): The number of digits to zero-pad the index
                to in the save name. Unused by default, since a fixed width causes
                overrides once the index overflows it. Defaults to -1.

        Returns:
            str: Model save name.
        """
        if num_digits == -1:
            return f"{FN_MODEL_PREFIX}{specifier}_{idx}"
        return f"{FN_MODEL_PREFIX}{specifier}_{idx:0{num_digits}}"
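
    # Illustrative outputs:
    #   BaseModel.model_save_name(7)                -> "model_epoch_7"
    #   BaseModel.model_save_name(7, num_digits=4)  -> "model_epoch_0007"
    #   BaseModel.model_save_name(7, "step")        -> "model_step_7"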
    @classmethod
    def load(
        cls,
        path: Union[str, Path],
        model_name: Optional[str] = None,
        file_extension: Optional[str] = FN_MODEL_FILE_EXT,
        device: Union[torch.device, str, int] = "auto",
        dict_key_prefix: str = "model_",
    ) -> "BaseModel":
        device = get_device(device)
        if isinstance(path, str):
            path = Path(path)
        if model_name is None:
            # `path` already points at the checkpoint file.
            save_path = path
        else:
            save_path = path / (model_name + file_extension)
        # `weights_only=False` is required on newer torch versions, since the
        # checkpoint pickles the model class alongside the state dict.
        checkpoint = torch.load(save_path, map_location=device, weights_only=False)
        model = cls.params_from_checkpoint(
            checkpoint=checkpoint, dict_key_prefix=dict_key_prefix
        )
        # Move the reconstructed model to the requested device; `map_location`
        # alone only places the loaded tensors there.
        return model.to(device)
    @classmethod
    def params_from_checkpoint(
        cls, checkpoint: Dict[str, Any], dict_key_prefix: str = "model_"
    ) -> "BaseModel":
        if hasattr(cls, "config_class"):
            # Optional dependency: only needed for dataclass-based configs.
            from dacite import from_dict

            model_cfg = from_dict(
                data_class=cls.config_class,
                data=checkpoint[f"{dict_key_prefix}data"],
            )
            model = cls(config=model_cfg)
        else:
            model = cls(**checkpoint[f"{dict_key_prefix}data"])
        model.load_state_dict(checkpoint[f"{dict_key_prefix}state_dict"])
        return model
    # @staticmethod
    # def class_and_params_from_checkpoint(
    #     checkpoint: Dict[str, Any], dict_key_prefix: str = "model_"
    # ) -> "BaseModel":
    #     from . import get_model_class
    #     model_class = get_model_class(checkpoint[f"{dict_key_prefix}name"])
    #     return model_class.params_from_checkpoint(
    #         checkpoint=checkpoint, dict_key_prefix=dict_key_prefix
    #     )
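
# Usage sketch (illustrative; `TinyModel` is a hypothetical subclass, not part
# of this commit; it keeps its constructor kwargs in `self.config` so that
# `params_from_checkpoint` can rebuild it):
#
#   class TinyModel(BaseModel):
#       def __init__(self, in_dim: int = 4, out_dim: int = 2):
#           super().__init__()
#           self.config = {"in_dim": in_dim, "out_dim": out_dim}
#           self.fc = nn.Linear(in_dim, out_dim)
#
#       def forward(self, x):
#           return self.fc(x)
#
#       def get_loss_func(self, **kwargs):
#           return nn.MSELoss(**kwargs)
#
#   Path("checkpoints").mkdir(exist_ok=True)  # `save` does not create dirs
#   model = TinyModel()
#   model.save("checkpoints", BaseModel.model_save_name(0))
#   # -> writes checkpoints/model_epoch_0.p
#   restored = TinyModel.load("checkpoints", BaseModel.model_save_name(0))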