Realistic scenario in which jax array is static #798

Closed
lockwo opened this issue Aug 12, 2024 · 4 comments

@lockwo
Contributor

lockwo commented Aug 12, 2024

I've seen several cases recently in which users (incorrectly) set JAX arrays as static fields, which leads to hard-to-debug behaviour. My question is: is there ever a use case in which setting a JAX array as a static field is not an error? (I couldn't think of any, since static values should be hashable and JAX arrays aren't.) If not, can we raise an error when a module is constructed and someone has set a JAX array as static?
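
As a minimal sketch of why this is hard to debug, the example below uses plain JAX rather than Equinox: a hypothetical `Layer` pytree keeps `bias` in its auxiliary data, which is roughly what marking a field static does. The array silently drops out of the leaves, so an update applied via a tree-map skips it, and the array itself cannot be hashed. The class and field names are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Layer:
    """Hypothetical toy layer: `weight` is a leaf, `bias` is (wrongly) kept static."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def tree_flatten(self):
        # Children are visible to tree-maps and transforms; aux data is not.
        return (self.weight,), self.bias

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)


layer = Layer(weight=jnp.ones(3), bias=jnp.ones(3))

# Only `weight` is a leaf, so an update applied via tree_map skips `bias`.
print(len(jax.tree_util.tree_leaves(layer)))
updated = jax.tree_util.tree_map(lambda p: p - 0.5, layer)
print(updated.weight, updated.bias)  # weight is shifted, bias is untouched

# Static values are expected to be hashable, and JAX arrays are not.
try:
    hash(layer.bias)
except TypeError as err:
    print("TypeError:", err)
```

Nothing here raises an error about the misplaced array; the only symptom is that `bias` never changes, which is exactly the kind of silent failure described in the issue.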

@patrick-kidger
Owner

We actually do have a few cases internal to Equinox in which we do this. For example, the `_nonvmapd` here:

```python
return _vmapd, Static((_nonvmapd, _out_axes))
```

That said, scrutinising them a little more carefully, it's not obvious to me that this is actually required. For example, in the above case I think it should still be fine to pass them across the `vmap` boundary with `out_axes=None` for that output, and the above implementation may just be an artifact of how things were originally set up.
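
A minimal sketch of the `out_axes=None` alternative, assuming a toy function `f` (hypothetical, not the Equinox internals referenced above) whose second output does not depend on the batched input. `jax.vmap` hands that output back unbatched, without wrapping it in a static container:

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # `x` is batched; `y` is not, so the second output does not vary over the batch.
    return x * y, y + 1.0


# in_axes=None for `y` and out_axes=None for the second output: the unbatched
# result crosses the vmap boundary as an ordinary (non-static) output.
batched_f = jax.vmap(f, in_axes=(0, None), out_axes=(0, None))

xs = jnp.arange(6.0).reshape(3, 2)
y = jnp.array([10.0, 20.0])
out_batched, out_unbatched = batched_f(xs, y)
print(out_batched.shape, out_unbatched.shape)
```

This only covers array outputs that genuinely don't vary across the batch; whether every value currently wrapped in `Static` can be handled this way is exactly the open question in the comment.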


Generally speaking, `static` is used to make tree-maps skip over a particular field. This isn't technically the same as requiring hashability. The two just get conflated because one major use case is the tree-map performed when JIT'ing, and for that all non-traced arguments must be hashable.
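
To illustrate the JIT case, here is a sketch using `jax.jit`'s `static_argnums`: static (non-traced) arguments are hashed as part of jit's caching, so a tuple is accepted while a JAX array is rejected. The function `scale` is a hypothetical example. The exact exception type and message may vary across JAX versions, so the sketch catches both `TypeError` and `ValueError`.

```python
import jax
import jax.numpy as jnp


def scale(x, factor):
    # `factor` is treated as a compile-time constant when marked static.
    return x * factor[0]


scale_jit = jax.jit(scale, static_argnums=1)

# A tuple is hashable, so it works as a static argument.
print(scale_jit(jnp.ones(2), (3.0,)))

# A JAX array is not hashable, so jit refuses it as a static argument.
try:
    scale_jit(jnp.ones(2), jnp.array([3.0]))
except (TypeError, ValueError) as err:
    print(type(err).__name__)
```

By contrast, a plain `jax.tree_util.tree_map` over a pytree never hashes its auxiliary data, which is why a static array can go unnoticed until something like JIT caching touches it.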


Overall I think if we wanted to make this an error then I'd be happy to do so. We'd need to be sure that all of the internal use cases of this can be expressed in different ways though. If you decide to take a look at this, then as always, happy to receive a PR!
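
If this were made an error, one possible shape for the check is sketched below. It is a hypothetical helper, `check_static_fields`, called from `__post_init__` on a plain dataclass; static fields are assumed to be marked with `metadata={"static": True}`, loosely mirroring how `eqx.field(static=True)` is spelled. This is not Equinox's actual implementation.

```python
import dataclasses

import jax
import jax.numpy as jnp


def check_static_fields(obj) -> None:
    """Hypothetical check: reject JAX arrays stored in fields marked static."""
    for f in dataclasses.fields(obj):
        if f.metadata.get("static", False):
            value = getattr(obj, f.name)
            if isinstance(value, jax.Array):
                raise TypeError(
                    f"Field {f.name!r} is marked static but holds a JAX array; "
                    "static fields should hold hashable, non-array values."
                )


@dataclasses.dataclass
class Model:
    weight: jax.Array
    # Hypothetical static fields, marked via dataclass metadata.
    activation: str = dataclasses.field(default="relu", metadata={"static": True})
    scale: object = dataclasses.field(default=1.0, metadata={"static": True})

    def __post_init__(self):
        check_static_fields(self)


Model(weight=jnp.ones(3))  # fine: the static fields hold a str and a float

try:
    Model(weight=jnp.ones(3), scale=jnp.array(2.0))
except TypeError as err:
    print(err)
```

As written, the check only inspects the top level of each static field, so a static field holding, say, a tuple of arrays would slip through; a fuller version would probably need to flatten the value with `jax.tree_util.tree_leaves` first.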

@lockwo
Contributor Author

lockwo commented Aug 12, 2024

"This isn't technically the same as requiring hashability." I thought static values get made static by being set as auxiliary variables to the pytree, which jax indicates should be hashable?

"Overall I think if we wanted to make this an error then I'd be happy to do so" sg I will look into it

@patrick-kidger
Owner

"by being set as auxiliary variables to the pytree, which jax indicates should be hashable?"

Yup, but I don't think JAX actually requires hashability of these unless you attempt to hash the tree structure of the pytree, i.e. `jax.tree_util.tree_structure(pytree)`.

@lockwo
Contributor Author

lockwo commented Aug 13, 2024

#800 (comment)

lockwo closed this as completed Aug 21, 2024