diff --git a/src/models/sequence/ssm/ss_kernel_diag.py b/src/models/sequence/ssm/ss_kernel_diag.py
index 6cb279b..554725c 100644
--- a/src/models/sequence/ssm/ss_kernel_diag.py
+++ b/src/models/sequence/ssm/ss_kernel_diag.py
@@ -20,15 +20,20 @@ log = get_logger(__name__)
 # This could be None if the CUDA import fails
 from src.ops.vandermonde import log_vandermonde_fast
+
 try:
     import pykeops
     from src.ops.vandermonde import log_vandermonde, log_vandermonde_transpose
+
     has_pykeops = True
     log.info("Pykeops installation found.")
 except ImportError:
     has_pykeops = False
     from src.ops.vandermonde import log_vandermonde_naive as log_vandermonde
-    from src.ops.vandermonde import log_vandermonde_transpose_naive as log_vandermonde_transpose
+    from src.ops.vandermonde import (
+        log_vandermonde_transpose_naive as log_vandermonde_transpose,
+    )
+
     log.warning(
         "Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency."
     )
@@ -37,7 +42,7 @@ except ImportError:
 _c2r = torch.view_as_real
 _r2c = torch.view_as_complex
 
-if tuple(map(int, torch.__version__.split('.')[:2])) >= (1, 10):
+if tuple(map(int, torch.__version__.split(".")[:2])) >= (1, 10):
     _resolve_conj = lambda x: x.conj().resolve_conj()
 else:
     _resolve_conj = lambda x: x.conj()
@@ -48,15 +53,18 @@ class SSKernelDiag(OptimModule):
 
     def __init__(
         self,
-        A, B, C, log_dt,
+        A,
+        B,
+        C,
+        log_dt,
         L=None,
-        disc='bilinear',
-        real_type='exp',
+        disc="bilinear",
+        real_type="exp",
         lr=None,
         bandlimit=None,
         force_real=False,
+        **kwargs,
     ):
-
         super().__init__()
         self.L = L
         self.disc = disc
@@ -68,7 +76,7 @@ class SSKernelDiag(OptimModule):
         assert A.size(-1) == C.size(-1)
         self.H = log_dt.size(-1)
         self.N = A.size(-1)
-        assert A.size(-2) == B.size(-2) # Number of independent SSMs trained
+        assert A.size(-2) == B.size(-2)  # Number of independent SSMs trained
         assert self.H % A.size(-2) == 0
         self.n_ssm = A.size(-2)
         self.repeat = self.H // A.size(0)
@@ -77,42 +85,48 @@ class SSKernelDiag(OptimModule):
         self.C = nn.Parameter(_c2r(_resolve_conj(C)))
 
         # Register parameters
-        if lr is None or isinstance(lr, float): lr_dict = {}
-        else: lr_dict, lr = lr, None
+        if lr is None or isinstance(lr, float):
+            lr_dict = {}
+        else:
+            lr_dict, lr = lr, None
 
-        self.register("log_dt", log_dt, lr_dict.get('dt', lr))
-        self.register("B", _c2r(B), lr_dict.get('B', lr))
-        self.register("inv_A_real", self._A_init(A.real), lr_dict.get('A', lr))
-        self.register("A_imag", A.imag, lr_dict.get('A', lr))
+        self.register("log_dt", log_dt, lr_dict.get("dt", lr))
+        self.register("B", _c2r(B), lr_dict.get("B", lr))
+        self.register("inv_A_real", self._A_init(A.real), lr_dict.get("A", lr))
+        self.register("A_imag", A.imag, lr_dict.get("A", lr))
 
     def _A_init(self, A_real):
        A_real = torch.clamp(A_real, max=-1e-4)
-        if self.real_type == 'none':
+        if self.real_type == "none":
            return -A_real
-        elif self.real_type == 'exp':
-            return torch.log(-A_real) # Some of the HiPPO methods have real part 0
-        elif self.real_type == 'relu':
+        elif self.real_type == "exp":
+            return torch.log(
+                -A_real
+            )  # Some of the HiPPO methods have real part 0
+        elif self.real_type == "relu":
             return -A_real
-        elif self.real_type == 'sigmoid':
+        elif self.real_type == "sigmoid":
             return torch.logit(-A_real)
-        elif self.real_type == 'softplus':
-            return torch.log(torch.exp(-A_real)-1)
-        else: raise NotImplementedError
+        elif self.real_type == "softplus":
+            return torch.log(torch.exp(-A_real) - 1)
+        else:
+            raise NotImplementedError
 
     def _A(self):
         # Get the internal A (diagonal) parameter
-        if self.real_type == 'none':
+        if self.real_type == "none":
             A_real = -self.inv_A_real
-        elif self.real_type == 'exp':
+        elif self.real_type == "exp":
             A_real = -torch.exp(self.inv_A_real)
-        elif self.real_type == 'relu':
+        elif self.real_type == "relu":
             # JAX version seems to NaN if you allow 0's, although this code was fine without it
-            A_real = -F.relu(self.inv_A_real)-1e-4
+            A_real = -F.relu(self.inv_A_real) - 1e-4
-        elif self.real_type == 'sigmoid':
+        elif self.real_type == "sigmoid":
             A_real = -F.sigmoid(self.inv_A_real)
-        elif self.real_type == 'softplus':
+        elif self.real_type == "softplus":
             A_real = -F.softplus(self.inv_A_real)
-        else: raise NotImplementedError
+        else:
+            raise NotImplementedError
         A = A_real + 1j * self.A_imag
         return A
@@ -126,121 +140,134 @@ class SSKernelDiag(OptimModule):
         (B, H, L) output from initial state
         """
-        dt = torch.exp(self.log_dt) * rate # (H)
-        C = _r2c(self.C) # (C H N)
-        A = self._A() # (H N)
+        dt = torch.exp(self.log_dt) * rate  # (H)
+        C = _r2c(self.C)  # (C H N)
+        A = self._A()  # (H N)
         B = _r2c(self.B)
-        B = repeat(B, 't n -> 1 (v t) n', v=self.repeat)
+        B = repeat(B, "t n -> 1 (v t) n", v=self.repeat)
 
         # Force A to be real valued, so the whole kernel can be interpreted as a "multi-head EMA"
         if self.force_real:
             A = A.real + 0j
 
         if self.bandlimit is not None:
-            freqs = dt[:, None] / rate * A.imag.abs() / (2*math.pi) # (H, N)
-            mask = torch.where(freqs < self.bandlimit * .5, 1, 0)
+            freqs = dt[:, None] / rate * A.imag.abs() / (2 * math.pi)  # (H, N)
+            mask = torch.where(freqs < self.bandlimit * 0.5, 1, 0)
             C = C * mask
 
         # Incorporate dt into A
-        A = repeat(A, 't n -> (v t) n', v=self.repeat)
+        A = repeat(A, "t n -> (v t) n", v=self.repeat)
         dtA = A * dt.unsqueeze(-1)  # (H N)
 
-        # Augment B with state
         if state is not None:
             s = state / dt.unsqueeze(-1)
-            if self.disc == 'bilinear':
-                s = s * (1. + dtA/2)
-            elif self.disc == 'zoh':
-                s = s * dtA * dtA.exp() / (dtA.exp() - 1.)
-            B = torch.cat([s, B], dim=-3) # (1+B H N)
+            if self.disc == "bilinear":
+                s = s * (1.0 + dtA / 2)
+            elif self.disc == "zoh":
+                s = s * dtA * dtA.exp() / (dtA.exp() - 1.0)
+            B = torch.cat([s, B], dim=-3)  # (1+B H N)
 
         C = (B[:, None, :, :] * C).view(-1, self.H, self.N)
-        if self.disc == 'zoh':
+        if self.disc == "zoh":
             # Power up
-            C = C * (torch.exp(dtA)-1.) / A
+            C = C * (torch.exp(dtA) - 1.0) / A
             # TODO (TD): make it work for C.shape[0] > 1
             if log_vandermonde_fast is not None and C.shape[0] == 1:
-                K = log_vandermonde_fast(C.squeeze(0), dtA, L).unsqueeze(0) # (H L)
+                K = log_vandermonde_fast(C.squeeze(0), dtA, L).unsqueeze(
+                    0
+                )  # (H L)
             else:
-                K = log_vandermonde(C, dtA, L) # (H L)
-        elif self.disc == 'bilinear':
-            C = C * (1. - dtA/2).reciprocal() * dt.unsqueeze(-1) # or * dtA / A
-            dA = (1. + dtA/2) / (1. - dtA/2)
+                K = log_vandermonde(C, dtA, L)  # (H L)
+        elif self.disc == "bilinear":
+            C = (
+                C * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1)
+            )  # or * dtA / A
+            dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)
             if log_vandermonde_fast is not None:
-                dA_log = repeat(dA.log(), 'h d -> (c h) d', c=C.shape[0])
-                K = rearrange(log_vandermonde_fast(rearrange(C, 'c h d -> (c h) d'), dA_log, L),
-                              '(c h) d -> c h d', c=C.shape[0])
+                dA_log = repeat(dA.log(), "h d -> (c h) d", c=C.shape[0])
+                K = rearrange(
+                    log_vandermonde_fast(
+                        rearrange(C, "c h d -> (c h) d"), dA_log, L
+                    ),
+                    "(c h) d -> c h d",
+                    c=C.shape[0],
+                )
             else:
                 K = log_vandermonde(C, dA.log(), L)
-        elif self.disc == 'dss':
+        elif self.disc == "dss":
             # Implementation from DSS meant for case when real eigenvalues can be positive
-            P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device) # [H N L]
-            A_gt_0 = A.real > 0 # [N]
+            P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device)  # [H N L]
+            A_gt_0 = A.real > 0  # [N]
             if A_gt_0.any():
                 with torch.no_grad():
-                    P_max = dtA * (A_gt_0 * (L-1)) # [H N]
-                P = P - P_max.unsqueeze(-1) # [H N L]
-            S = P.exp() # [H N L]
+                    P_max = dtA * (A_gt_0 * (L - 1))  # [H N]
+                P = P - P_max.unsqueeze(-1)  # [H N L]
+            S = P.exp()  # [H N L]
 
-            dtA_neg = dtA * (1 - 2*A_gt_0) # [H N]
-            num = dtA_neg.exp() - 1 # [H N]
-            den = (dtA_neg * L).exp() - 1 # [H N]
+            dtA_neg = dtA * (1 - 2 * A_gt_0)  # [H N]
+            num = dtA_neg.exp() - 1  # [H N]
+            den = (dtA_neg * L).exp() - 1  # [H N]
 
             # Inline reciprocal function for DSS logic
             x = den * A
             x_conj = _resolve_conj(x)
-            r = x_conj / (x*x_conj + 1e-7)
+            r = x_conj / (x * x_conj + 1e-7)
 
-            C = C * num * r # [C H N]
-            K = contract('chn,hnl->chl', C, S).float()
-        else: assert False, f"{self.disc} not supported"
+            C = C * num * r  # [C H N]
+            K = contract("chn,hnl->chl", C, S).float()
+        else:
+            assert False, f"{self.disc} not supported"
 
-        K = K.view(-1, self.channels, self.H, L) # (1+B C H L)
+        K = K.view(-1, self.channels, self.H, L)  # (1+B C H L)
 
         if state is not None:
-            K_state = K[:-1, :, :, :] # (B C H L)
+            K_state = K[:-1, :, :, :]  # (B C H L)
         else:
             K_state = None
-        K = K[-1, :, :, :] # (C H L)
+        K = K[-1, :, :, :]  # (C H L)
         return K, K_state
 
     def _setup_step(self):
         # These methods are organized like this to be compatible with the NPLR kernel interface
-        dt = torch.exp(self.log_dt) # (H)
-        B = _r2c(self.B) # (H N)
-        C = _r2c(self.C) # (C H N)
+        dt = torch.exp(self.log_dt)  # (H)
+        B = _r2c(self.B)  # (H N)
+        C = _r2c(self.C)  # (C H N)
         self.dC = C
-        A = self._A() # (H N)
+        A = self._A()  # (H N)
 
-        A = repeat(A, 't n -> (v t) n', v=self.repeat)
-        B = repeat(B, 't n -> (v t) n', v=self.repeat)
+        A = repeat(A, "t n -> (v t) n", v=self.repeat)
+        B = repeat(B, "t n -> (v t) n", v=self.repeat)
 
         # Incorporate dt into A
         dtA = A * dt.unsqueeze(-1)  # (H N)
-        if self.disc == 'zoh':
-            self.dA = torch.exp(dtA) # (H N)
-            self.dB = B * (torch.exp(dtA)-1.) / A # (C H N)
-        elif self.disc == 'bilinear':
-            self.dA = (1. + dtA/2) / (1. - dtA/2)
-            self.dB = B * (1. - dtA/2).reciprocal() * dt.unsqueeze(-1) # or * dtA / A
-
+        if self.disc == "zoh":
+            self.dA = torch.exp(dtA)  # (H N)
+            self.dB = B * (torch.exp(dtA) - 1.0) / A  # (C H N)
+        elif self.disc == "bilinear":
+            self.dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)
+            self.dB = (
+                B * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1)
+            )  # or * dtA / A
 
     def default_state(self, *batch_shape):
         C = _r2c(self.C)
-        state = torch.zeros(*batch_shape, self.H, self.N, dtype=C.dtype, device=C.device)
+        state = torch.zeros(
+            *batch_shape, self.H, self.N, dtype=C.dtype, device=C.device
+        )
         return state
 
     def step(self, u, state):
-        next_state = contract("h n, b h n -> b h n", self.dA, state) \
-                + contract("h n, b h -> b h n", self.dB, u)
+        next_state = contract(
+            "h n, b h n -> b h n", self.dA, state
+        ) + contract("h n, b h -> b h n", self.dB, u)
         y = contract("c h n, b h n -> b c h", self.dC, next_state)
-        return 2*y.real, next_state
+        return 2 * y.real, next_state
 
     def forward_state(self, u, state):
         self._setup_step()
         AL = self.dA ** u.size(-1)
-        u = u.flip(-1).to(self.dA).contiguous() # (B H L)
+        u = u.flip(-1).to(self.dA).contiguous()  # (B H L)
         v = log_vandermonde_transpose(u, self.dB, self.dA.log(), u.size(-1))
         next_state = AL * state + v
         return next_state
@@ -307,7 +334,7 @@ class EMAKernel(OptimModule):
         nn.init.normal_(self.gamma, mean=0.0, std=1.0)
         # nn.init.normal_(self.omega, mean=0.0, std=1.0)
 
-    def coeffs(self): # Same as discretize
+    def coeffs(self):  # Same as discretize
         p = torch.sigmoid(self.delta)  # (H N 1)
         alpha = torch.sigmoid(self.alpha)
         q = 1.0 - p * alpha
@@ -319,12 +346,16 @@ class EMAKernel(OptimModule):
         vander = torch.arange(L).to(p).view(1, 1, L) * torch.log(q)  # (H N L)
         kernel = (p * self.beta) * torch.exp(vander)
         if self.efficient_bidirectional:
-            C = rearrange(self.gamma * self.scale, '(c h) n -> c h n', c=self.channels)
-            kernel = torch.einsum('dnl,cdn->cdl', kernel, C)
+            C = rearrange(
+                self.gamma * self.scale, "(c h) n -> c h n", c=self.channels
+            )
+            kernel = torch.einsum("dnl,cdn->cdl", kernel, C)
             # kernel = rearrange(kernel, 'c d l -> (c d) l')
         else:
-            kernel = torch.einsum('dnl,dn->dl', kernel, self.gamma * self.scale)
-            kernel = rearrange(kernel, '(c h) l -> c h l', c=self.channels)
+            kernel = torch.einsum(
+                "dnl,dn->dl", kernel, self.gamma * self.scale
+            )
+            kernel = rearrange(kernel, "(c h) l -> c h l", c=self.channels)
 
         kernel = kernel[..., :L]
         # kernel = rearrange(kernel, '(c h) l -> c h l', c=self.channels)
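
Two notes for reviewers, since the formatting churn above touches the numerically sensitive parts of `SSKernelDiag`. The sketch below is not part of the patch: it spells out, under assumed toy shapes, what the ZOH branch of `forward` computes when `log_vandermonde` falls back to the naive path. The factor of 2 reflects that the parameters store only one half of each conjugate pair (cf. `return 2 * y.real` in `step`); the helper name `log_vandermonde_sketch` and the sizes are illustrative, not repo API.

```python
import torch

def log_vandermonde_sketch(C, dtA, L):
    """K[h, l] = 2 * Re( sum_n C[h, n] * exp(dtA[h, n] * l) ).
    The naive Vandermonde product; the pykeops/CUDA paths compute the same
    thing without materializing the (H, N, L) intermediate."""
    l = torch.arange(L, device=C.device)            # (L,)
    vandermonde = torch.exp(dtA.unsqueeze(-1) * l)  # (H, N, L)
    return 2 * torch.einsum("hn,hnl->hl", C, vandermonde).real

# Toy shapes for illustration (H heads, N state size, L kernel length)
H, N, L = 4, 8, 16
A = -(0.5 + torch.rand(H, N)) + 1j * torch.randn(H, N)  # diagonal A, Re(A) < 0
dt = torch.exp(torch.randn(H) * 0.1 - 3)                # timescales, exp(log_dt)
C = torch.randn(H, N, dtype=torch.cfloat)
dtA = A * dt.unsqueeze(-1)                              # (H, N)
C_zoh = C * (torch.exp(dtA) - 1.0) / A                  # the "Power up" step
K = log_vandermonde_sketch(C_zoh, dtA, L)               # (H, L) real kernel
```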
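Continuing the same toy setup, the recurrent view configured by `_setup_step`/`step` should agree with the convolution kernel: unrolling x ← dA·x + dB·u on a unit impulse reproduces K. This is a quick sanity check when touching a discretization branch. B is taken as all-ones here so it folds into C the same way `forward` folds B into C; that choice is an assumption of this sketch, not something the patch changes.

```python
# ZOH discretization, as in _setup_step: dA = exp(dt*A), dB = B*(exp(dt*A)-1)/A
B_vec = torch.ones(H, N, dtype=torch.cfloat)   # assume B = 1 (folded into C)
dA = torch.exp(dtA)
dB = B_vec * (torch.exp(dtA) - 1.0) / A

x = torch.zeros(H, N, dtype=torch.cfloat)
ys = []
for t in range(L):
    u_t = 1.0 if t == 0 else 0.0               # unit impulse input
    x = dA * x + dB * u_t                      # state update, as in step()
    ys.append(2 * torch.einsum("hn,hn->h", C, x).real)
K_step = torch.stack(ys, dim=-1)               # (H, L)

assert torch.allclose(K_step, K, atol=1e-4)    # matches the kernel above
```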