diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 750fa40..e98d6f1 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -25,4 +25,4 @@ repos:
     rev: v1.1.331
     hooks:
       - id: pyright
-        additional_dependencies: ["equinox", "jax", "lineax", "pytest", "optax"]
+        additional_dependencies: ["equinox", "jax", "lineax", "pytest", "optax", "diffrax"]
diff --git a/benchmarks/levenberg-marquardt.py b/benchmarks/levenberg-marquardt.py
index 2181fdb..93ecbdf 100644
--- a/benchmarks/levenberg-marquardt.py
+++ b/benchmarks/levenberg-marquardt.py
@@ -81,7 +81,7 @@ def solve(
         # support forward-mode autodiff, which is used by Levenberg--Marquardt
         adjoint=dfx.DirectAdjoint(),
     )
-    return sol.ys
+    return sol.ys  # pyright: ignore


 def get_data() -> tuple[Float[Array, "3 2"], Float[Array, "3 50"]]:
diff --git a/optimistix/_misc.py b/optimistix/_misc.py
index bd58f79..bf7e5d7 100644
--- a/optimistix/_misc.py
+++ b/optimistix/_misc.py
@@ -103,26 +103,20 @@ def __call__(self, *args, **kwargs):
         return out, aux


-def jacobian(fn, in_size, out_size, has_aux=False):
-    """Compute the Jacobian of a function using forward or backward mode AD.
-
-    `jacobian` chooses between forward and backwards autodiff depending on the input
-    and output dimension of `fn`, as specified in `in_size` and `out_size`.
-    """
-
-    # Heuristic for which is better in each case
-    # These could probably be tuned a lot more.
-    if (in_size < 100) or (in_size <= 1.5 * out_size):
-        return jax.jacfwd(fn, has_aux=has_aux)
-    else:
-        return jax.jacrev(fn, has_aux=has_aux)
-
-
-def lin_to_grad(lin_fn, *primals):
-    # Only the shape and dtype of primals is evaluated, not the value itself. We convert
-    # to grad after linearising to avoid recompilation. (1.0 is a scaling factor.)
+def lin_to_grad(lin_fn, y_eval, autodiff_mode=None):
+    # Only the shape and dtype of y_eval is evaluated, not the value itself. (lin_fn
+    # was linearized at y_eval, and the values were stored.)
+    # We convert to grad after linearising for efficiency:
     # https://github.com/patrick-kidger/optimistix/issues/89#issuecomment-2447669714
-    return jax.linear_transpose(lin_fn, *primals)(1.0)
+    if autodiff_mode == "bwd":
+        (grad,) = jax.linear_transpose(lin_fn, y_eval)(1.0)  # (1.0 is a scaling factor)
+        return grad
+    if autodiff_mode == "fwd":
+        return jax.jacfwd(lin_fn)(y_eval)
+    else:
+        raise ValueError(
+            "Only `autodiff_mode='fwd'` or `autodiff_mode='bwd'` are valid."
+        )


 def _asarray(dtype, x):
diff --git a/optimistix/_solver/bfgs.py b/optimistix/_solver/bfgs.py
index 19e7994..c2f6fbb 100644
--- a/optimistix/_solver/bfgs.py
+++ b/optimistix/_solver/bfgs.py
@@ -155,6 +155,13 @@ class AbstractBFGS(
     information.

     This abstract version may be subclassed to choose alternative descent and searches.
+
+    Supports the following `options`:
+
+    - `autodiff_mode`: whether to use forward- or reverse-mode autodifferentiation to
+        compute the gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`,
+        which is usually more efficient. Changing this can be useful when the target
+        function does not support reverse-mode automatic differentiation.
     """

     rtol: AbstractVar[float]
@@ -205,6 +212,7 @@ def step(
         state: _BFGSState,
         tags: frozenset[object],
     ) -> tuple[Y, _BFGSState, Aux]:
+        autodiff_mode = options.get("autodiff_mode", "bwd")
         f_eval, lin_fn, aux_eval = jax.linearize(
             lambda _y: fn(_y, args), state.y_eval, has_aux=True
         )
@@ -218,10 +226,7 @@
         )

         def accepted(descent_state):
-            # We have linearised the function (i.e. computed its Jacobian at y_eval)
-            # above, here we convert it to a gradient of the same shape as y. y_eval is
-            # actually a dummy value here, since lin_fn already has the required info.
-            (grad,) = lin_to_grad(lin_fn, state.y_eval)
+            grad = lin_to_grad(lin_fn, state.y_eval, autodiff_mode=autodiff_mode)

             y_diff = (state.y_eval**ω - y**ω).ω
             if self.use_inverse:
@@ -318,6 +323,13 @@ class BFGS(AbstractBFGS[Y, Aux, _Hessian], strict=True):
     This is a quasi-Newton optimisation algorithm, whose defining feature is the way it
     progressively builds up a Hessian approximation using multiple steps of gradient
     information.
+
+    Supports the following `options`:
+
+    - `autodiff_mode`: whether to use forward- or reverse-mode autodifferentiation to
+        compute the gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`,
+        which is usually more efficient. Changing this can be useful when the target
+        function does not support reverse-mode automatic differentiation.
     """

     rtol: float
diff --git a/optimistix/_solver/gradient_methods.py b/optimistix/_solver/gradient_methods.py
index 726507d..6277aac 100644
--- a/optimistix/_solver/gradient_methods.py
+++ b/optimistix/_solver/gradient_methods.py
@@ -120,6 +120,13 @@ class AbstractGradientDescent(
     - `norm: Callable[[PyTree], Scalar]`
     - `descent: AbstractDescent`
     - `search: AbstractSearch`
+
+    Supports the following `options`:
+
+    - `autodiff_mode`: whether to use forward- or reverse-mode autodifferentiation to
+        compute the gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`,
+        which is usually more efficient. Changing this can be useful when the target
+        function does not support reverse-mode automatic differentiation.
     """

     rtol: AbstractVar[float]
@@ -162,6 +169,7 @@ def step(
         state: _GradientDescentState,
         tags: frozenset[object],
     ) -> tuple[Y, _GradientDescentState, Aux]:
+        autodiff_mode = options.get("autodiff_mode", "bwd")
         f_eval, lin_fn, aux_eval = jax.linearize(
             lambda _y: fn(_y, args), state.y_eval, has_aux=True
         )
@@ -175,10 +183,7 @@
         )

         def accepted(descent_state):
-            # We have linearised the function (i.e. computed its Jacobian at y_eval)
-            # above, here we convert it to a gradient of the same shape as y. y_eval is
-            # actually a dummy value here, since lin_fn already has the required info.
-            (grad,) = lin_to_grad(lin_fn, state.y_eval)
+            grad = lin_to_grad(lin_fn, state.y_eval, autodiff_mode=autodiff_mode)

             f_eval_info = FunctionInfo.EvalGrad(f_eval, grad)
             descent_state = self.descent.query(state.y_eval, f_eval_info, descent_state)
@@ -243,7 +248,15 @@ def postprocess(


 class GradientDescent(AbstractGradientDescent[Y, Aux], strict=True):
-    """Classic gradient descent with a learning rate `learning_rate`."""
+    """Classic gradient descent with a learning rate `learning_rate`.
+
+    Supports the following `options`:
+
+    - `autodiff_mode`: whether to use forward- or reverse-mode autodifferentiation to
+        compute the gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`,
+        which is usually more efficient. Changing this can be useful when the target
+        function does not support reverse-mode automatic differentiation.
+    """

     rtol: float
     atol: float
diff --git a/optimistix/_solver/nonlinear_cg.py b/optimistix/_solver/nonlinear_cg.py
index 8edac05..2479b15 100644
--- a/optimistix/_solver/nonlinear_cg.py
+++ b/optimistix/_solver/nonlinear_cg.py
@@ -175,7 +175,15 @@ def step(


 class NonlinearCG(AbstractGradientDescent[Y, Aux], strict=True):
-    """The nonlinear conjugate gradient method."""
+    """The nonlinear conjugate gradient method.
+
+    Supports the following `options`:
+
+    - `autodiff_mode`: whether to use forward- or reverse-mode autodifferentiation to
+        compute the gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`,
+        which is usually more efficient. Changing this can be useful when the target
+        function does not support reverse-mode automatic differentiation.
+    """

     rtol: float
     atol: float
diff --git a/tests/helpers.py b/tests/helpers.py
index 4eef610..437b08a 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -2,6 +2,7 @@
 from collections.abc import Callable
 from typing import Any, TypeVar

+import diffrax as dfx
 import equinox as eqx
 import equinox.internal as eqxi
 import jax
@@ -786,3 +787,37 @@ def apply(self, primal_fn, rewrite_fn, inputs, tags):
         del rewrite_fn, tags
         while_loop = ft.partial(eqxi.while_loop, kind="lax")
         return primal_fn(inputs + (while_loop,))
+
+
+def forward_only_ode(k, args):
+    # Test minimisers for use with dfx.ForwardMode. This test checks if the forward
+    # branch is entered as expected and that a (trivial) result is found.
+    # We're checking if trickier problems are solved correctly in the other tests.
+    del args
+    dy = lambda t, y, k: -k * y
+
+    def solve(_k):
+        return dfx.diffeqsolve(
+            dfx.ODETerm(dy),
+            dfx.Tsit5(),
+            0.0,
+            10.0,
+            0.1,
+            10.0,
+            args=_k,
+            adjoint=dfx.ForwardMode(),
+        )
+
+    data = jnp.asarray(solve(jnp.array(0.5)).ys)  # seems to make type checkers happy
+    fit = jnp.asarray(solve(k).ys)
+    return jnp.sum((data - fit) ** 2)
+
+
+forward_only_fn_init_options_expected = (
+    (
+        forward_only_ode,
+        jnp.array(0.6),
+        dict(autodiff_mode="fwd"),
+        jnp.array(0.5),
+    ),
+)
diff --git a/tests/requirements.txt b/tests/requirements.txt
index e9b55d0..ddfdbfa 100644
--- a/tests/requirements.txt
+++ b/tests/requirements.txt
@@ -4,4 +4,5 @@ lineax
 optax
 pytest
 pytest-xdist
-jaxlib
\ No newline at end of file
+jaxlib
+diffrax
\ No newline at end of file
diff --git a/tests/test_minimise.py b/tests/test_minimise.py
index 2b6ab2c..657c62c 100644
--- a/tests/test_minimise.py
+++ b/tests/test_minimise.py
@@ -15,6 +15,7 @@
     beale,
     bowl,
     finite_difference_jvp,
+    forward_only_fn_init_options_expected,
     matyas,
     minimisation_fn_minima_init_args,
     minimisers,
@@ -25,9 +26,12 @@
 smoke_aux = (jnp.ones((2, 3)), {"smoke_aux": jnp.ones(2)})


+@pytest.mark.parametrize(
+    "options", (dict(autodiff_mode="fwd"), dict(autodiff_mode="bwd"))
+)
 @pytest.mark.parametrize("solver", minimisers)
 @pytest.mark.parametrize("_fn, minimum, init, args", minimisation_fn_minima_init_args)
-def test_minimise(solver, _fn, minimum, init, args):
+def test_minimise(solver, _fn, minimum, init, args, options):
     if isinstance(solver, optx.GradientDescent):
         max_steps = 100_000
     else:
@@ -49,6 +53,7 @@ def test_minimise(solver, _fn, minimum, init, args):
         init,
         has_aux=has_aux,
         args=args,
+        options=options,
         max_steps=max_steps,
         throw=False,
     ).value
@@ -56,9 +61,12 @@
     assert tree_allclose(optx_min, minimum, atol=atol, rtol=rtol)


+@pytest.mark.parametrize(
+    "options", (dict(autodiff_mode="fwd"), dict(autodiff_mode="bwd"))
+)
 @pytest.mark.parametrize("solver", minimisers)
 @pytest.mark.parametrize("_fn, minimum, init, args", minimisation_fn_minima_init_args)
-def test_minimise_jvp(getkey, solver, _fn, minimum, init, args):
+def test_minimise_jvp(getkey, solver, _fn, minimum, init, args, options):
     if isinstance(solver, (optx.GradientDescent, optx.NonlinearCG)):
         max_steps = 100_000
         atol = rtol = 1e-2
     else:
@@ -88,6 +96,7 @@ def minimise(x, dynamic_args, *, adjoint):
             x,
             has_aux=has_aux,
             args=args,
+            options=options,
             max_steps=max_steps,
             adjoint=adjoint,
             throw=False,
@@ -188,3 +197,17 @@ def f(x, _):
     num_called_so_far = num_called
     optx.minimise(f, solver2, 1.0)
     assert num_called_so_far == num_called
+
+
+@pytest.mark.parametrize("solver", minimisers)
+@pytest.mark.parametrize(
+    "fn, y0, options, expected", forward_only_fn_init_options_expected
+)
+def test_forward_minimisation(fn, y0, options, expected, solver):
+    if isinstance(solver, optx.OptaxMinimiser):  # No support for forward option
+        return
+    else:
+        # Many steps because gradient descent takes ridiculously long
+        sol = optx.minimise(fn, solver, y0, options=options, max_steps=2**10)
+        assert sol.result == optx.RESULTS.successful
+        assert tree_allclose(sol.value, expected, atol=1e-4, rtol=1e-4)
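
Background for reviewers, not part of the patch: the heart of the change is that `lin_to_grad` now recovers the gradient from the `jax.linearize` output by one of two equivalent routes. A minimal standalone sketch of the idea, using a made-up scalar function `f` and point `y_eval` (illustrative only, not optimistix code):

import jax
import jax.numpy as jnp

f = lambda y: jnp.sum(jnp.sin(y) ** 2)
y_eval = jnp.arange(3.0)

# `lin_fn` is the linearisation of `f` at `y_eval`, analogous to what the solvers'
# `step` methods produce.
_, lin_fn = jax.linearize(f, y_eval)

# "bwd": transpose the linearised function once (what the old code always did).
(grad_bwd,) = jax.linear_transpose(lin_fn, y_eval)(1.0)

# "fwd": push basis tangents through `lin_fn` instead, so no transpose (and hence no
# reverse-mode rule for the target function) is ever required.
grad_fwd = jax.jacfwd(lin_fn)(y_eval)

assert jnp.allclose(grad_bwd, jax.grad(f)(y_eval))
assert jnp.allclose(grad_fwd, jax.grad(f)(y_eval))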
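
And a usage sketch of the new option from the caller's side, mirroring how the new tests invoke it. The quadratic `loss` below is a hypothetical stand-in for a target that only supports forward-mode autodiff, e.g. one wrapping `dfx.diffeqsolve(..., adjoint=dfx.ForwardMode())`:

import jax.numpy as jnp
import optimistix as optx

def loss(y, args):
    # Toy objective; in practice this would be the forward-mode-only function.
    return jnp.sum((y - 3.0) ** 2)

solver = optx.BFGS(rtol=1e-5, atol=1e-5)
sol = optx.minimise(
    loss,
    solver,
    jnp.array(0.0),
    options=dict(autodiff_mode="fwd"),  # "bwd" remains the default
)
print(sol.value)  # approximately 3.0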