
Adds an option to support forward-mode automatic differentiation in all minimisers #114

Merged: 7 commits into patrick-kidger:main from forward-fix, Feb 8, 2025

Conversation

@johannahaffner (Contributor)

Minimisers now accept "fwd" and "bwd" options, and lin_to_grad will use transposition to get the gradient in reverse mode, while calling a custom _jacfwd that evaluates the linearised function with unit pytrees when using forward mode. Where relevant, gradient-computation-related documentation has been moved to these two functions.
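For reference, here is a minimal JAX sketch of the two branches described above (not the optimistix implementation; `grad_from_linearised` is a made-up name): reverse mode transposes the linearised map, while forward mode evaluates it on unit basis vectors.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def grad_from_linearised(fn, y, mode="bwd"):
    # Sketch only: recover the gradient of a scalar-valued `fn` at `y` from
    # its linearisation.
    out, lin_fn = jax.linearize(fn, y)
    if mode == "bwd":
        # Reverse mode: transpose the linear map and pull back a unit cotangent.
        (grad,) = jax.linear_transpose(lin_fn, y)(jnp.ones_like(out))
        return grad
    else:
        # Forward mode: push each unit basis vector through the linear map.
        flat, unravel = ravel_pytree(y)
        cols = jnp.stack([lin_fn(unravel(e)) for e in jnp.eye(flat.size)])
        return unravel(cols)


# Both modes give the gradient [0., 2., 4.] of sum(y**2) at y = [0., 1., 2.].
grad_fwd = grad_from_linearised(lambda y: jnp.sum(y**2), jnp.arange(3.0), mode="fwd")
grad_bwd = grad_from_linearised(lambda y: jnp.sum(y**2), jnp.arange(3.0), mode="bwd")
```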

I have parameterised the tests for minimise and minimise_jvp to try both options; let me know if there are other tests I should also parameterise. I'm assuming everything else will also work if these two pass.

I have also removed a jacobian function from the _misc module that used size-based heuristics to toggle between forward- and reverse-mode automatic differentiation. Since we now support explicit options in all solvers that compute derivatives, this function no longer fits our current way of doing things. (I don't think it was used anywhere anyway.)
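(The removed heuristic worked roughly along these lines; this is an illustrative sketch rather than the deleted code, and `heuristic_jacobian` is a hypothetical name.)

```python
import jax


def heuristic_jacobian(fn, in_size, out_size):
    # Size-based heuristic: forward mode is typically cheaper when there are
    # fewer inputs than outputs, reverse mode in the opposite case.
    if in_size <= out_size:
        return jax.jacfwd(fn)
    else:
        return jax.jacrev(fn)
```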

Johanna Haffner added 2 commits February 2, 2025 13:14
… all minimisers.

Moves gradient computation, handling and related documentation into lin_to_grad.

Also removes an old jacobian function that used size-based heuristics to toggle between modes.
@johannahaffner (Contributor, Author)

This addresses #112 (comment).

optimistix/_solver/gradient_methods.py (review thread outdated, resolved)
optimistix/_misc.py (review thread outdated, resolved)
@johannahaffner (Contributor, Author)

Summary:

  • After some extra checks, I realised that the compilation-time overhead from using jax.jacfwd is additive (20 ms) and apparently independent of problem size. We now use jax.jacfwd.
  • The option is now called autodiff_mode; overall, I think that is clearest. (See the usage sketch after this list.)
  • I added a test to confirm that the forward branch is entered correctly and that we get the expected result for a trivial case involving dfx.ForwardMode. This adds diffrax as a test-time dependency.
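As a usage sketch (the exact API surface isn't spelled out in this thread, so the `options={"autodiff_mode": ...}` key and the choice of solver below are assumptions):

```python
import jax.numpy as jnp
import optimistix as optx


def rosenbrock(y, args):
    x1, x2 = y
    return (1.0 - x1) ** 2 + 100.0 * (x2 - x1**2) ** 2


solver = optx.BFGS(rtol=1e-8, atol=1e-8)
# Assumed: the new option is passed through the `options` dict of `optx.minimise`,
# with "fwd" selecting forward mode and "bwd" (the default) reverse mode.
sol = optx.minimise(
    rosenbrock,
    solver,
    jnp.array([2.0, 2.0]),
    options={"autodiff_mode": "fwd"},
)
```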

@johannahaffner (Contributor, Author)

Tests were breaking because pyright now complains about the return type of sol.ys potentially being None. Since this is not a concern in the Levenberg-Marquardt benchmarks, I opted to ignore this complaint.
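(For illustration only, since the benchmark code isn't shown here: two common ways to quiet such a complaint are narrowing the `Optional` with an assert or suppressing the diagnostic inline.)

```python
from typing import Optional

import jax
import jax.numpy as jnp

ys: Optional[jax.Array] = jnp.arange(3.0)  # stand-in for `sol.ys`

# Option 1: narrow the Optional type so pyright accepts the indexing below.
assert ys is not None
final = ys[-1]

# Option 2 (alternative): suppress the diagnostic inline instead.
# final = ys[-1]  # pyright: ignore[reportOptionalSubscript]
```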

Comment on lines +207 to +208
if isinstance(solver, optx.OptaxMinimiser): # No support for forward option
return
@patrick-kidger (Owner)


Nit: it'd be better to do

@pytest.mark.parametrize("solver", [solver for solver in minimisers if not isinstance(solver, optx.OptaxMinimiser)])

Not important enough to block merging, just noting in passing!

@johannahaffner (Contributor, Author)


Noted for the future :)

@patrick-kidger merged commit 61d68b1 into patrick-kidger:main on Feb 8, 2025 (2 checks passed)
@patrick-kidger (Owner)

Awesome! Merged :D

@johannahaffner deleted the forward-fix branch on February 8, 2025, 18:23