import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any

import torch
from torch import nn

from .interfaces import (
    ModelTrainInterface,
    SequenceConfigInterface,
    SequenceInterface,
)
from .base_model import BaseModel

LOGGER = logging.getLogger(__name__)


class LayerConfigInterface(ABC):
    @abstractmethod
    def assign_model_config_params(self, model_config):
        pass


@dataclass
class OptimizerConfig:
    lr: float = 1e-3
    betas: list[float] = field(default_factory=lambda: [0.9, 0.95])
    weight_decay: float = 0.0
    device_type: str = "cuda"  # TODO is this necessary?
    fused: bool = True
    eps: float = 1e-6


@dataclass
class BaseSequenceModelConfig(SequenceConfigInterface):
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
    # Gives a model a more distinctive name for use in configurations etc.;
    # filled in temporarily by Hydra / OmegaConf.
    shortname: str = ""


class ResettableParametersModule(nn.Module, ABC):
    """An nn.Module whose parameters can be re-initialized in place."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @abstractmethod
    def reset_parameters(self, **kwargs):
        pass


class BaseSequenceModelTrain(
    BaseModel, ModelTrainInterface, SequenceInterface
):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @abstractmethod
    def _create_optim_groups(self, model_config, **kwargs) -> list[dict[str, Any]]:
        """Return the parameter groups to be passed to the optimizer."""
        # TODO think of a nice way to separate functionality from child classes
        # TODO move this into BaseModel and make it a separate interface too
        pass

    def configure_optimizer(self) -> torch.optim.Optimizer:
        optim_cfg = self.config.optimizer
        optim_groups = self._create_optim_groups(self.config)
        # The fused AdamW kernel is only available on CUDA devices.
        use_fused = optim_cfg.device_type == "cuda" and optim_cfg.fused
        LOGGER.info("Using fused optimizer: %s", use_fused)
        extra_args = dict(fused=True) if use_fused else dict()
        extra_args["eps"] = optim_cfg.eps
        # Weight decay is expected to be set per parameter group in
        # `_create_optim_groups`, so it is not passed to AdamW directly.
        optimizer = torch.optim.AdamW(
            optim_groups,
            lr=optim_cfg.lr,
            betas=optim_cfg.betas,
            **extra_args,
        )
        return optimizer
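

# --- Example (illustrative sketch, not part of this module's API) -----------
# A minimal ResettableParametersModule subclass. The layer shape and the init
# scheme below are assumptions chosen for the example only.
class _ExampleProjection(ResettableParametersModule):
    def __init__(self, in_features: int = 8, out_features: int = 8):
        super().__init__()
        self.proj = nn.Linear(in_features, out_features)
        self.reset_parameters()

    def reset_parameters(self, **kwargs):
        # Re-initialize the parameters in place, e.g. to restart training
        # from scratch without rebuilding the module.
        nn.init.xavier_uniform_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)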
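

# --- Example (illustrative sketch, not part of this module's API) -----------
# A typical `_create_optim_groups` implementation splits parameters into a
# decay and a no-decay group (the common nanoGPT-style heuristic: tensors with
# >= 2 dimensions such as matmul weights and embeddings get weight decay, 1D
# tensors such as biases and norm scales do not). The helper name below is
# hypothetical; a concrete BaseSequenceModelTrain subclass would return groups
# like these from its `_create_optim_groups` override.
def _example_optim_groups(
    model: nn.Module, weight_decay: float
) -> list[dict[str, Any]]:
    params = [p for p in model.parameters() if p.requires_grad]
    return [
        {"params": [p for p in params if p.dim() >= 2], "weight_decay": weight_decay},
        {"params": [p for p in params if p.dim() < 2], "weight_decay": 0.0},
    ]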