diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index f9f39230..95b2c5ad 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -649,13 +649,9 @@ def device(x: Array, /) -> Device: return "cpu" elif is_dask_array(x): # Peek at the metadata of the jax array to determine type - try: - import numpy as np - if isinstance(x._meta, np.ndarray): - # Must be on CPU since backed by numpy - return "cpu" - except ImportError: - pass + if is_numpy_array(x._meta): + # Must be on CPU since backed by numpy + return "cpu" return _DASK_DEVICE elif is_jax_array(x): # JAX has .device() as a method, but it is being deprecated so that it diff --git a/docs/supported-array-libraries.md b/docs/supported-array-libraries.md index a016a636..1af9f3dc 100644 --- a/docs/supported-array-libraries.md +++ b/docs/supported-array-libraries.md @@ -137,3 +137,8 @@ The minimum supported Dask version is 2023.12.0. ## [Sparse](https://sparse.pydata.org/en/stable/) Similar to JAX, `sparse` Array API support is contained directly in `sparse`. + +(array-api-strict-support)= +## [array-api-strict](https://data-apis.org/array-api-strict/) + +array-api-strict exists only to test support for the Array API, so it does not need any wrappers. diff --git a/tests/_helpers.py b/tests/_helpers.py index e2a7e1d1..5b79aa46 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -1,14 +1,10 @@ from importlib import import_module -import sys import pytest wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"] -all_libraries = wrapped_libraries + ["jax.numpy"] +all_libraries = wrapped_libraries + ["array_api_strict", "jax.numpy", "sparse"] -# `sparse` added array API support as of Python 3.10. -if sys.version_info >= (3, 10): - all_libraries.append('sparse') def import_(library, wrapper=False): if library == 'cupy': @@ -20,9 +16,7 @@ def import_(library, wrapper=False): jax_numpy = import_module("jax.numpy") if not hasattr(jax_numpy, "__array_api_version__"): library = 'jax.experimental.array_api' - elif library.startswith('sparse'): - library = 'sparse' - else: + elif library in wrapped_libraries: library = 'array_api_compat.' + library return import_module(library) diff --git a/tests/test_all.py b/tests/test_all.py index 081bb82b..10a2a95d 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -18,7 +18,10 @@ @pytest.mark.parametrize("library", ["common"] + wrapped_libraries) def test_all(library): - import_(library, wrapper=True) + if library == "common": + import array_api_compat.common # noqa: F401 + else: + import_(library, wrapper=True) for mod_name in sys.modules: if not mod_name.startswith('array_api_compat.' + library): diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 9c26371c..b19ee1bf 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -14,12 +14,12 @@ @pytest.mark.parametrize("use_compat", [True, False, None]) @pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12", "2023.12"]) -@pytest.mark.parametrize("library", all_libraries + ['array_api_strict']) +@pytest.mark.parametrize("library", all_libraries) def test_array_namespace(library, api_version, use_compat): xp = import_(library) array = xp.asarray([1.0, 2.0, 3.0]) - if use_compat is True and library in {'array_api_strict', 'jax.numpy', 'sparse'}: + if use_compat and library not in wrapped_libraries: pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat)) return namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat) diff --git a/tests/test_common.py b/tests/test_common.py index e702e4a9..95c916ba 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -10,6 +10,7 @@ is_dask_array, is_jax_array, is_pydata_sparse_array, is_numpy_namespace, is_cupy_namespace, is_torch_namespace, is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace, + is_array_api_strict_namespace, ) from array_api_compat import ( @@ -33,6 +34,7 @@ 'dask.array': 'is_dask_namespace', 'jax.numpy': 'is_jax_namespace', 'sparse': 'is_pydata_sparse_namespace', + 'array_api_strict': 'is_array_api_strict_namespace', } @@ -74,7 +76,12 @@ def test_xp_is_array_generics(library): is_func = globals()[func] if is_func(x0): matches.append(library2) - assert matches in ([library], ["numpy"]) + + if library == "array_api_strict": + # There is no is_array_api_strict_array() function + assert matches == [] + else: + assert matches in ([library], ["numpy"]) @pytest.mark.parametrize("library", all_libraries) @@ -213,26 +220,33 @@ def test_to_device_host(library): @pytest.mark.parametrize("target_library", is_array_functions.keys()) @pytest.mark.parametrize("source_library", is_array_functions.keys()) def test_asarray_cross_library(source_library, target_library, request): - if source_library == "dask.array" and target_library == "torch": + def _xfail(reason: str) -> None: # Allow rest of test to execute instead of immediately xfailing # xref https://github.com/pandas-dev/pandas/issues/38902 + request.node.add_marker(pytest.mark.xfail(reason=reason)) + if source_library == "dask.array" and target_library == "torch": # TODO: remove xfail once # https://github.com/dask/dask/issues/8260 is resolved - request.node.add_marker(pytest.mark.xfail(reason="Bug in dask raising error on conversion")) - if source_library == "cupy" and target_library != "cupy": + _xfail(reason="Bug in dask raising error on conversion") + elif source_library == "jax.numpy" and target_library == "torch": + _xfail(reason="casts int to float") + elif source_library == "cupy" and target_library != "cupy": # cupy explicitly disallows implicit conversions to CPU pytest.skip(reason="cupy does not support implicit conversion to CPU") elif source_library == "sparse" and target_library != "sparse": pytest.skip(reason="`sparse` does not allow implicit densification") + src_lib = import_(source_library, wrapper=True) tgt_lib = import_(target_library, wrapper=True) is_tgt_type = globals()[is_array_functions[target_library]] - a = src_lib.asarray([1, 2, 3]) + a = src_lib.asarray([1, 2, 3], dtype=src_lib.int32) b = tgt_lib.asarray(a) assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}" + assert b.dtype == tgt_lib.int32 + @pytest.mark.parametrize("library", wrapped_libraries)