Realistic scenario in which jax array is static #798
We actually do have a few cases in which we do this internally in Equinox. For example, line 209 (as of commit 68cc26a).
That said, scrutinising them a little more carefully, it's not obvious to me that this is actually required. For example, in the above case I think it should still be fine to pass them across the boundary with […]

Generally speaking, `static` is used to make a tree-map ignore a particular field. This isn't technically the same as requiring hashability. The two just get conflated because one major use case is tree-mapping when JIT-compiling, and for that all non-traced arguments must be hashable.

Overall, if we wanted to make this an error then I'd be happy to do so. We'd need to be sure that all of the internal use cases can be expressed in different ways, though. If you decide to take a look at this then, as always, happy to receive a PR!
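To illustrate the point that "static" is about what a tree-map visits, here is a minimal sketch in plain JAX. It is not Equinox's implementation, and the class `Layer` and its fields are hypothetical. Whatever `tree_flatten` returns as auxiliary data is not a leaf, so `jax.tree_util.tree_map` passes it through untouched; this is roughly the effect of marking a field static.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Layer:
    """Hypothetical pytree: `weight` is a dynamic leaf, `name` is static aux data."""

    def __init__(self, weight, name):
        self.weight = weight
        self.name = name

    def tree_flatten(self):
        # Children are visited by tree-maps; aux data is carried along unchanged.
        return (self.weight,), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)


layer = Layer(jnp.array([1.0, 2.0]), "layer0")

# Only the dynamic field shows up as a leaf.
leaves = jax.tree_util.tree_leaves(layer)

# tree_map transforms the leaf and leaves the static field alone.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, layer)
print(len(leaves))     # 1
print(doubled.name)    # layer0
```

Note that nothing here hashes `name`; hashability only becomes a requirement once something (such as JIT caching) tries to hash it.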
"This isn't technically the same as requiring hashability." I thought static values are made static by being stored as auxiliary data of the pytree, which JAX says should be hashable?

"Overall I think if we wanted to make this an error then I'd be happy to do so." SG, I will look into it.
Yup, but JAX doesn't actually require hashability of these unless you attempt to hash the […]
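A minimal sketch of that distinction, using plain JAX and a hypothetical `Holder` class that (incorrectly) keeps an array in its aux data. Flattening and unflattening work, because nothing hashes the aux data; hashing the array itself fails, and so does passing it as a `jax.jit` static argument, since JIT caching has to hash static arguments. The exact exception type raised by `jit` may vary between JAX versions, so the sketch catches both `TypeError` and `ValueError`.

```python
import functools

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Holder:
    """Hypothetical pytree that stores a JAX array as static aux data."""

    def __init__(self, value, table):
        self.value = value
        self.table = table  # a JAX array placed in aux data (the problematic case)

    def tree_flatten(self):
        return (self.value,), self.table

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)


h = Holder(jnp.array(1.0), jnp.arange(3))

# Round-tripping through flatten/unflatten never hashes the aux data.
leaves, treedef = jax.tree_util.tree_flatten(h)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

# Hashing the array directly fails: JAX arrays are unhashable.
try:
    hash(h.table)
    hashable = True
except TypeError:
    hashable = False


@functools.partial(jax.jit, static_argnums=1)
def identity(x, s):
    return x


# A JIT static argument must be hashable, so this call is rejected.
try:
    identity(1.0, jnp.arange(3))
    jit_failed = False
except (TypeError, ValueError):
    jit_failed = True

print(hashable, jit_failed)
```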
I've seen several cases recently in which users (incorrectly) set JAX arrays as static fields, which leads to results that are hard to debug. My question is: is there ever a use case in which setting a JAX array as a static field is not an error? I couldn't think of any, since static values should be hashable and JAX arrays aren't. If not, can we add an error to the module parsing when someone tries to set a JAX array as static?
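One possible shape for such a check, as a hypothetical sketch using only the standard `dataclasses` module and JAX. The helper `check_static_fields` and the `Model` class are invented for illustration, and the sketch assumes static fields are marked with `{"static": True}` in their dataclass metadata; that marking is an assumption of this sketch, not a description of Equinox internals. The check runs on field values at construction time, because a type annotation alone doesn't say whether the value will actually be an array.

```python
import dataclasses

import jax
import jax.numpy as jnp


def check_static_fields(obj):
    """Hypothetical check: raise if any field marked static holds a JAX array.

    Assumes static fields carry {"static": True} in their dataclass metadata.
    """
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        if f.metadata.get("static", False) and isinstance(value, jax.Array):
            raise TypeError(
                f"Field '{f.name}' is marked static but holds a JAX array; "
                "JAX arrays are unhashable and should be dynamic fields."
            )


@dataclasses.dataclass
class Model:
    weight: jax.Array
    scale: object = dataclasses.field(default=1.0, metadata={"static": True})

    def __post_init__(self):
        # Validate at construction time, once the actual values are known.
        check_static_fields(self)


ok = Model(jnp.ones(2), scale=2.0)  # fine: the static field holds a Python float

try:
    Model(jnp.ones(2), scale=jnp.array(2.0))
    raised = False
except TypeError as err:
    raised = True
    print(err)
```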