Files
state_interpretation/models/interfaces.py
2023-08-05 17:35:11 +02:00

78 lines
2.2 KiB
Python

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Optional, Protocol, Sequence
import torch
class ModelTrainInterface(ABC):
    """Interface for the training-related hooks of a model.

    Subclasses must implement :meth:`get_loss_func`. Overriding
    :meth:`configure_optimizer` is optional; by default it returns ``None``.
    """

    def configure_optimizer(self) -> Optional[torch.optim.Optimizer]:
        """Return a model-specific optimizer, or ``None`` when not overridden.

        NOTE(review): presumably ``None`` tells the caller to build its own
        default optimizer — confirm against the training code.
        """
        return None

    @abstractmethod
    def get_loss_func(self, **kwargs) -> Callable[[Any], torch.Tensor]:
        """Return the loss function for this model.

        Args:
            **kwargs: Implementation-specific options; this interface places
                no constraints on them.

        Returns:
            A callable whose result is a ``torch.Tensor``.
            NOTE(review): the annotation declares a single positional
            argument. Confirm this against callers if loss functions actually
            take e.g. (prediction, target).
        """
class SequenceInterface(ABC):
    """Generic interface describing a sequence together with its label.

    Here a sequence also carries its label, so the label's dimension
    (``output_dim``) is part of the sequence interface too. Every sequence
    has a length, exposed as ``context_length``.

    A sequence comes in one of the following flavors.

    Input sequence:
        - sequence of tokens (e.g. words): ``vocab_size`` must be specified.
        - sequence of vectors: ``input_dim`` must be specified.

    Output sequence:
        - next token (e.g. word): ``vocab_size`` must be specified
          (e.g. Causal Language Modeling).
        - (sequence of) vectors: ``output_dim`` must be specified
          (e.g. Forecasting).
        - label: ``output_dim`` must be specified
          (e.g. Sequence Classification).

    Examples. These show which attributes are set, not their types:
    ``input_dim`` and ``output_dim`` are typed ``Optional[Sequence[int]]``,
    while ``vocab_size`` is typed ``Optional[int]``.

        - Causal Language Modeling: input_dim=None, output_dim=None,
          vocab_size set.
        - Forecasting: input_dim set, output_dim set, vocab_size=None.
        - Sequence Classification (general sequence): input_dim set,
          output_dim set, vocab_size=None.
        - Sequence Classification (text): input_dim=None, output_dim set,
          vocab_size set.

    Subclasses must implement ``context_length``. ``input_dim``,
    ``output_dim`` and ``vocab_size`` return ``None`` unless overridden.
    """

    @property
    def input_dim(self) -> Optional[Sequence[int]]:
        """Dimension(s) of vector inputs; ``None`` unless overridden."""
        return None

    @property
    def output_dim(self) -> Optional[Sequence[int]]:
        """Dimension(s) of the output/label; ``None`` unless overridden."""
        return None

    @property
    def vocab_size(self) -> Optional[int]:
        """Vocabulary size for token sequences; ``None`` unless overridden."""
        return None

    @property
    @abstractmethod
    def context_length(self) -> int:
        """Length of the sequence; must be provided by subclasses."""
@dataclass
class SequenceConfigInterface:
    """Configuration fields shared by sequence models.

    Holds the ``context_length``, ``input_dim`` and ``output_dim`` attributes
    that ``SequenceInterface`` also exposes. Only ``context_length`` is
    required; both dimension fields default to ``None``.
    """

    # Length of the sequence (required, no default).
    context_length: int
    # NOTE(review): a ``vocab_size: Optional[int] = None`` field was left
    # commented out here. Enabling it changes the generated ``__init__`` and
    # may clash with subclass field defaults — confirm before adding it.
    # Dimension(s) of vector inputs, if any.
    input_dim: Optional[Sequence[int]] = None
    # Dimension(s) of the output/label, if any.
    output_dim: Optional[Sequence[int]] = None
class Tokenizer(Protocol):
    """Structural type for a tokenizer: callable with keyword args and sized."""

    def __call__(self, **kwargs) -> Any:
        """Tokenize according to keyword arguments; result type unspecified."""

    def __len__(self) -> int:
        """Return the tokenizer's size (presumably vocabulary size — confirm)."""
class TokenizerInterface(ABC):
    """Mixin-style interface for objects that may expose a tokenizer."""

    @property
    def tokenizer(self) -> Optional[Tokenizer]:
        """Tokenizer used by this object; ``None`` unless overridden."""
        return None