Passing a vmapped submodule as field #206

Open
mayalenE opened this issue Oct 11, 2022 · 5 comments
Labels
question User queries

Comments

@mayalenE

Hello,

I'm a new user of the Equinox library, so maybe my problem has an obvious solution, but I cannot figure out how to do this properly.

Basically I'm trying to pass a vmapped submodule as a field of another module class.

A minimal example would be something like this:

import equinox as eqx
import numpy as np
import optax
from jax import Array, vmap

class Submodule(eqx.Module):
    A: Array

    def __call__(self, x):  # x of shape N
        return self.A @ x


class Module(eqx.Module):
    submodule: Submodule
    
    def __init__(self, submodule):
        self.submodule = vmap(submodule)

    def __call__(self, x):  # x of shape (B,N)
        return self.submodule(x)


submodule = Submodule(np.zeros((4, 4)))
module = Module(submodule)
x = np.ones((10, 4, 1))
print(module(x).shape)

@eqx.filter_value_and_grad
@eqx.filter_jit
def grad_loss(module, x):
    x = module(x)
    return ((1 - x) ** 2).sum()


@eqx.filter_jit
def make_step(module, x, opt_state):
    loss, grads = grad_loss(module, x)
    updates, opt_state = optim.update(grads, opt_state)
    module = eqx.apply_updates(module, updates)
    return loss, module, opt_state


optim = optax.adam(0.2)
opt_state = optim.init(module)
for optim_step_idx in range(40):
    loss, module, opt_state = make_step(module, x, opt_state)
    print(loss, module.submodule.A.sum())

The above code returns the following error:
[image: screenshot of the error traceback]

The forward pass works, but things break once it is combined with Optax.

Any ideas how to do that?
Thanks,

Maya

@patrick-kidger
Owner

Summary: whilst submodule is a PyTree, vmap(submodule) is just a plain function. It isn't a PyTree of arrays (JAX treats it as a single opaque leaf), so Optax doesn't know how to optimise it.


To expand on that.

Take a look at the source code for vmap. You'll see that it defines a new function on the fly and then returns it.

This breaks the PyTree structure of your overall model. When you pass module into optim.init, then module.submodule is just an arbitrary Python function, and Optax has no idea how to treat that as a parameter.

The fix is pretty simple: use self.submodule = eqx.filter_vmap(submodule) instead. This does return a PyTree, so things should work as expected.

This is completely unrelated to the "filtering" part of filter_vmap, by the way. The filtered transformations also fix a lot of edge cases!


(In addition you should likely use optim.init(eqx.filter(module, eqx.is_inexact_array)); see the FAQ.)

@mayalenE
Author

Thanks for the detailed answer Patrick, very clear!

The filter_* transformations are indeed a nice way to get a PyTree back. Using eqx.filter_vmap together with eqx.filter(module, eqx.is_inexact_array) works well in the above example :)

However, it is a bit confusing when I don't want to use the smart filtering and instead want to specify in_axes and out_axes for my vmapped function. Can I define a custom filter_vmap, or use in_axes and out_axes arguments as in the original vmap?

@mayalenE
Author

mayalenE commented Oct 12, 2022

Oh, I found an old issue (#87) on that: using args=in_axes and out=out_axes seems to work!
I'm not sure I understand what these are doing; I'll have to dig more into the way filters work.

Thanks

patrick-kidger added the question label on May 7, 2023
@Tsarpf

Tsarpf commented Feb 21, 2024

Hi,

I just hit this issue while using the filtering:

opt_state = optim.init(eqx.filter(module, eqx.is_inexact_array))

but with vmapped submodules/fields, and without realizing I should use eqx.filter_vmap (switching to it solved the issue, thanks!).

This setup led to all the vmapped submodules in the opt_state tree having their value set to None, so these modules' weight updates were silently skipped.

Maybe an error could be thrown somewhere, or it could be mentioned in the tutorials somewhere? If not, good luck to the next person who hits the same problem 😅

@patrick-kidger
Owner

Agreed! I've just opened #665 to try and detect this case and emit a warning.
