xla_gpu_deterministic_ops=true breaks simple indexing in jax #27796
Comments
By using an additional for loop we can solve this. I have looked into it and got the following (please correct me if I am wrong):

    import os
    arr = jnp.array([1, 2, 0, 0])
    for i, pos in enumerate(positions):
        ...
    print(new_arr)
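Here is a minimal sketch of the loop idea. It assumes, though the thread does not confirm this, that the failing code writes several values into `arr` with a single `.at[...].set(...)` call whose index array contains duplicates. In that case JAX does not promise which conflicting write wins. The helper `set_sequentially` and the sample `positions`/`values` are hypothetical. Only `arr`, `positions`, `new_arr` and the `enumerate` loop come from the comment above.

```python
import jax.numpy as jnp


def set_sequentially(arr, positions, values):
    """Hypothetical helper: write values[i] to arr[positions[i]] one index at a time.

    Each .at[pos].set(...) touches a single index, so the writes happen in
    loop order and a repeated position keeps the value written last.
    """
    new_arr = arr  # JAX arrays are immutable; each .at[] returns a new array
    for i, pos in enumerate(positions):
        new_arr = new_arr.at[pos].set(values[i])
    return new_arr


arr = jnp.array([1, 2, 0, 0])
positions = [2, 3, 2]  # hypothetical: index 2 appears twice
values = [7, 8, 9]     # hypothetical values
new_arr = set_sequentially(arr, positions, values)
print(new_arr)  # [1 2 9 8]
```

Because each step writes a single index, no scatter ever contains conflicting writes, and later positions overwrite earlier ones. The trade-off is one update per index. Under `jax.jit` the Python loop is unrolled into a long chain of ops when `positions` is large.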
That seems to work! Do you know if there is a guide somewhere that lists all such changes in behavior (and their fixes) that may arise with the --xla_gpu_deterministic_ops=true flag? I have a medium-sized codebase that I want to run with this flag, and I am not sure a priori what changes I should make across the codebase to avoid any errors.
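On the question of finding affected code a priori, one rough approach is to run the same script twice in fresh processes, once without and once with the flag, and compare the printed output. This is an assumption, not something suggested in the thread. A fresh process matters because XLA reads `XLA_FLAGS` when the backend starts. The helper names below are hypothetical. Because this is a GPU flag, the comparison is only meaningful on a machine with a GPU.

```python
import os
import subprocess
import sys

DETERMINISTIC_FLAG = "--xla_gpu_deterministic_ops=true"


def run_with_xla_flags(snippet: str, xla_flags: str) -> str:
    """Hypothetical helper: run a Python snippet in a fresh interpreter with XLA_FLAGS set."""
    env = dict(os.environ)
    env["XLA_FLAGS"] = xla_flags
    result = subprocess.run(
        [sys.executable, "-c", snippet],
        env=env,
        capture_output=True,
        text=True,
        check=True,  # raise CalledProcessError if the snippet itself crashes
    )
    return result.stdout


def compare_with_deterministic_flag(snippet: str):
    """Hypothetical helper: return (baseline_output, flagged_output, outputs_match).

    Any XLA_FLAGS already in the environment are kept for both runs; the
    flagged run only appends the deterministic-ops flag.
    """
    base = os.environ.get("XLA_FLAGS", "")
    baseline = run_with_xla_flags(snippet, base)
    flagged = run_with_xla_flags(snippet, f"{base} {DETERMINISTIC_FLAG}".strip())
    return baseline, flagged, baseline == flagged
```

To use it, pass the source of a small reproduction script as `snippet` and check the third element of the result. This only catches differences that show up in printed output, so it is a smoke test rather than a complete audit of a codebase.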
I'm not sure whether there is a comprehensive guide covering all the changes and fixes required when using the --xla_gpu_deterministic_ops=true flag, but I'd be happy to help. Let me know if I can support you further, whether by debugging issues or reviewing parts of your code.
I am also facing the same issue. For me, deterministic computation is important and I cannot remove the flag. Any suggestions from the JAX team on how to handle this?
Any news on this? Or is there some way we can push this forward?
I apologize for the late response. I will look into it further and try to solve this.
Description
In order to make JAX behavior deterministic on GPUs, the recommended solution is to use

XLA_FLAGS=--xla_gpu_deterministic_ops=true

However, with jax 0.5.3 and jaxlib 0.5.3, this breaks simple indexing behavior. For example, without the flag the indexing example prints the expected

[1 2 0 0]

while setting the xla_gpu_deterministic_ops flag to true returns

[0, 2, 0, 0]

This issue is potentially related to issue #26836.
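Here is a minimal sketch of the setup described above. The flag can also be set from Python instead of the shell, as long as this happens before jax initializes its backend. Setting it before `import jax` is the safe choice. The helper `add_xla_flag` is hypothetical. It appends to any existing `XLA_FLAGS` rather than overwriting them.

```python
import os


def add_xla_flag(flag: str) -> None:
    """Hypothetical helper: append one flag to XLA_FLAGS, keeping existing flags."""
    current = os.environ.get("XLA_FLAGS", "").split()
    if flag not in current:  # avoid adding the same flag twice
        current.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(current)


# XLA reads XLA_FLAGS when jax sets up its backend, so do this before importing jax.
add_xla_flag("--xla_gpu_deterministic_ops=true")

import jax.numpy as jnp  # noqa: E402  (import must come after the flag is set)

arr = jnp.array([1, 2, 0, 0])
print(arr)
```

Exporting `XLA_FLAGS=--xla_gpu_deterministic_ops=true` in the shell before launching Python has the same effect.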
System info (python version, jaxlib version, accelerator, etc.)