Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support the copy keyword in asarray #119

Merged
merged 41 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
e75ba03
Fix some internal documentation
asmeurer Mar 19, 2024
41719af
Factor out the list of wrapped libraries in the tests
asmeurer Mar 19, 2024
f1068a3
Fix typo
asmeurer Mar 19, 2024
73047a9
Add a test for the copy flag in asarray
asmeurer Mar 20, 2024
fa36c20
Properly test copy=None with the dtype argument
asmeurer Mar 20, 2024
558e4ed
Move the asarray numpy implementation to numpy/_aliases
asmeurer Mar 20, 2024
440e1c1
Update xfails
asmeurer Mar 21, 2024
f57ac54
Add a CuPy specific implementation for asarray
asmeurer Mar 21, 2024
166c650
Update test to test that CuPy does not handle copy=False
asmeurer Mar 21, 2024
a194344
Update cupy buffer protocol copy test
asmeurer Mar 21, 2024
a1eea09
Merge branch 'main' into asarray-copy
asmeurer Mar 21, 2024
d84983a
Structure the copy flag in cupy.asarray better
asmeurer Mar 22, 2024
11d27dd
Remove no longer correct note from docstring
asmeurer Mar 22, 2024
f6b5ea2
Add dask.array specific implementation of asarray()
asmeurer Mar 22, 2024
7c0116c
Run the normal tests against different versions of numpy
asmeurer Mar 22, 2024
4d54461
Fix workflow synatx
asmeurer Mar 22, 2024
d7807e1
Install everything in one pip command
asmeurer Mar 22, 2024
354e007
Drop support for Python 3.8
asmeurer Mar 22, 2024
dfac540
Update extras_require in setup.py
asmeurer Mar 22, 2024
c7b5780
Fix ruff errors
asmeurer Mar 22, 2024
3e1f24c
Only run numpy tests for numpy 1.21
asmeurer Mar 22, 2024
cee1696
Don't include "jax.numpy" in the numpy-only tests
asmeurer Mar 25, 2024
c58fbec
Test Python 3.12 on CI
asmeurer Mar 25, 2024
2e5c759
Skip NumPy 1.21 in Python 3.12
asmeurer Mar 25, 2024
04551ed
Fix bash syntax
asmeurer Mar 25, 2024
f7fb29f
Only run numpy specific tests for numpy=dev
asmeurer Mar 25, 2024
689366a
Fix bash syntax
asmeurer Mar 25, 2024
c060cee
Run tests with -v
asmeurer Mar 25, 2024
4112eaf
Disable other libraries too for the numpy-only tests
asmeurer Mar 25, 2024
b171583
Add requirements-dev.txt
asmeurer Mar 25, 2024
8aa76b7
Add setuptools to requirements-dev.txt
asmeurer Mar 25, 2024
7105866
Use pip for the test install
asmeurer Mar 25, 2024
aeb1cc4
Fix workflow syntax
asmeurer Mar 25, 2024
f7e724e
Try again to fix workflow syntax
asmeurer Mar 25, 2024
d7e8532
Keep trying to fix the workflow syntax
asmeurer Mar 25, 2024
fbd6e1b
Try more syntax fixes
asmeurer Mar 25, 2024
6397108
Fix workflow syntax
asmeurer Mar 25, 2024
d0d068d
Merge branch 'main' into asarray-copy
asmeurer Mar 27, 2024
764d5ab
Better job names for docs build and deploy
asmeurer Mar 27, 2024
d5a7cc6
Merge branch 'main' into asarray-copy
asmeurer Mar 27, 2024
2dcd864
Merge branch 'asarray-copy' of github.com:asmeurer/array-api-compat i…
asmeurer Mar 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,11 @@ def _asarray(
"""
Array API compatibility wrapper for asarray().

See the corresponding documentation in NumPy/CuPy and/or the array API
See the corresponding documentation in the array library and/or the array API
specification for more details.

'namespace' may be an array module namespace. This is needed to support
conversion of sequences of Python scalars.
"""
if namespace is None:
try:
Expand Down
47 changes: 41 additions & 6 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from __future__ import annotations

from functools import partial

import cupy as cp

from ..common import _aliases
from .._internal import get_xp

asarray = asarray_cupy = partial(_aliases._asarray, namespace='cupy')
asarray.__doc__ = _aliases._asarray.__doc__
del partial
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Union
from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol

bool = cp.bool_

Expand Down Expand Up @@ -62,6 +61,42 @@
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
tensordot = get_xp(cp)(_aliases.tensordot)

# asarray also adds the copy keyword, which is not present in numpy 1.0.
def asarray(
obj: Union[
ndarray,
bool,
int,
float,
NestedSequence[bool | int | float],
SupportsBufferProtocol,
],
/,
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
copy: Optional[bool] = None,
**kwargs,
) -> ndarray:
"""
Array API compatibility wrapper for asarray().

See the corresponding documentation in the array library and/or the array API
specification for more details.

'namespace' may be an array module namespace. This is needed to support
conversion of sequences of Python scalars.
"""
with cp.cuda.Device(device):
# cupy is like NumPy 1.26 (except without _CopyMode). See the comments
# in asarray in numpy/_aliases.py.
if copy is None:
copy = False
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@leofang did you say at the meeting today that cupy 14 would definitely support NumPy 2.0? The meaning of copy=False changed between NumPy 1.26 and 2.0. Is it safe to add an `if cupy.__version__ >= 14` check here? If we just leave this wrapper code as it is, it will break once cupy makes the change, because the default would become the array API copy=False semantics.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was TBD so here's my question: I would think like dpctl once CuPy is compliant we'd no longer need array-api-compat coverage? If so, I don't think we need a version check here, as it'd become moot anyway? WDYT?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way the compat library currently works is that it unconditionally wraps cupy. We could remove the wrapping once cupy is compliant but we'd need a version check to do that too.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean, the version check could be added to, say, __init__.py instead of here, once we know better after which CuPy version all array-api-compat wrappers can just fall back to native CuPy without extra action (=no-ops), right? Maybe I oversimplified it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If cupy is completely compliant and needs no wrappers we could do that. We're not doing that yet for numpy 2.0 but we probably should.

At any rate, I've thought of a better way to structure this code so that it won't actually break when cupy changes the meaning of copy=False, so this isn't actually important. I'll still need to update this to remove the NotImplementedError for copy=False, but it isn't pressing to try to do it ahead of time now because I've made it so that the default behavior won't break, and that would be a new feature for cupy anyways.

elif copy is False:
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")

return cp.array(obj, copy=copy, dtype=dtype, **kwargs)

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp, 'vecdot'):
Expand All @@ -73,7 +108,7 @@
else:
isdtype = get_xp(cp)(_aliases.isdtype)

__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos',
__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow']
Expand Down
6 changes: 6 additions & 0 deletions array_api_compat/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,10 @@

from ..common._helpers import * # noqa: F403

try:
# Used in asarray(). Not present in older versions.
from numpy import _CopyMode

Check failure on line 26 in array_api_compat/numpy/__init__.py

View workflow job for this annotation

GitHub Actions / check-ruff

Ruff (F401)

array_api_compat/numpy/__init__.py:26:23: F401 `numpy._CopyMode` imported but unused; consider using `importlib.util.find_spec` to test for availability
except ImportError:
pass

__array_api_version__ = '2022.12'
70 changes: 64 additions & 6 deletions array_api_compat/numpy/_aliases.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from __future__ import annotations

from functools import partial

from ..common import _aliases

from .._internal import get_xp

asarray = asarray_numpy = partial(_aliases._asarray, namespace='numpy')
asarray.__doc__ = _aliases._asarray.__doc__
del partial
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Union
from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol

import numpy as np
bool = np.bool_
Expand Down Expand Up @@ -62,6 +61,65 @@
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
tensordot = get_xp(np)(_aliases.tensordot)

def _supports_buffer_protocol(obj):
try:
memoryview(obj)
except TypeError:
return False
return True

# asarray also adds the copy keyword, which is not present in numpy 1.0.
# asarray() is different enough between numpy, cupy, and dask, and the logic
# complicated enough, that it's easier to define it separately for each module
# rather than trying to combine everything into one function in common/
def asarray(
    obj: Union[
        ndarray,
        bool,
        int,
        float,
        NestedSequence[bool | int | float],
        SupportsBufferProtocol,
    ],
    /,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    copy: "Optional[Union[bool, np._CopyMode]]" = None,
    **kwargs,
) -> ndarray:
    """
    Array API compatibility wrapper for asarray().

    See the corresponding documentation in NumPy and/or the array API
    specification for more details.

    Parameters
    ----------
    obj
        Object to convert to a NumPy array.
    dtype
        Desired dtype of the result, or None to let NumPy choose.
    device
        Only ``"cpu"`` and ``None`` are accepted on NumPy < 2.0; on
        NumPy >= 2.0 the value is forwarded to ``np.asarray``.
    copy
        ``True`` always copies, ``False`` never copies, and ``None`` copies
        only if needed.
    **kwargs
        Forwarded unchanged to the underlying NumPy function.

    Raises
    ------
    ValueError
        On NumPy < 2.0, if `device` is neither ``"cpu"`` nor ``None``.
    NotImplementedError
        If ``copy=False`` is requested on a NumPy version that has no
        ``np._CopyMode``.
    """
    # Compare the major version numerically. Comparing only the first
    # character of the version string would misclassify a multi-digit major
    # version (e.g. "10.0" starts with "1").
    if int(np.__version__.split('.', 1)[0]) >= 2:
        # NumPy 2.0 asarray() is completely array API compatible. No need for
        # the complicated logic below
        return np.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs)

    if device not in ["cpu", None]:
        raise ValueError(f"Unsupported device for NumPy: {device!r}")

    if hasattr(np, '_CopyMode'):
        # Translate the array API tri-state flag into np.array()'s _CopyMode.
        # Any other value (e.g. an existing _CopyMode member) is passed
        # through untouched.
        if copy is None:
            copy = np._CopyMode.IF_NEEDED
        elif copy is False:
            copy = np._CopyMode.NEVER
        elif copy is True:
            copy = np._CopyMode.ALWAYS
    else:
        # Not present in older NumPys. In this case, we cannot really support
        # copy=False.
        if copy is False:
            raise NotImplementedError("asarray(copy=False) requires a newer version of NumPy.")

    return np.array(obj, copy=copy, dtype=dtype, **kwargs)

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(np, 'vecdot'):
Expand All @@ -73,7 +131,7 @@
else:
isdtype = get_xp(np)(_aliases.isdtype)

__all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'acos',
__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow']
Expand Down
3 changes: 3 additions & 0 deletions tests/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import pytest


wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"]
all_libraries = wrapped_libraries + ["jax.numpy"]

def import_(library, wrapper=False):
if 'jax' in library and sys.version_info < (3, 9):
pytest.skip('JAX array API support does not support Python 3.8')
Expand Down
4 changes: 2 additions & 2 deletions tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@

import sys

from ._helpers import import_
from ._helpers import import_, wrapped_libraries

import pytest

@pytest.mark.parametrize("library", ["common", "cupy", "numpy", "torch", "dask.array"])
@pytest.mark.parametrize("library", ["common"] + wrapped_libraries)
def test_all(library):
import_(library, wrapper=True)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_array_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import array_api_compat
from array_api_compat import array_namespace

from ._helpers import import_
from ._helpers import import_, all_libraries

@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
@pytest.mark.parametrize("library", all_libraries)
@pytest.mark.parametrize("api_version", [None, "2021.12"])
def test_array_namespace(library, api_version):
xp = import_(library)
Expand Down
103 changes: 99 additions & 4 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

from array_api_compat import is_array_api_obj, device, to_device

from ._helpers import import_
from ._helpers import import_, wrapped_libraries, all_libraries

import pytest
import numpy as np
import array
from numpy.testing import assert_allclose

is_functions = {
Expand All @@ -29,7 +30,7 @@ def test_is_xp_array(library, func):

assert is_array_api_obj(x)

@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
@pytest.mark.parametrize("library", all_libraries)
def test_device(library):
xp = import_(library, wrapper=True)

Expand All @@ -43,7 +44,7 @@ def test_device(library):
assert device(x) == device(x2)


@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
@pytest.mark.parametrize("library", wrapped_libraries)
def test_to_device_host(library):
# different libraries have different semantics
# for DtoH transfers; ensure that we support a portable
Expand All @@ -64,7 +65,7 @@ def test_to_device_host(library):

@pytest.mark.parametrize("target_library", is_functions.keys())
@pytest.mark.parametrize("source_library", is_functions.keys())
def test_asarray(source_library, target_library, request):
def test_asarray_cross_library(source_library, target_library, request):
if source_library == "dask.array" and target_library == "torch":
# Allow rest of test to execute instead of immediately xfailing
# xref https://github.com/pandas-dev/pandas/issues/38902
Expand All @@ -83,3 +84,97 @@ def test_asarray(source_library, target_library, request):
b = tgt_lib.asarray(a)

assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"

@pytest.mark.parametrize("library", wrapped_libraries)
def test_asarray_copy(library):
    # The array API test suite does not yet exercise the copy flag of
    # asarray() very rigorously, so it is checked here. Once
    # https://github.com/data-apis/array-api-tests/issues/241 is fixed this
    # test should be deletable.
    xp = import_(library, wrapper=True)
    asarray = xp.asarray
    xp_all = xp.all
    is_lib_array = globals()[is_functions[library]]

    # copy=False is rejected with NotImplementedError by the cupy wrapper and
    # by the numpy wrapper on a NumPy 1.x that has no _CopyMode.
    numpy_without_copymode = (
        library == 'numpy'
        and xp.__version__[0] < '2'
        and not hasattr(xp, '_CopyMode')
    )
    supports_copy_false = not (numpy_without_copymode or library == 'cupy')
    # Error expected when copy=False is requested but a copy is unavoidable.
    if supports_copy_false:
        copy_false_error = ValueError
    else:
        copy_false_error = NotImplementedError

    # copy=True: the result must not alias the input.
    src = asarray([1])
    out = asarray(src, copy=True)
    assert is_lib_array(out)
    src[0] = 0
    assert xp_all(out[0] == 1)
    assert xp_all(src[0] == 0)

    # copy=False: the result must alias the input (where supported).
    src = asarray([1])
    if supports_copy_false:
        out = asarray(src, copy=False)
        assert is_lib_array(out)
        src[0] = 0
        assert xp_all(out[0] == 0)
    else:
        with pytest.raises(NotImplementedError):
            asarray(src, copy=False)

    # copy=False together with a dtype change cannot be satisfied.
    src = asarray([1])
    with pytest.raises(copy_false_error):
        asarray(src, copy=False, dtype=xp.float64)

    # copy=None without a dtype change: no copy is made.
    src = asarray([1])
    out = asarray(src, copy=None)
    assert is_lib_array(out)
    src[0] = 0
    assert xp_all(out[0] == 0)

    # copy=None with a dtype change: a copy is made.
    src = asarray([1.0], dtype=xp.float32)
    out = asarray(src, dtype=xp.float64, copy=None)
    assert is_lib_array(out)
    src[0] = 0.0
    assert xp_all(out[0] == 1.0)

    # copy=None with the same dtype given explicitly: no copy is made.
    src = asarray([1.0], dtype=xp.float64)
    out = asarray(src, dtype=xp.float64, copy=None)
    assert is_lib_array(out)
    src[0] = 0.0
    assert xp_all(out[0] == 0.0)

    # Python built-in types
    for obj in [True, 0, 0.0, 0j, [0], [[0]]]:
        asarray(obj, copy=True)  # No error
        asarray(obj, copy=None)  # No error
        with pytest.raises(copy_false_error):
            asarray(obj, copy=False)

    # Use the standard library array to test the buffer protocol
    buf = array.array('f', [1.0])
    out = asarray(buf, copy=True)
    assert is_lib_array(out)
    buf[0] = 0.0
    assert xp_all(out[0] == 1.0)

    buf = array.array('f', [1.0])
    if supports_copy_false:
        out = asarray(buf, copy=False)
        assert is_lib_array(out)
        buf[0] = 0.0
        assert xp_all(out[0] == 0.0)
    else:
        with pytest.raises(NotImplementedError):
            asarray(buf, copy=False)

    buf = array.array('f', [1.0])
    out = asarray(buf, copy=None)
    assert is_lib_array(out)
    buf[0] = 0.0
    # A copy is required for libraries where the default device is not CPU
    expected = 1.0 if library == 'cupy' else 0.0
    assert xp_all(out[0] == expected)
6 changes: 3 additions & 3 deletions tests/test_isdtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest

from ._helpers import import_
from ._helpers import import_, wrapped_libraries

# Check the known dtypes by their string names

Expand Down Expand Up @@ -64,7 +64,7 @@ def isdtype_(dtype_, kind):
assert type(res) is bool # noqa: E721
return res

@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
@pytest.mark.parametrize("library", wrapped_libraries)
def test_isdtype_spec_dtypes(library):
xp = import_(library, wrapper=True)

Expand Down Expand Up @@ -98,7 +98,7 @@ def test_isdtype_spec_dtypes(library):
'bfloat16',
]

@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
@pytest.mark.parametrize("library", wrapped_libraries)
@pytest.mark.parametrize("dtype_", additional_dtypes)
def test_isdtype_additional_dtypes(library, dtype_):
xp = import_(library, wrapper=True)
Expand Down
Loading