78 lines
2.2 KiB
Python
78 lines
2.2 KiB
Python
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
from typing import Any, Callable, Optional, Protocol, Sequence
|
|
|
|
import torch
|
|
|
|
|
|
class ModelTrainInterface(ABC):
|
|
def configure_optimizer(self) -> Optional[torch.optim.Optimizer]:
|
|
return None
|
|
|
|
@abstractmethod
|
|
def get_loss_func(self, **kwargs) -> Callable[[Any], torch.Tensor]:
|
|
pass
|
|
|
|
|
|
class SequenceInterface(ABC):
|
|
"""This is a generic interface for a sequence.
|
|
In our case a sequence also includes its label. Therefore, the label (aka output_dim)
|
|
is also part of the sequence interface.
|
|
A sequence always has a length (=context_length).
|
|
|
|
A sequence can have one of the following flavors:
|
|
Input sequence:
|
|
- sequence of tokens (e.g. words): vocab_size must be specified
|
|
- sequence of vectors: input_dim must be specified
|
|
|
|
Output sequence:
|
|
- next token (e.g. word): `vocab_size` must be specified (e.g. Causal Language Modeling).
|
|
- (sequence of) vectors: output_dim must be specified (e.g. Forecasting)
|
|
- label: output_dim must be specified (e.g. Sequence Classification)
|
|
|
|
Examples:
|
|
- Causal Language Modeling: input_dim = None, output_dim = None, vocab_size = int
|
|
- Forecasting: input_dim = int, output_dim = int, vocab_size = None
|
|
- Sequence Classification (General Sequence): input_dim = int, output_dim = int, vocab_size = None
|
|
- Sequence Classification (Text): input_dim = None, output_dim = int, vocab_size = int
|
|
"""
|
|
|
|
@property
|
|
def input_dim(self) -> Optional[Sequence[int]]:
|
|
return None
|
|
|
|
@property
|
|
def output_dim(self) -> Optional[Sequence[int]]:
|
|
return None
|
|
|
|
@property
|
|
def vocab_size(self) -> Optional[int]:
|
|
return None
|
|
|
|
@property
|
|
@abstractmethod
|
|
def context_length(self) -> int:
|
|
pass
|
|
|
|
|
|
@dataclass
|
|
class SequenceConfigInterface:
|
|
context_length: int
|
|
# vocab_size: Optional[int] = None
|
|
input_dim: Optional[Sequence[int]] = None
|
|
output_dim: Optional[Sequence[int]] = None
|
|
|
|
|
|
class Tokenizer(Protocol):
|
|
def __call__(self, **kwargs) -> Any:
|
|
...
|
|
|
|
def __len__(self) -> int:
|
|
...
|
|
|
|
|
|
class TokenizerInterface(ABC):
|
|
@property
|
|
def tokenizer(self) -> Optional[Tokenizer]:
|
|
return None
|