Hi, I noticed that `equinox.internal.if_mapped` has been removed in a recent version of Equinox rather than added to the main package. What is the reason for this, and what would be a reasonable replacement to replicate its behavior?
I have highly nested pytrees in my package, so `if_mapped` turns out to be very useful for saving memory when moving across `vmap` boundaries.
So this hooked into some JAX internals that are getting removed in a future version of JAX. You could perhaps replicate the same behavior yourself with something similar built on current APIs, but I'm not sure exactly what that would look like.
The fact that this used such internals is why it was undocumented in equinox.internal instead of the main namespace, as I knew this might need to happen at some point!
FWIW I never really found a use-case for this in my own work as it meant that the output shape was hard to track.
That's okay, makes sense! It's probably best practice to be explicit about `out_axes` anyway. The best replacement for the behavior is probably just to build an `out_axes` pytree explicitly from a `filter_spec`.
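A minimal sketch of that approach in plain JAX, assuming `filter_spec` is a pytree of bools with the same structure as the function's output (`True` meaning the leaf is batched along axis 0). `make_out_axes`, `step`, and the dict keys are hypothetical names for illustration, not part of Equinox's API. Leaves marked `False` get `out_axes=None`, so `jax.vmap` returns them once instead of carrying a batch dimension.

```python
import jax
import jax.numpy as jnp


def make_out_axes(filter_spec):
    """Hypothetical helper: turn a pytree of bools into a pytree of out_axes.

    True  -> 0    (leaf is batched; results are stacked along axis 0)
    False -> None (leaf is unbatched; returned once, without a batch axis)
    """
    return jax.tree_util.tree_map(lambda batched: 0 if batched else None, filter_spec)


def step(x, c):
    # `x` varies across the batch; `c` is shared, so "const" never depends on `x`.
    return {"state": x * 2.0, "const": c + 1.0}


filter_spec = {"state": True, "const": False}
out_axes = make_out_axes(filter_spec)  # {"state": 0, "const": None}

batched_step = jax.vmap(step, in_axes=(0, None), out_axes=out_axes)

xs = jnp.arange(3.0)
c = jnp.ones((1000,))
out = batched_step(xs, c)

print(out["state"].shape)  # (3,)
print(out["const"].shape)  # (1000,) rather than (3, 1000)
```

If a leaf marked `False` does end up depending on a batched input, `jax.vmap` raises an error instead of silently broadcasting it, which keeps output shapes explicit. The same `tree_map` works for deeply nested outputs as long as `filter_spec` mirrors the output's structure.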