Replacement for equinox.internal.if_mapped? #916

Open
mjo22 opened this issue Dec 19, 2024 · 2 comments
Labels
question User queries

Comments

@mjo22

mjo22 commented Dec 19, 2024

Hi, I noticed that equinox.internal.if_mapped was removed in a recent version of equinox rather than being promoted to the main package. What is the reason for this, and what would be a reasonable replacement that replicates its behavior?

I have highly nested pytrees in my package, so if_mapped turns out to be very useful for saving memory when moving across vmap boundaries: outputs that are not actually batched do not get a batch axis added to them.
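To illustrate the memory issue, here is a minimal sketch using plain `jax.vmap` (not if_mapped itself). With the default `out_axes=0`, an output that does not depend on the batch is broadcast to carry a batch axis anyway; marking it `None` in `out_axes` keeps it unbatched. The function `f` and the array sizes are hypothetical, chosen only to show the shapes.

```python
import jax
import jax.numpy as jnp


def f(x, table):
    # `x` varies over the batch; `table` is passed through unchanged.
    return x * 2.0, table


xs = jnp.arange(4.0)
table = jnp.ones((1000,))

# Default out_axes=0: every output gets a leading batch axis, so `table`
# is broadcast to shape (4, 1000) even though it is identical across the batch.
out_default = jax.vmap(f, in_axes=(0, None))(xs, table)

# out_axes=(0, None): the second output is left unbatched and keeps shape (1000,).
out_explicit = jax.vmap(f, in_axes=(0, None), out_axes=(0, None))(xs, table)

print(out_default[1].shape)   # (4, 1000)
print(out_explicit[1].shape)  # (1000,)
```

With deeply nested pytrees, this broadcasting happens at every unbatched leaf, which is where the memory cost comes from.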

@patrick-kidger
Owner

So this hooked into some JAX internals that are being removed in a future version of JAX. You might be able to replicate the same behavior yourself with something similar built on newer APIs, but I'm not sure exactly what.

This reliance on internals is why it lived, undocumented, in equinox.internal rather than in the main namespace: I knew this might need to happen at some point!

FWIW I never really found a use case for this in my own work, as it made the output shape hard to track.

Sorry that I can't give you better news!

@patrick-kidger patrick-kidger added the question User queries label Dec 19, 2024
@mjo22
Author

mjo22 commented Dec 19, 2024

This is okay, makes sense! It is probably best practice to be explicit about the out_axes anyway. The best replacement for the behavior is probably to create an out_axes pytree explicitly from a boolean filter_spec (True where an output leaf should be batched), like the following:

import jax

out_axes = jax.tree.map(lambda x: 0 if x else None, filter_spec)
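A slightly fuller sketch of this approach, using plain `jax.vmap`. The `filter_spec` dict, the `step` function, and the parameter shapes are hypothetical; the assumption is that `filter_spec` has the same structure as the vmapped function's output, with a boolean at each leaf.

```python
import jax
import jax.numpy as jnp

# Hypothetical filter_spec mirroring the output structure:
# True = leaf carries a batch axis, False = leaf is shared across the batch.
filter_spec = {"state": True, "params": {"w": False, "b": False}}

# Map True -> 0 (batch along axis 0) and False -> None (leave unbatched).
out_axes = jax.tree.map(lambda x: 0 if x else None, filter_spec)

params = {"w": jnp.ones((3, 3)), "b": jnp.zeros((3,))}


def step(x, params):
    # Only `state` depends on the batched input `x`; `params` is returned as-is.
    state = params["w"] @ x + params["b"]
    return {"state": state, "params": params}


xs = jnp.ones((8, 3))
out = jax.vmap(step, in_axes=(0, None), out_axes=out_axes)(xs, params)

print(out["state"].shape)        # (8, 3)
print(out["params"]["w"].shape)  # (3, 3)
```

Unlike if_mapped, this does not adapt automatically: if a leaf marked False actually ends up depending on the batch, `jax.vmap` raises an error instead of silently adding an axis, which keeps the output shapes predictable.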
