diff --git a/equinox/nn/_embedding.py b/equinox/nn/_embedding.py
index 1fdd4135..a579199f 100644
--- a/equinox/nn/_embedding.py
+++ b/equinox/nn/_embedding.py
@@ -10,8 +10,8 @@
 from .._module import field, Module
 
 
-internal_rope_embedding_cache = {}
-internal_sinusoidal_positional_encoding_cache = {}
+internal_rope_embedding_cache: dict[int, Array] = {}
+internal_sinusoidal_positional_encoding_cache: dict[tuple[int, float], Array] = {}
 
 cache_clears.append(internal_rope_embedding_cache.clear)
 cache_clears.append(internal_sinusoidal_positional_encoding_cache.clear)
@@ -176,7 +176,6 @@ def __init__(
         **kwargs,
     ):
         """**Arguments:**
-
         `RotaryPositionalEmbedding` requires:
         - `embedding_size`: Size of the token embeddings. Must be non-negative.
         - `key`: Not used; provided for compatibility with the rest of the Equinox API.
@@ -198,9 +197,14 @@ def precompute_freqs_cis(
         freqs = 1.0 / (
             theta ** (jnp.arange(0, embedding_size, 2)[jnp.newaxis, :] / embedding_size)
         )
-        t = jnp.arange(end)
+
+        t = jnp.arange(end, dtype=float)
         freqs_outer = jnp.outer(t, freqs)
-        freqs_cis = jnp.cos(freqs_outer) + jnp.sin(freqs_outer) * 1j
+        freqs_cis = (
+            jnp.array(jnp.cos(freqs_outer), dtype=jnp.complex64)
+            + jnp.array(jnp.sin(freqs_outer), dtype=jnp.complex64) * 1j
+        )
+
         return freqs_cis
 
     @jax.named_scope("eqx.nn.RotaryPositionalEmbedding")
@@ -236,18 +240,24 @@ def __call__(
         neg_half_x = self.negate_half(x)
 
         with jax.ensure_compile_time_eval():
-            if (embedding_size, seq_len) in internal_rope_embedding_cache:
-                freqs_cis = internal_rope_embedding_cache[(embedding_size, seq_len)]
+            if embedding_size in internal_rope_embedding_cache:
+                freqs_cis = internal_rope_embedding_cache[embedding_size]
+                freqs_cis_seq_len, _ = freqs_cis.shape
+                if seq_len > freqs_cis_seq_len:
+                    freqs_cis = self.precompute_freqs_cis(embedding_size, seq_len)
+                    internal_rope_embedding_cache[embedding_size] = freqs_cis
+                else:
+                    freqs_cis = freqs_cis[:seq_len]
             else:
                 freqs_cis = self.precompute_freqs_cis(embedding_size, seq_len)
-                internal_rope_embedding_cache[(embedding_size, seq_len)] = freqs_cis
+                internal_rope_embedding_cache[embedding_size] = freqs_cis
 
         assert freqs_cis is not None, "freqs_cis must not be None."
         freqs_real = jnp.tile(freqs_cis.real, (1, 2))
         freqs_imag = jnp.tile(freqs_cis.imag, (1, 2))
         x_rope = (x * freqs_real) + (neg_half_x * freqs_imag)
 
-        return jax.lax.stop_gradient(x_rope)
+        return x_rope
 
 
 class SinusoidalPositionalEmbedding(Module):
@@ -259,42 +269,31 @@ class SinusoidalPositionalEmbedding(Module):
     !!! example
 
         The following example demonstrates how to use `SinusoidalPositionalEmbedding` in
-        a simple transformer model.
+        a simple transformer model. Note that you should apply the positional encoding
+        on the input directly - before applying any projections (as opposed to RoPE
+        embeddings, which should be applied after the projections).
+
+
 
         ```python
-        class TransformerBlock(eqx.Module):
+        class Transformer(eqx.Module):
             ...
-            key_embeddings: SinusoidalPositionalEmbedding
-            query_embeddings: SinusoidalPositionalEmbedding
+            sinusoidal_embeddings: SinusoidalPositionalEmbedding
 
             def __init__(...):
                 ...
-                self.query_embeddings = SinusoidalPositionalEmbedding(
-                    embedding_size=n_embd
-                )
-                self.key_embeddings = SinusoidalPositionalEmbedding(
+                self.sinusoidal_embeddings = SinusoidalPositionalEmbedding(
                     embedding_size=n_embd
                 )
                 ...
 
-            def __call__(...):
-                def process_heads(query_heads, key_heads, value_heads):
-                    query_heads = jax.vmap(self.query_embeddings,
-                                           in_axes=1,
-                                           out_axes=1)(query_heads)
-                    key_heads = jax.vmap(self.key_embeddings,
-                                         in_axes=1,
-                                         out_axes=1)(key_heads)
+            def __call__(self, x: Float[Array, "seq_len n_input_dims"]):
+                x = self.embedding(x)  # Apply the input embedding first
+                # this maps x to shape (seq_len, n_embd)
+                x = self.sinusoidal_embeddings(x)  # Apply the positional encoding
 
-                    return query_heads, key_heads, value_heads
+                # x = self.mha_attention(...)  # Apply the attention mechanism
 
-                mha_output = self.mha_attention(
-                    process_heads=process_heads,
-                    query=jax.vmap(self.rms_norm)(x),
-                    key_=jax.vmap(self.rms_norm)(x),
-                    value=jax.vmap(self.rms_norm)(x),
-                    mask=mask,
-                )
         ```
 
     ??? cite
@@ -327,7 +326,6 @@ def __init__(
         **kwargs,
     ):
         """**Arguments:**
-
         `SinusoidalPositionalEmbedding` requires:
         - `embedding_size`: Size of the token embeddings. Must be non-negative.
         - `theta`: The frequency of the sinusoidal positional encoding.
@@ -353,16 +351,17 @@ def __init__(
     def get_positional_encoding(
         embedding_size: int, seq_len: int, theta: float = 10000.0
     ) -> Float[Array, "seq_len embedding_size"]:
-        pos = jnp.arange(seq_len)[:, jnp.newaxis]
+        pos = jnp.arange(seq_len, dtype=float)[:, jnp.newaxis]
+
         div_term = jnp.exp(
-            jnp.arange(0, embedding_size, 2) * -(jnp.log(theta) / embedding_size)
+            jnp.arange(0, embedding_size, 2, dtype=float)
+            * -(jnp.log(theta) / embedding_size)
         )
-        # the following expression is closer to the actual notation they used.
-        # div_term = 1 / 10000 ** (jnp.arange(0, embedding_size, 2) / embedding_size)
-        pos_enc = jnp.zeros((seq_len, embedding_size))
-        pos_enc = pos_enc.at[:, 0::2].set(jnp.sin(pos * div_term))
-        pos_enc = pos_enc.at[:, 1::2].set(jnp.cos(pos * div_term))
-        return pos_enc
+
+        mult = pos * div_term.reshape(1, -1)
+        sines = jnp.sin(mult)
+        cosines = jnp.cos(mult)
+        return jnp.stack([sines, cosines], axis=2).reshape((seq_len, embedding_size))
 
     @jax.named_scope("eqx.nn.SinusoidalPositionalEmbedding")
     def __call__(
@@ -395,18 +394,27 @@ def __call__(
         with jax.ensure_compile_time_eval():
             if (
                 embedding_size,
-                seq_len,
                 self.theta,
             ) in internal_sinusoidal_positional_encoding_cache:
                 freqs_cis = internal_sinusoidal_positional_encoding_cache[
-                    (embedding_size, seq_len, self.theta)
+                    (embedding_size, self.theta)
                 ]
+                freqs_cis_seq_len, _ = freqs_cis.shape
+                if seq_len > freqs_cis_seq_len:
+                    freqs_cis = self.get_positional_encoding(
+                        embedding_size, seq_len, self.theta
+                    )
+                    internal_sinusoidal_positional_encoding_cache[
+                        (embedding_size, self.theta)
+                    ] = freqs_cis
+                else:
+                    freqs_cis = freqs_cis[:seq_len]
             else:
                 freqs_cis = self.get_positional_encoding(
                     embedding_size, seq_len, self.theta
                 )
-                internal_rope_embedding_cache[
-                    (embedding_size, seq_len, self.theta)
+                internal_sinusoidal_positional_encoding_cache[
+                    (embedding_size, self.theta)
                 ] = freqs_cis
 
         assert freqs_cis is not None, "freqs_cis must not be None."
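For reference, a minimal standalone sketch of the caching policy introduced in `RotaryPositionalEmbedding.__call__` above: the table is keyed by `embedding_size` alone, recomputed only when a longer `seq_len` is requested, and sliced down to a prefix for shorter sequences. The helper names (`_freqs_cache`, `_precompute`, `_get_freqs_cis`) are hypothetical and not part of Equinox.

```python
import jax.numpy as jnp

# Hypothetical standalone reproduction of the cache policy; not the library code.
_freqs_cache: dict[int, jnp.ndarray] = {}


def _precompute(embedding_size: int, seq_len: int, theta: float = 10000.0) -> jnp.ndarray:
    # Complex rotation factors, one row per position, one column per frequency pair.
    freqs = 1.0 / (
        theta ** (jnp.arange(0, embedding_size, 2, dtype=float) / embedding_size)
    )
    t = jnp.arange(seq_len, dtype=float)
    outer = jnp.outer(t, freqs)
    return jnp.cos(outer) + 1j * jnp.sin(outer)  # shape (seq_len, embedding_size // 2)


def _get_freqs_cis(embedding_size: int, seq_len: int) -> jnp.ndarray:
    cached = _freqs_cache.get(embedding_size)
    if cached is None or seq_len > cached.shape[0]:
        # Cache miss, or the cached table is too short: recompute at the longer length.
        cached = _precompute(embedding_size, seq_len)
        _freqs_cache[embedding_size] = cached
    # Shorter sequences reuse a prefix of the cached table.
    return cached[:seq_len]
```

Under this policy, `_get_freqs_cis(64, 128)` builds a 128-row table once, and a later `_get_freqs_cis(64, 32)` slices that same table instead of storing a separate entry per `(embedding_size, seq_len)` pair.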
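The rewritten `get_positional_encoding` builds the sinusoidal table by stacking sines and cosines along a trailing axis and flattening, which interleaves them exactly as the previous `.at[:, 0::2]` / `.at[:, 1::2]` scatter did. A small self-contained equivalence check (an editor's sketch, not part of the patch):

```python
import jax.numpy as jnp

seq_len, embedding_size, theta = 16, 8, 10000.0

pos = jnp.arange(seq_len, dtype=float)[:, jnp.newaxis]
div_term = jnp.exp(
    jnp.arange(0, embedding_size, 2, dtype=float) * -(jnp.log(theta) / embedding_size)
)
mult = pos * div_term.reshape(1, -1)

# New construction: stack sin/cos on a trailing axis, then flatten to interleave them.
new = jnp.stack([jnp.sin(mult), jnp.cos(mult)], axis=2).reshape((seq_len, embedding_size))

# Old construction: scatter sines into even columns and cosines into odd columns.
old = jnp.zeros((seq_len, embedding_size))
old = old.at[:, 0::2].set(jnp.sin(mult))
old = old.at[:, 1::2].set(jnp.cos(mult))

assert jnp.allclose(new, old)
```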
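Finally, removing the trailing `jax.lax.stop_gradient` means gradients now flow through the rotated embeddings. A hedged sketch of what that implies for downstream code, assuming the post-patch behaviour of `eqx.nn.RotaryPositionalEmbedding`:

```python
import equinox as eqx
import jax
import jax.numpy as jnp

rope = eqx.nn.RotaryPositionalEmbedding(embedding_size=4)


def loss(x):
    # Any scalar function of the rotated embeddings.
    return jnp.sum(rope(x) ** 2)


x = jnp.ones((3, 4))  # (seq_len, embedding_size)
grads = jax.grad(loss)(x)
# With the previous stop_gradient on the return value these gradients were all zero;
# after this change they are non-zero and can drive upstream parameters.
print(grads)
```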