Add RWKV, H3, Hyena
models/base.py (Normal file, 73 additions)
@@ -0,0 +1,73 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, List

import torch
from torch import nn

from .interfaces import (
    ModelTrainInterface,
    SequenceConfigInterface,
    SequenceInterface,
)
from .base_model import BaseModel

LOGGER = logging.getLogger(__name__)

class LayerConfigInterface(ABC):
    @abstractmethod
    def assign_model_config_params(self, model_config):
        pass

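# Example (illustration only, not part of this commit): a concrete layer config
# would implement assign_model_config_params to copy shared fields from the model
# config into the layer config. The class and field names here are hypothetical.
@dataclass
class ExampleLayerConfig(LayerConfigInterface):
    embedding_dim: int = -1
    num_heads: int = 8

    def assign_model_config_params(self, model_config):
        self.embedding_dim = model_config.embedding_dim
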
@dataclass
class OptimizerConfig:
    lr: float = 1e-3
    betas: List[float] = field(default_factory=lambda: [0.9, 0.95])
    weight_decay: float = 0.0
    device_type: str = "cuda"  # TODO is this necessary?
    fused: bool = True
    eps: float = 1e-6

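# Example (illustration only, not part of this commit): overriding the defaults,
# e.g. for CPU debugging where fused AdamW is not available. With Hydra/OmegaConf
# these fields would normally be filled from a config file instead.
cpu_optim_cfg = OptimizerConfig(lr=3e-4, device_type="cpu", fused=False)
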
@dataclass
class BaseSequenceModelConfig(SequenceConfigInterface):
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)

    # Gives a model a more distinctive name for use in configurations etc.;
    # temporarily filled in by Hydra or OmegaConf.
    shortname: str = ""

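# Example (illustration only, not part of this commit): a concrete model config
# would typically extend BaseSequenceModelConfig with architecture fields. Every
# field name below is hypothetical, and SequenceConfigInterface may require more.
@dataclass
class ExampleModelConfig(BaseSequenceModelConfig):
    embedding_dim: int = 512
    num_layers: int = 12
    shortname: str = "example"

# Sketch of intended usage:
# example_cfg = ExampleModelConfig(optimizer=OptimizerConfig(lr=1e-4))
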
class ResettableParametersModule(nn.Module, ABC):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @abstractmethod
    def reset_parameters(self, **kwargs):
        pass

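# Example (illustration only, not part of this commit): a concrete module
# implementing reset_parameters so a trainer can re-initialize weights without
# rebuilding the module. The class name and init scheme are hypothetical.
class ExampleProjection(ResettableParametersModule):
    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=False)
        self.reset_parameters()

    def reset_parameters(self, **kwargs):
        nn.init.normal_(self.proj.weight, mean=0.0, std=0.02)
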
class BaseSequenceModelTrain(BaseModel, ModelTrainInterface, SequenceInterface):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @abstractmethod
    def _create_optim_groups(self, model_config, **kwargs) -> list[dict[str, Any]]:
        # Returns parameter groups (e.g. with and without weight decay) for the optimizer.
        # TODO think of a nice way to separate functionality from child classes
        # TODO move this into BaseModel and make it a separate interface too
        pass

    def configure_optimizer(self) -> torch.optim.Optimizer:
        optim_cfg = self.config.optimizer
        optim_groups = self._create_optim_groups(self.config)
        # Fused AdamW is only available on CUDA.
        use_fused = optim_cfg.device_type == "cuda" and optim_cfg.fused
        LOGGER.info(f"Using fused optimizer: {use_fused}")
        extra_args = dict(fused=True) if use_fused else dict()
        extra_args["eps"] = optim_cfg.eps
        optimizer = torch.optim.AdamW(
            optim_groups, lr=optim_cfg.lr, betas=optim_cfg.betas, **extra_args
        )
        return optimizer

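# Example (illustration only, not part of this commit): a subclass would implement
# _create_optim_groups, typically splitting parameters into a weight-decayed and a
# non-decayed group. The ExampleSequenceModel name is hypothetical, other methods
# required by the interfaces are omitted, and BaseModel is assumed to be an nn.Module.
class ExampleSequenceModel(BaseSequenceModelTrain):
    def _create_optim_groups(self, model_config, **kwargs) -> list[dict[str, Any]]:
        decay, no_decay = [], []
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue
            # Common heuristic: no weight decay for biases and 1-D params (norms).
            (no_decay if param.ndim < 2 else decay).append(param)
        wd = model_config.optimizer.weight_decay
        return [
            {"params": decay, "weight_decay": wd},
            {"params": no_decay, "weight_decay": 0.0},
        ]

# Sketch of intended usage:
# optimizer = ExampleSequenceModel(config=...).configure_optimizer()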