Add RWKV, H3, Hyena
This commit is contained in:
165
models/base_model.py
Normal file
165
models/base_model.py
Normal file
@@ -0,0 +1,165 @@
|
||||
import copy
import logging
from abc import ABC, abstractmethod
from dataclasses import asdict
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union

import torch
from torch import nn
|
||||
|
||||
# Filename prefix used for saved model checkpoints (see BaseModel.model_save_name).
FN_MODEL_PREFIX = "model_"
# Default file extension appended to checkpoint filenames in save()/load().
FN_MODEL_FILE_EXT = ".p"
|
||||
|
||||
|
||||
def get_device(device: Union[torch.device, str, int]) -> torch.device:
    """Resolve a device specifier to a concrete ``torch.device``.

    Args:
        device: A ``torch.device``, a device string (e.g. ``"cpu"``,
            ``"cuda:1"``, or ``"auto"``, which requests CUDA), or an integer
            GPU index. Negative integers select the CPU.

    Returns:
        torch.device: The requested device, falling back to the CPU when a
        CUDA device is requested but CUDA is not available.
    """
    if device == "auto":
        # "auto" prefers the GPU; availability is verified below.
        device = "cuda"
    if isinstance(device, int):
        # Negative indices conventionally mean "use the CPU".
        if device < 0:
            device = torch.device("cpu")
        else:
            device = torch.device(f"cuda:{device}")
    else:
        device = torch.device(device)

    if device.type == "cuda" and not torch.cuda.is_available():
        # Fix: the original called LOGGER.warn(...) on a LOGGER name that is
        # never defined in this file (NameError on this path); logging.warn is
        # also deprecated in favor of warning().
        logging.getLogger(__name__).warning(
            f"Device '{str(device)}' is not available! Using cpu now."
        )
        return torch.device("cpu")
    return device
|
||||
|
||||
|
||||
class BaseModel(nn.Module, ABC):
    """Abstract base model with checkpoint save/load support.

    Subclasses must implement :meth:`forward` and :meth:`get_loss_func`, and
    should assign ``self.config`` (a dict or a dataclass instance) so that
    constructor parameters can be stored in, and restored from, checkpoints.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Constructor configuration; a dict or dataclass set by subclasses.
        self.config = None

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Compute the model's forward pass (must be overridden)."""
        pass

    @abstractmethod
    def get_loss_func(self, **kwargs):
        """Return the loss function used to train this model (must be overridden)."""
        pass

    def _get_constructor_parameters(self) -> dict:
        """Return ``self.config`` as a plain dict (dataclasses are converted)."""
        if isinstance(self.config, dict):
            return self.config
        return asdict(self.config)

    def reset_parameters(self):
        """Re-initialize all parameters using the function from :meth:`get_init_fn`.

        Fix: the original unconditionally called ``self.apply(None)`` when
        ``get_init_fn`` was not overridden, raising ``TypeError``; a ``None``
        init function is now a no-op.
        """
        init_fn = self.get_init_fn()
        if init_fn is not None:
            self.apply(init_fn)

    def get_init_fn(self) -> Optional[Callable[[torch.Tensor], None]]:
        """Return the parameter-init function, or None (override in subclasses)."""
        # Fix: annotated Optional — the base implementation returns None.
        return None

    @property
    def num_parameters(self) -> int:
        """Total number of parameters in the model (0 if it has none).

        Fix: avoids building a throwaway tensor; the original also returned
        float ``0.0`` for a parameterless model instead of int ``0``.
        """
        return sum(p.numel() for p in self.parameters())

    @property
    def device(self) -> torch.device:
        """Device of the model's first parameter (model must have parameters)."""
        return next(iter(self.parameters())).device

    def copy_to_cpu(self) -> "BaseModel":
        """Return a deep copy of the model moved to the CPU."""
        return copy.deepcopy(self).to(torch.device("cpu"))

    def get_checkpoint_data(
        self, dict_key_prefix: str = "model_"
    ) -> Dict[str, Any]:
        """Assemble the checkpoint dict: state dict, config, class name, class.

        Args:
            dict_key_prefix: Prefix for every key in the returned dict, so
                multiple models can share one checkpoint file.

        Returns:
            Dict[str, Any]: Mapping with ``<prefix>state_dict``,
            ``<prefix>data`` (constructor parameters), ``<prefix>name`` and
            ``<prefix>class``.
        """
        return {
            f"{dict_key_prefix}state_dict": self.state_dict(),
            f"{dict_key_prefix}data": self._get_constructor_parameters(),
            f"{dict_key_prefix}name": self.__class__.__name__,
            f"{dict_key_prefix}class": self.__class__,
        }

    def save(
        self,
        path: Union[str, Path],
        model_name: str,
        file_extension: Optional[str] = FN_MODEL_FILE_EXT,
        dict_key_prefix: str = "model_",
    ) -> None:
        """Save the model checkpoint to ``path / (model_name + file_extension)``.

        Args:
            path: Target directory.
            model_name: Base filename (without extension).
            file_extension: Extension appended to the filename; None means none.
            dict_key_prefix: Key prefix forwarded to :meth:`get_checkpoint_data`.
        """
        path = Path(path)  # Path(Path(...)) is a no-op, so no isinstance check needed
        # Fix: file_extension is Optional, but the original crashed with a
        # TypeError when it was actually None.
        save_path = path / (model_name + (file_extension or ""))
        torch.save(self.get_checkpoint_data(dict_key_prefix), save_path)

    @staticmethod
    def model_save_name(
        idx: int, specifier: str = "epoch", num_digits: int = -1
    ) -> str:
        """Get a consistent model save name.

        Args:
            idx (int): Epoch / iteration number.
            specifier (str, optional): A specifier for the idx. Defaults to epoch.
            num_digits (int, optional): The number of digits in the save name. Unused by default,
                since this causes overrides when we have an overflow. Defaults to -1.

        Returns:
            str: Model save name.
        """
        if num_digits == -1:
            return f"{FN_MODEL_PREFIX}{specifier}_{idx}"
        else:
            return f"{FN_MODEL_PREFIX}{specifier}_{idx:0{num_digits}}"

    @classmethod
    def load(
        cls,
        path: Union[str, Path],
        model_name: Optional[str] = None,
        file_extension: Optional[str] = FN_MODEL_FILE_EXT,
        device: Union[torch.device, str, int] = "auto",
        dict_key_prefix: str = "model_",
    ) -> "BaseModel":
        """Load a model from a checkpoint file.

        Args:
            path: Checkpoint file (when ``model_name`` is None) or its directory.
            model_name: Base filename; combined with ``file_extension`` when given.
            file_extension: Extension of the checkpoint file; None means none.
            device: Device specifier resolved via :func:`get_device`.
            dict_key_prefix: Key prefix used when the checkpoint was written.

        Returns:
            BaseModel: The reconstructed model with loaded weights.
        """
        device = get_device(device)
        path = Path(path)
        if model_name is None:
            # ``path`` already points at the checkpoint file itself.
            save_path = path
        else:
            # Fix: like save(), tolerate file_extension=None.
            save_path = path / (model_name + (file_extension or ""))
        # NOTE(review): the checkpoint stores arbitrary Python objects (the
        # model class); newer torch versions default to weights_only=True and
        # may reject such files — confirm the targeted torch version.
        checkpoint = torch.load(save_path, map_location=device)

        return cls.params_from_checkpoint(
            checkpoint=checkpoint, dict_key_prefix=dict_key_prefix
        )

    @classmethod
    def params_from_checkpoint(
        cls, checkpoint: Dict[str, Any], dict_key_prefix: str = "model_"
    ) -> "BaseModel":
        """Build a model instance from an in-memory checkpoint dict.

        Uses the subclass's ``config_class`` (via dacite) when present;
        otherwise the stored data is passed as constructor kwargs.
        """
        if hasattr(cls, "config_class"):
            # Local import: dacite is only needed for dataclass-configured models.
            from dacite import from_dict

            model_cfg = from_dict(
                data_class=cls.config_class,
                data=checkpoint[f"{dict_key_prefix}data"],
            )
            model = cls(config=model_cfg)
        else:
            model = cls(**checkpoint[f"{dict_key_prefix}data"])
        model.load_state_dict(checkpoint[f"{dict_key_prefix}state_dict"])
        return model
|
||||
Reference in New Issue
Block a user