Commit
added pyright ignore to test - aren't we using that anymore?
Artur-Galstyan committed Aug 12, 2024
1 parent 68cc26a commit 7660d67
Showing 3 changed files with 48 additions and 17 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -13,3 +13,6 @@ examples/MNIST
 examples/multipart_serialised.eqx
 .python-version
 .DS_Store
+.ruff_cache
+.pytest_cache
+.venv
46 changes: 33 additions & 13 deletions equinox/nn/_embedding.py
@@ -126,18 +126,22 @@ def __call__(...):
     def process_heads(
         query_heads: Float[Array, "seq_length num_heads qk_size"],
         key_heads: Float[Array, "seq_length num_heads qk_size"],
-        value_heads: Float[Array, "seq_length num_heads vo_size"]
+        value_heads: Float[Array, "seq_length num_heads vo_size"],
+        index: Int[Array, ""]
     ) -> tuple[
         Float[Array, "seq_length num_heads qk_size"],
         Float[Array, "seq_length num_heads qk_size"],
         Float[Array, "seq_length num_heads vo_size"]
     ]:
-        query_heads = jax.vmap(self.rope_embeddings,
-                               in_axes=1,
-                               out_axes=1)(query_heads)
-        key_heads = jax.vmap(self.rope_embeddings,
-                             in_axes=1,
-                             out_axes=1)(key_heads)
+        # index is the autoregressive index of the current token
+        rope_partial = functools.partial(
+            self.rope_embeddings, offset=index
+        )
+        query_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1)(query_heads)
+        key_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1)(key_heads)
         return query_heads, key_heads, value_heads
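
Note: with `index` added, the hook above no longer matches the three-argument signature that `eqx.nn.MultiheadAttention`'s `process_heads` argument expects, so a concrete `index` has to be bound first (e.g. via `functools.partial`). A minimal sketch of the wiring, closing over `index` instead; the `TransformerBlock` module and its field names are illustrative, not part of this commit:

import functools
import equinox as eqx
import jax
from jaxtyping import Array, Float, Int

class TransformerBlock(eqx.Module):  # hypothetical module for illustration
    attention: eqx.nn.MultiheadAttention
    rope_embeddings: eqx.nn.RotaryPositionalEmbedding

    def __call__(self, x: Float[Array, "seq_length embed"], index: Int[Array, ""]):
        def process_heads(query_heads, key_heads, value_heads):
            # Rotate queries and keys starting at absolute position `index`,
            # so a KV-cached decode step matches full-sequence behaviour.
            rope = functools.partial(self.rope_embeddings, offset=index)
            query_heads = jax.vmap(rope, in_axes=1, out_axes=1)(query_heads)
            key_heads = jax.vmap(rope, in_axes=1, out_axes=1)(key_heads)
            return query_heads, key_heads, value_heads

        return self.attention(x, x, x, process_heads=process_heads)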
@@ -161,13 +165,16 @@ def process_heads(
     """

     embedding_size: int = field(static=True)
+    max_seq_length: int = field(static=True)
     theta: float = field(static=True, default=10_000.0)

     def __check_init__(self):
         if self.embedding_size < 0:
             raise ValueError("`embedding_size` must not be negative.")
         if (self.embedding_size % 2) != 0:
             raise ValueError("`embedding_size` must be even.")
+        if self.max_seq_length < 0:
+            raise ValueError("`max_seq_length` must not be negative.")

     @staticmethod
     def rotate_half(x: Float[Array, "seq_length embedding_size"]):
@@ -194,12 +201,14 @@ def precompute_freqs_cis(
     def __call__(
         self,
         x: Float[Array, "seq_length embedding_size"],
+        offset: Int[Array, ""] = jnp.array(0),
         *,
         key: Optional[PRNGKeyArray] = None,
     ) -> Float[Array, "seq_length embedding_size"]:
         """**Arguments:**

         - `x`: A JAX array of shape `(seq_length, embedding_size)`.
+        - `offset`: The offset to apply to the positional encoding.
         - `key`: Ignored; provided for compatibility with the rest of the Equinox API.
             (Keyword only argument.)
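
A usage sketch of the new `offset` argument, assuming the constructor shown in this diff (`max_seq_length` passed as a keyword; the sizes here are arbitrary):

import equinox as eqx
import jax
import jax.numpy as jnp

rope = eqx.nn.RotaryPositionalEmbedding(embedding_size=64, max_seq_length=128)
x = jax.random.normal(jax.random.PRNGKey(0), (8, 64))  # (seq_length, embedding_size)

y0 = rope(x)                       # rows rotated as positions 0..7
y4 = rope(x, offset=jnp.array(4))  # same rows rotated as positions 4..11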
@@ -215,37 +224,48 @@ def __call__(
                 f"x.shape[-1] must match self.embedding_size, "
                 f"but {x.shape[-1]} != {self.embedding_size}"
             )
+        if seq_len > self.max_seq_length:
+            raise ValueError(
+                f"seq_len must be less than or equal to self.max_seq_length, "
+                f"but {seq_len} > {self.max_seq_length}"
+            )

         with jax.ensure_compile_time_eval():
             if embedding_size in internal_rope_embedding_cache:
                 freqs_cis = internal_rope_embedding_cache[embedding_size]
                 freqs_cis_seq_len, _ = freqs_cis.shape
-                if seq_len > freqs_cis_seq_len:
+                if self.max_seq_length > freqs_cis_seq_len:
                     freqs_cis = self.precompute_freqs_cis(
-                        embedding_size, seq_len, self.theta
+                        embedding_size, self.max_seq_length, self.theta
                     )
                     internal_rope_embedding_cache[embedding_size] = freqs_cis
                 else:
-                    freqs_cis = freqs_cis[:seq_len]
+                    freqs_cis = freqs_cis[: self.max_seq_length]
             else:
                 freqs_cis = self.precompute_freqs_cis(
-                    embedding_size, seq_len, self.theta
+                    embedding_size, self.max_seq_length, self.theta
                 )
                 internal_rope_embedding_cache[embedding_size] = freqs_cis

+        freqs_cis = jax.lax.dynamic_slice_in_dim(freqs_cis, offset, seq_len)
+
         freqs_real = jnp.tile(freqs_cis.real, (1, 2))
         freqs_imag = jnp.tile(freqs_cis.imag, (1, 2))

         rotate_x = self.rotate_half(x)

         x_rope = (x * freqs_real) + (rotate_x * freqs_imag)
         return x_rope
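
The mechanic worth noting above: the frequency table is now always precomputed at the full `max_seq_length` (a static shape, so the cache stays valid under `jit`), and `jax.lax.dynamic_slice_in_dim` then extracts the `seq_len` rows starting at the traced `offset`, roughly `freqs_cis[offset : offset + seq_len]` but with a start index that may be a traced array. A standalone illustration with stand-in data, not Equinox code:

import jax
import jax.numpy as jnp

table = jnp.arange(10)  # stand-in for the precomputed freqs_cis rows

@jax.jit
def window(offset):
    # fixed slice length (static), traced start index
    return jax.lax.dynamic_slice_in_dim(table, offset, 4)

print(window(jnp.array(3)))  # [3 4 5 6]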


 RotaryPositionalEmbedding.__init__.__doc__ = """**Arguments:**

 - `embedding_size`: Size of the token embeddings. Must be non-negative and even.
 - `theta`: The base frequency for the sinusoidal functions. It defines the rate
     of oscillation for the sine and cosine waves that encode positional information
     into the embeddings. The larger the theta value, the slower the oscillations
     and vice versa. Defaults to 10_000.0.
+- `max_seq_length`: The maximum sequence length for which to precompute the
+    positional encodings. This is used to determine the size of the precomputed
+    positional encodings.
 """
16 changes: 12 additions & 4 deletions tests/test_nn.py
@@ -238,8 +238,8 @@ def test_mlp_learnt_activation():
         key=jrandom.PRNGKey(5678),
     )
     x = jnp.array([0.5, 0.7])
-    assert mlp.activation.negative_slope.shape == (2, 8)
-    assert mlp.final_activation.negative_slope.shape == (5,)
+    assert mlp.activation.negative_slope.shape == (2, 8)  # pyright: ignore
+    assert mlp.final_activation.negative_slope.shape == (5,)  # pyright: ignore

     @eqx.filter_jit
     @eqx.filter_grad
@@ -1352,13 +1352,16 @@ def test_prelu(getkey):


 def test_rope_embeddings_shapes(getkey):
     embedding_size = 32
-    rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size)

     n_heads = 4
     seq_length = 8
     query_size = 32
     key_size = 32

+    rope_embeddings = eqx.nn.RotaryPositionalEmbedding(
+        embedding_size, max_seq_length=seq_length
+    )
+
     query_heads = jax.random.normal(
         key=getkey(), shape=(seq_length, n_heads, query_size)
     )
@@ -1435,7 +1438,12 @@ def test_rope_embeddings_values():
         seq_length, embedding_size
     )

-    rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size)
+    rope_embeddings = eqx.nn.RotaryPositionalEmbedding(
+        embedding_size, max_seq_length=seq_length
+    )
     res = rope_embeddings(x)

     assert jnp.allclose(res, expected_values, atol=1e-6)
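
A natural follow-up property test, not part of this commit, sketched under the new API: applying `offset=k` should match embedding the same rows at absolute positions k..k+seq_length-1 of a longer sequence, since each row's rotation depends only on its position:

def test_rope_embeddings_offset_consistency(getkey):
    embedding_size = 8
    rope = eqx.nn.RotaryPositionalEmbedding(embedding_size, max_seq_length=16)
    prefix = jax.random.normal(getkey(), (4, embedding_size))
    x = jax.random.normal(getkey(), (4, embedding_size))
    full = rope(jnp.concatenate([prefix, x]))  # x occupies positions 4..7
    shifted = rope(x, offset=jnp.array(4))     # same rows, offset by 4
    assert jnp.allclose(full[4:], shifted, atol=1e-6)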
