Adds an option to support forward-mode automatic differentiation in all minimisers #114

Merged (7 commits, Feb 8, 2025).
Showing changes from the first 3 commits.
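In practice the new behaviour is driven through the solver `options` dictionary. A minimal usage sketch, assuming the option name `mode` as it stands at this point in the PR; the objective, tolerances, and initial value below are illustrative and not taken from the diff:

import jax.numpy as jnp
import optimistix as optx


def rosenbrock(y, args):
    x1, x2 = y
    return (1.0 - x1) ** 2 + 100.0 * (x2 - x1**2) ** 2


solver = optx.BFGS(rtol=1e-8, atol=1e-8)
# "bwd" (reverse mode) remains the default; "fwd" requests forward-mode autodiff.
sol = optx.minimise(rosenbrock, solver, jnp.array([2.0, 2.5]), options={"mode": "fwd"})
print(sol.value)  # approximately [1.0, 1.0]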
57 changes: 39 additions & 18 deletions optimistix/_misc.py
@@ -8,6 +8,7 @@
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
from equinox.internal import ω
from jaxtyping import Array, ArrayLike, Bool, PyTree, Scalar
from lineax.internal import (
@@ -103,26 +104,46 @@ def __call__(self, *args, **kwargs):
return out, aux


-def jacobian(fn, in_size, out_size, has_aux=False):
-    """Compute the Jacobian of a function using forward or backward mode AD.
-
-    `jacobian` chooses between forward and backwards autodiff depending on the input
-    and output dimension of `fn`, as specified in `in_size` and `out_size`.
-    """
-
-    # Heuristic for which is better in each case
-    # These could probably be tuned a lot more.
-    if (in_size < 100) or (in_size <= 1.5 * out_size):
-        return jax.jacfwd(fn, has_aux=has_aux)
-    else:
-        return jax.jacrev(fn, has_aux=has_aux)
-
-
-def lin_to_grad(lin_fn, *primals):
-    # Only the shape and dtype of primals is evaluated, not the value itself. We convert
-    # to grad after linearising to avoid recompilation. (1.0 is a scaling factor.)
-    return jax.linear_transpose(lin_fn, *primals)(1.0)
+def _jacfwd(lin_fn, pytree):
+    """Custom version of jax.jacfwd that operates on a linearized function. Enables
+    obtaining a gradient without requiring transposition if the function that was
+    linearized returns a scalar, as in this case the Jacobian is equivalent to the
+    gradient. (Untested for other use cases.)
+    Tailored to the optimistix use-case, where all optimisation variables are pytrees
+    of arrays, sharing the same dtype.
+    """
+    leaves, treedef = jtu.tree_flatten(pytree)
+    static_sizes = [int(jnp.size(leaf)) for leaf in leaves]
+    indices = np.cumsum(static_sizes)[:-1]  # Define boundaries between elements
+    elements = sum(static_sizes)
+    dtype = jax.dtypes.result_type(*leaves)
+
+    def values_to_tree(values):
+        parts = jnp.split(values, indices)
+        reshaped = jtu.tree_map(lambda a, b: jnp.reshape(a, b.shape), parts, leaves)
+        return jtu.tree_unflatten(treedef, reshaped)
+
+    def directional_derivative(index):
+        values = jnp.zeros(elements, dtype=dtype).at[index].set(1.0)
+        unit_tree = values_to_tree(values)
+        return lin_fn(unit_tree)
+
+    derivatives = jax.vmap(directional_derivative)(jnp.arange(elements))
+    return values_to_tree(derivatives)
+
+
+def lin_to_grad(lin_fn, y_eval, mode=None):
+    # Only the shape and dtype of y_eval is evaluated, not the value itself. (lin_fn
+    # was linearized at y_eval, and the values were stored.)
+    # We convert to grad after linearising for efficiency:
+    # https://github.com/patrick-kidger/optimistix/issues/89#issuecomment-2447669714
+    if mode == "bwd":
+        (grad,) = jax.linear_transpose(lin_fn, y_eval)(1.0)  # (1.0 is a scaling factor)
+        return grad
+    if mode == "fwd":
+        return _jacfwd(lin_fn, y_eval)
+    else:
+        raise ValueError("Only `mode='fwd'` or `mode='bwd'` are valid.")


def _asarray(dtype, x):
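For intuition about the two branches of the new `lin_to_grad`: both recover the same gradient from a `jax.linearize`d scalar-valued function. The "bwd" branch transposes the linear map, while the "fwd" branch pushes each standard basis vector through it, which is what `_jacfwd` does in a pytree-aware way. A standalone sketch of that equivalence, using an illustrative function that is not part of this diff:

import jax
import jax.numpy as jnp


def f(y):
    return jnp.sum(jnp.sin(y) * y)  # scalar-valued objective


y = jnp.arange(1.0, 6.0)
_, lin_fn = jax.linearize(f, y)

# Reverse mode: transpose the linear map and pull back the scalar cotangent 1.0.
(grad_bwd,) = jax.linear_transpose(lin_fn, y)(1.0)

# Forward mode: push each basis vector e_i through the linearisation; J @ e_i = df/dy_i.
grad_fwd = jax.vmap(lin_fn)(jnp.eye(y.size, dtype=y.dtype))

assert jnp.allclose(grad_bwd, jax.grad(f)(y))
assert jnp.allclose(grad_fwd, grad_bwd)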
20 changes: 16 additions & 4 deletions optimistix/_solver/bfgs.py
@@ -155,6 +155,13 @@ class AbstractBFGS(
information.

This abstract version may be subclassed to choose alternative descent and searches.

Supports the following `options`:

- `mode`: whether to use forward- or reverse-mode autodifferentiation to compute the
gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`, which is
usually more efficient. Changing this can be useful when the target function
does not support reverse-mode automatic differentiation.
"""

rtol: AbstractVar[float]
@@ -205,6 +212,7 @@ def step(
state: _BFGSState,
tags: frozenset[object],
) -> tuple[Y, _BFGSState, Aux]:
mode = options.get("mode", "bwd")
f_eval, lin_fn, aux_eval = jax.linearize(
lambda _y: fn(_y, args), state.y_eval, has_aux=True
)
@@ -218,10 +226,7 @@
)

def accepted(descent_state):
-    # We have linearised the function (i.e. computed its Jacobian at y_eval)
-    # above, here we convert it to a gradient of the same shape as y. y_eval is
-    # actually a dummy value here, since lin_fn already has the required info.
-    (grad,) = lin_to_grad(lin_fn, state.y_eval)
+    grad = lin_to_grad(lin_fn, state.y_eval, mode=mode)

y_diff = (state.y_eval**ω - y**ω).ω
if self.use_inverse:
@@ -318,6 +323,13 @@ class BFGS(AbstractBFGS[Y, Aux, _Hessian], strict=True):
This is a quasi-Newton optimisation algorithm, whose defining feature is the way
it progressively builds up a Hessian approximation using multiple steps of gradient
information.

Supports the following `options`:

- `mode`: whether to use forward- or reverse-mode autodifferentiation to compute the
gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`, which is
usually more efficient. Changing this can be useful when the target function
does not support reverse-mode automatic differentiation.
"""

rtol: float
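A case where `mode="fwd"` is genuinely needed rather than merely convenient: an objective that JAX cannot reverse-mode differentiate, for example one containing a `lax.while_loop` (which supports forward-mode JVPs but cannot be transposed). A hedged sketch under the `mode` option as defined in this diff; the fixed-point objective is illustrative:

import jax.lax as lax
import jax.numpy as jnp
import optimistix as optx


def objective(y, args):
    # Newton iteration for sqrt(y), written with while_loop: forward-mode
    # differentiable, but reverse-mode differentiation through it raises an error.
    def cond(carry):
        _, i = carry
        return i < 10

    def body(carry):
        x, i = carry
        return 0.5 * (x + y / x), i + 1

    x, _ = lax.while_loop(cond, body, (jnp.ones_like(y), 0))
    return jnp.sum((x - 2.0) ** 2)


solver = optx.BFGS(rtol=1e-6, atol=1e-6)
# With the default mode="bwd" the transpose would fail; "fwd" sidesteps it.
sol = optx.minimise(objective, solver, jnp.array([3.0]), options={"mode": "fwd"})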
23 changes: 18 additions & 5 deletions optimistix/_solver/gradient_methods.py
@@ -120,6 +120,13 @@ class AbstractGradientDescent(
- `norm: Callable[[PyTree], Scalar]`
- `descent: AbstractDescent`
- `search: AbstractSearch`

Supports the following `options`:

- `mode`: whether to use forward- or reverse-mode autodifferentiation to compute the
gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`, which is
usually more efficient. Changing this can be useful when the target function
does not support reverse-mode automatic differentiation.
"""

rtol: AbstractVar[float]
@@ -162,6 +169,7 @@ def step(
state: _GradientDescentState,
tags: frozenset[object],
) -> tuple[Y, _GradientDescentState, Aux]:
mode = options.get("mode", "bwd")
f_eval, lin_fn, aux_eval = jax.linearize(
lambda _y: fn(_y, args), state.y_eval, has_aux=True
)
@@ -175,10 +183,7 @@
)

def accepted(descent_state):
-    # We have linearised the function (i.e. computed its Jacobian at y_eval)
-    # above, here we convert it to a gradient of the same shape as y. y_eval is
-    # actually a dummy value here, since lin_fn already has the required info.
-    (grad,) = lin_to_grad(lin_fn, state.y_eval)
+    grad = lin_to_grad(lin_fn, state.y_eval, mode=mode)

f_eval_info = FunctionInfo.EvalGrad(f_eval, grad)
descent_state = self.descent.query(state.y_eval, f_eval_info, descent_state)
@@ -243,7 +248,15 @@ def postprocess(


class GradientDescent(AbstractGradientDescent[Y, Aux], strict=True):
"""Classic gradient descent with a learning rate `learning_rate`."""
"""Classic gradient descent with a learning rate `learning_rate`.

Supports the following `options`:

- `mode`: whether to use forward- or reverse-mode autodifferentiation to compute the
gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`, which is
usually more efficient. Changing this can be useful when the target function
does not support reverse-mode automatic differentiation.
"""

rtol: float
atol: float
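The gradient-based solvers in this file read the same option. A brief sketch with `GradientDescent`; the learning rate, tolerances, and objective are placeholders:

import jax.numpy as jnp
import optimistix as optx


def quadratic(y, args):
    return jnp.sum((y - 1.0) ** 2)


solver = optx.GradientDescent(learning_rate=0.1, rtol=1e-6, atol=1e-6)
sol = optx.minimise(quadratic, solver, jnp.zeros(3), options={"mode": "fwd"})
# With mode="bwd" (the default) the solver follows exactly the same iterates,
# since both modes compute the same gradient.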
10 changes: 9 additions & 1 deletion optimistix/_solver/nonlinear_cg.py
@@ -175,7 +175,15 @@ def step(


class NonlinearCG(AbstractGradientDescent[Y, Aux], strict=True):
"""The nonlinear conjugate gradient method."""
"""The nonlinear conjugate gradient method.

Supports the following `options`:

- `mode`: whether to use forward- or reverse-mode autodifferentiation to compute the
gradient. Can be either `"fwd"` or `"bwd"`. Defaults to `"bwd"`, which is
usually more efficient. Changing this can be useful when the target function
does not support reverse-mode automatic differentiation.
"""

rtol: float
atol: float
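`NonlinearCG` inherits the same machinery through `AbstractGradientDescent`, so the option applies here too. A quick sketch comparing the two modes on the same problem (solver settings and objective are illustrative):

import jax.numpy as jnp
import optimistix as optx


def fn(y, args):
    return 0.5 * jnp.sum((y - 2.0) ** 2)


solver = optx.NonlinearCG(rtol=1e-6, atol=1e-6)
argmin_fwd = optx.minimise(fn, solver, jnp.zeros(4), options={"mode": "fwd"}).value
argmin_bwd = optx.minimise(fn, solver, jnp.zeros(4), options={"mode": "bwd"}).value
assert jnp.allclose(argmin_fwd, argmin_bwd)  # same gradients, hence the same solution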
8 changes: 6 additions & 2 deletions tests/test_minimise.py
@@ -25,9 +25,10 @@
smoke_aux = (jnp.ones((2, 3)), {"smoke_aux": jnp.ones(2)})


@pytest.mark.parametrize("options", (dict(mode="fwd"), dict(mode="bwd")))
@pytest.mark.parametrize("solver", minimisers)
@pytest.mark.parametrize("_fn, minimum, init, args", minimisation_fn_minima_init_args)
-def test_minimise(solver, _fn, minimum, init, args):
+def test_minimise(solver, _fn, minimum, init, args, options):
if isinstance(solver, optx.GradientDescent):
max_steps = 100_000
else:
@@ -49,16 +50,18 @@ def test_minimise(solver, _fn, minimum, init, args):
init,
has_aux=has_aux,
args=args,
options=options,
max_steps=max_steps,
throw=False,
).value
optx_min = _fn(optx_argmin, args)
assert tree_allclose(optx_min, minimum, atol=atol, rtol=rtol)


@pytest.mark.parametrize("options", (dict(mode="fwd"), dict(mode="bwd")))
@pytest.mark.parametrize("solver", minimisers)
@pytest.mark.parametrize("_fn, minimum, init, args", minimisation_fn_minima_init_args)
-def test_minimise_jvp(getkey, solver, _fn, minimum, init, args):
+def test_minimise_jvp(getkey, solver, _fn, minimum, init, args, options):
if isinstance(solver, (optx.GradientDescent, optx.NonlinearCG)):
max_steps = 100_000
atol = rtol = 1e-2
@@ -88,6 +91,7 @@ def minimise(x, dynamic_args, *, adjoint):
x,
has_aux=has_aux,
args=args,
options=options,
max_steps=max_steps,
adjoint=adjoint,
throw=False,
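The second test differentiates through the solve itself with `jax.jvp`; the `mode` option only controls how the solver builds its internal gradients, so the solve stays differentiable in the usual way. An illustrative sketch along the same lines as the test, with an arbitrary quadratic and tangent:

import jax
import jax.numpy as jnp
import optimistix as optx


def fn(y, args):
    return jnp.sum((y - args) ** 2)


def solve(args):
    solver = optx.BFGS(rtol=1e-8, atol=1e-8)
    sol = optx.minimise(fn, solver, jnp.zeros(3), args=args, options={"mode": "fwd"})
    return sol.value


args = jnp.array([1.0, 2.0, 3.0])
primal, tangent = jax.jvp(solve, (args,), (jnp.ones_like(args),))
# The argmin equals args, so the output tangent matches the input tangent.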