diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 25419c01..e7a868c9 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -178,7 +178,7 @@ def _check_api_version(api_version): elif api_version is not None and api_version != '2022.12': raise ValueError("Only the 2022.12 version of the array API specification is currently supported") -def array_namespace(*xs, api_version=None, _use_compat=True): +def array_namespace(*xs, api_version=None, use_compat=None): """ Get the array API compatible namespace for the arrays `xs`. @@ -191,6 +191,12 @@ def array_namespace(*xs, api_version=None, _use_compat=True): The newest version of the spec that you need support for (currently the compat library wrapped APIs support v2022.12). + use_compat: bool or None + If None (the default), the native namespace will be returned if it is + already array API compatible, otherwise a compat wrapper is used. If + True, the compat library wrapped library will be returned. If False, + the native library namespace is returned. + Returns ------- @@ -234,46 +240,66 @@ def your_function(x, y): is_jax_array """ + if use_compat not in [None, True, False]: + raise ValueError("use_compat must be None, True, or False") + + _use_compat = use_compat in [None, True] + namespaces = set() for x in xs: if is_numpy_array(x): - _check_api_version(api_version) - if _use_compat: - from .. import numpy as numpy_namespace + from .. import numpy as numpy_namespace + import numpy as np + if use_compat is True: + _check_api_version(api_version) namespaces.add(numpy_namespace) - else: - import numpy as np + elif use_compat is False: namespaces.add(np) + else: + # numpy 2.0 has __array_namespace__ and is fully array API + # compatible. + if hasattr(x, '__array_namespace__'): + namespaces.add(x.__array_namespace__(api_version=api_version)) + else: + namespaces.add(numpy_namespace) elif is_cupy_array(x): - _check_api_version(api_version) if _use_compat: + _check_api_version(api_version) from .. import cupy as cupy_namespace namespaces.add(cupy_namespace) else: import cupy as cp namespaces.add(cp) elif is_torch_array(x): - _check_api_version(api_version) if _use_compat: + _check_api_version(api_version) from .. import torch as torch_namespace namespaces.add(torch_namespace) else: import torch namespaces.add(torch) elif is_dask_array(x): - _check_api_version(api_version) if _use_compat: + _check_api_version(api_version) from ..dask import array as dask_namespace namespaces.add(dask_namespace) else: - raise TypeError("_use_compat cannot be False if input array is a dask array!") + import dask.array as da + namespaces.add(da) elif is_jax_array(x): - _check_api_version(api_version) - # jax.experimental.array_api is already an array namespace. We do - # not have a wrapper submodule for it. - import jax.experimental.array_api as jnp + if use_compat is True: + _check_api_version(api_version) + raise ValueError("JAX does not have an array-api-compat wrapper") + elif use_compat is False: + import jax.numpy as jnp + else: + # jax.experimental.array_api is already an array namespace. We do + # not have a wrapper submodule for it. + import jax.experimental.array_api as jnp namespaces.add(jnp) elif hasattr(x, '__array_namespace__'): + if use_compat is True: + raise ValueError("The given array does not have an array-api-compat wrapper") namespaces.add(x.__array_namespace__(api_version=api_version)) else: # TODO: Support Python scalars? diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 28f02f36..70378716 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -94,11 +94,6 @@ def asarray( See the corresponding documentation in the array library and/or the array API specification for more details. """ - if np.__version__[0] >= '2': - # NumPy 2.0 asarray() is completely array API compatible. No need for - # the complicated logic below - return np.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs) - if device not in ["cpu", None]: raise ValueError(f"Unsupported device for NumPy: {device!r}") diff --git a/numpy-dev-xfails.txt b/numpy-dev-xfails.txt index 51ff34ad..54c6b963 100644 --- a/numpy-dev-xfails.txt +++ b/numpy-dev-xfails.txt @@ -1,17 +1,6 @@ -# asarray(copy=False) is not yet implemented -array_api_tests/test_creation_functions.py::test_asarray_arrays - # finfo(float32).eps returns float32 but should return float array_api_tests/test_data_type_functions.py::test_finfo[float32] -# Array methods and attributes not already on np.ndarray cannot be wrapped -array_api_tests/test_has_names.py::test_has_names[array_method-to_device] -array_api_tests/test_has_names.py::test_has_names[array_attribute-device] - -# linalg tests require cleanups -# https://github.com/data-apis/array-api-tests/pull/101 -array_api_tests/test_linalg.py::test_solve - # NumPy deviates in some special cases for floordiv array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] diff --git a/tests/_helpers.py b/tests/_helpers.py index fc6b3e04..ffa2171e 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -2,8 +2,11 @@ import pytest -wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"] -all_libraries = wrapped_libraries + ["jax.numpy"] +wrapped_libraries = ["cupy", "torch", "dask.array"] +all_libraries = wrapped_libraries + ["numpy", "jax.numpy"] +import numpy as np +if np.__version__[0] == '1': + wrapped_libraries.append("numpy") def import_(library, wrapper=False): if library == 'cupy': diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 18880597..78705189 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -9,24 +9,29 @@ import array_api_compat from array_api_compat import array_namespace -from ._helpers import import_, all_libraries +from ._helpers import import_, all_libraries, wrapped_libraries -@pytest.mark.parametrize("library", all_libraries) -@pytest.mark.parametrize("api_version", [None, "2021.12"]) -def test_array_namespace(library, api_version): +@pytest.mark.parametrize("use_compat", [True, False, None]) +@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12"]) +@pytest.mark.parametrize("library", all_libraries + ['array_api_strict']) +def test_array_namespace(library, api_version, use_compat): xp = import_(library) array = xp.asarray([1.0, 2.0, 3.0]) - namespace = array_api_compat.array_namespace(array, api_version=api_version) + if use_compat is True and library in ['array_api_strict', 'jax.numpy']: + pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat)) + return + namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat) - if "array_api" in library: - assert namespace == xp + if use_compat is False or use_compat is None and library not in wrapped_libraries: + if library == "jax.numpy" and use_compat is None: + import jax.experimental.array_api + assert namespace == jax.experimental.array_api + else: + assert namespace == xp else: if library == "dask.array": assert namespace == array_api_compat.dask.array - elif library == "jax.numpy": - import jax.experimental.array_api - assert namespace == jax.experimental.array_api else: assert namespace == getattr(array_api_compat, library) @@ -64,14 +69,14 @@ def test_array_namespace_errors_torch(): pytest.raises(TypeError, lambda: array_namespace(x, y)) def test_api_version(): - x = np.asarray([1, 2]) - np_ = import_("numpy", wrapper=True) - assert array_namespace(x, api_version="2022.12") == np_ - assert array_namespace(x, api_version=None) == np_ - assert array_namespace(x) == np_ + x = torch.asarray([1, 2]) + torch_ = import_("torch", wrapper=True) + assert array_namespace(x, api_version="2022.12") == torch_ + assert array_namespace(x, api_version=None) == torch_ + assert array_namespace(x) == torch_ # Should issue a warning with warnings.catch_warnings(record=True) as w: - assert array_namespace(x, api_version="2021.12") == np_ + assert array_namespace(x, api_version="2021.12") == torch_ assert len(w) == 1 assert "2021.12" in str(w[0].message) diff --git a/tests/test_no_dependencies.py b/tests/test_no_dependencies.py index 391b55a0..8ad71a3c 100644 --- a/tests/test_no_dependencies.py +++ b/tests/test_no_dependencies.py @@ -11,8 +11,6 @@ import sys import subprocess -from ._helpers import import_ - import pytest class Array: @@ -54,6 +52,9 @@ def _test_dependency(mod): @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy", "array_api_strict"]) def test_numpy_dependency(library): + # This import is here because it imports numpy + from ._helpers import import_ + # This unfortunately won't go through any of the pytest machinery. We # reraise the exception as an AssertionError so that pytest will show it # in a semi-reasonable way diff --git a/torch-xfails.txt b/torch-xfails.txt index a9106fae..577f4640 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -183,6 +183,7 @@ array_api_tests/test_statistical_functions.py::test_var # The test suite is incorrectly checking sums that have loss of significance # (https://github.com/data-apis/array-api-tests/issues/168) array_api_tests/test_statistical_functions.py::test_sum +array_api_tests/test_statistical_functions.py::test_prod # These functions do not yet support complex numbers array_api_tests/test_operators_and_elementwise_functions.py::test_sign