Passing a vmapped submodule as field #206
Summary: what's going on here is that whilst …

To expand on that: take a look at the source code for … This breaks the PyTree structure of your overall model. When you pass …

The fix is pretty simple: use `eqx.filter_vmap` instead. This is completely unrelated to the "filtering" part of `filter_vmap`.

(In addition, you should likely use …)
Thanks for the detailed answer, Patrick, very clear! The `filter_*` functions are a nice fix, since they return a PyTree. Using `eqx.filter_vmap` together with `eqx.filter(module, eqx.is_inexact_array)` works well in my example :) However, it is a bit confusing when I don't want the smart filtering and instead want to specify `in_axes` and `out_axes` for my vmapped function. Can I define a custom `filter_vmap`, or pass `in_axes` and `out_axes` arguments like in the original `vmap`?
Oh, I found an old issue (#87) on that: using `args=in_axes` and `out=out_axes` seems to work! Thanks
Hi, I just hit this issue while using the filtering: …
but with vmapped submodules as fields, and without realizing I needed `eqx.filter_vmap` … This setup led to all the vmapped submodules in the … Maybe an error could be thrown somewhere, or this could be mentioned in the tutorials? If not, good luck to the next person who hits the same problem 😅
Agreed! I've just opened #665 to try and detect this case and emit a warning. |
Hello,
I'm a new user of the Equinox library, so maybe my problem has an obvious solution, but I cannot figure out how to do it properly.
Basically, I'm trying to pass a vmapped submodule as a field of another module class.
A minimal example would be something like this:
The above code returns the following error:
The forward pass works, but not when combined with optax.
Any ideas how to do that?
Thanks,
Maya