xla_gpu_deterministic_ops=true breaks simple indexing in jax #27796

Open

shehper opened this issue Apr 7, 2025 · 6 comments
Labels
bug Something isn't working

Comments

@shehper
shehper commented Apr 7, 2025

Description

In order to make JAX behavior deterministic on GPUs, the recommended solution is to use XLA_FLAGS=--xla_gpu_deterministic_ops=true. However, with jax 0.5.3 and jaxlib 0.5.3, this breaks simple indexing behavior.

For example,

import os
os.environ["XLA_FLAGS"] = ""
import jax.numpy as jnp

arr = jnp.array([1, 2, 0, 0])
positions = jnp.arange(arr.shape[0])
new_arr = jnp.zeros_like(arr)
new_arr = new_arr.at[positions].set(arr)
print(new_arr)

prints the expected [1 2 0 0], while the same code with the xla_gpu_deterministic_ops flag set to true,

import os
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
import jax.numpy as jnp

arr = jnp.array([1, 2, 0, 0])
positions = jnp.arange(arr.shape[0])
new_arr = jnp.zeros_like(arr)
new_arr = new_arr.at[positions].set(arr)
print(new_arr)

prints [0 2 0 0].

This issue is potentially related to #26836.
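For what it's worth, this particular update can also be expressed without any scatter at all. The sketch below uses NumPy (whose indexing jax.numpy mirrors) so it runs anywhere; the argsort inverse-permutation trick is an assumption that only applies when positions is a permutation of the indices, as in the repro above.

```python
import numpy as np

arr = np.array([1, 2, 0, 0])
positions = np.arange(arr.shape[0])

# scatter form of the update in the report: new_arr[positions] = arr
scattered = np.zeros_like(arr)
scattered[positions] = arr

# gather form: when `positions` is a permutation, reading `arr` through
# the inverse permutation yields the same result without any scatter op
gathered = arr[np.argsort(positions)]

print(scattered)  # [1 2 0 0]
print(gathered)   # [1 2 0 0]
```

A gather has no write conflicts, so it should not be affected by the deterministic-scatter code path that the flag enables.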

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.5.3
jaxlib: 0.5.3
numpy:  2.2.4
python: 3.10.15 | packaged by conda-forge | (main, Oct 16 2024, 01:24:24) [GCC 13.3.0]
device info: NVIDIA GeForce RTX 4090-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='apollo', release='6.8.0-57-generic', version='#59~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Wed Mar 19 17:07:41 UTC 2', machine='x86_64')


$ nvidia-smi
Mon Apr  7 14:40:36 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.144.03             Driver Version: 550.144.03     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4090        On  |   00000000:01:00.0 Off |                  Off |
|  0%   34C    P8             32W /  480W |   18558MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A    230881      C   ...Documents/AC-SolverX/env/bin/python      18550MiB |
+-----------------------------------------------------------------------------------------+
@shehper shehper added the bug Something isn't working label Apr 7, 2025
@Aniketsy
Aniketsy commented Apr 9, 2025

This can be worked around with an additional for loop. I have looked into it and found this:
"XLA's implementation of deterministic scatter operations is a behavior, not an issue, aimed at ensuring consistent results across different executions, but it can come with a performance cost."

Please correct me if I am wrong.

import os
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
import jax.numpy as jnp

arr = jnp.array([1, 2, 0, 0])
positions = jnp.arange(arr.shape[0])
new_arr = jnp.zeros_like(arr)

for i, pos in enumerate(positions):
    new_arr = new_arr.at[pos].set(arr[i])

print(new_arr)
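A caveat on the loop above: a Python for loop is unrolled during JAX tracing, so under jit it becomes N separate single-element updates, which can be slow for large arrays. Another scatter-free formulation, sketched here in NumPy under the assumption that a matmul reduction is acceptable for the array sizes involved, expresses the update as a one-hot matrix product:

```python
import numpy as np

arr = np.array([1, 2, 0, 0])
n = arr.shape[0]
positions = np.arange(n)

# one_hot[i, j] == 1 exactly when positions[i] == j
one_hot = (positions[:, None] == np.arange(n)[None, :]).astype(arr.dtype)

# result[j] = sum_i one_hot[i, j] * arr[i], i.e. the scattered values;
# the reduction is an ordinary matmul, so no scatter op is involved
new_arr = one_hot.T @ arr
print(new_arr)  # [1 2 0 0]
```

The trade-off is the O(n^2) one-hot matrix, so this only makes sense for moderately sized arrays.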

@shehper
Author

shehper commented Apr 9, 2025

That seems to work!

Do you know if there is a guide somewhere that lists all such changes in behavior (and their fixes) that may arise with the --xla_gpu_deterministic_ops=true flag? I have a medium-sized codebase which I want to run with this flag, and I am not sure a priori what changes I should make across the entire codebase to avoid such errors.

@Aniketsy
Aniketsy commented Apr 9, 2025

I’m not sure if there’s a comprehensive guide that covers all the changes and fixes required when using the --xla_gpu_deterministic_ops=true flag, but I’d be happy to help. Let me know how I can support you further, whether it’s debugging any issues or reviewing parts of your code.

@alessandrofasse
I am also facing the same issue. For me, deterministic computation is important and I cannot remove the flag. Any suggestions from the JAX team on how to handle this?

@alessandrofasse
Any news on this? Or can we push this somehow?

@Aniketsy
I apologize for the late response. I will look more into it and try to solve this.
