Using tfp.substrates.jax.optimizer.lbfgs_minimize with Equinox #906
Dear All,

I want to use the LBFGS optimizer and was wondering if there is any example of using an Equinox neural network model with the tfp.substrates.jax.optimizer.lbfgs_minimize optimizer. Thanks!

Comments
I am not aware of any examples using TFP + Equinox, but it should be possible, given that Equinox generally operates at the JAX level rather than as a wrapper. There are lots of examples using Optax + Equinox, and Optax also has an LBFGS example (https://optax.readthedocs.io/en/stable/_collections/examples/lbfgs.html), so that might be a good starting place. Other implementations of potential use are https://jaxopt.github.io/stable/unconstrained.html and https://docs.kidger.site/optimistix/api/minimise/#optimistix.BFGS.
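For reference, here is a minimal sketch (untested; the MLP, toy data, and hyperparameters are all illustrative) of one way to wire an Equinox model into tfp.substrates.jax.optimizer.lbfgs_minimize, which expects a function mapping a flat parameter vector to a (loss, gradient) pair:

```python
import jax
import jax.numpy as jnp
import equinox as eqx
from jax.flatten_util import ravel_pytree
from tensorflow_probability.substrates import jax as tfp

key = jax.random.PRNGKey(0)
model = eqx.nn.MLP(in_size=1, out_size=1, width_size=8, depth=2, key=key)

# Toy data (illustrative).
x = jnp.linspace(-1.0, 1.0, 32)[:, None]
y = jnp.sin(3.0 * x)

# Split the model into trainable arrays and static parts, then flatten the
# trainable part into the 1-D vector that lbfgs_minimize works with.
params, static = eqx.partition(model, eqx.is_inexact_array)
flat_params, unravel = ravel_pytree(params)

def loss_from_flat(flat):
    m = eqx.combine(unravel(flat), static)
    preds = jax.vmap(m)(x)
    return jnp.mean((preds - y) ** 2)

results = tfp.optimizer.lbfgs_minimize(
    jax.value_and_grad(loss_from_flat),  # returns (loss, gradient), as required
    initial_position=flat_params,
    max_iterations=500,
)

# Rebuild a trained Equinox model from the optimized flat vector.
trained = eqx.combine(unravel(results.position), static)
print(results.converged, results.objective_value)
```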
Thank you @lockwo and @patrick-kidger, it worked for me. I had another question.
There are a couple of ways to do what you're describing (since I assume you don't care much about that specific problem, but want to apply it to your, potentially much more complicated, use case):

1. You could modify the loss function (the function being differentiated) so that the gradient is naturally what you want, e.g. if the loss function had two terms, you could change the sign on the one that depends on certain parameters.
2. You could make the parameters a custom layer and write a custom gradient rule that just negates the gradient (not sure where this level of work would be needed, but it's possible).
3. (What I would probably do in most situations:) compute the gradient as usual, then add an extra step that multiplies that component of the gradient by -1 before applying the update. A sketch of this is shown below the list.
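A minimal sketch of option 3 (untested; the model, the SGD optimizer, and the choice of the last layer's bias are all illustrative), using Optax and eqx.tree_at to negate one component of the gradient before the update:

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

key = jax.random.PRNGKey(0)
model = eqx.nn.MLP(in_size=1, out_size=1, width_size=8, depth=2, key=key)
optim = optax.sgd(1e-2)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

def loss_fn(model, x, y):
    return jnp.mean((jax.vmap(model)(x) - y) ** 2)

@eqx.filter_jit
def step(model, opt_state, x, y):
    grads = eqx.filter_grad(loss_fn)(model, x, y)
    # Option 3: flip the sign of one gradient component before the update.
    # The gradient pytree mirrors the model pytree, so the same `where`
    # function addresses the corresponding leaf (here, the last layer's bias).
    grads = eqx.tree_at(lambda g: g.layers[-1].bias, grads, replace_fn=lambda b: -b)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state
```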
Thanks @lockwo, I will take option 3. I just had one doubt: how do I select the gradient corresponding to a specific parameter in the pytree? For the same example, how would I know the gradients corresponding to the bias?
The gradients are a pytree with the same structure as the parameters, so wherever the bias sits in the parameter pytree, its gradient sits in the same place in the gradient pytree.
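Concretely, a small illustrative snippet (the MLP and toy data are assumptions, not from the thread) showing that the same attribute path reaches both a parameter and its gradient:

```python
import jax
import jax.numpy as jnp
import equinox as eqx

key = jax.random.PRNGKey(0)
model = eqx.nn.MLP(in_size=1, out_size=1, width_size=8, depth=2, key=key)

def loss_fn(model, x, y):
    return jnp.mean((jax.vmap(model)(x) - y) ** 2)

x = jnp.ones((4, 1))
y = jnp.ones((4, 1))
grads = eqx.filter_grad(loss_fn)(model, x, y)

# The gradient pytree mirrors the model pytree: the same path that reaches
# a parameter reaches its gradient.
print(model.layers[-1].bias)  # the bias parameter
print(grads.layers[-1].bias)  # its gradient, at the same pytree location
```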
Thank you @lockwo.