diff --git a/equinox/_ad.py b/equinox/_ad.py
index a15312e4..17b11aec 100644
--- a/equinox/_ad.py
+++ b/equinox/_ad.py
@@ -724,9 +724,18 @@ def call(x, y, *, fn):
         def call_jvp(primals, tangents, *, fn):
             x, y = primals
             tx, ty = tangents
+            # `y` is not differentiated below, so it has a symbolic zero tangent,
+            # represented as a `None`.
+            assert ty is None
             primal_out = call(x, y, fn=fn)
-            tangent_out = tx**2 + ty
+            tangent_out = tx**2
             return primal_out, tangent_out
+
+        x = jnp.array(2.0)
+        y = jnp.array(2.0)
+        fn = lambda a, b: a + b
+        # This only computes gradients for the first argument `x`.
+        equinox.filter_grad(call)(x, y, fn=fn)
         ```
     """
 
@@ -759,11 +768,9 @@ def def_jvp(self, fn_jvp):
         def fn_jvp_wrapper(static, dynamic, tangents):
             (dynamic,) = dynamic
             (tangents,) = tangents
-            d_args, _ = dynamic
-            t_args, t_kwargs = tangents
+            t_args, t_kwargs = jtu.tree_map(_drop_nondiff, tangents, dynamic)
             if len(jtu.tree_leaves(t_kwargs)) > 0:
                 raise ValueError("Received keyword tangent")
-            t_args = jtu.tree_map(_drop_nondiff, t_args, d_args)
             args, kwargs = combine(dynamic, static)
             out, t_out = fn_jvp(args, t_args, **kwargs)
             t_out = jtu.tree_map(_none_to_zero, t_out, out, is_leaf=_is_none)
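
For context, a self-contained sketch of the updated docstring example, runnable on its own. The imports and the `@equinox.filter_custom_jvp` / `@call.def_jvp` decorators are not visible in the hunk and are assumed from the surrounding docstring; everything else mirrors the added lines above.

```python
import equinox
import jax.numpy as jnp


@equinox.filter_custom_jvp
def call(x, y, *, fn):
    return fn(x, y)


@call.def_jvp
def call_jvp(primals, tangents, *, fn):
    x, y = primals
    tx, ty = tangents
    # `y` is not differentiated below, so it has a symbolic zero tangent,
    # represented as a `None`.
    assert ty is None
    primal_out = call(x, y, fn=fn)
    tangent_out = tx**2
    return primal_out, tangent_out


x = jnp.array(2.0)
y = jnp.array(2.0)
fn = lambda a, b: a + b
# `filter_grad` only differentiates the first argument, so `y` reaches
# `call_jvp` with a `None` (symbolic zero) tangent.
grad_x = equinox.filter_grad(call)(x, y, fn=fn)
```

This example exercises the tangent handling in `def_jvp` that the second hunk restructures: non-differentiated inputs now have their tangents dropped by mapping `_drop_nondiff` over the full `tangents`/`dynamic` pytrees before the keyword-tangent check.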