Skip to content

Commit

Permalink
fixed dtype promotion
Browse files Browse the repository at this point in the history
  • Loading branch information
Artur-Galstyan committed Feb 20, 2024
1 parent 7877a94 commit a0022ac
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions equinox/nn/_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,12 +198,10 @@ def precompute_freqs_cis(
theta ** (jnp.arange(0, embedding_size, 2)[jnp.newaxis, :] / embedding_size)
)

t = jnp.arange(end, dtype=float)
t = jnp.arange(float(end))
freqs_outer = jnp.outer(t, freqs)
freqs_cis = (
jnp.array(jnp.cos(freqs_outer), dtype=jnp.complex64)
+ jnp.array(jnp.sin(freqs_outer), dtype=jnp.complex64) * 1j
)
with jax.numpy_dtype_promotion("standard"):
freqs_cis = jnp.cos(freqs_outer) + jnp.sin(freqs_outer) * 1j

return freqs_cis

Expand Down Expand Up @@ -354,8 +352,7 @@ def get_positional_encoding(
pos = jnp.arange(seq_len, dtype=float)[:, jnp.newaxis]

div_term = jnp.exp(
jnp.arange(0, embedding_size, 2, dtype=float)
* -(jnp.log(theta) / embedding_size)
jnp.arange(0.0, embedding_size, 2) * -(jnp.log(theta) / embedding_size)
)

mult = pos * div_term.reshape(1, -1)
Expand Down

0 comments on commit a0022ac

Please sign in to comment.