Skip to content

Commit

Permalink
docs and feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
Artur-Galstyan committed Jun 15, 2024
1 parent 6f747da commit bf41205
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 7 deletions.
8 changes: 8 additions & 0 deletions docs/api/nn/attention.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,11 @@
members:
- __init__
- __call__

---

::: equinox.nn.StandardKVCache
selection:
members:
- __init__
- __call__
1 change: 1 addition & 0 deletions equinox/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RotaryPositionalEmbedding as RotaryPositionalEmbedding,
)
from ._inference import inference_mode as inference_mode
from ._kv_cache import StandardKVCache as StandardKVCache
from ._linear import Identity as Identity, Linear as Linear
from ._mlp import MLP as MLP
from ._normalisation import (
Expand Down
7 changes: 4 additions & 3 deletions equinox/nn/_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,16 +199,17 @@ def __init__(
import equinox as eqx
import jax
seq_len = 3
state_length = 8
num_heads = 1
query_size = 6
standard_kv_cache = eqx.nn.StandardKVCache(
key_shape=(state_length, num_heads, query_size),
value_shape=(state_length, num_heads, query_size),
state_length=state_length,
num_heads=num_heads,
key_size=query_size,
value_size=query_size
)
mha, state = eqx.nn.make_with_state(MultiheadAttention)(
Expand Down
39 changes: 35 additions & 4 deletions equinox/nn/_kv_cache.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from collections.abc import Callable

import jax
import jax.lax as lax
import jax.numpy as jnp
from jaxtyping import Array, Float, Int

from .._misc import default_int_dtype
from .._misc import default_floating_dtype, default_int_dtype
from .._module import field, Module
from ._stateful import State, StateIndex

Expand All @@ -25,6 +26,11 @@


class StandardKVCache(Module):
"""
A class to manage the key and value caches for a transformer
model with autoregressive decoding.
"""

key_shape: tuple[int, int, int] = field(static=True)
value_shape: tuple[int, int, int] = field(static=True)

Expand All @@ -36,18 +42,30 @@ def __init__(
num_heads: int,
key_size: int,
value_size: int,
dtype=None,
):
r"""**Arguments:**
- `state_length`: The maximum sequence length the cache can hold (the size of the first axis of the key and value caches).
- `num_heads`: Number of parallel attention heads $h$.
- `key_size`: Number of input channels for key $K$.
- `value_size`: Number of input channels for value $V$.
- `dtype` (optional): The data type of the KV caches.
"""
dtype = default_floating_dtype() if dtype is None else dtype
self.key_shape = state_length, num_heads, key_size
self.value_shape = state_length, num_heads, value_size

self.autoregressive_index = StateIndex(
(
lambda _: jnp.empty(self.key_shape),
jnp.empty(self.value_shape),
lambda: jnp.empty(self.key_shape, dtype=dtype),
jnp.empty(self.value_shape, dtype=dtype),
jnp.zeros((), default_int_dtype()),
),
)

@jax.named_scope("eqx.nn.StandardKVCache")
def __call__(
self,
key_heads: Float[Array, "seq_length num_heads qk_size"],
Expand All @@ -59,13 +77,26 @@ def __call__(
Int[Array, ""],
State,
]:
"""**Arguments:**
- `key_heads`: The new key heads to be added to the cache.
- `value_heads`: The new value heads to be added to the cache.
- `state`: The current state, holding the cached keys and values together with the index for autoregressive decoding.
**Returns:**
A tuple `(key_state, value_state, index, state)` containing the updated keys
and values as well as the index and the new state.
The shape of `key_state` is `(state_length, num_heads, qk_size)`
and the shape of `value_state` is `(state_length, num_heads, vo_size)`.
"""
kv_seq_length, _, _ = key_heads.shape
key_state, value_state, index = state.get(self.autoregressive_index)
key_state = lax.dynamic_update_slice_in_dim(key_state, key_heads, index, axis=0)
value_state = lax.dynamic_update_slice_in_dim(
value_state, value_heads, index, axis=0
)
index = index + kv_seq_length
state = state.set(
self.autoregressive_index, (key_state, value_state, index + kv_seq_length)
)
Expand Down

0 comments on commit bf41205

Please sign in to comment.