From 4ca63eb56cf6b57e5bf09b5ea6b66dbac5c3739f Mon Sep 17 00:00:00 2001 From: Johanna Haffner Date: Mon, 27 Jan 2025 10:23:39 +0100 Subject: [PATCH 1/7] Adding an option to enable forward-mode automatic differentiation for all minimisers. Moves gradient computation, handling and related documentation into lin_to_grad. Also removes an old jacobian function that used size-based heuristics to toggle between modes. --- optimistix/_misc.py | 54 +++++++++++++++++--------- optimistix/_solver/bfgs.py | 20 ++++++++-- optimistix/_solver/gradient_methods.py | 23 ++++++++--- optimistix/_solver/nonlinear_cg.py | 10 ++++- tests/test_minimise.py | 8 +++- 5 files changed, 85 insertions(+), 30 deletions(-) diff --git a/optimistix/_misc.py b/optimistix/_misc.py index bd58f79..6eba372 100644 --- a/optimistix/_misc.py +++ b/optimistix/_misc.py @@ -8,6 +8,7 @@ import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu +import numpy as np from equinox.internal import ω from jaxtyping import Array, ArrayLike, Bool, PyTree, Scalar from lineax.internal import ( @@ -103,26 +104,43 @@ def __call__(self, *args, **kwargs): return out, aux -def jacobian(fn, in_size, out_size, has_aux=False): - """Compute the Jacobian of a function using forward or backward mode AD. - - `jacobian` chooses between forward and backwards autodiff depending on the input - and output dimension of `fn`, as specified in `in_size` and `out_size`. +def _jacfwd(lin_fn, pytree): + """Custom version of jax.jacfwd that directly uses a linearized function. + Takes inspiration from jax.jacfwd, but simplifies some steps: we only ever treat + PyTrees of arrays, where all elements have the same dtype. """ - - # Heuristic for which is better in each case - # These could probably be tuned a lot more. - if (in_size < 100) or (in_size <= 1.5 * out_size): - return jax.jacfwd(fn, has_aux=has_aux) - else: - return jax.jacrev(fn, has_aux=has_aux) - - -def lin_to_grad(lin_fn, *primals): - # Only the shape and dtype of primals is evaluated, not the value itself. We convert - # to grad after linearising to avoid recompilation. (1.0 is a scaling factor.) + leaves, treedef = jtu.tree_flatten(pytree) + static_sizes = [int(jnp.size(leaf)) for leaf in leaves] + indices = np.cumsum(static_sizes)[:-1] # Define boundaries between elements + elements = sum(static_sizes) + dtype = jax.dtypes.result_type(*leaves) + + def values_to_tree(values): + parts = jnp.split(values, indices) + reshaped = jtu.tree_map(lambda a, b: jnp.reshape(a, b.shape), parts, leaves) + return jtu.tree_unflatten(treedef, reshaped) + + def unit_tree(index): + values = jnp.zeros(elements, dtype=dtype).at[index].set(1.0) + return values_to_tree(values) + + unit_pytrees = [unit_tree(i) for i in range(elements)] + derivatives = jnp.stack([lin_fn(t) for t in unit_pytrees]) + return values_to_tree(derivatives) + + +def lin_to_grad(lin_fn, y_eval, mode=None): + # Only the shape and dtype of y_eval is evaluated, not the value itself. (lin_fn + # was linearized at y_eval, and the values were stored.) 
+ # We convert to grad after linearising for efficiency: # https://github.com/patrick-kidger/optimistix/issues/89#issuecomment-2447669714 - return jax.linear_transpose(lin_fn, *primals)(1.0) + if mode == "bwd": + (grad,) = jax.linear_transpose(lin_fn, y_eval)(1.0) # (1.0 is a scaling factor) + return grad + if mode == "fwd": + return _jacfwd(lin_fn, y_eval) + else: + raise ValueError("Only `mode='fwd'` or `mode='bwd'` are valid.") def _asarray(dtype, x): diff --git a/optimistix/_solver/bfgs.py b/optimistix/_solver/bfgs.py index 19e7994..0e26ed3 100644 --- a/optimistix/_solver/bfgs.py +++ b/optimistix/_solver/bfgs.py @@ -155,6 +155,13 @@ class AbstractBFGS( information. This abstract version may be subclassed to choose alternative descent and searches. + + Supports the following `options`: + + - `mode`: whether to use forward- or reverse-mode autodifferentiation to compute the + gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`, which is + usually more efficient. Changing this can be useful when the target function + does not support reverse-mode automatic differentiation. """ rtol: AbstractVar[float] @@ -205,6 +212,7 @@ def step( state: _BFGSState, tags: frozenset[object], ) -> tuple[Y, _BFGSState, Aux]: + mode = options.get("mode", "bwd") f_eval, lin_fn, aux_eval = jax.linearize( lambda _y: fn(_y, args), state.y_eval, has_aux=True ) @@ -218,10 +226,7 @@ def step( ) def accepted(descent_state): - # We have linearised the function (i.e. computed its Jacobian at y_eval) - # above, here we convert it to a gradient of the same shape as y. y_eval is - # actually a dummy value here, since lin_fn already has the required info. - (grad,) = lin_to_grad(lin_fn, state.y_eval) + grad = lin_to_grad(lin_fn, state.y_eval, mode=mode) y_diff = (state.y_eval**ω - y**ω).ω if self.use_inverse: @@ -318,6 +323,13 @@ class BFGS(AbstractBFGS[Y, Aux, _Hessian], strict=True): This is a quasi-Newton optimisation algorithm, whose defining feature is the way it progressively builds up a Hessian approximation using multiple steps of gradient information. + + Supports the following `options`: + + - `mode`: whether to use forward- or reverse-mode autodifferentiation to compute the + gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`, which is + usually more efficient. Changing this can be useful when the target function + does not support reverse-mode automatic differentiation. """ rtol: float diff --git a/optimistix/_solver/gradient_methods.py b/optimistix/_solver/gradient_methods.py index 726507d..5469ea6 100644 --- a/optimistix/_solver/gradient_methods.py +++ b/optimistix/_solver/gradient_methods.py @@ -120,6 +120,13 @@ class AbstractGradientDescent( - `norm: Callable[[PyTree], Scalar]` - `descent: AbstractDescent` - `search: AbstractSearch` + + Supports the following `options`: + + - `mode`: whether to use forward- or reverse-mode autodifferentiation to compute the + gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`, which is + usually more efficient. Changing this can be useful when the target function + does not support reverse-mode automatic differentiation. """ rtol: AbstractVar[float] @@ -162,6 +169,7 @@ def step( state: _GradientDescentState, tags: frozenset[object], ) -> tuple[Y, _GradientDescentState, Aux]: + mode = options.get("mode", "bwd") f_eval, lin_fn, aux_eval = jax.linearize( lambda _y: fn(_y, args), state.y_eval, has_aux=True ) @@ -175,10 +183,7 @@ def step( ) def accepted(descent_state): - # We have linearised the function (i.e. 
computed its Jacobian at y_eval) - # above, here we convert it to a gradient of the same shape as y. y_eval is - # actually a dummy value here, since lin_fn already has the required info. - (grad,) = lin_to_grad(lin_fn, state.y_eval) + grad = lin_to_grad(lin_fn, state.y_eval, mode=mode) f_eval_info = FunctionInfo.EvalGrad(f_eval, grad) descent_state = self.descent.query(state.y_eval, f_eval_info, descent_state) @@ -243,7 +248,15 @@ def postprocess( class GradientDescent(AbstractGradientDescent[Y, Aux], strict=True): - """Classic gradient descent with a learning rate `learning_rate`.""" + """Classic gradient descent with a learning rate `learning_rate`. + + Supports the following `options`: + + - `mode`: whether to use forward- or reverse-mode autodifferentiation to compute the + gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`, which is + usually more efficient. Changing this can be useful when the target function + does not support reverse-mode automatic differentiation. + """ rtol: float atol: float diff --git a/optimistix/_solver/nonlinear_cg.py b/optimistix/_solver/nonlinear_cg.py index 8edac05..4d00030 100644 --- a/optimistix/_solver/nonlinear_cg.py +++ b/optimistix/_solver/nonlinear_cg.py @@ -175,7 +175,15 @@ def step( class NonlinearCG(AbstractGradientDescent[Y, Aux], strict=True): - """The nonlinear conjugate gradient method.""" + """The nonlinear conjugate gradient method. + + Supports the following `options`: + + - `mode`: whether to use forward- or reverse-mode autodifferentiation to compute the + gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`, which is + usually more efficient. Changing this can be useful when the target function + does not support reverse-mode automatic differentiation. + """ rtol: float atol: float diff --git a/tests/test_minimise.py b/tests/test_minimise.py index 2b6ab2c..a0aa8a3 100644 --- a/tests/test_minimise.py +++ b/tests/test_minimise.py @@ -25,9 +25,10 @@ smoke_aux = (jnp.ones((2, 3)), {"smoke_aux": jnp.ones(2)}) +@pytest.mark.parametrize("options", (dict(mode="fwd"), dict(mode="bwd"))) @pytest.mark.parametrize("solver", minimisers) @pytest.mark.parametrize("_fn, minimum, init, args", minimisation_fn_minima_init_args) -def test_minimise(solver, _fn, minimum, init, args): +def test_minimise(solver, _fn, minimum, init, args, options): if isinstance(solver, optx.GradientDescent): max_steps = 100_000 else: @@ -49,6 +50,7 @@ def test_minimise(solver, _fn, minimum, init, args): init, has_aux=has_aux, args=args, + options=options, max_steps=max_steps, throw=False, ).value @@ -56,9 +58,10 @@ def test_minimise(solver, _fn, minimum, init, args): assert tree_allclose(optx_min, minimum, atol=atol, rtol=rtol) +@pytest.mark.parametrize("options", (dict(mode="fwd"), dict(mode="bwd"))) @pytest.mark.parametrize("solver", minimisers) @pytest.mark.parametrize("_fn, minimum, init, args", minimisation_fn_minima_init_args) -def test_minimise_jvp(getkey, solver, _fn, minimum, init, args): +def test_minimise_jvp(getkey, solver, _fn, minimum, init, args, options): if isinstance(solver, (optx.GradientDescent, optx.NonlinearCG)): max_steps = 100_000 atol = rtol = 1e-2 @@ -88,6 +91,7 @@ def minimise(x, dynamic_args, *, adjoint): x, has_aux=has_aux, args=args, + options=options, max_steps=max_steps, adjoint=adjoint, throw=False, From 2b2907e09afdedac0bf047c24799bbfe5e87fbab Mon Sep 17 00:00:00 2001 From: Johanna Haffner Date: Mon, 3 Feb 2025 09:46:08 +0100 Subject: [PATCH 2/7] improve docstring --- optimistix/_misc.py | 9 ++++++--- 1 file changed, 
6 insertions(+), 3 deletions(-) diff --git a/optimistix/_misc.py b/optimistix/_misc.py index 6eba372..eec63c3 100644 --- a/optimistix/_misc.py +++ b/optimistix/_misc.py @@ -105,9 +105,12 @@ def __call__(self, *args, **kwargs): def _jacfwd(lin_fn, pytree): - """Custom version of jax.jacfwd that directly uses a linearized function. - Takes inspiration from jax.jacfwd, but simplifies some steps: we only ever treat - PyTrees of arrays, where all elements have the same dtype. + """Custom version of jax.jacfwd that operates on a linearized function. Enables + obtaining a gradient without requiring transposition if the function that was + linearized returns a scalar, as in this case the Jacobian is equivalent to the + gradient. (Untested for other use cases.) + Tailored to the optimistix use-case, where all optimisation variables are pytrees of + arrays, sharing the same dtype. """ leaves, treedef = jtu.tree_flatten(pytree) static_sizes = [int(jnp.size(leaf)) for leaf in leaves] From b563c3e10c5a95ed9772021149f22eaeba0c4915 Mon Sep 17 00:00:00 2001 From: Johanna Haffner Date: Mon, 3 Feb 2025 13:16:21 +0100 Subject: [PATCH 3/7] move creation of unit pytrees to enable vmapping --- optimistix/_misc.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/optimistix/_misc.py b/optimistix/_misc.py index eec63c3..08ab04c 100644 --- a/optimistix/_misc.py +++ b/optimistix/_misc.py @@ -123,12 +123,12 @@ def values_to_tree(values): reshaped = jtu.tree_map(lambda a, b: jnp.reshape(a, b.shape), parts, leaves) return jtu.tree_unflatten(treedef, reshaped) - def unit_tree(index): + def directional_derivative(index): values = jnp.zeros(elements, dtype=dtype).at[index].set(1.0) - return values_to_tree(values) + unit_tree = values_to_tree(values) + return lin_fn(unit_tree) - unit_pytrees = [unit_tree(i) for i in range(elements)] - derivatives = jnp.stack([lin_fn(t) for t in unit_pytrees]) + derivatives = jax.vmap(directional_derivative)(jnp.arange(elements)) return values_to_tree(derivatives) From a7d23888ff87a720d88ad3510853d1ee9cfed2bd Mon Sep 17 00:00:00 2001 From: Johanna Haffner Date: Thu, 6 Feb 2025 18:54:31 +0100 Subject: [PATCH 4/7] remove custom _jacfwd, replace with eqx.filter_jacfwd --- optimistix/_misc.py | 31 +------------------------------ 1 file changed, 1 insertion(+), 30 deletions(-) diff --git a/optimistix/_misc.py b/optimistix/_misc.py index 08ab04c..5a08d54 100644 --- a/optimistix/_misc.py +++ b/optimistix/_misc.py @@ -8,7 +8,6 @@ import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu -import numpy as np from equinox.internal import ω from jaxtyping import Array, ArrayLike, Bool, PyTree, Scalar from lineax.internal import ( @@ -104,34 +103,6 @@ def __call__(self, *args, **kwargs): return out, aux -def _jacfwd(lin_fn, pytree): - """Custom version of jax.jacfwd that operates on a linearized function. Enables - obtaining a gradient without requiring transposition if the function that was - linearized returns a scalar, as in this case the Jacobian is equivalent to the - gradient. (Untested for other use cases.) - Tailored to the optimistix use-case, where all optimisation variables are pytrees of - arrays, sharing the same dtype. 
- """ - leaves, treedef = jtu.tree_flatten(pytree) - static_sizes = [int(jnp.size(leaf)) for leaf in leaves] - indices = np.cumsum(static_sizes)[:-1] # Define boundaries between elements - elements = sum(static_sizes) - dtype = jax.dtypes.result_type(*leaves) - - def values_to_tree(values): - parts = jnp.split(values, indices) - reshaped = jtu.tree_map(lambda a, b: jnp.reshape(a, b.shape), parts, leaves) - return jtu.tree_unflatten(treedef, reshaped) - - def directional_derivative(index): - values = jnp.zeros(elements, dtype=dtype).at[index].set(1.0) - unit_tree = values_to_tree(values) - return lin_fn(unit_tree) - - derivatives = jax.vmap(directional_derivative)(jnp.arange(elements)) - return values_to_tree(derivatives) - - def lin_to_grad(lin_fn, y_eval, mode=None): # Only the shape and dtype of y_eval is evaluated, not the value itself. (lin_fn # was linearized at y_eval, and the values were stored.) @@ -141,7 +112,7 @@ def lin_to_grad(lin_fn, y_eval, mode=None): (grad,) = jax.linear_transpose(lin_fn, y_eval)(1.0) # (1.0 is a scaling factor) return grad if mode == "fwd": - return _jacfwd(lin_fn, y_eval) + return eqx.filter_jacfwd(lin_fn)(y_eval) else: raise ValueError("Only `mode='fwd'` or `mode='bwd'` are valid.") From dff7cb13099b30a20cdf22a0516ae7b234fc6a33 Mon Sep 17 00:00:00 2001 From: Johanna Haffner Date: Thu, 6 Feb 2025 19:04:26 +0100 Subject: [PATCH 5/7] eqx.filter_jacfwd -> jax.jacfwd --- optimistix/_misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimistix/_misc.py b/optimistix/_misc.py index 5a08d54..71c0ab2 100644 --- a/optimistix/_misc.py +++ b/optimistix/_misc.py @@ -112,7 +112,7 @@ def lin_to_grad(lin_fn, y_eval, mode=None): (grad,) = jax.linear_transpose(lin_fn, y_eval)(1.0) # (1.0 is a scaling factor) return grad if mode == "fwd": - return eqx.filter_jacfwd(lin_fn)(y_eval) + return jax.jacfwd(lin_fn)(y_eval) else: raise ValueError("Only `mode='fwd'` or `mode='bwd'` are valid.") From f831466f908f83c9f8f58f925f1efdfa38df6c18 Mon Sep 17 00:00:00 2001 From: Johanna Haffner Date: Fri, 7 Feb 2025 13:30:46 +0100 Subject: [PATCH 6/7] mode -> autodiff_mode, adds forward-only test case --- .pre-commit-config.yaml | 2 +- optimistix/_misc.py | 10 +++++--- optimistix/_solver/bfgs.py | 20 +++++++-------- optimistix/_solver/gradient_methods.py | 20 +++++++-------- optimistix/_solver/nonlinear_cg.py | 8 +++--- tests/helpers.py | 35 ++++++++++++++++++++++++++ tests/requirements.txt | 3 ++- tests/test_minimise.py | 23 +++++++++++++++-- 8 files changed, 89 insertions(+), 32 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 750fa40..e98d6f1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,4 +25,4 @@ repos: rev: v1.1.331 hooks: - id: pyright - additional_dependencies: ["equinox", "jax", "lineax", "pytest", "optax"] + additional_dependencies: ["equinox", "jax", "lineax", "pytest", "optax", "diffrax"] diff --git a/optimistix/_misc.py b/optimistix/_misc.py index 71c0ab2..bf7e5d7 100644 --- a/optimistix/_misc.py +++ b/optimistix/_misc.py @@ -103,18 +103,20 @@ def __call__(self, *args, **kwargs): return out, aux -def lin_to_grad(lin_fn, y_eval, mode=None): +def lin_to_grad(lin_fn, y_eval, autodiff_mode=None): # Only the shape and dtype of y_eval is evaluated, not the value itself. (lin_fn # was linearized at y_eval, and the values were stored.) 
# We convert to grad after linearising for efficiency: # https://github.com/patrick-kidger/optimistix/issues/89#issuecomment-2447669714 - if mode == "bwd": + if autodiff_mode == "bwd": (grad,) = jax.linear_transpose(lin_fn, y_eval)(1.0) # (1.0 is a scaling factor) return grad - if mode == "fwd": + if autodiff_mode == "fwd": return jax.jacfwd(lin_fn)(y_eval) else: - raise ValueError("Only `mode='fwd'` or `mode='bwd'` are valid.") + raise ValueError( + "Only `autodiff_mode='fwd'` or `autodiff_mode='bwd'` are valid." + ) def _asarray(dtype, x): diff --git a/optimistix/_solver/bfgs.py b/optimistix/_solver/bfgs.py index 0e26ed3..c2f6fbb 100644 --- a/optimistix/_solver/bfgs.py +++ b/optimistix/_solver/bfgs.py @@ -158,10 +158,10 @@ class AbstractBFGS( Supports the following `options`: - - `mode`: whether to use forward- or reverse-mode autodifferentiation to compute the - gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`, which is - usually more efficient. Changing this can be useful when the target function - does not support reverse-mode automatic differentiation. + - `autodiff_mode`: whether to use forward- or reverse-mode autodifferentiation to + compute the gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`, + which is usually more efficient. Changing this can be useful when the target + function does not support reverse-mode automatic differentiation. """ rtol: AbstractVar[float] @@ -212,7 +212,7 @@ def step( state: _BFGSState, tags: frozenset[object], ) -> tuple[Y, _BFGSState, Aux]: - mode = options.get("mode", "bwd") + autodiff_mode = options.get("autodiff_mode", "bwd") f_eval, lin_fn, aux_eval = jax.linearize( lambda _y: fn(_y, args), state.y_eval, has_aux=True ) @@ -226,7 +226,7 @@ def step( ) def accepted(descent_state): - grad = lin_to_grad(lin_fn, state.y_eval, mode=mode) + grad = lin_to_grad(lin_fn, state.y_eval, autodiff_mode=autodiff_mode) y_diff = (state.y_eval**ω - y**ω).ω if self.use_inverse: @@ -326,10 +326,10 @@ class BFGS(AbstractBFGS[Y, Aux, _Hessian], strict=True): Supports the following `options`: - - `mode`: whether to use forward- or reverse-mode autodifferentiation to compute the - gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`, which is - usually more efficient. Changing this can be useful when the target function - does not support reverse-mode automatic differentiation. + - `autodiff_mode`: whether to use forward- or reverse-mode autodifferentiation to + compute the gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`, + which is usually more efficient. Changing this can be useful when the target + function does not support reverse-mode automatic differentiation. """ rtol: float diff --git a/optimistix/_solver/gradient_methods.py b/optimistix/_solver/gradient_methods.py index 5469ea6..6277aac 100644 --- a/optimistix/_solver/gradient_methods.py +++ b/optimistix/_solver/gradient_methods.py @@ -123,10 +123,10 @@ class AbstractGradientDescent( Supports the following `options`: - - `mode`: whether to use forward- or reverse-mode autodifferentiation to compute the - gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`, which is - usually more efficient. Changing this can be useful when the target function - does not support reverse-mode automatic differentiation. + - `autodiff_mode`: whether to use forward- or reverse-mode autodifferentiation to + compute the gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`, + which is usually more efficient. 
Changing this can be useful when the target + function does not support reverse-mode automatic differentiation. """ rtol: AbstractVar[float] @@ -169,7 +169,7 @@ def step( state: _GradientDescentState, tags: frozenset[object], ) -> tuple[Y, _GradientDescentState, Aux]: - mode = options.get("mode", "bwd") + autodiff_mode = options.get("autodiff_mode", "bwd") f_eval, lin_fn, aux_eval = jax.linearize( lambda _y: fn(_y, args), state.y_eval, has_aux=True ) @@ -183,7 +183,7 @@ def step( ) def accepted(descent_state): - grad = lin_to_grad(lin_fn, state.y_eval, mode=mode) + grad = lin_to_grad(lin_fn, state.y_eval, autodiff_mode=autodiff_mode) f_eval_info = FunctionInfo.EvalGrad(f_eval, grad) descent_state = self.descent.query(state.y_eval, f_eval_info, descent_state) @@ -252,10 +252,10 @@ class GradientDescent(AbstractGradientDescent[Y, Aux], strict=True): Supports the following `options`: - - `mode`: whether to use forward- or reverse-mode autodifferentiation to compute the - gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`, which is - usually more efficient. Changing this can be useful when the target function - does not support reverse-mode automatic differentiation. + - `autodiff_mode`: whether to use forward- or reverse-mode autodifferentiation to + compute the gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`, + which is usually more efficient. Changing this can be useful when the target + function does not support reverse-mode automatic differentiation. """ rtol: float diff --git a/optimistix/_solver/nonlinear_cg.py b/optimistix/_solver/nonlinear_cg.py index 4d00030..2479b15 100644 --- a/optimistix/_solver/nonlinear_cg.py +++ b/optimistix/_solver/nonlinear_cg.py @@ -179,10 +179,10 @@ class NonlinearCG(AbstractGradientDescent[Y, Aux], strict=True): Supports the following `options`: - - `mode`: whether to use forward- or reverse-mode autodifferentiation to compute the - gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`, which is - usually more efficient. Changing this can be useful when the target function - does not support reverse-mode automatic differentiation. + - `autodiff_mode`: whether to use forward- or reverse-mode autodifferentiation to + compute the gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`, + which is usually more efficient. Changing this can be useful when the target + function does not support reverse-mode automatic differentiation. """ rtol: float diff --git a/tests/helpers.py b/tests/helpers.py index 4eef610..437b08a 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -2,6 +2,7 @@ from collections.abc import Callable from typing import Any, TypeVar +import diffrax as dfx import equinox as eqx import equinox.internal as eqxi import jax @@ -786,3 +787,37 @@ def apply(self, primal_fn, rewrite_fn, inputs, tags): del rewrite_fn, tags while_loop = ft.partial(eqxi.while_loop, kind="lax") return primal_fn(inputs + (while_loop,)) + + +def forward_only_ode(k, args): + # Test minimisers for use with dfx.ForwardMode. This test checks if the forward + # branch is entered as expected and that a (trivial) result is found. + # We're checking if trickier problems are solved correctly in the other tests. 
+ del args + dy = lambda t, y, k: -k * y + + def solve(_k): + return dfx.diffeqsolve( + dfx.ODETerm(dy), + dfx.Tsit5(), + 0.0, + 10.0, + 0.1, + 10.0, + args=_k, + adjoint=dfx.ForwardMode(), + ) + + data = jnp.asarray(solve(jnp.array(0.5)).ys) # seems to make type checkers happy + fit = jnp.asarray(solve(k).ys) + return jnp.sum((data - fit) ** 2) + + +forward_only_fn_init_options_expected = ( + ( + forward_only_ode, + jnp.array(0.6), + dict(autodiff_mode="fwd"), + jnp.array(0.5), + ), +) diff --git a/tests/requirements.txt b/tests/requirements.txt index e9b55d0..ddfdbfa 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -4,4 +4,5 @@ lineax optax pytest pytest-xdist -jaxlib \ No newline at end of file +jaxlib +diffrax \ No newline at end of file diff --git a/tests/test_minimise.py b/tests/test_minimise.py index a0aa8a3..657c62c 100644 --- a/tests/test_minimise.py +++ b/tests/test_minimise.py @@ -15,6 +15,7 @@ beale, bowl, finite_difference_jvp, + forward_only_fn_init_options_expected, matyas, minimisation_fn_minima_init_args, minimisers, @@ -25,7 +26,9 @@ smoke_aux = (jnp.ones((2, 3)), {"smoke_aux": jnp.ones(2)}) -@pytest.mark.parametrize("options", (dict(mode="fwd"), dict(mode="bwd"))) +@pytest.mark.parametrize( + "options", (dict(autodiff_mode="fwd"), dict(autodiff_mode="bwd")) +) @pytest.mark.parametrize("solver", minimisers) @pytest.mark.parametrize("_fn, minimum, init, args", minimisation_fn_minima_init_args) def test_minimise(solver, _fn, minimum, init, args, options): @@ -58,7 +61,9 @@ def test_minimise(solver, _fn, minimum, init, args, options): assert tree_allclose(optx_min, minimum, atol=atol, rtol=rtol) -@pytest.mark.parametrize("options", (dict(mode="fwd"), dict(mode="bwd"))) +@pytest.mark.parametrize( + "options", (dict(autodiff_mode="fwd"), dict(autodiff_mode="bwd")) +) @pytest.mark.parametrize("solver", minimisers) @pytest.mark.parametrize("_fn, minimum, init, args", minimisation_fn_minima_init_args) def test_minimise_jvp(getkey, solver, _fn, minimum, init, args, options): @@ -192,3 +197,17 @@ def f(x, _): num_called_so_far = num_called optx.minimise(f, solver2, 1.0) assert num_called_so_far == num_called + + +@pytest.mark.parametrize("solver", minimisers) +@pytest.mark.parametrize( + "fn, y0, options, expected", forward_only_fn_init_options_expected +) +def test_forward_minimisation(fn, y0, options, expected, solver): + if isinstance(solver, optx.OptaxMinimiser): # No support for forward option + return + else: + # Many steps because gradient descent takes ridiculously long + sol = optx.minimise(fn, solver, y0, options=options, max_steps=2**10) + assert sol.result == optx.RESULTS.successful + assert tree_allclose(sol.value, expected, atol=1e-4, rtol=1e-4) From 1e4c4123ae04d39c446cfe6e58bfd647f4e4e42f Mon Sep 17 00:00:00 2001 From: Johanna Haffner Date: Fri, 7 Feb 2025 13:41:38 +0100 Subject: [PATCH 7/7] add a pyright: ignore flag for return type (sol.ys) of dfx.diffeqsolve --- benchmarks/levenberg-marquardt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/levenberg-marquardt.py b/benchmarks/levenberg-marquardt.py index 2181fdb..93ecbdf 100644 --- a/benchmarks/levenberg-marquardt.py +++ b/benchmarks/levenberg-marquardt.py @@ -81,7 +81,7 @@ def solve( # support forward-mode autodiff, which is used by Levenberg--Marquardt adjoint=dfx.DirectAdjoint(), ) - return sol.ys + return sol.ys # pyright: ignore def get_data() -> tuple[Float[Array, "3 2"], Float[Array, "3 50"]]:
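
For reference, the user-facing behaviour added by this series can be exercised as in the sketch below. This is an illustrative example only: the objective function, tolerances and initial value are made up, while `optx.BFGS`, `optx.minimise` and the `autodiff_mode` option are the pieces introduced or touched by the patches above. Passing `autodiff_mode="fwd"` computes the gradient with forward-mode autodiff, which matters when the objective only supports forward-mode differentiation (for example a diffrax solve using `dfx.ForwardMode`, as in the new `forward_only_ode` test helper); the default `"bwd"` keeps the previous reverse-mode behaviour.

import jax.numpy as jnp
import optimistix as optx

def loss(y, args):
    # Simple quadratic with its minimum at y = 1; stands in for any scalar objective.
    del args
    return jnp.sum((y - 1.0) ** 2)

solver = optx.BFGS(rtol=1e-6, atol=1e-6)
sol = optx.minimise(
    loss,
    solver,
    jnp.zeros(3),
    options=dict(autodiff_mode="fwd"),  # gradient via forward-mode AD; "bwd" is the default
)
# sol.value is approximately jnp.array([1.0, 1.0, 1.0])

Note that the option is read per step via `options.get("autodiff_mode", "bwd")`, so it applies to the gradient-based minimisers shown above (BFGS, gradient descent, nonlinear CG); the tests skip `optx.OptaxMinimiser`, which does not take this option.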