Skip to content

Commit

Permalink
Merge pull request #119 from asmeurer/asarray-copy
Browse files Browse the repository at this point in the history
Support the copy keyword in asarray
  • Loading branch information
asmeurer authored Mar 27, 2024
2 parents ecb4c57 + 2dcd864 commit 311d0aa
Show file tree
Hide file tree
Showing 18 changed files with 316 additions and 136 deletions.
11 changes: 5 additions & 6 deletions .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11']
python-version: ['3.9', '3.10', '3.11', '3.12']

steps:
- name: Checkout array-api-compat
Expand All @@ -55,16 +55,15 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
# NumPy 1.21 doesn't support Python 3.11. NumPy 2.0 doesn't support
# Python 3.8. There doesn't seem to be a way to put this in the numpy
# 1.21 config file.
if: "! ((matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21')) || (matrix.python-version == '3.8' && inputs.package-name == 'numpy' && contains(inputs.xfails-file-extra, 'dev')))"
# NumPy 1.21 doesn't support Python 3.11. There doesn't seem to be a way
# to put this in the numpy 1.21 config file.
if: "! ((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
run: |
python -m pip install --upgrade pip
python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}' ${{ inputs.extra-requires }}
python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt
- name: Run the array API testsuite (${{ inputs.package-name }})
if: "! ((matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21')) || (matrix.python-version == '3.8' && inputs.package-name == 'numpy' && contains(inputs.xfails-file-extra, 'dev')))"
if: "! ((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
env:
ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.module-name || inputs.package-name }}
# This enables the NEP 50 type promotion behavior (without it a lot of
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/docs-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Docs Build
on: [push, pull_request]

jobs:
build:
docs-build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/docs-deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ on:
- main

jobs:
deploy:
docs-deploy:
runs-on: ubuntu-latest
environment:
name: docs-deploy
Expand Down
24 changes: 20 additions & 4 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11']
python-version: ['3.9', '3.10', '3.11', '3.12']
numpy-version: ['1.21', '1.26', 'dev']
exclude:
- python-version: '3.11'
numpy-version: '1.21'
- python-version: '3.12'
numpy-version: '1.21'
fail-fast: true
steps:
- uses: actions/checkout@v4
Expand All @@ -15,11 +21,21 @@ jobs:
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest numpy torch dask[array] jax[cpu]
if [ "${{ matrix.numpy-version }}" == "dev" ]; then
PIP_EXTRA='numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple'
elif [ "${{ matrix.numpy-version }}" == "1.21" ]; then
PIP_EXTRA='numpy==1.21.*'
else
PIP_EXTRA='numpy==1.26.*'
fi
python -m pip install -r requirements-dev.txt $PIP_EXTRA
- name: Run Tests
run: |
pytest
if [[ "${{ matrix.numpy-version }}" == "1.21" || "${{ matrix.numpy-version }}" == "dev" ]]; then
PYTEST_EXTRA=(-k "numpy and not jax and not torch and not dask")
fi
pytest -v "${PYTEST_EXTRA[@]}"
# Make sure it installs
python setup.py install
python -m pip install .
92 changes: 4 additions & 88 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@

from typing import TYPE_CHECKING
if TYPE_CHECKING:
import numpy as np
from typing import Optional, Sequence, Tuple, Union
from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
from ._typing import ndarray, Device, Dtype

from typing import NamedTuple
from types import ModuleType
import inspect

from ._helpers import _check_device, is_numpy_array, array_namespace
from ._helpers import _check_device

# These functions are modified from the NumPy versions.

# Creation functions add the device keyword (which does nothing for NumPy)

def arange(
start: Union[int, float],
/,
Expand Down Expand Up @@ -268,90 +268,6 @@ def var(
def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
return xp.transpose(x, axes)

# Creation functions add the device keyword (which does nothing for NumPy)

# asarray also adds the copy keyword
def _asarray(
obj: Union[
ndarray,
bool,
int,
float,
NestedSequence[bool | int | float],
SupportsBufferProtocol,
],
/,
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
copy: "Optional[Union[bool, np._CopyMode]]" = None,
namespace = None,
**kwargs,
) -> ndarray:
"""
Array API compatibility wrapper for asarray().
See the corresponding documentation in NumPy/CuPy and/or the array API
specification for more details.
"""
if namespace is None:
try:
xp = array_namespace(obj, _use_compat=False)
except ValueError:
# TODO: What about lists of arrays?
raise ValueError("A namespace must be specified for asarray() with non-array input")
elif isinstance(namespace, ModuleType):
xp = namespace
elif namespace == 'numpy':
import numpy as xp
elif namespace == 'cupy':
import cupy as xp
elif namespace == 'dask.array':
import dask.array as xp
else:
raise ValueError("Unrecognized namespace argument to asarray()")

_check_device(xp, device)
if is_numpy_array(obj):
import numpy as np
if hasattr(np, '_CopyMode'):
# Not present in older NumPys
COPY_FALSE = (False, np._CopyMode.IF_NEEDED)
COPY_TRUE = (True, np._CopyMode.ALWAYS)
else:
COPY_FALSE = (False,)
COPY_TRUE = (True,)
else:
COPY_FALSE = (False,)
COPY_TRUE = (True,)
if copy in COPY_FALSE and namespace != "dask.array":
# copy=False is not yet implemented in xp.asarray
raise NotImplementedError("copy=False is not yet implemented")
if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)):
if dtype is not None and obj.dtype != dtype:
copy = True
if copy in COPY_TRUE:
return xp.array(obj, copy=True, dtype=dtype)
return obj
elif namespace == "dask.array":
if copy in COPY_TRUE:
if dtype is None:
return obj.copy()
# Go through numpy, since dask copy is no-op by default
import numpy as np
obj = np.array(obj, dtype=dtype, copy=True)
return xp.array(obj, dtype=dtype)
else:
import dask.array as da
import numpy as np
if not isinstance(obj, da.Array):
obj = np.asarray(obj, dtype=dtype)
return da.from_array(obj)
return obj

return xp.asarray(obj, dtype=dtype, **kwargs)

# np.reshape calls the keyword argument 'newshape' instead of 'shape'
def reshape(x: ndarray,
/,
Expand Down
57 changes: 51 additions & 6 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from __future__ import annotations

from functools import partial

import cupy as cp

from ..common import _aliases
from .._internal import get_xp

asarray = asarray_cupy = partial(_aliases._asarray, namespace='cupy')
asarray.__doc__ = _aliases._asarray.__doc__
del partial
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Union
from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol

bool = cp.bool_

Expand Down Expand Up @@ -62,6 +61,52 @@
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
tensordot = get_xp(cp)(_aliases.tensordot)

_copy_default = object()

# asarray also adds the copy keyword, which is not present in numpy 1.0.
def asarray(
obj: Union[
ndarray,
bool,
int,
float,
NestedSequence[bool | int | float],
SupportsBufferProtocol,
],
/,
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
copy: Optional[bool] = _copy_default,
**kwargs,
) -> ndarray:
"""
Array API compatibility wrapper for asarray().
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
with cp.cuda.Device(device):
# cupy is like NumPy 1.26 (except without _CopyMode). See the comments
# in asarray in numpy/_aliases.py.
if copy is not _copy_default:
# A future version of CuPy will change the meaning of copy=False
# to mean no-copy. We don't know for certain what version it will
# be yet, so to avoid breaking that version, we use a different
# default value for copy so asarray(obj) with no copy kwarg will
# always do the copy-if-needed behavior.

# This will still need to be updated to remove the
# NotImplementedError for copy=False, but at least this won't
# break the default or existing behavior.
if copy is None:
copy = False
elif copy is False:
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
kwargs['copy'] = copy

return cp.array(obj, dtype=dtype, **kwargs)

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp, 'vecdot'):
Expand All @@ -73,7 +118,7 @@
else:
isdtype = get_xp(cp)(_aliases.isdtype)

__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos',
__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow']
Expand Down
47 changes: 42 additions & 5 deletions array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
if TYPE_CHECKING:
from typing import Optional, Union

from ...common._typing import Device, Dtype, Array
from ...common._typing import Device, Dtype, Array, NestedSequence, SupportsBufferProtocol

import dask.array as da

Expand Down Expand Up @@ -76,10 +76,6 @@ def _dask_arange(
arange = get_xp(da)(_dask_arange)
eye = get_xp(da)(_aliases.eye)

from functools import partial
asarray = partial(_aliases._asarray, namespace='dask.array')
asarray.__doc__ = _aliases._asarray.__doc__

linspace = get_xp(da)(_aliases.linspace)
eye = get_xp(da)(_aliases.eye)
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
Expand Down Expand Up @@ -113,6 +109,47 @@ def _dask_arange(
matmul = get_xp(np)(_aliases.matmul)
tensordot = get_xp(np)(_aliases.tensordot)


# asarray also adds the copy keyword, which is not present in numpy 1.0.
def asarray(
obj: Union[
Array,
bool,
int,
float,
NestedSequence[bool | int | float],
SupportsBufferProtocol,
],
/,
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
copy: "Optional[Union[bool, np._CopyMode]]" = None,
**kwargs,
) -> Array:
"""
Array API compatibility wrapper for asarray().
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
if copy is False:
# copy=False is not yet implemented in dask
raise NotImplementedError("copy=False is not yet implemented")
elif copy is True:
if isinstance(obj, da.Array) and dtype is None:
return obj.copy()
# Go through numpy, since dask copy is no-op by default
obj = np.array(obj, dtype=dtype, copy=True)
return da.array(obj, dtype=dtype)
else:
if not isinstance(obj, da.Array) or dtype is not None and obj.dtype != dtype:
obj = np.asarray(obj, dtype=dtype)
return da.from_array(obj)
return obj

return da.asarray(obj, dtype=dtype, **kwargs)

from dask.array import (
# Element wise aliases
arccos as acos,
Expand Down
6 changes: 6 additions & 0 deletions array_api_compat/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,10 @@

from ..common._helpers import * # noqa: F403

try:
# Used in asarray(). Not present in older versions.
from numpy import _CopyMode # noqa: F401
except ImportError:
pass

__array_api_version__ = '2022.12'
Loading

0 comments on commit 311d0aa

Please sign in to comment.