Skip to content

Commit

Permalink
Fixed a case in which we're hashing tracers.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Jun 24, 2024
1 parent 5a5bf28 commit 86ca654
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions equinox/internal/_loop/checkpointed.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,10 @@ def _checkpointed_while_loop(vjp_arg, cond_fun, checkpoints, buffers, max_steps)
"""Uncheckpointed forward used when not differentiating."""
del checkpoints, buffers, max_steps
init_val, body_fun = vjp_arg
_body_fun = lambda x: body_fun(x) # hashable wrapper; JAX issue #13554
while_loop = jax.named_call(lax.while_loop, name="checkpointed-no-vjp")
return while_loop(cond_fun, _body_fun, init_val)
# Hashable wrapper; JAX issue #13554 and
# https://github.com/patrick-kidger/equinox/issues/768
return while_loop(lambda x: cond_fun(x), lambda x: body_fun(x), init_val)


def _scalar_index(i, x):
Expand Down

0 comments on commit 86ca654

Please sign in to comment.