Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes crash on symbolic zero tangents for kwargs. #749

Merged
merged 1 commit into from
Jun 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions equinox/_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,9 +724,18 @@ def call(x, y, *, fn):
def call_jvp(primals, tangents, *, fn):
x, y = primals
tx, ty = tangents
# `y` is not differentiated below, so it has a symbolic zero tangent,
# represented as a `None`.
assert ty is None
primal_out = call(x, y, fn=fn)
tangent_out = tx**2 + ty
tangent_out = tx**2
return primal_out, tangent_out

x = jnp.array(2.0)
y = jnp.array(2.0)
fn = lambda a, b: a + b
# This only computes gradients for the first argument `x`.
equinox.filter_grad(call)(x, y, fn=fn)
```
"""

Expand Down Expand Up @@ -759,11 +768,9 @@ def def_jvp(self, fn_jvp):
def fn_jvp_wrapper(static, dynamic, tangents):
(dynamic,) = dynamic
(tangents,) = tangents
d_args, _ = dynamic
t_args, t_kwargs = tangents
t_args, t_kwargs = jtu.tree_map(_drop_nondiff, tangents, dynamic)
if len(jtu.tree_leaves(t_kwargs)) > 0:
raise ValueError("Received keyword tangent")
t_args = jtu.tree_map(_drop_nondiff, t_args, d_args)
args, kwargs = combine(dynamic, static)
out, t_out = fn_jvp(args, t_args, **kwargs)
t_out = jtu.tree_map(_none_to_zero, t_out, out, is_leaf=_is_none)
Expand Down
Loading