Skip to content

Commit

Permalink
added mamba block, need to test
Browse files Browse the repository at this point in the history
  • Loading branch information
Artur-Galstyan committed Feb 15, 2024
1 parent 8aa2c72 commit 2e6d8f6
Showing 1 changed file with 64 additions and 3 deletions.
67 changes: 64 additions & 3 deletions equinox/nn/_selective_state_space_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,41 @@

import jax
import jax.numpy as jnp
from jaxtyping import Array, PRNGKeyArray
from jaxtyping import Array, Float, PRNGKeyArray

from .._module import field, Module
from ._conv import Conv1d
from ._linear import Linear


def _selective_scan(
u: Float[Array, "seq_len d_inner"],
delta: Float[Array, "seq_len d_inner"],
A: Float[Array, "d_inner state_space_dims"],
B: Float[Array, "seq_len state_space_dims"],
C: Float[Array, "seq_len state_space_dims"],
D: Float[Array, "d_inner"], # noqa
):
seq_len, _ = u.shape
d_inner, state_space_dims = A.shape

delta_A = jnp.exp(jnp.einsum("l d,d n -> l d n", delta, A))
delta_B_u = jnp.einsum("l d,l n,l d -> l d n", delta, B, u)

x_res = jnp.zeros(shape=(d_inner, state_space_dims))

def step(x, i):
x = delta_A[i] * x + delta_B_u[i]

y = jnp.einsum("d n,n -> d", x, C[i, :])
return x, y

_, ys = jax.lax.scan(step, x_res, jnp.arange(seq_len))

ys = ys + u * D
return ys


class SelectiveStateSpaceModel(Module, strict=True):
"""
State Space Model with Selective Scan. This is the implementation of the
Expand Down Expand Up @@ -129,5 +157,38 @@ def __init__(
)

@jax.named_scope("eqx.nn.StateSpaceModel")
def __call__(self) -> Array:
raise NotImplementedError
def __call__(self, x: Float[Array, "seq_len n_input_dims"]) -> Array:
seq_len, d = x.shape
if d != self.n_input_dims:
raise ValueError(
f"Input dimension mismatch: expected {self.n_input_dims}, got {d}"
)
x_and_res = jax.vmap(self.in_proj)(x)
(x, res) = jnp.split(x_and_res, 2, axis=-1)

x = jnp.transpose(x)
x = self.conv1d(x)[:, :seq_len]
x = jnp.transpose(x)
x = jax.nn.silu(x)

y = self._ssm(x)
y = y * jax.nn.silu(res)

output = jax.vmap(self.out_proj)(y)
return output

def _ssm(self, x: Float[Array, "seq_len d_inner"]) -> Array:
A = -jnp.exp(self.A_log)
D = self.D

x_delta_b_c = jax.vmap(self.x_proj)(x)

split_indices = [
self.dt_rank,
self.dt_rank + self.state_space_dims,
]
delta, B, C = jnp.split(x_delta_b_c, split_indices, axis=-1)
delta = jax.nn.softplus(jax.vmap(self.dt_proj)(delta))

y = _selective_scan(x, delta, A, B, C, D)
return y

0 comments on commit 2e6d8f6

Please sign in to comment.