better mem, adhering to strict jax config
Artur-Galstyan committed Feb 19, 2024
1 parent 0bd6026 commit 8832c30
Showing 1 changed file with 54 additions and 46 deletions.
equinox/nn/_embedding.py: 100 changes (54 additions, 46 deletions)
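The substantive change, visible in the `__call__` hunks below, is that the RoPE cache is now keyed by `embedding_size` alone (and the sinusoidal cache by `(embedding_size, theta)`) instead of by every `(embedding_size, seq_len)` combination; a cached table is sliced when the requested sequence is shorter and recomputed and replaced when it is longer, so at most one table per key is ever held. A minimal, self-contained sketch of this caching pattern, where `compute_table` is a hypothetical stand-in for `precompute_freqs_cis` / `get_positional_encoding`:

```python
import jax.numpy as jnp
from jaxtyping import Array


def compute_table(embedding_size: int, seq_len: int) -> Array:
    # Hypothetical stand-in for `precompute_freqs_cis` / `get_positional_encoding`.
    t = jnp.arange(seq_len, dtype=float)
    freqs = 1.0 / (
        10000.0 ** (jnp.arange(0, embedding_size, 2, dtype=float) / embedding_size)
    )
    return jnp.outer(t, freqs)


_cache: dict[int, Array] = {}


def get_table(embedding_size: int, seq_len: int) -> Array:
    table = _cache.get(embedding_size)
    if table is None or seq_len > table.shape[0]:
        # Cache miss, or the cached table is too short: build a longer one
        # and keep it for future calls.
        table = compute_table(embedding_size, seq_len)
        _cache[embedding_size] = table
    # Slice down to the requested length; only one table per embedding size
    # is ever stored, which is where the memory saving comes from.
    return table[:seq_len]
```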
@@ -10,8 +10,8 @@
from .._module import field, Module


internal_rope_embedding_cache = {}
internal_sinusoidal_positional_encoding_cache = {}
internal_rope_embedding_cache: dict[int, Array] = {}
internal_sinusoidal_positional_encoding_cache: dict[tuple[int, float], Array] = {}
cache_clears.append(internal_rope_embedding_cache.clear)
cache_clears.append(internal_sinusoidal_positional_encoding_cache.clear)

@@ -176,7 +176,6 @@ def __init__(
**kwargs,
):
"""**Arguments:**
`RotaryPositionalEmbedding` requires:
- `embedding_size`: Size of the token embeddings. Must be non-negative.
- `key`: Not used; provided for compatibility with the rest of the Equinox API.
@@ -198,9 +197,14 @@ def precompute_freqs_cis(
freqs = 1.0 / (
theta ** (jnp.arange(0, embedding_size, 2)[jnp.newaxis, :] / embedding_size)
)
t = jnp.arange(end)

t = jnp.arange(end, dtype=float)
freqs_outer = jnp.outer(t, freqs)
freqs_cis = jnp.cos(freqs_outer) + jnp.sin(freqs_outer) * 1j
freqs_cis = (
jnp.array(jnp.cos(freqs_outer), dtype=jnp.complex64)
+ jnp.array(jnp.sin(freqs_outer), dtype=jnp.complex64) * 1j
)

return freqs_cis

@jax.named_scope("eqx.nn.RotaryPositionalEmbedding")
@@ -236,18 +240,24 @@ def __call__(
neg_half_x = self.negate_half(x)

with jax.ensure_compile_time_eval():
if (embedding_size, seq_len) in internal_rope_embedding_cache:
freqs_cis = internal_rope_embedding_cache[(embedding_size, seq_len)]
if embedding_size in internal_rope_embedding_cache:
freqs_cis = internal_rope_embedding_cache[embedding_size]
freqs_cis_seq_len, _ = freqs_cis.shape
if seq_len > freqs_cis_seq_len:
freqs_cis = self.precompute_freqs_cis(embedding_size, seq_len)
internal_rope_embedding_cache[embedding_size] = freqs_cis
else:
freqs_cis = freqs_cis[:seq_len]
else:
freqs_cis = self.precompute_freqs_cis(embedding_size, seq_len)
internal_rope_embedding_cache[(embedding_size, seq_len)] = freqs_cis
internal_rope_embedding_cache[embedding_size] = freqs_cis

assert freqs_cis is not None, "freqs_cis must not be None."
freqs_real = jnp.tile(freqs_cis.real, (1, 2))
freqs_imag = jnp.tile(freqs_cis.imag, (1, 2))

x_rope = (x * freqs_real) + (neg_half_x * freqs_imag)
return jax.lax.stop_gradient(x_rope)
return x_rope


class SinusoidalPositionalEmbedding(Module):
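Before the diff moves on to `SinusoidalPositionalEmbedding`, a short usage sketch of the rotary embedding as changed above: `__call__` now returns `x_rope` without a `stop_gradient`, and repeated calls reuse (or grow) a single cached `freqs_cis` table per embedding size. This assumes the class is exported as `eqx.nn.RotaryPositionalEmbedding` and constructed with `embedding_size` only, as in this file:

```python
import equinox as eqx
import jax.numpy as jnp

# Assumed public path for the class defined in equinox/nn/_embedding.py.
rope = eqx.nn.RotaryPositionalEmbedding(embedding_size=64)

q = jnp.ones((16, 64))  # (seq_len, embedding_size)
q_rope = rope(q)        # same shape; gradients now flow through the rotation

# A longer sequence reuses the cache entry for embedding_size=64, growing the
# cached freqs_cis table instead of adding a new (embedding_size, seq_len) entry.
q_long_rope = rope(jnp.ones((32, 64)))
```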
@@ -259,42 +269,31 @@ class SinusoidalPositionalEmbedding(Module):
!!! example
The following example demonstrates how to use `SinusoidalPositionalEmbedding` in
a simple transformer model.
a simple transformer model. Note that you should apply the positional encoding
on the input directly - before applying any projections (as opposed to RoPE
embeddings, which should be applied after the projections).
```python
class TransformerBlock(eqx.Module):
class Transformer(eqx.Module):
...
key_embeddings: SinusoidalPositionalEmbedding
query_embeddings: SinusoidalPositionalEmbedding
sinusoidal_embeddings: SinusoidalPositionalEmbedding
def __init__(...):
...
self.query_embeddings = SinusoidalPositionalEmbedding(
embedding_size=n_embd
)
self.key_embeddings = SinusoidalPositionalEmbedding(
self.sinusoidal_embeddings = SinusoidalPositionalEmbedding(
embedding_size=n_embd
)
...
def __call__(...):
def process_heads(query_heads, key_heads, value_heads):
query_heads = jax.vmap(self.query_embeddings,
in_axes=1,
out_axes=1)(query_heads)
key_heads = jax.vmap(self.key_embeddings,
in_axes=1,
out_axes=1)(key_heads)
def __call__(x: Float[Array, "seq_len n_input_dims"]):
x = self.embedding(x) # Apply the input embedding first
# this maps x to shape (seq_len, n_embd)
x = self.sinusoidal_embeddings(x) # Apply the positional encoding
return query_heads, key_heads, value_heads
# x = self.mha_attention(...) # Apply the attention mechanism
mha_output = self.mha_attention(
process_heads=process_heads,
query=jax.vmap(self.rms_norm)(x),
key_=jax.vmap(self.rms_norm)(x),
value=jax.vmap(self.rms_norm)(x),
mask=mask,
)
```
??? cite
@@ -327,7 +326,6 @@ def __init__(
**kwargs,
):
"""**Arguments:**
`SinusoidalPositionalEmbedding` requires:
- `embedding_size`: Size of the token embeddings. Must be non-negative.
- `theta`: The frequency of the sinusoidal positional encoding.
@@ -353,16 +351,17 @@ def __init__(
def get_positional_encoding(
embedding_size: int, seq_len: int, theta: float = 10000.0
) -> Float[Array, "seq_len embedding_size"]:
pos = jnp.arange(seq_len)[:, jnp.newaxis]
pos = jnp.arange(seq_len, dtype=float)[:, jnp.newaxis]

div_term = jnp.exp(
jnp.arange(0, embedding_size, 2) * -(jnp.log(theta) / embedding_size)
jnp.arange(0, embedding_size, 2, dtype=float)
* -(jnp.log(theta) / embedding_size)
)
# the following expression is closer to the actual notation they used.
# div_term = 1 / 10000 ** (jnp.arange(0, embedding_size, 2) / embedding_size)
pos_enc = jnp.zeros((seq_len, embedding_size))
pos_enc = pos_enc.at[:, 0::2].set(jnp.sin(pos * div_term))
pos_enc = pos_enc.at[:, 1::2].set(jnp.cos(pos * div_term))
return pos_enc

mult = pos * div_term.reshape(1, -1)
sines = jnp.sin(mult)
cosines = jnp.cos(mult)
return jnp.stack([sines, cosines], axis=2).reshape((seq_len, embedding_size))

@jax.named_scope("eqx.nn.SinusoidalPositionalEmbedding")
def __call__(
@@ -395,18 +394,27 @@ def __call__(
with jax.ensure_compile_time_eval():
if (
embedding_size,
seq_len,
self.theta,
) in internal_sinusoidal_positional_encoding_cache:
freqs_cis = internal_sinusoidal_positional_encoding_cache[
(embedding_size, seq_len, self.theta)
(embedding_size, self.theta)
]
freqs_cis_seq_len, _ = freqs_cis.shape
if seq_len > freqs_cis_seq_len:
freqs_cis = self.get_positional_encoding(
embedding_size, seq_len, self.theta
)
internal_sinusoidal_positional_encoding_cache[
(embedding_size, self.theta)
] = freqs_cis
else:
freqs_cis = freqs_cis[:seq_len]
else:
freqs_cis = self.get_positional_encoding(
embedding_size, seq_len, self.theta
)
internal_rope_embedding_cache[
(embedding_size, seq_len, self.theta)
internal_sinusoidal_positional_encoding_cache[
(embedding_size, self.theta)
] = freqs_cis

assert freqs_cis is not None, "freqs_cis must not be None."
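On the "strict jax config" half of the commit message: the new `dtype=float` arguments to `jnp.arange` and the explicit `jnp.complex64` casts avoid implicit dtype promotion, which JAX rejects when strict dtype promotion is enabled (e.g. via the `jax.numpy_dtype_promotion("strict")` context manager). An illustrative sketch, not part of the commit, of the failure the integer `arange` used to cause under that setting:

```python
import jax
import jax.numpy as jnp

embedding_size = 8
freqs = 1.0 / (
    10000.0 ** (jnp.arange(0, embedding_size, 2, dtype=float) / embedding_size)
)

with jax.numpy_dtype_promotion("strict"):
    t_float = jnp.arange(16, dtype=float)
    _ = jnp.outer(t_float, freqs)  # fine: both operands are floating point
    try:
        t_int = jnp.arange(16)       # integer dtype
        _ = jnp.outer(t_int, freqs)  # int * float is rejected in strict mode
    except Exception as err:
        print(f"strict promotion rejected the mixed-dtype outer product: {err}")
```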
