From c3c7f33cacee6eb481176a2e83bdc416318af5dc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Korbinian=20P=C3=B6ppel?=
Date: Sat, 5 Aug 2023 17:37:52 +0200
Subject: [PATCH] Add safari diff.

---
 ...d81b2fce85f12d78812c427c4e8e9cb0ac42.patch | 359 ++++++++++++++++++
 1 file changed, 359 insertions(+)
 create mode 100644 safari_diffs/23a1d81b2fce85f12d78812c427c4e8e9cb0ac42.patch

diff --git a/safari_diffs/23a1d81b2fce85f12d78812c427c4e8e9cb0ac42.patch b/safari_diffs/23a1d81b2fce85f12d78812c427c4e8e9cb0ac42.patch
new file mode 100644
index 0000000..54a0870
--- /dev/null
+++ b/safari_diffs/23a1d81b2fce85f12d78812c427c4e8e9cb0ac42.patch
@@ -0,0 +1,359 @@
+diff --git a/src/models/sequence/ssm/ss_kernel_diag.py b/src/models/sequence/ssm/ss_kernel_diag.py
+index 6cb279b..554725c 100644
+--- a/src/models/sequence/ssm/ss_kernel_diag.py
++++ b/src/models/sequence/ssm/ss_kernel_diag.py
+@@ -20,15 +20,20 @@ log = get_logger(__name__)
+
+ # This could be None if the CUDA import fails
+ from src.ops.vandermonde import log_vandermonde_fast
++
+ try:
+     import pykeops
+     from src.ops.vandermonde import log_vandermonde, log_vandermonde_transpose
++
+     has_pykeops = True
+     log.info("Pykeops installation found.")
+ except ImportError:
+     has_pykeops = False
+     from src.ops.vandermonde import log_vandermonde_naive as log_vandermonde
+-    from src.ops.vandermonde import log_vandermonde_transpose_naive as log_vandermonde_transpose
++    from src.ops.vandermonde import (
++        log_vandermonde_transpose_naive as log_vandermonde_transpose,
++    )
++
+     log.warning(
+         "Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency."
+     )
+@@ -37,7 +42,7 @@ except ImportError:
+
+ _c2r = torch.view_as_real
+ _r2c = torch.view_as_complex
+
+-if tuple(map(int, torch.__version__.split('.')[:2])) >= (1, 10):
++if tuple(map(int, torch.__version__.split(".")[:2])) >= (1, 10):
+     _resolve_conj = lambda x: x.conj().resolve_conj()
+ else:
+     _resolve_conj = lambda x: x.conj()
+@@ -48,15 +53,18 @@ class SSKernelDiag(OptimModule):
+
+     def __init__(
+         self,
+-        A, B, C, log_dt,
++        A,
++        B,
++        C,
++        log_dt,
+         L=None,
+-        disc='bilinear',
+-        real_type='exp',
++        disc="bilinear",
++        real_type="exp",
+         lr=None,
+         bandlimit=None,
+         force_real=False,
++        **kwargs,
+     ):
+-
+         super().__init__()
+         self.L = L
+         self.disc = disc
+@@ -68,7 +76,7 @@ class SSKernelDiag(OptimModule):
+         assert A.size(-1) == C.size(-1)
+         self.H = log_dt.size(-1)
+         self.N = A.size(-1)
+-        assert A.size(-2) == B.size(-2) # Number of independent SSMs trained
++        assert A.size(-2) == B.size(-2)  # Number of independent SSMs trained
+         assert self.H % A.size(-2) == 0
+         self.n_ssm = A.size(-2)
+         self.repeat = self.H // A.size(0)
+@@ -77,42 +85,48 @@ class SSKernelDiag(OptimModule):
+         self.C = nn.Parameter(_c2r(_resolve_conj(C)))
+
+         # Register parameters
+-        if lr is None or isinstance(lr, float): lr_dict = {}
+-        else: lr_dict, lr = lr, None
++        if lr is None or isinstance(lr, float):
++            lr_dict = {}
++        else:
++            lr_dict, lr = lr, None
+
+-        self.register("log_dt", log_dt, lr_dict.get('dt', lr))
+-        self.register("B", _c2r(B), lr_dict.get('B', lr))
+-        self.register("inv_A_real", self._A_init(A.real), lr_dict.get('A', lr))
+-        self.register("A_imag", A.imag, lr_dict.get('A', lr))
++        self.register("log_dt", log_dt, lr_dict.get("dt", lr))
++        self.register("B", _c2r(B), lr_dict.get("B", lr))
++        self.register("inv_A_real", self._A_init(A.real), lr_dict.get("A", lr))
++        self.register("A_imag", A.imag, lr_dict.get("A", lr))
+
+     def _A_init(self, A_real):
+         A_real = torch.clamp(A_real, max=-1e-4)
+-        if self.real_type == 'none':
++        if self.real_type == "none":
+             return -A_real
+-        elif self.real_type == 'exp':
+-            return torch.log(-A_real) # Some of the HiPPO methods have real part 0
+-        elif self.real_type == 'relu':
++        elif self.real_type == "exp":
++            return torch.log(
++                -A_real
++            )  # Some of the HiPPO methods have real part 0
++        elif self.real_type == "relu":
+             return -A_real
+-        elif self.real_type == 'sigmoid':
++        elif self.real_type == "sigmoid":
+             return torch.logit(-A_real)
+-        elif self.real_type == 'softplus':
+-            return torch.log(torch.exp(-A_real)-1)
+-        else: raise NotImplementedError
++        elif self.real_type == "softplus":
++            return torch.log(torch.exp(-A_real) - 1)
++        else:
++            raise NotImplementedError
+
+     def _A(self):
+         # Get the internal A (diagonal) parameter
+-        if self.real_type == 'none':
++        if self.real_type == "none":
+             A_real = -self.inv_A_real
+-        elif self.real_type == 'exp':
++        elif self.real_type == "exp":
+             A_real = -torch.exp(self.inv_A_real)
+-        elif self.real_type == 'relu':
++        elif self.real_type == "relu":
+             # JAX version seems to NaN if you alloA 0's, although this code Aas fine Aithout it
+-            A_real = -F.relu(self.inv_A_real)-1e-4
+-        elif self.real_type == 'sigmoid':
++            A_real = -F.relu(self.inv_A_real) - 1e-4
++        elif self.real_type == "sigmoid":
+             A_real = -F.sigmoid(self.inv_A_real)
+-        elif self.real_type == 'softplus':
++        elif self.real_type == "softplus":
+             A_real = -F.softplus(self.inv_A_real)
+-        else: raise NotImplementedError
++        else:
++            raise NotImplementedError
+         A = A_real + 1j * self.A_imag
+         return A
+
+@@ -126,121 +140,134 @@ class SSKernelDiag(OptimModule):
+         (B, H, L) output from initial state
+         """
+
+-        dt = torch.exp(self.log_dt) * rate # (H)
+-        C = _r2c(self.C) # (C H N)
+-        A = self._A() # (H N)
++        dt = torch.exp(self.log_dt) * rate  # (H)
++        C = _r2c(self.C)  # (C H N)
++        A = self._A()  # (H N)
+
+         B = _r2c(self.B)
+-        B = repeat(B, 't n -> 1 (v t) n', v=self.repeat)
++        B = repeat(B, "t n -> 1 (v t) n", v=self.repeat)
+
+         # Force A to be real valued, so the whole kernel can be interpreted as a "multi-head EMA"
+         if self.force_real:
+             A = A.real + 0j
+
+         if self.bandlimit is not None:
+-            freqs = dt[:, None] / rate * A.imag.abs() / (2*math.pi) # (H, N)
+-            mask = torch.where(freqs < self.bandlimit * .5, 1, 0)
++            freqs = dt[:, None] / rate * A.imag.abs() / (2 * math.pi)  # (H, N)
++            mask = torch.where(freqs < self.bandlimit * 0.5, 1, 0)
+             C = C * mask
+
+         # Incorporate dt into A
+-        A = repeat(A, 't n -> (v t) n', v=self.repeat)
++        A = repeat(A, "t n -> (v t) n", v=self.repeat)
+         dtA = A * dt.unsqueeze(-1)  # (H N)
+
+-
+         # Augment B with state
+         if state is not None:
+             s = state / dt.unsqueeze(-1)
+-            if self.disc == 'bilinear':
+-                s = s * (1. + dtA/2)
+-            elif self.disc == 'zoh':
+-                s = s * dtA * dtA.exp() / (dtA.exp() - 1.)
+-            B = torch.cat([s, B], dim=-3) # (1+B H N)
++            if self.disc == "bilinear":
++                s = s * (1.0 + dtA / 2)
++            elif self.disc == "zoh":
++                s = s * dtA * dtA.exp() / (dtA.exp() - 1.0)
++            B = torch.cat([s, B], dim=-3)  # (1+B H N)
+
+         C = (B[:, None, :, :] * C).view(-1, self.H, self.N)
+-        if self.disc == 'zoh':
++        if self.disc == "zoh":
+             # Power up
+-            C = C * (torch.exp(dtA)-1.) / A
++            C = C * (torch.exp(dtA) - 1.0) / A
+             # TODO (TD): make it work for C.shape[0] > 1
+             if log_vandermonde_fast is not None and C.shape[0] == 1:
+-                K = log_vandermonde_fast(C.squeeze(0), dtA, L).unsqueeze(0) # (H L)
++                K = log_vandermonde_fast(C.squeeze(0), dtA, L).unsqueeze(
++                    0
++                )  # (H L)
+             else:
+-                K = log_vandermonde(C, dtA, L) # (H L)
+-        elif self.disc == 'bilinear':
+-            C = C * (1. - dtA/2).reciprocal() * dt.unsqueeze(-1) # or * dtA / A
+-            dA = (1. + dtA/2) / (1. - dtA/2)
++                K = log_vandermonde(C, dtA, L)  # (H L)
++        elif self.disc == "bilinear":
++            C = (
++                C * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1)
++            )  # or * dtA / A
++            dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)
+             if log_vandermonde_fast is not None:
+-                dA_log = repeat(dA.log(), 'h d -> (c h) d', c=C.shape[0])
+-                K = rearrange(log_vandermonde_fast(rearrange(C, 'c h d -> (c h) d'), dA_log, L),
+-                              '(c h) d -> c h d', c=C.shape[0])
++                dA_log = repeat(dA.log(), "h d -> (c h) d", c=C.shape[0])
++                K = rearrange(
++                    log_vandermonde_fast(
++                        rearrange(C, "c h d -> (c h) d"), dA_log, L
++                    ),
++                    "(c h) d -> c h d",
++                    c=C.shape[0],
++                )
+             else:
+                 K = log_vandermonde(C, dA.log(), L)
+-        elif self.disc == 'dss':
++        elif self.disc == "dss":
+             # Implementation from DSS meant for case when real eigenvalues can be positive
+-            P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device) # [H N L]
+-            A_gt_0 = A.real > 0 # [N]
++            P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device)  # [H N L]
++            A_gt_0 = A.real > 0  # [N]
+             if A_gt_0.any():
+                 with torch.no_grad():
+-                    P_max = dtA * (A_gt_0 * (L-1)) # [H N]
+-                P = P - P_max.unsqueeze(-1) # [H N L]
+-            S = P.exp() # [H N L]
++                    P_max = dtA * (A_gt_0 * (L - 1))  # [H N]
++                P = P - P_max.unsqueeze(-1)  # [H N L]
++            S = P.exp()  # [H N L]
+
+-            dtA_neg = dtA * (1 - 2*A_gt_0) # [H N]
+-            num = dtA_neg.exp() - 1 # [H N]
+-            den = (dtA_neg * L).exp() - 1 # [H N]
++            dtA_neg = dtA * (1 - 2 * A_gt_0)  # [H N]
++            num = dtA_neg.exp() - 1  # [H N]
++            den = (dtA_neg * L).exp() - 1  # [H N]
+
+             # Inline reciprocal function for DSS logic
+             x = den * A
+             x_conj = _resolve_conj(x)
+-            r = x_conj / (x*x_conj + 1e-7)
++            r = x_conj / (x * x_conj + 1e-7)
+
+-            C = C * num * r # [C H N]
+-            K = contract('chn,hnl->chl', C, S).float()
+-        else: assert False, f"{self.disc} not supported"
++            C = C * num * r  # [C H N]
++            K = contract("chn,hnl->chl", C, S).float()
++        else:
++            assert False, f"{self.disc} not supported"
+
+-        K = K.view(-1, self.channels, self.H, L) # (1+B C H L)
++        K = K.view(-1, self.channels, self.H, L)  # (1+B C H L)
+         if state is not None:
+-            K_state = K[:-1, :, :, :] # (B C H L)
++            K_state = K[:-1, :, :, :]  # (B C H L)
+         else:
+             K_state = None
+-        K = K[-1, :, :, :] # (C H L)
++        K = K[-1, :, :, :]  # (C H L)
+         return K, K_state
+
+     def _setup_step(self):
+         # These methods are organized like this to be compatible with the NPLR kernel interface
+-        dt = torch.exp(self.log_dt) # (H)
+-        B = _r2c(self.B) # (H N)
+-        C = _r2c(self.C) # (C H N)
++        dt = torch.exp(self.log_dt)  # (H)
++        B = _r2c(self.B)  # (H N)
++        C = _r2c(self.C)  # (C H N)
+         self.dC = C
+-        A = self._A() # (H N)
++        A = self._A()  # (H N)
+
+-        A = repeat(A, 't n -> (v t) n', v=self.repeat)
+-        B = repeat(B, 't n -> (v t) n', v=self.repeat)
++        A = repeat(A, "t n -> (v t) n", v=self.repeat)
++        B = repeat(B, "t n -> (v t) n", v=self.repeat)
+
+         # Incorporate dt into A
+         dtA = A * dt.unsqueeze(-1)  # (H N)
+-        if self.disc == 'zoh':
+-            self.dA = torch.exp(dtA) # (H N)
+-            self.dB = B * (torch.exp(dtA)-1.) / A # (C H N)
+-        elif self.disc == 'bilinear':
+-            self.dA = (1. + dtA/2) / (1. - dtA/2)
+-            self.dB = B * (1. - dtA/2).reciprocal() * dt.unsqueeze(-1) # or * dtA / A
+-
++        if self.disc == "zoh":
++            self.dA = torch.exp(dtA)  # (H N)
++            self.dB = B * (torch.exp(dtA) - 1.0) / A  # (C H N)
++        elif self.disc == "bilinear":
++            self.dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)
++            self.dB = (
++                B * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1)
++            )  # or * dtA / A
+
+     def default_state(self, *batch_shape):
+         C = _r2c(self.C)
+-        state = torch.zeros(*batch_shape, self.H, self.N, dtype=C.dtype, device=C.device)
++        state = torch.zeros(
++            *batch_shape, self.H, self.N, dtype=C.dtype, device=C.device
++        )
+         return state
+
+     def step(self, u, state):
+-        next_state = contract("h n, b h n -> b h n", self.dA, state) \
+-                + contract("h n, b h -> b h n", self.dB, u)
++        next_state = contract(
++            "h n, b h n -> b h n", self.dA, state
++        ) + contract("h n, b h -> b h n", self.dB, u)
+         y = contract("c h n, b h n -> b c h", self.dC, next_state)
+-        return 2*y.real, next_state
++        return 2 * y.real, next_state
+
+     def forward_state(self, u, state):
+         self._setup_step()
+         AL = self.dA ** u.size(-1)
+-        u = u.flip(-1).to(self.dA).contiguous() # (B H L)
++        u = u.flip(-1).to(self.dA).contiguous()  # (B H L)
+         v = log_vandermonde_transpose(u, self.dB, self.dA.log(), u.size(-1))
+         next_state = AL * state + v
+         return next_state
+@@ -307,7 +334,7 @@ class EMAKernel(OptimModule):
+         nn.init.normal_(self.gamma, mean=0.0, std=1.0)
+         # nn.init.normal_(self.omega, mean=0.0, std=1.0)
+
+-    def coeffs(self): # Same as discretize
++    def coeffs(self):  # Same as discretize
+         p = torch.sigmoid(self.delta)  # (H N 1)
+         alpha = torch.sigmoid(self.alpha)
+         q = 1.0 - p * alpha
+@@ -319,12 +346,16 @@ class EMAKernel(OptimModule):
+         vander = torch.arange(L).to(p).view(1, 1, L) * torch.log(q)  # (H N L)
+         kernel = (p * self.beta) * torch.exp(vander)
+         if self.efficient_bidirectional:
+-            C = rearrange(self.gamma * self.scale, '(c h) n -> c h n', c=self.channels)
+-            kernel = torch.einsum('dnl,cdn->cdl', kernel, C)
++            C = rearrange(
++                self.gamma * self.scale, "(c h) n -> c h n", c=self.channels
++            )
++            kernel = torch.einsum("dnl,cdn->cdl", kernel, C)
+             # kernel = rearrange(kernel, 'c d l -> (c d) l')
+         else:
+-            kernel = torch.einsum('dnl,dn->dl', kernel, self.gamma * self.scale)
+-            kernel = rearrange(kernel, '(c h) l -> c h l', c=self.channels)
++            kernel = torch.einsum(
++                "dnl,dn->dl", kernel, self.gamma * self.scale
++            )
++            kernel = rearrange(kernel, "(c h) l -> c h l", c=self.channels)
+
+         kernel = kernel[..., :L]
+         # kernel = rearrange(kernel, '(c h) l -> c h l', c=self.channels)
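# For reference: the "zoh" and "bilinear" branches that the patched
# _setup_step()/step() code above switches on are the two standard
# discretizations of a diagonal SSM x' = A x + B u.  The sketch below is a
# minimal, self-contained illustration of those formulas only; it is not part
# of the patch, and the names (A, B, C, dt, u) are illustrative stand-ins, not
# the module's registered parameters or buffers.
import torch

def discretize(A, B, dt, disc="zoh"):
    """Return (dA, dB) for a diagonal SSM with per-feature step size dt."""
    dtA = A * dt.unsqueeze(-1)  # (H, N)
    if disc == "zoh":
        # Zero-order hold: dA = exp(dt*A), dB = B * (exp(dt*A) - 1) / A
        dA = torch.exp(dtA)
        dB = B * (torch.exp(dtA) - 1.0) / A
    elif disc == "bilinear":
        # Bilinear (Tustin): dA = (1 + dt*A/2) / (1 - dt*A/2)
        dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)
        dB = B * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1)
    else:
        raise NotImplementedError(disc)
    return dA, dB

def step(dA, dB, C, state, u):
    """One recurrence step: state <- dA*state + dB*u, output y = 2*Re(C.state)."""
    next_state = dA * state + dB * u.unsqueeze(-1)   # (B, H, N)
    y = torch.einsum("chn,bhn->bch", C, next_state)  # (B, C, H)
    return 2 * y.real, next_state

# Tiny usage example with random diagonal parameters (Re(A) < 0 for stability).
H, N = 4, 8
A = -(torch.rand(H, N) + 0.1) + 1j * torch.randn(H, N)
B = torch.randn(H, N, dtype=torch.cfloat)
C = torch.randn(1, H, N, dtype=torch.cfloat)
dt = torch.full((H,), 1e-2)
dA, dB = discretize(A, B, dt, disc="bilinear")
y, s = step(dA, dB, C, torch.zeros(2, H, N, dtype=torch.cfloat), torch.randn(2, H))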