import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, List

import torch
from torch import nn

from .interfaces import (
    ModelTrainInterface,
    SequenceConfigInterface,
    SequenceInterface,
)
from .base_model import BaseModel

LOGGER = logging.getLogger(__name__)


class LayerConfigInterface(ABC):
    @abstractmethod
    def assign_model_config_params(self, model_config):
        pass


@dataclass
class OptimizerConfig:
    lr: float = 1e-3
    betas: List = field(default_factory=lambda: [0.9, 0.95])
    weight_decay: float = 0.0
    device_type: str = "cuda"  # TODO is this necessary?
    fused: bool = True
    eps: float = 1e-6


@dataclass
class BaseSequenceModelConfig(SequenceConfigInterface):
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)

    # Gives the model a more distinctive name used in configurations etc.;
    # temporarily filled in by Hydra or OmegaConf.
    shortname: str = ""


class ResettableParametersModule(nn.Module, ABC):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @abstractmethod
    def reset_parameters(self, **kwargs):
        pass


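# Hypothetical example, not part of the original interfaces: a minimal
# ResettableParametersModule subclass that wraps a single nn.Linear layer and
# delegates re-initialization to the layer's own reset_parameters(). The class
# and attribute names here are illustrative assumptions, not part of this
# module's API.
class ExampleResettableLinear(ResettableParametersModule):
    def __init__(self, in_features: int, out_features: int, **kwargs):
        super().__init__(**kwargs)
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

    def reset_parameters(self, **kwargs):
        # Re-initialize the wrapped layer in place.
        self.linear.reset_parameters()

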
class BaseSequenceModelTrain(
    BaseModel, ModelTrainInterface, SequenceInterface
):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @abstractmethod
    def _create_optim_groups(self, config, **kwargs) -> list[dict[str, Any]]:
        # Subclasses build the optimizer parameter groups; the model config is
        # passed in positionally by configure_optimizer() below.
        # TODO think of a nice way to separate functionality from child classes
        # TODO move this into BaseModel and make it a separate interface too
        pass

    def configure_optimizer(self) -> torch.optim.Optimizer:
        optim_cfg = self.config.optimizer
        optim_groups = self._create_optim_groups(self.config)
        # The fused AdamW implementation is only used on CUDA devices.
        use_fused = optim_cfg.device_type == "cuda" and optim_cfg.fused
        LOGGER.info(f"Using fused optimizer: {use_fused}")
        extra_args = dict(fused=True) if use_fused else dict()
        extra_args["eps"] = optim_cfg.eps
        optimizer = torch.optim.AdamW(
            optim_groups, lr=optim_cfg.lr, betas=optim_cfg.betas, **extra_args
        )
        return optimizer
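

# Hypothetical example, not part of this module's API: one common way a concrete
# subclass could build the parameter groups that configure_optimizer() expects is
# to apply weight decay only to >=2D tensors (weight matrices) and exempt biases
# and norm parameters. The function name and the ndim heuristic are illustrative
# assumptions, not the project's actual grouping scheme.
def example_create_optim_groups(
    model: nn.Module, optim_cfg: OptimizerConfig
) -> list[dict[str, Any]]:
    decay_params = [
        p for p in model.parameters() if p.requires_grad and p.ndim >= 2
    ]
    no_decay_params = [
        p for p in model.parameters() if p.requires_grad and p.ndim < 2
    ]
    return [
        {"params": decay_params, "weight_decay": optim_cfg.weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
# Usage sketch:
#   groups = example_create_optim_groups(model, OptimizerConfig())
#   optimizer = torch.optim.AdamW(groups, lr=OptimizerConfig().lr)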