Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't wrap NumPy 2.0 at all #126

Merged
merged 7 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 40 additions & 14 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def _check_api_version(api_version):
elif api_version is not None and api_version != '2022.12':
raise ValueError("Only the 2022.12 version of the array API specification is currently supported")

def array_namespace(*xs, api_version=None, _use_compat=True):
def array_namespace(*xs, api_version=None, use_compat=None):
"""
Get the array API compatible namespace for the arrays `xs`.

Expand All @@ -191,6 +191,12 @@ def array_namespace(*xs, api_version=None, _use_compat=True):
The newest version of the spec that you need support for (currently
the compat library wrapped APIs support v2022.12).

use_compat: bool or None
If None (the default), the native namespace will be returned if it is
already array API compatible, otherwise a compat wrapper is used. If
True, the compat library wrapped library will be returned. If False,
the native library namespace is returned.

Returns
-------

Expand Down Expand Up @@ -234,46 +240,66 @@ def your_function(x, y):
is_jax_array

"""
if use_compat not in [None, True, False]:
raise ValueError("use_compat must be None, True, or False")

_use_compat = use_compat in [None, True]

namespaces = set()
for x in xs:
if is_numpy_array(x):
_check_api_version(api_version)
if _use_compat:
from .. import numpy as numpy_namespace
from .. import numpy as numpy_namespace
import numpy as np
if use_compat is True:
_check_api_version(api_version)
namespaces.add(numpy_namespace)
else:
import numpy as np
elif use_compat is False:
namespaces.add(np)
else:
# numpy 2.0 has __array_namespace__ and is fully array API
# compatible.
if hasattr(x, '__array_namespace__'):
namespaces.add(x.__array_namespace__(api_version=api_version))
else:
namespaces.add(numpy_namespace)
elif is_cupy_array(x):
_check_api_version(api_version)
if _use_compat:
_check_api_version(api_version)
from .. import cupy as cupy_namespace
namespaces.add(cupy_namespace)
else:
import cupy as cp
namespaces.add(cp)
elif is_torch_array(x):
_check_api_version(api_version)
if _use_compat:
_check_api_version(api_version)
from .. import torch as torch_namespace
namespaces.add(torch_namespace)
else:
import torch
namespaces.add(torch)
elif is_dask_array(x):
_check_api_version(api_version)
if _use_compat:
_check_api_version(api_version)
from ..dask import array as dask_namespace
namespaces.add(dask_namespace)
else:
raise TypeError("_use_compat cannot be False if input array is a dask array!")
import dask.array as da
namespaces.add(da)
elif is_jax_array(x):
_check_api_version(api_version)
# jax.experimental.array_api is already an array namespace. We do
# not have a wrapper submodule for it.
import jax.experimental.array_api as jnp
if use_compat is True:
_check_api_version(api_version)
raise ValueError("JAX does not have an array-api-compat wrapper")
elif use_compat is False:
import jax.numpy as jnp
else:
# jax.experimental.array_api is already an array namespace. We do
# not have a wrapper submodule for it.
import jax.experimental.array_api as jnp
namespaces.add(jnp)
elif hasattr(x, '__array_namespace__'):
if use_compat is True:
raise ValueError("The given array does not have an array-api-compat wrapper")
namespaces.add(x.__array_namespace__(api_version=api_version))
else:
# TODO: Support Python scalars?
Expand Down
5 changes: 0 additions & 5 deletions array_api_compat/numpy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,6 @@ def asarray(
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
if np.__version__[0] >= '2':
# NumPy 2.0 asarray() is completely array API compatible. No need for
# the complicated logic below
return np.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs)

if device not in ["cpu", None]:
raise ValueError(f"Unsupported device for NumPy: {device!r}")

Expand Down
11 changes: 0 additions & 11 deletions numpy-dev-xfails.txt
Original file line number Diff line number Diff line change
@@ -1,17 +1,6 @@
# asarray(copy=False) is not yet implemented
array_api_tests/test_creation_functions.py::test_asarray_arrays

# finfo(float32).eps returns float32 but should return float
array_api_tests/test_data_type_functions.py::test_finfo[float32]

# Array methods and attributes not already on np.ndarray cannot be wrapped
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
array_api_tests/test_has_names.py::test_has_names[array_attribute-device]

# linalg tests require cleanups
# https://github.com/data-apis/array-api-tests/pull/101
array_api_tests/test_linalg.py::test_solve

# NumPy deviates in some special cases for floordiv
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
Expand Down
7 changes: 5 additions & 2 deletions tests/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

import pytest

wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"]
all_libraries = wrapped_libraries + ["jax.numpy"]
wrapped_libraries = ["cupy", "torch", "dask.array"]
all_libraries = wrapped_libraries + ["numpy", "jax.numpy"]
import numpy as np
if np.__version__[0] == '1':
wrapped_libraries.append("numpy")

def import_(library, wrapper=False):
if library == 'cupy':
Expand Down
37 changes: 21 additions & 16 deletions tests/test_array_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,29 @@
import array_api_compat
from array_api_compat import array_namespace

from ._helpers import import_, all_libraries
from ._helpers import import_, all_libraries, wrapped_libraries

@pytest.mark.parametrize("library", all_libraries)
@pytest.mark.parametrize("api_version", [None, "2021.12"])
def test_array_namespace(library, api_version):
@pytest.mark.parametrize("use_compat", [True, False, None])
@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12"])
@pytest.mark.parametrize("library", all_libraries + ['array_api_strict'])
def test_array_namespace(library, api_version, use_compat):
xp = import_(library)

array = xp.asarray([1.0, 2.0, 3.0])
namespace = array_api_compat.array_namespace(array, api_version=api_version)
if use_compat is True and library in ['array_api_strict', 'jax.numpy']:
pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
return
namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat)

if "array_api" in library:
assert namespace == xp
if use_compat is False or use_compat is None and library not in wrapped_libraries:
if library == "jax.numpy" and use_compat is None:
import jax.experimental.array_api
assert namespace == jax.experimental.array_api
else:
assert namespace == xp
else:
if library == "dask.array":
assert namespace == array_api_compat.dask.array
elif library == "jax.numpy":
import jax.experimental.array_api
assert namespace == jax.experimental.array_api
else:
assert namespace == getattr(array_api_compat, library)

Expand Down Expand Up @@ -64,14 +69,14 @@ def test_array_namespace_errors_torch():
pytest.raises(TypeError, lambda: array_namespace(x, y))

def test_api_version():
x = np.asarray([1, 2])
np_ = import_("numpy", wrapper=True)
assert array_namespace(x, api_version="2022.12") == np_
assert array_namespace(x, api_version=None) == np_
assert array_namespace(x) == np_
x = torch.asarray([1, 2])
torch_ = import_("torch", wrapper=True)
assert array_namespace(x, api_version="2022.12") == torch_
assert array_namespace(x, api_version=None) == torch_
assert array_namespace(x) == torch_
# Should issue a warning
with warnings.catch_warnings(record=True) as w:
assert array_namespace(x, api_version="2021.12") == np_
assert array_namespace(x, api_version="2021.12") == torch_
assert len(w) == 1
assert "2021.12" in str(w[0].message)

Expand Down
5 changes: 3 additions & 2 deletions tests/test_no_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
import sys
import subprocess

from ._helpers import import_

import pytest

class Array:
Expand Down Expand Up @@ -54,6 +52,9 @@ def _test_dependency(mod):
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array",
"jax.numpy", "array_api_strict"])
def test_numpy_dependency(library):
# This import is here because it imports numpy
from ._helpers import import_

# This unfortunately won't go through any of the pytest machinery. We
# reraise the exception as an AssertionError so that pytest will show it
# in a semi-reasonable way
Expand Down
1 change: 1 addition & 0 deletions torch-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ array_api_tests/test_statistical_functions.py::test_var
# The test suite is incorrectly checking sums that have loss of significance
# (https://github.com/data-apis/array-api-tests/issues/168)
array_api_tests/test_statistical_functions.py::test_sum
array_api_tests/test_statistical_functions.py::test_prod

# These functions do not yet support complex numbers
array_api_tests/test_operators_and_elementwise_functions.py::test_sign
Expand Down
Loading