added rope test + fixed kv cache + mini refactor
Artur-Galstyan committed Jul 13, 2024
1 parent 0b7bbf0 commit 699b93e
Showing 4 changed files with 409 additions and 82 deletions.
153 changes: 89 additions & 64 deletions equinox/nn/_attention.py
@@ -8,14 +8,14 @@
import jax
import jax.numpy as jnp
import jax.random as jrandom
from jaxtyping import Array, Bool, Float, PRNGKeyArray
from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray

from .._misc import default_floating_dtype
from .._misc import default_floating_dtype, default_int_dtype
from .._module import field, Module
from ._dropout import Dropout
from ._kv_cache import KVCacheCallable
from ._linear import Linear
from ._stateful import State
from ._stateful import State, StateIndex


def dot_product_attention_weights(
@@ -127,6 +127,7 @@ class MultiheadAttention(Module, strict=True):
dropout: Dropout

kv_cache: Optional[KVCacheCallable]
index: Optional[StateIndex]

num_heads: int = field(static=True)
query_size: int = field(static=True)
@@ -277,6 +278,12 @@ def __init__(
self.use_value_bias = use_value_bias
self.use_output_bias = use_output_bias
self.kv_cache = kv_cache
if self.kv_cache is not None:
self.index = StateIndex(
jnp.zeros((), default_int_dtype()),
)
else:
self.index = None

@jax.named_scope("eqx.nn.MultiheadAttention")
def __call__(
@@ -301,6 +308,7 @@ def __call__(
Float[Array, "seq_length num_heads qk_size"],
Float[Array, "seq_length num_heads qk_size"],
Float[Array, "seq_length num_heads vo_size"],
Int[Array, ""],
],
tuple[
Float[Array, "seq_length num_heads qk_size"],
@@ -332,84 +340,44 @@ def __call__(
- `inference`: As [`equinox.nn.Dropout.__call__`][]. (Keyword only
argument.)
- `deterministic`: (Deprecated in favour of `inference`.)
- `process_heads`: A function that takes in the query, key, and value heads and
returns new query, key, and value heads. For example, this can be
used to implement relative positional embeddings -
see e.g. `RotaryPositionalEmbedding` for an example. (Keyword only argument.)
- `process_heads`: A function that takes in the query, key, value heads as well
as the current autoregressive index and returns new query, key, and value
heads. For example, this can be used to implement relative positional
embeddings - see e.g. `RotaryPositionalEmbedding` for an example.
(Keyword only argument.)
**Returns:**
A JAX array of shape `(query_seq_length, output_size)`.
"""

if deterministic is not None:
inference = deterministic
warnings.warn(
"MultiheadAttention()(deterministic=...) is deprecated "
"in favour of MultiheadAttention()(inference=...)"
)

query_seq_length, _ = query.shape
kv_seq_length, _ = key_.shape
kv_seq_length2, _ = value.shape

if kv_seq_length != kv_seq_length2:
raise ValueError("key and value must both be sequences of equal length.")

query_heads = self._project(self.query_proj, query)
key_heads = self._project(self.key_proj, key_)
value_heads = self._project(self.value_proj, value)
query_seq_length, kv_seq_length = self._get_query_and_kv_seq_lengths(
query, key_, value
)

# TODO: Apply RoPE somehow, somewhere here
query_heads, key_heads, value_heads = self._project_heads(query, key_, value)
index = self._get_start_index(state)

if process_heads is not None:
q_shape, k_shape, v_shape = (
query_heads.shape,
key_heads.shape,
value_heads.shape,
query_heads, key_heads, value_heads = _process_heads(
process_heads, query_heads, key_heads, value_heads, index
)
query_heads, key_heads, value_heads = process_heads(
query_heads,
key_heads,
value_heads,
)

if (
query_heads.shape != q_shape
or key_heads.shape != k_shape
or value_heads.shape != v_shape
):
raise ValueError(
"process_heads must not change the shape of the heads."
)

if state is None:
state_length = None
index = None
causal_mask_offset = 0
else:
if self.kv_cache is None:
raise ValueError(
"State was provided, but cannot use autoregressive decoding "
"without specifying "
"`MultiheadAttention(..., kv_cache=...)`. "
"See `equinox.nn.StandardKVCache` for an example."
)

key_state, value_state, index, state = self.kv_cache(
key_heads, value_heads, state
if state:
key_heads, value_heads, kv_seq_length, state = self._handle_kv_cache(
key_heads, value_heads, index, query_seq_length, state
)
_check_kv_shapes(key_state, value_state, key_heads, value_heads)
state_length, _, _ = key_state.shape

causal_mask_offset = index
key_heads = key_state
value_heads = value_state
kv_seq_length = state_length
mask = _generate_mask(mask, query_seq_length, kv_seq_length, index)
if self.kv_cache is not None:
mask = _mask_unwritten_parts(kv_seq_length, query_seq_length, mask, index)

mask = _generate_mask(mask, query_seq_length, kv_seq_length, causal_mask_offset)
if state_length is not None:
mask = _mask_unwritten_parts(state_length, query_seq_length, mask, index)
attn_fn = partial(
dot_product_attention, dropout=self.dropout, inference=inference
)
@@ -432,6 +400,49 @@ def __call__(
else:
return out, state

def _get_query_and_kv_seq_lengths(self, query, key_, value):
query_seq_length, _ = query.shape
kv_seq_length, _ = key_.shape
kv_seq_length2, _ = value.shape

if kv_seq_length != kv_seq_length2:
raise ValueError("key and value must both be sequences of equal length.")

return query_seq_length, kv_seq_length

def _handle_kv_cache(self, key_heads, value_heads, index, query_seq_length, state):
if self.kv_cache is None:
raise ValueError(
"State was provided, but cannot use autoregressive decoding "
"without specifying "
"`MultiheadAttention(..., kv_cache=...)`. "
"See `equinox.nn.StandardKVCache` for an example."
)

key_state, value_state, state = self.kv_cache(
key_heads, value_heads, index, state
)
_check_kv_shapes(key_state, value_state, key_heads, value_heads)
kv_seq_length, _, _ = key_state.shape

assert self.index is not None
state = state.set(self.index, index + query_seq_length)
return key_state, value_state, kv_seq_length, state

def _get_start_index(self, state):
if state is not None and self.index is not None:
index = state.get(self.index)
else:
index = jnp.array(0, dtype=default_int_dtype())

return index

def _project_heads(self, query, key_, value):
query_heads = self._project(self.query_proj, query)
key_heads = self._project(self.key_proj, key_)
value_heads = self._project(self.value_proj, value)
return query_heads, key_heads, value_heads

def _project(self, proj, x):
seq_length, _ = x.shape
projection = jax.vmap(proj)(x)
@@ -489,7 +500,7 @@ def _generate_mask(


def _mask_unwritten_parts(
state_length: int,
kv_seq_length: int,
query_seq_length: int,
mask: Union[
Optional[Bool[Array, "q_seq kv_seq"]],
@@ -498,9 +509,23 @@ def _mask_unwritten_parts(
index: Optional[Array],
):
# Also mask out the latter parts of the state we haven't written into yet.
unwritten_mask = jnp.arange(state_length) < index # pyright: ignore
unwritten_mask = jnp.arange(kv_seq_length) < index # pyright: ignore
if mask is None:
mask = jnp.broadcast_to(unwritten_mask, (query_seq_length, state_length))
mask = jnp.broadcast_to(unwritten_mask, (query_seq_length, kv_seq_length))
else:
mask = mask & unwritten_mask.reshape(*mask.shape)
return mask
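
For concreteness, the following is a small standalone illustration (not part of this diff) of what `_mask_unwritten_parts` computes when no user-supplied mask is present; the lengths and index below are arbitrary example values.

import jax.numpy as jnp

# Example values, chosen only for illustration.
kv_seq_length = 8      # length of the pre-allocated KV state
query_seq_length = 1   # one new token being decoded
index = jnp.array(3)   # number of tokens already written into the cache

# Cache positions >= index have not been written yet, so they are masked out.
unwritten_mask = jnp.arange(kv_seq_length) < index
mask = jnp.broadcast_to(unwritten_mask, (query_seq_length, kv_seq_length))
# mask == [[ True  True  True False False False False False]]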


def _process_heads(process_heads, query_heads, key_heads, value_heads, index):
q_shape, k_shape, v_shape = (
query_heads.shape,
key_heads.shape,
value_heads.shape,
)
qs, ks, vs = process_heads(query_heads, key_heads, value_heads, index)

if qs.shape != q_shape or ks.shape != k_shape or vs.shape != v_shape:
raise ValueError("process_heads must not change the shape of the heads.")

return qs, ks, vs
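
Taken together, the refactor keeps the autoregressive index in the attention layer's own state and leaves the cache callable responsible only for writing the key/value heads. Below is a minimal decoding-loop sketch under the API shown in this diff; the `StandardKVCache` constructor arguments (`state_length`, `num_heads`, `key_size`, `value_size`) are assumptions for illustration only, and just the `kv_cache=...` keyword and the stateful call signature are taken from the diff.

import jax.numpy as jnp
import jax.random as jr
import equinox as eqx

key = jr.PRNGKey(0)
# Hypothetical constructor arguments for the cache.
cache = eqx.nn.StandardKVCache(
    state_length=16, num_heads=2, key_size=4, value_size=4
)
mha, state = eqx.nn.make_with_state(eqx.nn.MultiheadAttention)(
    num_heads=2, query_size=8, kv_cache=cache, key=key
)

token = jnp.ones((1, 8))  # one token per decoding step
for _ in range(3):
    # With a state, the layer returns (output, updated state).
    out, state = mha(token, token, token, state=state)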
19 changes: 11 additions & 8 deletions equinox/nn/_embedding.py
@@ -127,19 +127,21 @@ def process_heads(
query_heads: Float[Array, "seq_length num_heads qk_size"],
key_heads: Float[Array, "seq_length num_heads qk_size"],
value_heads: Float[Array, "seq_length num_heads vo_size"],
**kwargs
index: Int[Array, ""]
) -> tuple[
Float[Array, "seq_length num_heads qk_size"],
Float[Array, "seq_length num_heads qk_size"],
Float[Array, "seq_length num_heads vo_size"]
]:
offset = kwargs.get("offset", 0)
query_heads = jax.vmap(self.rope_embeddings,
in_axes=(1, None),
out_axes=1)(query_heads, offset)
key_heads = jax.vmap(self.rope_embeddings,
in_axes=(1, None),
out_axes=1)(key_heads, offset)
# index is the autoregressive index of the current token
rope_partial = functools.partial(
self.rope_embeddings,
offset=index,
)
query_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1)(query_heads)
key_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1)(key_heads)
return query_heads, key_heads, value_heads
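
As a standalone sketch of the pattern shown in this docstring, the snippet below wires `RotaryPositionalEmbedding` into a `process_heads` function using the new `index` argument. It assumes, as the docstring example suggests, that on this branch the embedding accepts an `offset` when called; the embedding size is arbitrary but must equal the per-head `qk_size`.

import functools
import jax
import equinox as eqx
from jaxtyping import Array, Float, Int

rope = eqx.nn.RotaryPositionalEmbedding(embedding_size=8)

def process_heads(
    query_heads: Float[Array, "seq_length num_heads qk_size"],
    key_heads: Float[Array, "seq_length num_heads qk_size"],
    value_heads: Float[Array, "seq_length num_heads vo_size"],
    index: Int[Array, ""],
):
    # Apply RoPE head-by-head, offset by the autoregressive index.
    rope_partial = functools.partial(rope, offset=index)
    query_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1)(query_heads)
    key_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1)(key_heads)
    return query_heads, key_heads, value_heads

This function would then be passed at call time, e.g. `attention(query, key, value, process_heads=process_heads, state=state)`.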
@@ -215,6 +217,7 @@ def __call__(
A JAX array of shape `(seq_length, embedding_size)`, with the rotary positional
encoding applied to the input.
"""

seq_len, embedding_size = x.shape
if embedding_size != self.embedding_size:
raise ValueError(
15 changes: 6 additions & 9 deletions equinox/nn/_kv_cache.py
@@ -5,7 +5,7 @@
import jax.numpy as jnp
from jaxtyping import Array, Float, Int

from .._misc import default_floating_dtype, default_int_dtype
from .._misc import default_floating_dtype
from .._module import field, Module
from ._stateful import State, StateIndex

@@ -14,12 +14,12 @@
[
Float[Array, "seq_length num_heads qk_size"],
Float[Array, "seq_length num_heads vo_size"],
Int[Array, ""],
State,
],
tuple[
Float[Array, "state_length num_heads qk_size"],
Float[Array, "state_length num_heads vo_size"],
Int[Array, ""],
State,
],
]
@@ -61,7 +61,6 @@ def __init__(
(
jnp.empty(self.key_shape, dtype=dtype),
jnp.empty(self.value_shape, dtype=dtype),
jnp.zeros((), default_int_dtype()),
),
)

@@ -70,11 +69,11 @@ def __call__(
self,
key_heads: Float[Array, "seq_length num_heads qk_size"],
value_heads: Float[Array, "seq_length num_heads vo_size"],
index: Int[Array, ""],
state: State,
) -> tuple[
Float[Array, "state_length num_heads qk_size"],
Float[Array, "state_length num_heads vo_size"],
Int[Array, ""],
State,
]:
"""**Arguments:**
@@ -92,13 +91,11 @@ def __call__(
"""
kv_seq_length, _, _ = key_heads.shape
key_state, value_state, index = state.get(self.autoregressive_index)
key_state, value_state = state.get(self.autoregressive_index)
key_state = lax.dynamic_update_slice_in_dim(key_state, key_heads, index, axis=0)
value_state = lax.dynamic_update_slice_in_dim(
value_state, value_heads, index, axis=0
)
state = state.set(
self.autoregressive_index, (key_state, value_state, index + kv_seq_length)
)
state = state.set(self.autoregressive_index, (key_state, value_state))

return key_state, value_state, index, state
return key_state, value_state, state
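
To make the new protocol concrete, here is a minimal custom cache satisfying `KVCacheCallable` as defined above: the autoregressive index is supplied by the attention layer, and the cache only writes the heads into its state. The class and argument names are illustrative rather than library API, and the logic simply mirrors `StandardKVCache`.

import equinox as eqx
import jax.numpy as jnp
from jax import lax
from equinox.nn import State, StateIndex
from jaxtyping import Array, Float, Int

class FixedKVCache(eqx.Module):
    cache_index: StateIndex
    state_length: int = eqx.field(static=True)

    def __init__(self, state_length, num_heads, qk_size, vo_size):
        self.state_length = state_length
        # Pre-allocate the key/value state as a single piece of mutable state.
        self.cache_index = StateIndex(
            (
                jnp.zeros((state_length, num_heads, qk_size)),
                jnp.zeros((state_length, num_heads, vo_size)),
            )
        )

    def __call__(
        self,
        key_heads: Float[Array, "seq_length num_heads qk_size"],
        value_heads: Float[Array, "seq_length num_heads vo_size"],
        index: Int[Array, ""],
        state: State,
    ) -> tuple[
        Float[Array, "state_length num_heads qk_size"],
        Float[Array, "state_length num_heads vo_size"],
        State,
    ]:
        key_state, value_state = state.get(self.cache_index)
        # Write the new heads at the position given by the caller's index.
        key_state = lax.dynamic_update_slice_in_dim(key_state, key_heads, index, axis=0)
        value_state = lax.dynamic_update_slice_in_dim(value_state, value_heads, index, axis=0)
        state = state.set(self.cache_index, (key_state, value_state))
        return key_state, value_state, state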