Skip to content

Commit

Permalink
BUG: .device attribute inside jax.jit
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jan 21, 2025
1 parent 8a79994 commit 2aeb387
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 17 deletions.
18 changes: 11 additions & 7 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,14 +658,18 @@ def device(x: Array, /) -> Device:
pass
return _DASK_DEVICE
elif is_jax_array(x):
# JAX has .device() as a method, but it is being deprecated so that it
# can become a property, in accordance with the standard. In order for
# this function to not break when JAX makes the flip, we check for
# both here.
if inspect.ismethod(x.device):
return x.device()
# FIXME Jitted JAX arrays do not have a device attribute
# https://github.com/jax-ml/jax/issues/26000
# Return None in this case. Note that this workaround breaks
# the standard and will result in new arrays being created on the
# default device instead of the same device as the input array(s).
x_device = getattr(x, 'device', None)
# Older JAX releases had .device() as a method, which has been replaced
# with a property in accordance with the standard.
if inspect.ismethod(x_device):
return x_device()
else:
return x.device
return x_device
elif is_pydata_sparse_array(x):
# `sparse` will gain `.device`, so check for this first.
x_device = getattr(x, 'device', None)
Expand Down
12 changes: 5 additions & 7 deletions tests/test_array_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@ def test_array_namespace(library, api_version, use_compat):
if use_compat is True and library in {'array_api_strict', 'jax.numpy', 'sparse'}:
pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
return
namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat)
namespace = array_namespace(array, api_version=api_version, use_compat=use_compat)

if use_compat is False or use_compat is None and library not in wrapped_libraries:
if library == "jax.numpy" and use_compat is None:
import jax.numpy
if hasattr(jax.numpy, "__array_api_version__"):

Check failure on line 29 in tests/test_array_namespace.py

View workflow job for this annotation

GitHub Actions / check-ruff

Ruff (F823)

tests/test_array_namespace.py:29:24: F823 Local variable `jax` referenced before assignment
# JAX v0.4.32 or later uses jax.numpy directly
assert namespace == jax.numpy
Expand All @@ -44,7 +43,7 @@ def test_array_namespace(library, api_version, use_compat):

if library == "numpy":
# check that the same namespace is returned for NumPy scalars
scalar_namespace = array_api_compat.array_namespace(
scalar_namespace = array_namespace(
xp.float64(0.0), api_version=api_version, use_compat=use_compat
)
assert scalar_namespace == namespace
Expand Down Expand Up @@ -75,8 +74,7 @@ def test_array_namespace(library, api_version, use_compat):
def test_jax_zero_gradient():
jx = jax.numpy.arange(4)
jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
assert (array_api_compat.get_namespace(jax_zero) is
array_api_compat.get_namespace(jx))
assert array_namespace(jax_zero) is array_namespace(jx)

def test_array_namespace_errors():
pytest.raises(TypeError, lambda: array_namespace([1]))
Expand All @@ -91,7 +89,7 @@ def test_array_namespace_errors_torch():
x = np.asarray([1, 2])
pytest.raises(TypeError, lambda: array_namespace(x, y))

def test_api_version():
def test_api_version_torch():
x = torch.asarray([1, 2])
torch_ = import_("torch", wrapper=True)
assert array_namespace(x, api_version="2023.12") == torch_
Expand All @@ -113,7 +111,7 @@ def test_api_version():

def test_get_namespace():
# Backwards compatible wrapper
assert array_api_compat.get_namespace is array_api_compat.array_namespace
assert array_api_compat.get_namespace is array_namespace

def test_python_scalars():
a = torch.asarray([1, 2])
Expand Down
9 changes: 6 additions & 3 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
import numpy as np
import array
from numpy.testing import assert_allclose
from numpy.testing import assert_equal

from array_api_compat import ( # noqa: F401
is_numpy_array, is_cupy_array, is_torch_array,
Expand Down Expand Up @@ -188,7 +188,10 @@ def test_device(library):
dev = device(x)

x2 = to_device(x, dev)
assert device(x) == device(x2)
assert device(x2) == device(x)

x3 = xp.asarray(x, device=dev)
assert device(x3) == device(x)


@pytest.mark.parametrize("library", wrapped_libraries)
Expand All @@ -207,7 +210,7 @@ def test_to_device_host(library):
# a `device(x)` query; however, what's really important
# here is that we can test portably after calling
# to_device(x, "cpu") to return to host
assert_allclose(x, expected)
assert_equal(x, expected)


@pytest.mark.parametrize("target_library", is_array_functions.keys())
Expand Down
20 changes: 20 additions & 0 deletions tests/test_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import jax
import jax.numpy as jnp
from numpy.testing import assert_equal

from array_api_compat import device, to_device


def test_device_jit():
# Test work around to https://github.com/jax-ml/jax/issues/26000
@jax.jit
def f(x):
return jnp.zeros(1, device=device(x))

@jax.jit
def g(x):
return to_device(jnp.zeros(1), device(x))

x = jnp.ones(1)
assert_equal(f(x), jnp.asarray(0))
assert_equal(g(x), jnp.asarray(0))

0 comments on commit 2aeb387

Please sign in to comment.