diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 95b2c5ad..434e7d87 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -648,20 +648,24 @@ def device(x: Array, /) -> Device: if is_numpy_array(x): return "cpu" elif is_dask_array(x): - # Peek at the metadata of the jax array to determine type + # Peek at the metadata of the Dask array to determine type if is_numpy_array(x._meta): # Must be on CPU since backed by numpy return "cpu" return _DASK_DEVICE elif is_jax_array(x): - # JAX has .device() as a method, but it is being deprecated so that it - # can become a property, in accordance with the standard. In order for - # this function to not break when JAX makes the flip, we check for - # both here. - if inspect.ismethod(x.device): - return x.device() + # FIXME Jitted JAX arrays do not have a device attribute + # https://github.com/jax-ml/jax/issues/26000 + # Return None in this case. Note that this workaround breaks + # the standard and will result in new arrays being created on the + # default device instead of the same device as the input array(s). + x_device = getattr(x, 'device', None) + # Older JAX releases had .device() as a method, which has been replaced + # with a property in accordance with the standard. + if inspect.ismethod(x_device): + return x_device() else: - return x.device + return x_device elif is_pydata_sparse_array(x): # `sparse` will gain `.device`, so check for this first. x_device = getattr(x, 'device', None) @@ -792,8 +796,11 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] raise ValueError(f"Unsupported device {device!r}") elif is_jax_array(x): if not hasattr(x, "__array_namespace__"): - # In JAX v0.4.31 and older, this import adds to_device method to x. + # In JAX v0.4.31 and older, this import adds to_device method to x... import jax.experimental.array_api # noqa: F401 + # ... but only on eager JAX. It won't work inside jax.jit. + if not hasattr(x, "to_device"): + return x return x.to_device(device, stream=stream) elif is_pydata_sparse_array(x) and device == _device(x): # Perform trivial check to return the same array if diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index b19ee1bf..a66a64d9 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -22,7 +22,7 @@ def test_array_namespace(library, api_version, use_compat): if use_compat and library not in wrapped_libraries: pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat)) return - namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat) + namespace = array_namespace(array, api_version=api_version, use_compat=use_compat) if use_compat is False or use_compat is None and library not in wrapped_libraries: if library == "jax.numpy" and use_compat is None: @@ -44,7 +44,7 @@ def test_array_namespace(library, api_version, use_compat): if library == "numpy": # check that the same namespace is returned for NumPy scalars - scalar_namespace = array_api_compat.array_namespace( + scalar_namespace = array_namespace( xp.float64(0.0), api_version=api_version, use_compat=use_compat ) assert scalar_namespace == namespace @@ -75,8 +75,7 @@ def test_array_namespace(library, api_version, use_compat): def test_jax_zero_gradient(): jx = jax.numpy.arange(4) jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx) - assert (array_api_compat.get_namespace(jax_zero) is - array_api_compat.get_namespace(jx)) + assert array_namespace(jax_zero) is array_namespace(jx) def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace([1])) @@ -91,7 +90,7 @@ def test_array_namespace_errors_torch(): x = np.asarray([1, 2]) pytest.raises(TypeError, lambda: array_namespace(x, y)) -def test_api_version(): +def test_api_version_torch(): x = torch.asarray([1, 2]) torch_ = import_("torch", wrapper=True) assert array_namespace(x, api_version="2023.12") == torch_ @@ -113,7 +112,7 @@ def test_api_version(): def test_get_namespace(): # Backwards compatible wrapper - assert array_api_compat.get_namespace is array_api_compat.array_namespace + assert array_api_compat.get_namespace is array_namespace def test_python_scalars(): a = torch.asarray([1, 2]) diff --git a/tests/test_common.py b/tests/test_common.py index 95c916ba..e95e305e 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -3,7 +3,7 @@ import pytest import numpy as np import array -from numpy.testing import assert_allclose +from numpy.testing import assert_equal from array_api_compat import ( # noqa: F401 is_numpy_array, is_cupy_array, is_torch_array, @@ -195,7 +195,10 @@ def test_device(library): dev = device(x) x2 = to_device(x, dev) - assert device(x) == device(x2) + assert device(x2) == device(x) + + x3 = xp.asarray(x, device=dev) + assert device(x3) == device(x) @pytest.mark.parametrize("library", wrapped_libraries) @@ -214,7 +217,7 @@ def test_to_device_host(library): # a `device(x)` query; however, what's really important # here is that we can test portably after calling # to_device(x, "cpu") to return to host - assert_allclose(x, expected) + assert_equal(x, expected) @pytest.mark.parametrize("target_library", is_array_functions.keys()) diff --git a/tests/test_jax.py b/tests/test_jax.py new file mode 100644 index 00000000..e33cec02 --- /dev/null +++ b/tests/test_jax.py @@ -0,0 +1,34 @@ +import jax +import jax.numpy as jnp +from numpy.testing import assert_equal +import pytest + +from array_api_compat import device, to_device + +HAS_JAX_0_4_31 = jax.__version__ >= "0.4.31" + + +@pytest.mark.parametrize( + "func", + [ + lambda x: jnp.zeros(1, device=device(x)), + lambda x: jnp.zeros_like(jnp.ones(1, device=device(x))), + lambda x: jnp.zeros_like(jnp.empty(1, device=device(x))), + lambda x: jnp.full(1, fill_value=0, device=device(x)), + pytest.param( + lambda x: jnp.asarray([0], device=device(x)), + marks=pytest.mark.skipif( + not HAS_JAX_0_4_31, reason="asarray() has no device= parameter" + ), + ), + lambda x: to_device(jnp.zeros(1), device(x)), + ] +) +def test_device_jit(func): + # Test work around to https://github.com/jax-ml/jax/issues/26000 + # Also test missing to_device() method in JAX < 0.4.31 + # when inside jax.jit, even after importing jax.experimental.array_api + + x = jnp.ones(1) + assert_equal(func(x), jnp.asarray([0])) + assert_equal(jax.jit(func)(x), jnp.asarray([0]))