using tfp.substrates.jax.optimizer.lbfgs_minimize with Equinox #906

Open
raj-brown opened this issue Dec 5, 2024 · 6 comments
Labels
question User queries

Comments

@raj-brown

Dear All-
I want to use the L-BFGS optimizer and was wondering if there is any example of using an Equinox neural network model with tfp.substrates.jax.optimizer.lbfgs_minimize.

Thanks!

@lockwo
Contributor

lockwo commented Dec 6, 2024

I am not aware of any examples using TFP + Equinox, but it should be possible, given that Equinox generally operates at the JAX level rather than as a wrapper. There are lots of examples using optax + Equinox, and optax also has an L-BFGS example (https://optax.readthedocs.io/en/stable/_collections/examples/lbfgs.html), so that might be a good starting place. Other implementations of potential use are https://jaxopt.github.io/stable/unconstrained.html and https://docs.kidger.site/optimistix/api/minimise/#optimistix.BFGS.
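
For concreteness, here is one plausible way to wire the two together (a minimal sketch, not an established TFP + Equinox recipe; the MLP and data are made up for illustration). `lbfgs_minimize` expects a flat parameter vector and a value-and-gradients function, so the Equinox model is partitioned into trainable arrays and static parts, and the arrays are flattened:

```python
import jax
import jax.numpy as jnp
import equinox as eqx
from jax.flatten_util import ravel_pytree
from tensorflow_probability.substrates import jax as tfp

# Toy model and data, purely for illustration.
model = eqx.nn.MLP(in_size=1, out_size=1, width_size=8, depth=2,
                   key=jax.random.PRNGKey(0))
x = jnp.linspace(-1.0, 1.0, 32)[:, None]
y = 2.0 * x

# lbfgs_minimize works on a flat parameter vector, so split the model
# into its trainable arrays and static parts, then flatten the arrays.
params, static = eqx.partition(model, eqx.is_inexact_array)
flat_params, unravel = ravel_pytree(params)

def loss(flat):
    m = eqx.combine(unravel(flat), static)
    pred = jax.vmap(m)(x)
    return jnp.mean((pred - y) ** 2)

# jax.value_and_grad returns (value, gradients), which is exactly the
# signature lbfgs_minimize expects from value_and_gradients_function.
results = tfp.optimizer.lbfgs_minimize(
    jax.value_and_grad(loss),
    initial_position=flat_params,
    max_iterations=100,
)
trained_model = eqx.combine(unravel(results.position), static)
```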

@patrick-kidger patrick-kidger added the question User queries label Dec 6, 2024
@raj-brown
Author

Thank you @lockwo and @patrick-kidger, it worked for me. I had another question.
For the example at https://docs.kidger.site/equinox/all-of-equinox/: if I want to maximize the loss function only for extra_bias and minimize it for the rest of the parameters, how do I selectively choose the parameters for the min/max operation? I would really appreciate your suggestions @patrick-kidger @lockwo. Thanks!

@lockwo
Contributor

lockwo commented Dec 9, 2024

To do what you're describing there are a couple of ways (since I assume you don't much care about that specific problem, but want to apply it to your, potentially much more complicated, use case):

1. You could modify the loss function (the function being differentiated) such that the gradient is naturally what you want, e.g. if the loss function had two terms, you could change the sign on the one that depends on certain parameters.
2. You could make the parameters a custom layer and write a custom gradient rule to just negate it (not sure where this level of work would be needed, but it's possible).
3. (What I would probably do in most situations.) Just compute the gradient and then add an extra step to multiply that component of the gradient by -1 before applying it, as in the sketch below.
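
A minimal sketch of option 3, assuming a module loosely reconstructed from the "all of Equinox" tutorial (the `Linear` class with an `extra_bias` field here is illustrative, not copied verbatim). The gradients are computed as usual, then `eqx.tree_at` flips the sign of the `extra_bias` component before the optimizer step:

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

# Toy module with an extra_bias field, loosely following the tutorial.
class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array
    extra_bias: jax.Array

    def __init__(self, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (2, 2))
        self.bias = jax.random.normal(bkey, (2,))
        self.extra_bias = jnp.ones(2)

    def __call__(self, x):
        return self.weight @ x + self.bias + self.extra_bias

def loss(model, x, y):
    return jnp.mean((jax.vmap(model)(x) - y) ** 2)

optim = optax.sgd(1e-2)

@eqx.filter_jit
def step(model, opt_state, x, y):
    grads = eqx.filter_grad(loss)(model, x, y)
    # Flip the sign of the extra_bias gradient, so that descending the
    # modified gradient minimizes the loss in every parameter except
    # extra_bias, which it maximizes.
    grads = eqx.tree_at(lambda g: g.extra_bias, grads, -grads.extra_bias)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state
```

Usage would be the standard Equinox loop: `model = Linear(jax.random.PRNGKey(0))`, `opt_state = optim.init(eqx.filter(model, eqx.is_array))`, then call `step` repeatedly.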

@raj-brown
Author

Thanks @lockwo. I will take choice 3. I just had one doubt: how do I pick out the gradient corresponding to a specific parameter in the pytree? For the same example, how would I know which gradients correspond to the extra_bias pytree? Thank you very much.

@lockwo
Contributor

lockwo commented Dec 9, 2024

The gradients are a pytree with the same structure as the parameters, so wherever the bias is in the parameter pytree, the corresponding gradient is in the same place in the gradient pytree.
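
A tiny self-contained illustration of this (the parameter names are made up): `jax.grad` returns a pytree that mirrors its input, so you index into the gradients exactly as you would into the parameters.

```python
import jax
import jax.numpy as jnp

# Toy parameter pytree; the gradient comes back with identical structure.
params = {"weight": jnp.ones((2, 2)), "extra_bias": jnp.zeros(2)}

def loss(p):
    return jnp.sum(p["weight"] @ p["extra_bias"])

grads = jax.grad(loss)(params)
print(grads["extra_bias"])  # lives at the same key as params["extra_bias"]
```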

@raj-brown
Author

Thank you @lockwo!
