Add safari diff.
safari_diffs/23a1d81b2fce85f12d78812c427c4e8e9cb0ac42.patch (new file, 359 lines added)
@@ -0,0 +1,359 @@
diff --git a/src/models/sequence/ssm/ss_kernel_diag.py b/src/models/sequence/ssm/ss_kernel_diag.py
index 6cb279b..554725c 100644
--- a/src/models/sequence/ssm/ss_kernel_diag.py
+++ b/src/models/sequence/ssm/ss_kernel_diag.py
@@ -20,15 +20,20 @@ log = get_logger(__name__)
 
 # This could be None if the CUDA import fails
 from src.ops.vandermonde import log_vandermonde_fast
+
 try:
     import pykeops
     from src.ops.vandermonde import log_vandermonde, log_vandermonde_transpose
+
     has_pykeops = True
     log.info("Pykeops installation found.")
 except ImportError:
     has_pykeops = False
     from src.ops.vandermonde import log_vandermonde_naive as log_vandermonde
-    from src.ops.vandermonde import log_vandermonde_transpose_naive as log_vandermonde_transpose
+    from src.ops.vandermonde import (
+        log_vandermonde_transpose_naive as log_vandermonde_transpose,
+    )
+
     log.warning(
         "Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency."
     )
@@ -37,7 +42,7 @@ except ImportError:
 _c2r = torch.view_as_real
 _r2c = torch.view_as_complex
 
-if tuple(map(int, torch.__version__.split('.')[:2])) >= (1, 10):
+if tuple(map(int, torch.__version__.split(".")[:2])) >= (1, 10):
     _resolve_conj = lambda x: x.conj().resolve_conj()
 else:
     _resolve_conj = lambda x: x.conj()
@@ -48,15 +53,18 @@ class SSKernelDiag(OptimModule):
 
     def __init__(
         self,
-        A, B, C, log_dt,
+        A,
+        B,
+        C,
+        log_dt,
         L=None,
-        disc='bilinear',
-        real_type='exp',
+        disc="bilinear",
+        real_type="exp",
         lr=None,
         bandlimit=None,
         force_real=False,
+        **kwargs,
     ):
-
         super().__init__()
         self.L = L
         self.disc = disc
@@ -68,7 +76,7 @@ class SSKernelDiag(OptimModule):
         assert A.size(-1) == C.size(-1)
         self.H = log_dt.size(-1)
         self.N = A.size(-1)
-        assert A.size(-2) == B.size(-2) # Number of independent SSMs trained
+        assert A.size(-2) == B.size(-2)  # Number of independent SSMs trained
         assert self.H % A.size(-2) == 0
         self.n_ssm = A.size(-2)
         self.repeat = self.H // A.size(0)
@@ -77,42 +85,48 @@
         self.C = nn.Parameter(_c2r(_resolve_conj(C)))
 
         # Register parameters
-        if lr is None or isinstance(lr, float): lr_dict = {}
-        else: lr_dict, lr = lr, None
+        if lr is None or isinstance(lr, float):
+            lr_dict = {}
+        else:
+            lr_dict, lr = lr, None
 
-        self.register("log_dt", log_dt, lr_dict.get('dt', lr))
-        self.register("B", _c2r(B), lr_dict.get('B', lr))
-        self.register("inv_A_real", self._A_init(A.real), lr_dict.get('A', lr))
-        self.register("A_imag", A.imag, lr_dict.get('A', lr))
+        self.register("log_dt", log_dt, lr_dict.get("dt", lr))
+        self.register("B", _c2r(B), lr_dict.get("B", lr))
+        self.register("inv_A_real", self._A_init(A.real), lr_dict.get("A", lr))
+        self.register("A_imag", A.imag, lr_dict.get("A", lr))
 
     def _A_init(self, A_real):
         A_real = torch.clamp(A_real, max=-1e-4)
-        if self.real_type == 'none':
+        if self.real_type == "none":
             return -A_real
-        elif self.real_type == 'exp':
-            return torch.log(-A_real) # Some of the HiPPO methods have real part 0
-        elif self.real_type == 'relu':
+        elif self.real_type == "exp":
+            return torch.log(
+                -A_real
+            )  # Some of the HiPPO methods have real part 0
+        elif self.real_type == "relu":
             return -A_real
-        elif self.real_type == 'sigmoid':
+        elif self.real_type == "sigmoid":
             return torch.logit(-A_real)
-        elif self.real_type == 'softplus':
-            return torch.log(torch.exp(-A_real)-1)
-        else: raise NotImplementedError
+        elif self.real_type == "softplus":
+            return torch.log(torch.exp(-A_real) - 1)
+        else:
+            raise NotImplementedError
 
     def _A(self):
         # Get the internal A (diagonal) parameter
-        if self.real_type == 'none':
+        if self.real_type == "none":
             A_real = -self.inv_A_real
-        elif self.real_type == 'exp':
+        elif self.real_type == "exp":
             A_real = -torch.exp(self.inv_A_real)
-        elif self.real_type == 'relu':
+        elif self.real_type == "relu":
             # JAX version seems to NaN if you allow 0's, although this code was fine without it
-            A_real = -F.relu(self.inv_A_real)-1e-4
-        elif self.real_type == 'sigmoid':
+            A_real = -F.relu(self.inv_A_real) - 1e-4
+        elif self.real_type == "sigmoid":
             A_real = -F.sigmoid(self.inv_A_real)
-        elif self.real_type == 'softplus':
+        elif self.real_type == "softplus":
             A_real = -F.softplus(self.inv_A_real)
-        else: raise NotImplementedError
+        else:
+            raise NotImplementedError
         A = A_real + 1j * self.A_imag
         return A
 
@@ -126,121 +140,134 @@
         (B, H, L) output from initial state
         """
 
-        dt = torch.exp(self.log_dt) * rate # (H)
-        C = _r2c(self.C) # (C H N)
-        A = self._A() # (H N)
+        dt = torch.exp(self.log_dt) * rate  # (H)
+        C = _r2c(self.C)  # (C H N)
+        A = self._A()  # (H N)
 
         B = _r2c(self.B)
-        B = repeat(B, 't n -> 1 (v t) n', v=self.repeat)
+        B = repeat(B, "t n -> 1 (v t) n", v=self.repeat)
 
         # Force A to be real valued, so the whole kernel can be interpreted as a "multi-head EMA"
         if self.force_real:
             A = A.real + 0j
 
         if self.bandlimit is not None:
-            freqs = dt[:, None] / rate * A.imag.abs() / (2*math.pi) # (H, N)
-            mask = torch.where(freqs < self.bandlimit * .5, 1, 0)
+            freqs = dt[:, None] / rate * A.imag.abs() / (2 * math.pi)  # (H, N)
+            mask = torch.where(freqs < self.bandlimit * 0.5, 1, 0)
             C = C * mask
 
         # Incorporate dt into A
-        A = repeat(A, 't n -> (v t) n', v=self.repeat)
+        A = repeat(A, "t n -> (v t) n", v=self.repeat)
         dtA = A * dt.unsqueeze(-1)  # (H N)
 
-
         # Augment B with state
         if state is not None:
             s = state / dt.unsqueeze(-1)
-            if self.disc == 'bilinear':
-                s = s * (1. + dtA/2)
-            elif self.disc == 'zoh':
-                s = s * dtA * dtA.exp() / (dtA.exp() - 1.)
-            B = torch.cat([s, B], dim=-3) # (1+B H N)
+            if self.disc == "bilinear":
+                s = s * (1.0 + dtA / 2)
+            elif self.disc == "zoh":
+                s = s * dtA * dtA.exp() / (dtA.exp() - 1.0)
+            B = torch.cat([s, B], dim=-3)  # (1+B H N)
 
         C = (B[:, None, :, :] * C).view(-1, self.H, self.N)
-        if self.disc == 'zoh':
+        if self.disc == "zoh":
             # Power up
-            C = C * (torch.exp(dtA)-1.) / A
+            C = C * (torch.exp(dtA) - 1.0) / A
             # TODO (TD): make it work for C.shape[0] > 1
             if log_vandermonde_fast is not None and C.shape[0] == 1:
-                K = log_vandermonde_fast(C.squeeze(0), dtA, L).unsqueeze(0) # (H L)
+                K = log_vandermonde_fast(C.squeeze(0), dtA, L).unsqueeze(
+                    0
+                )  # (H L)
             else:
-                K = log_vandermonde(C, dtA, L) # (H L)
-        elif self.disc == 'bilinear':
-            C = C * (1. - dtA/2).reciprocal() * dt.unsqueeze(-1) # or * dtA / A
-            dA = (1. + dtA/2) / (1. - dtA/2)
+                K = log_vandermonde(C, dtA, L)  # (H L)
+        elif self.disc == "bilinear":
+            C = (
+                C * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1)
+            )  # or * dtA / A
+            dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)
             if log_vandermonde_fast is not None:
-                dA_log = repeat(dA.log(), 'h d -> (c h) d', c=C.shape[0])
-                K = rearrange(log_vandermonde_fast(rearrange(C, 'c h d -> (c h) d'), dA_log, L),
-                              '(c h) d -> c h d', c=C.shape[0])
+                dA_log = repeat(dA.log(), "h d -> (c h) d", c=C.shape[0])
+                K = rearrange(
+                    log_vandermonde_fast(
+                        rearrange(C, "c h d -> (c h) d"), dA_log, L
+                    ),
+                    "(c h) d -> c h d",
+                    c=C.shape[0],
+                )
             else:
                 K = log_vandermonde(C, dA.log(), L)
-        elif self.disc == 'dss':
+        elif self.disc == "dss":
             # Implementation from DSS meant for case when real eigenvalues can be positive
-            P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device) # [H N L]
-            A_gt_0 = A.real > 0 # [N]
+            P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device)  # [H N L]
+            A_gt_0 = A.real > 0  # [N]
             if A_gt_0.any():
                 with torch.no_grad():
-                    P_max = dtA * (A_gt_0 * (L-1)) # [H N]
-                P = P - P_max.unsqueeze(-1) # [H N L]
-            S = P.exp() # [H N L]
+                    P_max = dtA * (A_gt_0 * (L - 1))  # [H N]
+                P = P - P_max.unsqueeze(-1)  # [H N L]
+            S = P.exp()  # [H N L]
 
-            dtA_neg = dtA * (1 - 2*A_gt_0) # [H N]
-            num = dtA_neg.exp() - 1 # [H N]
-            den = (dtA_neg * L).exp() - 1 # [H N]
+            dtA_neg = dtA * (1 - 2 * A_gt_0)  # [H N]
+            num = dtA_neg.exp() - 1  # [H N]
+            den = (dtA_neg * L).exp() - 1  # [H N]
 
             # Inline reciprocal function for DSS logic
             x = den * A
             x_conj = _resolve_conj(x)
-            r = x_conj / (x*x_conj + 1e-7)
+            r = x_conj / (x * x_conj + 1e-7)
 
-            C = C * num * r # [C H N]
-            K = contract('chn,hnl->chl', C, S).float()
-        else: assert False, f"{self.disc} not supported"
+            C = C * num * r  # [C H N]
+            K = contract("chn,hnl->chl", C, S).float()
+        else:
+            assert False, f"{self.disc} not supported"
 
-        K = K.view(-1, self.channels, self.H, L) # (1+B C H L)
+        K = K.view(-1, self.channels, self.H, L)  # (1+B C H L)
         if state is not None:
-            K_state = K[:-1, :, :, :] # (B C H L)
+            K_state = K[:-1, :, :, :]  # (B C H L)
         else:
             K_state = None
-        K = K[-1, :, :, :] # (C H L)
+        K = K[-1, :, :, :]  # (C H L)
         return K, K_state
 
     def _setup_step(self):
         # These methods are organized like this to be compatible with the NPLR kernel interface
-        dt = torch.exp(self.log_dt) # (H)
-        B = _r2c(self.B) # (H N)
-        C = _r2c(self.C) # (C H N)
+        dt = torch.exp(self.log_dt)  # (H)
+        B = _r2c(self.B)  # (H N)
+        C = _r2c(self.C)  # (C H N)
         self.dC = C
-        A = self._A() # (H N)
+        A = self._A()  # (H N)
 
-        A = repeat(A, 't n -> (v t) n', v=self.repeat)
-        B = repeat(B, 't n -> (v t) n', v=self.repeat)
+        A = repeat(A, "t n -> (v t) n", v=self.repeat)
+        B = repeat(B, "t n -> (v t) n", v=self.repeat)
 
         # Incorporate dt into A
         dtA = A * dt.unsqueeze(-1)  # (H N)
-        if self.disc == 'zoh':
-            self.dA = torch.exp(dtA) # (H N)
-            self.dB = B * (torch.exp(dtA)-1.) / A # (C H N)
-        elif self.disc == 'bilinear':
-            self.dA = (1. + dtA/2) / (1. - dtA/2)
-            self.dB = B * (1. - dtA/2).reciprocal() * dt.unsqueeze(-1) # or * dtA / A
-
+        if self.disc == "zoh":
+            self.dA = torch.exp(dtA)  # (H N)
+            self.dB = B * (torch.exp(dtA) - 1.0) / A  # (C H N)
+        elif self.disc == "bilinear":
+            self.dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)
+            self.dB = (
+                B * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1)
+            )  # or * dtA / A
 
     def default_state(self, *batch_shape):
         C = _r2c(self.C)
-        state = torch.zeros(*batch_shape, self.H, self.N, dtype=C.dtype, device=C.device)
+        state = torch.zeros(
+            *batch_shape, self.H, self.N, dtype=C.dtype, device=C.device
+        )
         return state
 
     def step(self, u, state):
-        next_state = contract("h n, b h n -> b h n", self.dA, state) \
-                + contract("h n, b h -> b h n", self.dB, u)
+        next_state = contract(
+            "h n, b h n -> b h n", self.dA, state
+        ) + contract("h n, b h -> b h n", self.dB, u)
         y = contract("c h n, b h n -> b c h", self.dC, next_state)
-        return 2*y.real, next_state
+        return 2 * y.real, next_state
 
     def forward_state(self, u, state):
        self._setup_step()
         AL = self.dA ** u.size(-1)
-        u = u.flip(-1).to(self.dA).contiguous() # (B H L)
+        u = u.flip(-1).to(self.dA).contiguous()  # (B H L)
         v = log_vandermonde_transpose(u, self.dB, self.dA.log(), u.size(-1))
         next_state = AL * state + v
         return next_state
@@ -307,7 +334,7 @@ class EMAKernel(OptimModule):
         nn.init.normal_(self.gamma, mean=0.0, std=1.0)
         # nn.init.normal_(self.omega, mean=0.0, std=1.0)
 
-    def coeffs(self): # Same as discretize
+    def coeffs(self):  # Same as discretize
         p = torch.sigmoid(self.delta)  # (H N 1)
         alpha = torch.sigmoid(self.alpha)
         q = 1.0 - p * alpha
@@ -319,12 +346,16 @@ class EMAKernel(OptimModule):
         vander = torch.arange(L).to(p).view(1, 1, L) * torch.log(q)  # (H N L)
         kernel = (p * self.beta) * torch.exp(vander)
         if self.efficient_bidirectional:
-            C = rearrange(self.gamma * self.scale, '(c h) n -> c h n', c=self.channels)
-            kernel = torch.einsum('dnl,cdn->cdl', kernel, C)
+            C = rearrange(
+                self.gamma * self.scale, "(c h) n -> c h n", c=self.channels
+            )
+            kernel = torch.einsum("dnl,cdn->cdl", kernel, C)
             # kernel = rearrange(kernel, 'c d l -> (c d) l')
         else:
-            kernel = torch.einsum('dnl,dn->dl', kernel, self.gamma * self.scale)
-            kernel = rearrange(kernel, '(c h) l -> c h l', c=self.channels)
+            kernel = torch.einsum(
+                "dnl,dn->dl", kernel, self.gamma * self.scale
+            )
+            kernel = rearrange(kernel, "(c h) l -> c h l", c=self.channels)
 
         kernel = kernel[..., :L]
         # kernel = rearrange(kernel, '(c h) l -> c h l', c=self.channels)
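
The first hunk above touches the import fallback for log_vandermonde. For orientation, the snippet below is a minimal sketch of what that Vandermonde-style kernel contraction computes, assuming the (C, dtA, L) calling convention visible in the diff and a single output channel; it is not the repository's log_vandermonde_naive (which also has a pykeops counterpart), just an illustration of the math.

import torch

def log_vandermonde_sketch(C: torch.Tensor, dtA: torch.Tensor, L: int) -> torch.Tensor:
    """Hedged sketch: K[h, l] = 2 * Re( sum_n C[h, n] * exp(dtA[h, n] * l) ).

    C:   (H, N) complex coefficients (single channel for simplicity)
    dtA: (H, N) complex log-eigenvalues, already scaled by the step size dt
    L:   kernel length
    """
    l = torch.arange(L, device=C.device)            # (L,)
    vandermonde = torch.exp(dtA.unsqueeze(-1) * l)  # (H, N, L)
    K = torch.einsum("hn,hnl->hl", C, vandermonde)  # (H, L), complex
    # Factor of 2 assumes the parameters store one member of each conjugate pair.
    return 2 * K.real

# Tiny usage example with hypothetical shapes (H=4 heads, N=8 states):
H, N, L = 4, 8, 16
C = torch.randn(H, N, dtype=torch.cfloat)
dtA = -torch.rand(H, N) + 1j * torch.randn(H, N)  # negative real part for stability
print(log_vandermonde_sketch(C, dtA, L).shape)    # torch.Size([4, 16])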
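
Separately, the _setup_step hunk only re-wraps the ZOH and bilinear discretization formulas without changing them. As a reference for what those two branches compute, here is a small self-contained sketch for a diagonal state space, using the same expressions that appear in the diff; the function name and the scalar example are illustrative, not part of the repository.

import torch

def discretize_diag(A, B, dt, disc="zoh"):
    # Discretize the diagonal SSM x' = A x + B u with step size dt.
    # ZOH uses the exact exponential; bilinear uses the Tustin transform,
    # mirroring the two branches in _setup_step above. Inputs broadcast elementwise.
    dtA = A * dt
    if disc == "zoh":
        dA = torch.exp(dtA)
        dB = B * (torch.exp(dtA) - 1.0) / A
    elif disc == "bilinear":
        dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)
        dB = B * (1.0 - dtA / 2).reciprocal() * dt
    else:
        raise NotImplementedError(disc)
    return dA, dB

# Example: a single complex eigenvalue, unit input, dt = 0.01.
A = torch.tensor(-0.5 + 3.0j)
B = torch.tensor(1.0 + 0.0j)
dA_zoh, dB_zoh = discretize_diag(A, B, torch.tensor(0.01), "zoh")
dA_bil, dB_bil = discretize_diag(A, B, torch.tensor(0.01), "bilinear")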