-
Notifications
You must be signed in to change notification settings - Fork 16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adds an option to support forward-mode automatic differentiation in all minimisers #114
Adds an option to support forward-mode automatic differentiation in all minimisers #114
Conversation
… all minimisers. Moves gradient computation, handling and related documentation into lin_to_grad. Also removes an old jacobian function that used size-based heuristics to toggle between modes.
This addresses #112 (comment). |
Summary:
|
Tests were breaking because pyright now complains about the return type of |
if isinstance(solver, optx.OptaxMinimiser): # No support for forward option | ||
return |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: it'd be better to do
@pytest.mark.parametrize("solver", [solver for solver in minimisers if not isinstance(solver, optx.OptaxMinimiser)]
Not important enough to block merging, just noting in passing!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Noted for the future :)
Awesome! Merged :D |
Minimisers now accept a
"fwd"
and a"bwd"
option, andlin_to_grad
will use transposition to get the gradient with reverse-mode, while calling a custom_jacfwd
that evaluates the linearised function with unit-pytrees when using forward-mode. Where relevant, gradient-computation related documentation has been moved to these two functions.I have parameterised the tests for
minimise
andminimise_jvp
to try both options, let me know if there are other tests I should also parameterise. I'm assuming everything else will also work if these two pass.I have also removed a function
jacobian
from the_misc
module, that used size-based heuristics to toggle between forward- and reverse-mode automatic differentiation. Since we now support explicit options in all solvers that compute derivatives, this function no longer fits our current way of doing things. (I don't think it was used anywhere anyway.)