Add safari diff.

This commit is contained in:
2023-08-05 17:37:52 +02:00
parent 7b15a413d4
commit c3c7f33cac

View File

@@ -0,0 +1,359 @@
diff --git a/src/models/sequence/ssm/ss_kernel_diag.py b/src/models/sequence/ssm/ss_kernel_diag.py
index 6cb279b..554725c 100644
--- a/src/models/sequence/ssm/ss_kernel_diag.py
+++ b/src/models/sequence/ssm/ss_kernel_diag.py
@@ -20,15 +20,20 @@ log = get_logger(__name__)
# This could be None if the CUDA import fails
from src.ops.vandermonde import log_vandermonde_fast
+
try:
import pykeops
from src.ops.vandermonde import log_vandermonde, log_vandermonde_transpose
+
has_pykeops = True
log.info("Pykeops installation found.")
except ImportError:
has_pykeops = False
from src.ops.vandermonde import log_vandermonde_naive as log_vandermonde
- from src.ops.vandermonde import log_vandermonde_transpose_naive as log_vandermonde_transpose
+ from src.ops.vandermonde import (
+ log_vandermonde_transpose_naive as log_vandermonde_transpose,
+ )
+
log.warning(
"Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency."
)
@@ -37,7 +42,7 @@ except ImportError:
_c2r = torch.view_as_real
_r2c = torch.view_as_complex
-if tuple(map(int, torch.__version__.split('.')[:2])) >= (1, 10):
+if tuple(map(int, torch.__version__.split(".")[:2])) >= (1, 10):
_resolve_conj = lambda x: x.conj().resolve_conj()
else:
_resolve_conj = lambda x: x.conj()
@@ -48,15 +53,18 @@ class SSKernelDiag(OptimModule):
def __init__(
self,
- A, B, C, log_dt,
+ A,
+ B,
+ C,
+ log_dt,
L=None,
- disc='bilinear',
- real_type='exp',
+ disc="bilinear",
+ real_type="exp",
lr=None,
bandlimit=None,
force_real=False,
+ **kwargs,
):
-
super().__init__()
self.L = L
self.disc = disc
@@ -68,7 +76,7 @@ class SSKernelDiag(OptimModule):
assert A.size(-1) == C.size(-1)
self.H = log_dt.size(-1)
self.N = A.size(-1)
- assert A.size(-2) == B.size(-2) # Number of independent SSMs trained
+ assert A.size(-2) == B.size(-2) # Number of independent SSMs trained
assert self.H % A.size(-2) == 0
self.n_ssm = A.size(-2)
self.repeat = self.H // A.size(0)
@@ -77,42 +85,48 @@ class SSKernelDiag(OptimModule):
self.C = nn.Parameter(_c2r(_resolve_conj(C)))
# Register parameters
- if lr is None or isinstance(lr, float): lr_dict = {}
- else: lr_dict, lr = lr, None
+ if lr is None or isinstance(lr, float):
+ lr_dict = {}
+ else:
+ lr_dict, lr = lr, None
- self.register("log_dt", log_dt, lr_dict.get('dt', lr))
- self.register("B", _c2r(B), lr_dict.get('B', lr))
- self.register("inv_A_real", self._A_init(A.real), lr_dict.get('A', lr))
- self.register("A_imag", A.imag, lr_dict.get('A', lr))
+ self.register("log_dt", log_dt, lr_dict.get("dt", lr))
+ self.register("B", _c2r(B), lr_dict.get("B", lr))
+ self.register("inv_A_real", self._A_init(A.real), lr_dict.get("A", lr))
+ self.register("A_imag", A.imag, lr_dict.get("A", lr))
def _A_init(self, A_real):
A_real = torch.clamp(A_real, max=-1e-4)
- if self.real_type == 'none':
+ if self.real_type == "none":
return -A_real
- elif self.real_type == 'exp':
- return torch.log(-A_real) # Some of the HiPPO methods have real part 0
- elif self.real_type == 'relu':
+ elif self.real_type == "exp":
+ return torch.log(
+ -A_real
+ ) # Some of the HiPPO methods have real part 0
+ elif self.real_type == "relu":
return -A_real
- elif self.real_type == 'sigmoid':
+ elif self.real_type == "sigmoid":
return torch.logit(-A_real)
- elif self.real_type == 'softplus':
- return torch.log(torch.exp(-A_real)-1)
- else: raise NotImplementedError
+ elif self.real_type == "softplus":
+ return torch.log(torch.exp(-A_real) - 1)
+ else:
+ raise NotImplementedError
def _A(self):
# Get the internal A (diagonal) parameter
- if self.real_type == 'none':
+ if self.real_type == "none":
A_real = -self.inv_A_real
- elif self.real_type == 'exp':
+ elif self.real_type == "exp":
A_real = -torch.exp(self.inv_A_real)
- elif self.real_type == 'relu':
+ elif self.real_type == "relu":
# JAX version seems to NaN if you alloA 0's, although this code Aas fine Aithout it
- A_real = -F.relu(self.inv_A_real)-1e-4
- elif self.real_type == 'sigmoid':
+ A_real = -F.relu(self.inv_A_real) - 1e-4
+ elif self.real_type == "sigmoid":
A_real = -F.sigmoid(self.inv_A_real)
- elif self.real_type == 'softplus':
+ elif self.real_type == "softplus":
A_real = -F.softplus(self.inv_A_real)
- else: raise NotImplementedError
+ else:
+ raise NotImplementedError
A = A_real + 1j * self.A_imag
return A
@@ -126,121 +140,134 @@ class SSKernelDiag(OptimModule):
(B, H, L) output from initial state
"""
- dt = torch.exp(self.log_dt) * rate # (H)
- C = _r2c(self.C) # (C H N)
- A = self._A() # (H N)
+ dt = torch.exp(self.log_dt) * rate # (H)
+ C = _r2c(self.C) # (C H N)
+ A = self._A() # (H N)
B = _r2c(self.B)
- B = repeat(B, 't n -> 1 (v t) n', v=self.repeat)
+ B = repeat(B, "t n -> 1 (v t) n", v=self.repeat)
# Force A to be real valued, so the whole kernel can be interpreted as a "multi-head EMA"
if self.force_real:
A = A.real + 0j
if self.bandlimit is not None:
- freqs = dt[:, None] / rate * A.imag.abs() / (2*math.pi) # (H, N)
- mask = torch.where(freqs < self.bandlimit * .5, 1, 0)
+ freqs = dt[:, None] / rate * A.imag.abs() / (2 * math.pi) # (H, N)
+ mask = torch.where(freqs < self.bandlimit * 0.5, 1, 0)
C = C * mask
# Incorporate dt into A
- A = repeat(A, 't n -> (v t) n', v=self.repeat)
+ A = repeat(A, "t n -> (v t) n", v=self.repeat)
dtA = A * dt.unsqueeze(-1) # (H N)
-
# Augment B with state
if state is not None:
s = state / dt.unsqueeze(-1)
- if self.disc == 'bilinear':
- s = s * (1. + dtA/2)
- elif self.disc == 'zoh':
- s = s * dtA * dtA.exp() / (dtA.exp() - 1.)
- B = torch.cat([s, B], dim=-3) # (1+B H N)
+ if self.disc == "bilinear":
+ s = s * (1.0 + dtA / 2)
+ elif self.disc == "zoh":
+ s = s * dtA * dtA.exp() / (dtA.exp() - 1.0)
+ B = torch.cat([s, B], dim=-3) # (1+B H N)
C = (B[:, None, :, :] * C).view(-1, self.H, self.N)
- if self.disc == 'zoh':
+ if self.disc == "zoh":
# Power up
- C = C * (torch.exp(dtA)-1.) / A
+ C = C * (torch.exp(dtA) - 1.0) / A
# TODO (TD): make it work for C.shape[0] > 1
if log_vandermonde_fast is not None and C.shape[0] == 1:
- K = log_vandermonde_fast(C.squeeze(0), dtA, L).unsqueeze(0) # (H L)
+ K = log_vandermonde_fast(C.squeeze(0), dtA, L).unsqueeze(
+ 0
+ ) # (H L)
else:
- K = log_vandermonde(C, dtA, L) # (H L)
- elif self.disc == 'bilinear':
- C = C * (1. - dtA/2).reciprocal() * dt.unsqueeze(-1) # or * dtA / A
- dA = (1. + dtA/2) / (1. - dtA/2)
+ K = log_vandermonde(C, dtA, L) # (H L)
+ elif self.disc == "bilinear":
+ C = (
+ C * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1)
+ ) # or * dtA / A
+ dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)
if log_vandermonde_fast is not None:
- dA_log = repeat(dA.log(), 'h d -> (c h) d', c=C.shape[0])
- K = rearrange(log_vandermonde_fast(rearrange(C, 'c h d -> (c h) d'), dA_log, L),
- '(c h) d -> c h d', c=C.shape[0])
+ dA_log = repeat(dA.log(), "h d -> (c h) d", c=C.shape[0])
+ K = rearrange(
+ log_vandermonde_fast(
+ rearrange(C, "c h d -> (c h) d"), dA_log, L
+ ),
+ "(c h) d -> c h d",
+ c=C.shape[0],
+ )
else:
K = log_vandermonde(C, dA.log(), L)
- elif self.disc == 'dss':
+ elif self.disc == "dss":
# Implementation from DSS meant for case when real eigenvalues can be positive
- P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device) # [H N L]
- A_gt_0 = A.real > 0 # [N]
+ P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device) # [H N L]
+ A_gt_0 = A.real > 0 # [N]
if A_gt_0.any():
with torch.no_grad():
- P_max = dtA * (A_gt_0 * (L-1)) # [H N]
- P = P - P_max.unsqueeze(-1) # [H N L]
- S = P.exp() # [H N L]
+ P_max = dtA * (A_gt_0 * (L - 1)) # [H N]
+ P = P - P_max.unsqueeze(-1) # [H N L]
+ S = P.exp() # [H N L]
- dtA_neg = dtA * (1 - 2*A_gt_0) # [H N]
- num = dtA_neg.exp() - 1 # [H N]
- den = (dtA_neg * L).exp() - 1 # [H N]
+ dtA_neg = dtA * (1 - 2 * A_gt_0) # [H N]
+ num = dtA_neg.exp() - 1 # [H N]
+ den = (dtA_neg * L).exp() - 1 # [H N]
# Inline reciprocal function for DSS logic
x = den * A
x_conj = _resolve_conj(x)
- r = x_conj / (x*x_conj + 1e-7)
+ r = x_conj / (x * x_conj + 1e-7)
- C = C * num * r # [C H N]
- K = contract('chn,hnl->chl', C, S).float()
- else: assert False, f"{self.disc} not supported"
+ C = C * num * r # [C H N]
+ K = contract("chn,hnl->chl", C, S).float()
+ else:
+ assert False, f"{self.disc} not supported"
- K = K.view(-1, self.channels, self.H, L) # (1+B C H L)
+ K = K.view(-1, self.channels, self.H, L) # (1+B C H L)
if state is not None:
- K_state = K[:-1, :, :, :] # (B C H L)
+ K_state = K[:-1, :, :, :] # (B C H L)
else:
K_state = None
- K = K[-1, :, :, :] # (C H L)
+ K = K[-1, :, :, :] # (C H L)
return K, K_state
def _setup_step(self):
# These methods are organized like this to be compatible with the NPLR kernel interface
- dt = torch.exp(self.log_dt) # (H)
- B = _r2c(self.B) # (H N)
- C = _r2c(self.C) # (C H N)
+ dt = torch.exp(self.log_dt) # (H)
+ B = _r2c(self.B) # (H N)
+ C = _r2c(self.C) # (C H N)
self.dC = C
- A = self._A() # (H N)
+ A = self._A() # (H N)
- A = repeat(A, 't n -> (v t) n', v=self.repeat)
- B = repeat(B, 't n -> (v t) n', v=self.repeat)
+ A = repeat(A, "t n -> (v t) n", v=self.repeat)
+ B = repeat(B, "t n -> (v t) n", v=self.repeat)
# Incorporate dt into A
dtA = A * dt.unsqueeze(-1) # (H N)
- if self.disc == 'zoh':
- self.dA = torch.exp(dtA) # (H N)
- self.dB = B * (torch.exp(dtA)-1.) / A # (C H N)
- elif self.disc == 'bilinear':
- self.dA = (1. + dtA/2) / (1. - dtA/2)
- self.dB = B * (1. - dtA/2).reciprocal() * dt.unsqueeze(-1) # or * dtA / A
-
+ if self.disc == "zoh":
+ self.dA = torch.exp(dtA) # (H N)
+ self.dB = B * (torch.exp(dtA) - 1.0) / A # (C H N)
+ elif self.disc == "bilinear":
+ self.dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)
+ self.dB = (
+ B * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1)
+ ) # or * dtA / A
def default_state(self, *batch_shape):
C = _r2c(self.C)
- state = torch.zeros(*batch_shape, self.H, self.N, dtype=C.dtype, device=C.device)
+ state = torch.zeros(
+ *batch_shape, self.H, self.N, dtype=C.dtype, device=C.device
+ )
return state
def step(self, u, state):
- next_state = contract("h n, b h n -> b h n", self.dA, state) \
- + contract("h n, b h -> b h n", self.dB, u)
+ next_state = contract(
+ "h n, b h n -> b h n", self.dA, state
+ ) + contract("h n, b h -> b h n", self.dB, u)
y = contract("c h n, b h n -> b c h", self.dC, next_state)
- return 2*y.real, next_state
+ return 2 * y.real, next_state
def forward_state(self, u, state):
self._setup_step()
AL = self.dA ** u.size(-1)
- u = u.flip(-1).to(self.dA).contiguous() # (B H L)
+ u = u.flip(-1).to(self.dA).contiguous() # (B H L)
v = log_vandermonde_transpose(u, self.dB, self.dA.log(), u.size(-1))
next_state = AL * state + v
return next_state
@@ -307,7 +334,7 @@ class EMAKernel(OptimModule):
nn.init.normal_(self.gamma, mean=0.0, std=1.0)
# nn.init.normal_(self.omega, mean=0.0, std=1.0)
- def coeffs(self): # Same as discretize
+ def coeffs(self): # Same as discretize
p = torch.sigmoid(self.delta) # (H N 1)
alpha = torch.sigmoid(self.alpha)
q = 1.0 - p * alpha
@@ -319,12 +346,16 @@ class EMAKernel(OptimModule):
vander = torch.arange(L).to(p).view(1, 1, L) * torch.log(q) # (H N L)
kernel = (p * self.beta) * torch.exp(vander)
if self.efficient_bidirectional:
- C = rearrange(self.gamma * self.scale, '(c h) n -> c h n', c=self.channels)
- kernel = torch.einsum('dnl,cdn->cdl', kernel, C)
+ C = rearrange(
+ self.gamma * self.scale, "(c h) n -> c h n", c=self.channels
+ )
+ kernel = torch.einsum("dnl,cdn->cdl", kernel, C)
# kernel = rearrange(kernel, 'c d l -> (c d) l')
else:
- kernel = torch.einsum('dnl,dn->dl', kernel, self.gamma * self.scale)
- kernel = rearrange(kernel, '(c h) l -> c h l', c=self.channels)
+ kernel = torch.einsum(
+ "dnl,dn->dl", kernel, self.gamma * self.scale
+ )
+ kernel = rearrange(kernel, "(c h) l -> c h l", c=self.channels)
kernel = kernel[..., :L]
# kernel = rearrange(kernel, '(c h) l -> c h l', c=self.channels)