Skip to content

Commit

Permalink
Simplify an obsolete edge case check
Browse files Browse the repository at this point in the history
  • Loading branch information
knyazer committed Mar 13, 2024
1 parent 3c25190 commit 80761de
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
2 changes: 0 additions & 2 deletions equinox/nn/_dropout.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -84,8 +84,6 @@ def __call__(

if inference is None:
inference = self.inference
- if isinstance(self.p, (int, float)) and self.p == 0:
- inference = True
if inference:
return x
elif key is None:
Expand Down
12 changes: 8 additions & 4 deletions tests/test_nn.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -763,11 +763,12 @@ def test_multihead_attention(getkey):
qk_size=13,
vo_size=17,
key=getkey(),
+ inference=True,
)
q = jrandom.uniform(getkey(), (19, 3))
k = jrandom.uniform(getkey(), (23, 5))
v = jrandom.uniform(getkey(), (23, 7))
- assert attn(q, k, v).shape == (19, 11)
+ assert attn(q, k, v, key=jrandom.PRNGKey(1)).shape == (19, 11)

attn = eqx.nn.MultiheadAttention(num_heads=2, query_size=4, key=getkey())
attn = eqx.tree_at(
Expand All @@ -781,18 +782,21 @@ def test_multihead_attention(getkey):
[jnp.arange(16.0).reshape(4, 4) for _ in range(4)],
)
x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
- assert jnp.allclose(attn(x, x, x), jnp.array([[680.0, 1960.0, 3240.0, 4520.0]]))
+ assert jnp.allclose(
+ attn(x, x, x, key=jrandom.PRNGKey(2)),
+ jnp.array([[680.0, 1960.0, 3240.0, 4520.0]]),
+ )

x = jnp.arange(1, 13, dtype=jnp.float32).reshape(3, 4)
mask = jnp.broadcast_to(jnp.array([True, False, False]), (2, 3, 3))
assert jnp.allclose(
- attn(x, x, x, mask),
+ attn(x, x, x, mask, key=jrandom.PRNGKey(3)),
jnp.broadcast_to(jnp.array([[680.0, 1960.0, 3240.0, 4520.0]]), (3, 4)),
)

mask = jnp.broadcast_to(jnp.array([True, False, False]), (3, 3))
assert jnp.allclose(
- attn(x, x, x, mask),
+ attn(x, x, x, mask, key=jrandom.PRNGKey(4)),
jnp.broadcast_to(jnp.array([[680.0, 1960.0, 3240.0, 4520.0]]), (3, 4)),
)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_shared.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -39,7 +39,7 @@ class MyModule(eqx.Module):

def __init__(self):
attention = eqx.nn.MultiheadAttention(
- num_heads=3, query_size=12, key=getkey()
+ num_heads=3, query_size=12, key=getkey(), inference=True
)
my_proj = eqx.nn.Linear(12, 12, use_bias=False, key=getkey())
where = lambda pair: pair[1].key_proj
Expand Down Expand Up @@ -88,7 +88,7 @@ class MyModule(eqx.Module):
def __init__(self):
my_proj = eqx.nn.Linear(12, 12, use_bias=False, key=getkey())
attention = eqx.nn.MultiheadAttention(
- num_heads=3, query_size=12, key=getkey()
+ num_heads=3, query_size=12, key=getkey(), inference=True
)
where = lambda pair: (pair[1].key_proj, pair[1].query_proj.weight)
get = lambda pair: (pair[0], pair[0].weight + 1)
Expand Down

0 comments on commit 80761de

Please sign in to comment.