Add RWKV, H3, Hyena
This commit is contained in:
77
models/interfaces.py
Normal file
77
models/interfaces.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional, Protocol, Sequence
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class ModelTrainInterface(ABC):
|
||||
def configure_optimizer(self) -> Optional[torch.optim.Optimizer]:
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def get_loss_func(self, **kwargs) -> Callable[[Any], torch.Tensor]:
|
||||
pass
|
||||
|
||||
|
||||
class SequenceInterface(ABC):
|
||||
"""This is a generic interface for a sequence.
|
||||
In our case a sequence also includes its label. Therefore, the label (aka output_dim)
|
||||
is also part of the sequence interface.
|
||||
A sequence always has a length (=context_length).
|
||||
|
||||
A sequence can have one of the following flavors:
|
||||
Input sequence:
|
||||
- sequence of tokens (e.g. words): vocab_size must be specified
|
||||
- sequence of vectors: input_dim must be specified
|
||||
|
||||
Output sequence:
|
||||
- next token (e.g. word): `vocab_size` must be specified (e.g. Causal Language Modeling).
|
||||
- (sequence of) vectors: output_dim must be specified (e.g. Forecasting)
|
||||
- label: output_dim must be specified (e.g. Sequence Classification)
|
||||
|
||||
Examples:
|
||||
- Causal Language Modeling: input_dim = None, output_dim = None, vocab_size = int
|
||||
- Forecasting: input_dim = int, output_dim = int, vocab_size = None
|
||||
- Sequence Classification (General Sequence): input_dim = int, output_dim = int, vocab_size = None
|
||||
- Sequence Classification (Text): input_dim = None, output_dim = int, vocab_size = int
|
||||
"""
|
||||
|
||||
@property
|
||||
def input_dim(self) -> Optional[Sequence[int]]:
|
||||
return None
|
||||
|
||||
@property
|
||||
def output_dim(self) -> Optional[Sequence[int]]:
|
||||
return None
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> Optional[int]:
|
||||
return None
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def context_length(self) -> int:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class SequenceConfigInterface:
|
||||
context_length: int
|
||||
# vocab_size: Optional[int] = None
|
||||
input_dim: Optional[Sequence[int]] = None
|
||||
output_dim: Optional[Sequence[int]] = None
|
||||
|
||||
|
||||
class Tokenizer(Protocol):
|
||||
def __call__(self, **kwargs) -> Any:
|
||||
...
|
||||
|
||||
def __len__(self) -> int:
|
||||
...
|
||||
|
||||
|
||||
class TokenizerInterface(ABC):
|
||||
@property
|
||||
def tokenizer(self) -> Optional[Tokenizer]:
|
||||
return None
|
||||
Reference in New Issue
Block a user