JAX backend fails for latent scan variables #6718
Possibly related to #6351. The issue is that there is no 1-to-1 map between the Scan RV and the Scan value variable (due to the weird output of Scan actually being a Slice, I think).
It works with the default backend just fine; I was hoping it was something to do with how the…
Ah, okay, that sounds different then.
Regardless of this issue, we should definitely clean up the…
How complex a fix would that be?
By the way, this seems to be triggered by the…
I don't quite know, but it's worth a look. One option is to take the user-provided mode and exclude rewrites that are incompatible with JAX (since we then know that we are compiling to JAX). Otherwise, we could add an optional kwarg to the dispatch function with the…
Yeah, it is trying to feed a NumPy Generator as input. This would also fix #6697, which would be a big improvement.
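For context on why a NumPy Generator cannot be fed to JAX: JAX represents randomness with explicit, immutable PRNG keys rather than a stateful `numpy.random.Generator`. The sketch below is plain JAX, not PyMC or PyTensor code, and does not reproduce what the PyTensor JAX linker does; `generator_is_rejected_by_jit` and `gaussian_random_walk` are hypothetical helpers. It shows that `jax.jit` rejects a Generator argument, and how a random walk inside `jax.lax.scan` instead threads the key through the loop carry.

```python
import jax
import jax.numpy as jnp
import numpy as np


def generator_is_rejected_by_jit() -> bool:
    """Return True if jax.jit refuses a NumPy Generator as an argument."""
    try:
        jax.jit(lambda g: g)(np.random.default_rng(0))
    except TypeError:
        # JAX only accepts arrays (and pytrees of arrays) as traced arguments.
        return True
    return False


def gaussian_random_walk(key, n_steps):
    """Gaussian random walk with the RNG state carried explicitly through lax.scan."""

    def step(carry, _):
        key, x_prev = carry
        # Split instead of mutating: one subkey for this draw, one key carried forward.
        key, subkey = jax.random.split(key)
        x = x_prev + jax.random.normal(subkey)
        return (key, x), x

    init = (key, jnp.zeros(()))
    _, traj = jax.lax.scan(step, init, xs=None, length=n_steps)
    return traj


# n_steps fixes the scan length, so it must be static under jit.
walk = jax.jit(gaussian_random_walk, static_argnums=1)
traj = walk(jax.random.PRNGKey(0), 100)
```

Because the key is an ordinary array in the carry, the same key reproduces the same trajectory. A Generator has no such array representation, so it cannot take part in a JAX-compiled graph.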
An immediate solution to your problem is to pass a valid initval: model.register_rv(traj, name='traj', initval=np.zeros(100))
Good to know! I can try to look at the mode problem as well over the next couple of days if you're busy with other things.
The error happens because… So far nothing terrible, but this Scan has RNGs, which are not compatible with JAX. Usually we convert shared RNGs with a warning, but Scan does not expose these as shared variables to the JAXLinker (they get converted to NominalVariables), so no special hackery is done, and JAX receives the NumPy RNGs. The error is more understandable then:
In more recent versions of JAX it's slightly different:
I think this would be fixed by pymc-devs/pytensor#278, since the RNG is an explicit input to the inner function created by Scan. However, in general we shouldn't need to define custom modes, especially custom modes with different linkers internally. Choosing which rewrites get triggered perhaps makes a bit more sense, but the backend?
This wouldn't be fixed by pymc-devs/pytensor#278, because the outputs wouldn't be NumPy Generators, and the Scan would fail when it tried to set the shared variables. We should not allow Scan in the default backend to use a JAX/PyTorch linker (Numba should be fine).
Should we just deprecate the mode argument in scan altogether?
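If the mode argument were deprecated, one common Python pattern is a sentinel default plus a `FutureWarning`, so that callers who still pass it are warned while everyone else sees nothing. A minimal sketch: `scan_sketch` is a hypothetical stand-in, not PyTensor's `scan`, and the warning text is illustrative only.

```python
import warnings

# Sentinel so that "argument not passed" is distinguishable from mode=None.
_NOT_GIVEN = object()


def scan_sketch(fn, sequence, mode=_NOT_GIVEN):
    """Hypothetical scan-like helper; only the deprecation of `mode` is illustrated."""
    if mode is not _NOT_GIVEN:
        warnings.warn(
            "The `mode` argument is deprecated and will be removed in a future "
            "release; it is ignored.",
            FutureWarning,
            stacklevel=2,  # point the warning at the caller, not this helper
        )
    # A plain Python loop stands in for the real Scan machinery; mode is unused.
    return [fn(x) for x in sequence]
```

Using a sentinel rather than `mode=None` means an explicit `mode=None` from a caller is still reported.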
I'm still on the fence about whether we want control over the rewrites, but we can reassess later if the need arises.
Describe the issue:
Not sure if this belongs here or in the PyTensor repo; I'm putting it here because the minimal example I could come up with uses PyMC. If you create a Scan variable, register it without observations, and then use it in further computation, the graph fails to compile with the JAX backend.
Reproducible code example:
Error message:
PyMC version information: