Skip to content

Commit

Permalink
Merge pull request #238 from crusaderky/jax_jit_device
Browse files Browse the repository at this point in the history
  • Loading branch information
lucascolley authored Jan 26, 2025
2 parents 73f6426 + e7e71fb commit 0d66f80
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 18 deletions.
25 changes: 16 additions & 9 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,20 +648,24 @@ def device(x: Array, /) -> Device:
if is_numpy_array(x):
return "cpu"
elif is_dask_array(x):
# Peek at the metadata of the jax array to determine type
# Peek at the metadata of the Dask array to determine type
if is_numpy_array(x._meta):
# Must be on CPU since backed by numpy
return "cpu"
return _DASK_DEVICE
elif is_jax_array(x):
# JAX has .device() as a method, but it is being deprecated so that it
# can become a property, in accordance with the standard. In order for
# this function to not break when JAX makes the flip, we check for
# both here.
if inspect.ismethod(x.device):
return x.device()
# FIXME Jitted JAX arrays do not have a device attribute
# https://github.com/jax-ml/jax/issues/26000
# Return None in this case. Note that this workaround breaks
# the standard and will result in new arrays being created on the
# default device instead of the same device as the input array(s).
x_device = getattr(x, 'device', None)
# Older JAX releases had .device() as a method, which has been replaced
# with a property in accordance with the standard.
if inspect.ismethod(x_device):
return x_device()
else:
return x.device
return x_device
elif is_pydata_sparse_array(x):
# `sparse` will gain `.device`, so check for this first.
x_device = getattr(x, 'device', None)
Expand Down Expand Up @@ -792,8 +796,11 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
raise ValueError(f"Unsupported device {device!r}")
elif is_jax_array(x):
if not hasattr(x, "__array_namespace__"):
# In JAX v0.4.31 and older, this import adds to_device method to x.
# In JAX v0.4.31 and older, this import adds to_device method to x...
import jax.experimental.array_api # noqa: F401
# ... but only on eager JAX. It won't work inside jax.jit.
if not hasattr(x, "to_device"):
return x
return x.to_device(device, stream=stream)
elif is_pydata_sparse_array(x) and device == _device(x):
# Perform trivial check to return the same array if
Expand Down
11 changes: 5 additions & 6 deletions tests/test_array_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_array_namespace(library, api_version, use_compat):
if use_compat and library not in wrapped_libraries:
pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
return
namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat)
namespace = array_namespace(array, api_version=api_version, use_compat=use_compat)

if use_compat is False or use_compat is None and library not in wrapped_libraries:
if library == "jax.numpy" and use_compat is None:
Expand All @@ -44,7 +44,7 @@ def test_array_namespace(library, api_version, use_compat):

if library == "numpy":
# check that the same namespace is returned for NumPy scalars
scalar_namespace = array_api_compat.array_namespace(
scalar_namespace = array_namespace(
xp.float64(0.0), api_version=api_version, use_compat=use_compat
)
assert scalar_namespace == namespace
Expand Down Expand Up @@ -75,8 +75,7 @@ def test_array_namespace(library, api_version, use_compat):
def test_jax_zero_gradient():
jx = jax.numpy.arange(4)
jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
assert (array_api_compat.get_namespace(jax_zero) is
array_api_compat.get_namespace(jx))
assert array_namespace(jax_zero) is array_namespace(jx)

def test_array_namespace_errors():
pytest.raises(TypeError, lambda: array_namespace([1]))
Expand All @@ -91,7 +90,7 @@ def test_array_namespace_errors_torch():
x = np.asarray([1, 2])
pytest.raises(TypeError, lambda: array_namespace(x, y))

def test_api_version():
def test_api_version_torch():
x = torch.asarray([1, 2])
torch_ = import_("torch", wrapper=True)
assert array_namespace(x, api_version="2023.12") == torch_
Expand All @@ -113,7 +112,7 @@ def test_api_version():

def test_get_namespace():
# Backwards compatible wrapper
assert array_api_compat.get_namespace is array_api_compat.array_namespace
assert array_api_compat.get_namespace is array_namespace

def test_python_scalars():
a = torch.asarray([1, 2])
Expand Down
9 changes: 6 additions & 3 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
import numpy as np
import array
from numpy.testing import assert_allclose
from numpy.testing import assert_equal

from array_api_compat import ( # noqa: F401
is_numpy_array, is_cupy_array, is_torch_array,
Expand Down Expand Up @@ -195,7 +195,10 @@ def test_device(library):
dev = device(x)

x2 = to_device(x, dev)
assert device(x) == device(x2)
assert device(x2) == device(x)

x3 = xp.asarray(x, device=dev)
assert device(x3) == device(x)


@pytest.mark.parametrize("library", wrapped_libraries)
Expand All @@ -214,7 +217,7 @@ def test_to_device_host(library):
# a `device(x)` query; however, what's really important
# here is that we can test portably after calling
# to_device(x, "cpu") to return to host
assert_allclose(x, expected)
assert_equal(x, expected)


@pytest.mark.parametrize("target_library", is_array_functions.keys())
Expand Down
34 changes: 34 additions & 0 deletions tests/test_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import jax
import jax.numpy as jnp
from numpy.testing import assert_equal
import pytest

from array_api_compat import device, to_device

HAS_JAX_0_4_31 = jax.__version__ >= "0.4.31"


@pytest.mark.parametrize(
"func",
[
lambda x: jnp.zeros(1, device=device(x)),
lambda x: jnp.zeros_like(jnp.ones(1, device=device(x))),
lambda x: jnp.zeros_like(jnp.empty(1, device=device(x))),
lambda x: jnp.full(1, fill_value=0, device=device(x)),
pytest.param(
lambda x: jnp.asarray([0], device=device(x)),
marks=pytest.mark.skipif(
not HAS_JAX_0_4_31, reason="asarray() has no device= parameter"
),
),
lambda x: to_device(jnp.zeros(1), device(x)),
]
)
def test_device_jit(func):
# Test work around to https://github.com/jax-ml/jax/issues/26000
# Also test missing to_device() method in JAX < 0.4.31
# when inside jax.jit, even after importing jax.experimental.array_api

x = jnp.ones(1)
assert_equal(func(x), jnp.asarray([0]))
assert_equal(jax.jit(func)(x), jnp.asarray([0]))

0 comments on commit 0d66f80

Please sign in to comment.