def _selective_scan(
    u: Float[Array, "seq_len d_inner"],
    delta: Float[Array, "seq_len d_inner"],
    A: Float[Array, "d_inner state_space_dims"],
    B: Float[Array, "seq_len state_space_dims"],
    C: Float[Array, "seq_len state_space_dims"],
    D: Float[Array, "d_inner"],  # noqa
) -> Float[Array, "seq_len d_inner"]:
    """Run the discretised selective state-space recurrence over a sequence.

    Starting from an all-zero hidden state `x` of shape
    `(d_inner, state_space_dims)`, every time step `t` performs

        x_t = exp(delta_t * A) * x_{t-1} + delta_t * B_t * u_t
        y_t = x_t @ C_t + D * u_t

    where `delta_t` and `u_t` are broadcast along the state axis and `B_t`
    along the channel axis.

    **Arguments:**

    - `u`: input sequence, one row per time step.
    - `delta`: per-step, per-channel step sizes.
    - `A`: state transition parameters (shared across time steps).
    - `B`: per-step input projection onto the state dimensions.
    - `C`: per-step read-out projection from the state dimensions.
    - `D`: per-channel skip weight applied to `u`.

    **Returns:**

    The output sequence with the same shape as `u`.
    """
    d_inner, state_space_dims = A.shape

    # Discretised transition and input terms for all time steps at once.
    delta_A = jnp.exp(jnp.einsum("l d,d n -> l d n", delta, A))
    delta_B_u = jnp.einsum("l d,l n,l d -> l d n", delta, B, u)

    x_init = jnp.zeros(shape=(d_inner, state_space_dims))

    def step(x, inputs):
        delta_A_t, delta_B_u_t, C_t = inputs
        x = delta_A_t * x + delta_B_u_t
        y = jnp.einsum("d n,n -> d", x, C_t)
        return x, y

    # Scan directly over the stacked per-step tensors: `lax.scan` slices the
    # leading axis itself, so no traced integer indexing is needed and
    # inconsistent sequence lengths are rejected instead of silently indexed.
    _, ys = jax.lax.scan(step, x_init, (delta_A, delta_B_u, C))

    # Skip connection.
    return ys + u * D
# NOTE(review): the scope string below does not match the enclosing class name
# (`SelectiveStateSpaceModel`); left unchanged here - confirm whether intended.
@jax.named_scope("eqx.nn.StateSpaceModel")
def __call__(self, x: Float[Array, "seq_len n_input_dims"]) -> Array:
    """Apply the selective state-space block to a single sequence.

    The input is projected by `in_proj` and split into a main branch and a
    gating branch. The main branch goes through `conv1d`, a SiLU activation
    and the selective scan (`_ssm`); it is then multiplied by the SiLU of the
    gating branch and projected by `out_proj`.

    **Arguments:**

    - `x`: a two-dimensional array of shape `(seq_len, n_input_dims)`.

    **Returns:**

    The result of applying `out_proj` to every time step.

    **Raises:**

    - `ValueError`: if `x` is not two-dimensional, or if its last dimension
      differs from `self.n_input_dims`.
    """
    if x.ndim != 2:
        # Fail with an explicit message instead of a tuple-unpacking error.
        raise ValueError(
            "Input must be two-dimensional with shape "
            f"(seq_len, n_input_dims), got an array with shape {x.shape}"
        )
    seq_len, d = x.shape
    if d != self.n_input_dims:
        raise ValueError(
            f"Input dimension mismatch: expected {self.n_input_dims}, got {d}"
        )

    # One projection yields both branches; split it in half on the last axis.
    x_and_res = jax.vmap(self.in_proj)(x)
    x, res = jnp.split(x_and_res, 2, axis=-1)

    # Put the sequence axis last for the convolution, keep only the first
    # `seq_len` output positions, then restore the original layout.
    # NOTE(review): presumably `conv1d` pads so its output is longer than
    # `seq_len` - confirm against `__init__`.
    x = jnp.transpose(x)
    x = self.conv1d(x)[:, :seq_len]
    x = jnp.transpose(x)
    x = jax.nn.silu(x)

    y = self._ssm(x)
    y = y * jax.nn.silu(res)  # gate with the second branch

    return jax.vmap(self.out_proj)(y)


def _ssm(self, x: Float[Array, "seq_len d_inner"]) -> Array:
    """Compute the input-dependent SSM parameters and run the selective scan.

    `x_proj` is applied to every time step and its output is split along the
    last axis into `delta` (first `dt_rank` entries), `B` (next
    `state_space_dims` entries) and `C` (the remainder). `delta` is then
    passed through `dt_proj` and a softplus before the scan.

    **Arguments:**

    - `x`: the activated output of the convolution, one row per time step.

    **Returns:**

    The output of `_selective_scan` for this sequence.
    """
    # Stored in log-space; negating the exponential keeps every entry of A
    # strictly negative.
    A = -jnp.exp(self.A_log)
    D = self.D

    x_delta_b_c = jax.vmap(self.x_proj)(x)

    split_indices = [
        self.dt_rank,
        self.dt_rank + self.state_space_dims,
    ]
    delta, B, C = jnp.split(x_delta_b_c, split_indices, axis=-1)
    # Softplus keeps the step sizes positive.
    delta = jax.nn.softplus(jax.vmap(self.dt_proj)(delta))

    return _selective_scan(x, delta, A, B, C, D)