import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, List

import torch
from torch import nn

from .interfaces import (
    ModelTrainInterface,
    SequenceConfigInterface,
    SequenceInterface,
)
from .base_model import BaseModel

LOGGER = logging.getLogger(__name__)


class LayerConfigInterface(ABC):
    """Interface for layer configs that assign their parameters onto a model config."""

    @abstractmethod
    def assign_model_config_params(self, model_config):
        pass


@dataclass
class OptimizerConfig:
    """Hyperparameters for the AdamW optimizer."""

    lr: float = 1e-3
    betas: List[float] = field(default_factory=lambda: [0.9, 0.95])
    weight_decay: float = 0.0
    device_type: str = "cuda"  # TODO is this necessary?
    fused: bool = True
    eps: float = 1e-6


@dataclass
class BaseSequenceModelConfig(SequenceConfigInterface):
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
    # Gives the model a more distinctive name for use in configurations etc.;
    # temporarily filled in by Hydra or OmegaConf.
    shortname: str = ""
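
# Example (an assumption, not part of the original module): such a config is
# typically materialized via Hydra/OmegaConf, e.g.
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.structured(BaseSequenceModelConfig(shortname="toy"))
#   print(OmegaConf.to_yaml(cfg))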


class ResettableParametersModule(nn.Module, ABC):
    """An nn.Module whose parameters can be (re-)initialized via reset_parameters()."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @abstractmethod
    def reset_parameters(self, **kwargs):
        pass
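

# --- Illustrative sketch (an assumption, not part of the original module) ---
# A minimal concrete ResettableParametersModule: a linear layer whose
# parameters can be re-initialized in place via reset_parameters().
class _ExampleResettableLinear(ResettableParametersModule):
    def __init__(self, in_features: int = 8, out_features: int = 8, **kwargs):
        super().__init__(**kwargs)
        self.linear = nn.Linear(in_features, out_features)

    def reset_parameters(self, **kwargs):
        # Delegate to nn.Linear's own initialization scheme.
        self.linear.reset_parameters()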


class BaseSequenceModelTrain(BaseModel, ModelTrainInterface, SequenceInterface):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @abstractmethod
    def _create_optim_groups(self, config, **kwargs) -> list[dict[str, Any]]:
        # Note: takes the model config positionally to match the call in
        # configure_optimizer() below.
        # TODO think of a nice way to separate functionality from child classes
        # TODO move this into BaseModel and make it a separate interface too
        pass
    def configure_optimizer(self) -> torch.optim.Optimizer:
        optim_cfg = self.config.optimizer
        optim_groups = self._create_optim_groups(self.config)
        # Fused AdamW is only available on CUDA.
        use_fused = optim_cfg.device_type == "cuda" and optim_cfg.fused
        LOGGER.info(f"Using fused optimizer: {use_fused}")
        extra_args = dict(fused=True) if use_fused else dict()
        extra_args["eps"] = optim_cfg.eps
        # weight_decay is not passed here; it is expected to be set per
        # parameter group by _create_optim_groups.
        optimizer = torch.optim.AdamW(
            optim_groups, lr=optim_cfg.lr, betas=optim_cfg.betas, **extra_args
        )
        return optimizer
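

# --- Illustrative sketch (an assumption, not part of the original module) ---
# One common way a subclass could implement `_create_optim_groups`: apply
# weight decay only to weight matrices (ndim >= 2) and exempt biases and norm
# parameters. Shown as a standalone helper so it makes no assumptions about
# the abstract interfaces above.
def _example_optim_groups(
    model: nn.Module, weight_decay: float
) -> list[dict[str, Any]]:
    params = [p for p in model.parameters() if p.requires_grad]
    decay = [p for p in params if p.ndim >= 2]
    no_decay = [p for p in params if p.ndim < 2]
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]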