Skip to content

Commit

Permalink
Merge branch 'main' into 2024.12
Browse files Browse the repository at this point in the history
  • Loading branch information
ev-br authored Jan 24, 2025
2 parents a448710 + 73f6426 commit 528c172
Show file tree
Hide file tree
Showing 20 changed files with 472 additions and 123 deletions.
7 changes: 1 addition & 6 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,6 @@ def unique_values(x: ndarray, /, xp) -> ndarray:
**kwargs,
)

def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray:
    """
    Cast `x` to `dtype`.

    Parameters
    ----------
    x
        Array to convert; must expose ``.dtype`` and ``.astype``.
    dtype
        Target data type.
    copy
        When False and `x` already has the requested dtype, `x` itself is
        returned. In every other case the call is forwarded to
        ``x.astype(dtype=dtype, copy=copy)``.
    """
    # The dtype comparison is only evaluated when no copy was requested.
    if copy or not (dtype == x.dtype):
        return x.astype(dtype=dtype, copy=copy)
    return x

# These functions have different keyword argument names

def std(
Expand Down Expand Up @@ -579,7 +574,7 @@ def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'astype', 'std', 'var', 'cumulative_sum', 'cumulative_prod', 'clip', 'permute_dims',
'std', 'var', 'cumulative_sum', 'cumulative_prod', 'clip', 'permute_dims',
'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
'unstack', 'sign']
135 changes: 108 additions & 27 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import inspect
import warnings

def _is_jax_zero_gradient_array(x):
def _is_jax_zero_gradient_array(x: object) -> bool:
"""Return True if `x` is a zero-gradient array.
These arrays are a design quirk of Jax that may one day be removed.
Expand All @@ -32,7 +32,8 @@ def _is_jax_zero_gradient_array(x):

return isinstance(x, np.ndarray) and x.dtype == jax.float0

def is_numpy_array(x):

def is_numpy_array(x: object) -> bool:
"""
Return True if `x` is a NumPy array.
Expand Down Expand Up @@ -63,7 +64,8 @@ def is_numpy_array(x):
return (isinstance(x, (np.ndarray, np.generic))
and not _is_jax_zero_gradient_array(x))

def is_cupy_array(x):

def is_cupy_array(x: object) -> bool:
"""
Return True if `x` is a CuPy array.
Expand Down Expand Up @@ -93,7 +95,8 @@ def is_cupy_array(x):
# TODO: Should we reject ndarray subclasses?
return isinstance(x, cp.ndarray)

def is_torch_array(x):

def is_torch_array(x: object) -> bool:
"""
Return True if `x` is a PyTorch tensor.
Expand All @@ -120,7 +123,8 @@ def is_torch_array(x):
# TODO: Should we reject Tensor subclasses?
return isinstance(x, torch.Tensor)

def is_ndonnx_array(x):

def is_ndonnx_array(x: object) -> bool:
"""
Return True if `x` is a ndonnx Array.
Expand All @@ -147,7 +151,8 @@ def is_ndonnx_array(x):

return isinstance(x, ndx.Array)

def is_dask_array(x):

def is_dask_array(x: object) -> bool:
"""
Return True if `x` is a dask.array Array.
Expand All @@ -174,7 +179,8 @@ def is_dask_array(x):

return isinstance(x, dask.array.Array)

def is_jax_array(x):

def is_jax_array(x: object) -> bool:
"""
Return True if `x` is a JAX array.
Expand Down Expand Up @@ -202,6 +208,7 @@ def is_jax_array(x):

return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)


def is_pydata_sparse_array(x) -> bool:
"""
Return True if `x` is an array from the `sparse` package.
Expand Down Expand Up @@ -231,7 +238,8 @@ def is_pydata_sparse_array(x) -> bool:
# TODO: Account for other backends.
return isinstance(x, sparse.SparseArray)

def is_array_api_obj(x):

def is_array_api_obj(x: object) -> bool:
"""
Return True if `x` is an array API compatible array object.
Expand All @@ -254,10 +262,12 @@ def is_array_api_obj(x):
or is_pydata_sparse_array(x) \
or hasattr(x, '__array_namespace__')

def _compat_module_name() -> str:
    """
    Return the dotted name of the package that contains this module.

    This module's ``__name__`` is expected to end in ``'.common._helpers'``;
    the returned value is ``__name__`` with that suffix removed. The
    ``is_*_namespace`` helpers append e.g. ``'.numpy'`` to it.

    Raises
    ------
    AssertionError
        If ``__name__`` does not end with ``'.common._helpers'``.
    """
    suffix = '.common._helpers'
    # Internal invariant: this file lives at <package>/common/_helpers.py.
    assert __name__.endswith(suffix)
    return __name__.removesuffix(suffix)


def is_numpy_namespace(xp) -> bool:
"""
Returns True if `xp` is a NumPy namespace.
Expand All @@ -278,6 +288,7 @@ def is_numpy_namespace(xp) -> bool:
"""
return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}


def is_cupy_namespace(xp) -> bool:
"""
Returns True if `xp` is a CuPy namespace.
Expand All @@ -298,6 +309,7 @@ def is_cupy_namespace(xp) -> bool:
"""
return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}


def is_torch_namespace(xp) -> bool:
"""
Returns True if `xp` is a PyTorch namespace.
Expand All @@ -319,7 +331,7 @@ def is_torch_namespace(xp) -> bool:
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}


def is_ndonnx_namespace(xp):
def is_ndonnx_namespace(xp) -> bool:
"""
Returns True if `xp` is an NDONNX namespace.
Expand All @@ -337,7 +349,8 @@ def is_ndonnx_namespace(xp):
"""
return xp.__name__ == 'ndonnx'

def is_dask_namespace(xp):

def is_dask_namespace(xp) -> bool:
"""
Returns True if `xp` is a Dask namespace.
Expand All @@ -357,7 +370,8 @@ def is_dask_namespace(xp):
"""
return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}

def is_jax_namespace(xp):

def is_jax_namespace(xp) -> bool:
"""
Returns True if `xp` is a JAX namespace.
Expand All @@ -378,7 +392,8 @@ def is_jax_namespace(xp):
"""
return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}

def is_pydata_sparse_namespace(xp):

def is_pydata_sparse_namespace(xp) -> bool:
"""
Returns True if `xp` is a pydata/sparse namespace.
Expand All @@ -396,7 +411,8 @@ def is_pydata_sparse_namespace(xp):
"""
return xp.__name__ == 'sparse'

def is_array_api_strict_namespace(xp):

def is_array_api_strict_namespace(xp) -> bool:
"""
Returns True if `xp` is an array-api-strict namespace.
Expand All @@ -414,13 +430,15 @@ def is_array_api_strict_namespace(xp):
"""
return xp.__name__ == 'array_api_strict'

def _check_api_version(api_version):

def _check_api_version(api_version: str) -> None:
if api_version in ['2021.12', '2022.12']:
warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2023.12")
elif api_version is not None and api_version not in ['2021.12', '2022.12',
'2023.12']:
raise ValueError("Only the 2023.12 version of the array API specification is currently supported")


def array_namespace(*xs, api_version=None, use_compat=None):
"""
Get the array API compatible namespace for the arrays `xs`.
Expand Down Expand Up @@ -631,13 +649,9 @@ def device(x: Array, /) -> Device:
return "cpu"
elif is_dask_array(x):
# Peek at the metadata of the dask array to determine type
try:
import numpy as np
if isinstance(x._meta, np.ndarray):
# Must be on CPU since backed by numpy
return "cpu"
except ImportError:
pass
if is_numpy_array(x._meta):
# Must be on CPU since backed by numpy
return "cpu"
return _DASK_DEVICE
elif is_jax_array(x):
# JAX has .device() as a method, but it is being deprecated so that it
Expand Down Expand Up @@ -788,24 +802,30 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
return x.to_device(device, stream=stream)


def size(x):
def size(x: Array) -> int | None:
"""
Return the total number of elements of x.
This is equivalent to `x.size` according to the `standard
<https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html>`__.
This helper is included because PyTorch defines `size` in an
:external+torch:meth:`incompatible way <torch.Tensor.size>`.
It also fixes dask.array's behaviour which returns nan for unknown sizes, whereas
the standard requires None.
"""
# Lazy API compliant arrays, such as ndonnx, can contain None in their shape
if None in x.shape:
return None
return math.prod(x.shape)
out = math.prod(x.shape)
# dask.array.Array.shape can contain NaN
return None if math.isnan(out) else out


def is_writeable_array(x) -> bool:
def is_writeable_array(x: object) -> bool:
"""
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
Return False if `x` is not an array API compatible object.
Warning
-------
Expand All @@ -816,7 +836,67 @@ def is_writeable_array(x) -> bool:
return x.flags.writeable
if is_jax_array(x) or is_pydata_sparse_array(x):
return False
return True
return is_array_api_obj(x)


def is_lazy_array(x: object) -> bool:
    """Return True if x is potentially a future or it may be otherwise impossible or
    expensive to eagerly read its contents, regardless of their size, e.g. by
    calling ``bool(x)`` or ``float(x)``.

    Return False otherwise; e.g. ``bool(x)`` etc. is guaranteed to succeed and to be
    cheap as long as the array has the right dtype and size.

    Note
    ----
    This function errs on the side of caution for array types that may or may not be
    lazy, e.g. JAX arrays, by always returning True for them.
    """
    # Backends that are treated as always eager.
    eager_backends = (
        is_numpy_array,
        is_cupy_array,
        is_torch_array,
        is_pydata_sparse_array,
    )
    if any(check(x) for check in eager_backends):
        return False

    # **JAX note:** while it is possible to determine if you're inside or outside
    # jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
    # as we do below for unknown arrays, this is not recommended by JAX best practices.

    # **Dask note:** Dask eagerly computes the graph on __bool__, __float__, and so on.
    # This behaviour, while impossible to change without breaking backwards
    # compatibility, is highly detrimental to performance as the whole graph will end
    # up being computed multiple times.

    # Backends that are always reported as lazy.
    lazy_backends = (is_jax_array, is_dask_array, is_ndonnx_array)
    if any(check(x) for check in lazy_backends):
        return True

    # Anything that is not an array at all is not lazy.
    if not is_array_api_obj(x):
        return False

    # Unknown Array API compatible object. Note that this test may have dire consequences
    # in terms of performance, e.g. for a lazy object that eagerly computes the graph
    # on __bool__ (dask is one such example, which however is special-cased above).

    n_elements = size(x)
    if n_elements is None:
        # Unknown size: treat as lazy.
        return True

    # Select a single point of the array
    xp = array_namespace(x)
    probe = x
    if n_elements > 1:
        probe = xp.reshape(probe, (-1,))[0]
    # Cast to dtype=bool and deal with size 0 arrays
    probe = xp.any(probe)

    # The Array API standard dictates that __bool__ should raise TypeError if the
    # output cannot be defined.
    # Here we allow for it to raise arbitrary exceptions, e.g. like Dask does.
    try:
        bool(probe)
    except Exception:
        return True
    return False


__all__ = [
Expand All @@ -840,6 +920,7 @@ def is_writeable_array(x) -> bool:
"is_pydata_sparse_array",
"is_pydata_sparse_namespace",
"is_writeable_array",
"is_lazy_array",
"size",
"to_device",
]
Expand Down
22 changes: 18 additions & 4 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import cupy as cp

from ..common import _aliases
from ..common import _aliases, _helpers
from .._internal import get_xp

from ._info import __array_namespace_info__
Expand Down Expand Up @@ -46,7 +46,6 @@
unique_counts = get_xp(cp)(_aliases.unique_counts)
unique_inverse = get_xp(cp)(_aliases.unique_inverse)
unique_values = get_xp(cp)(_aliases.unique_values)
astype = _aliases.astype
std = get_xp(cp)(_aliases.std)
var = get_xp(cp)(_aliases.var)
cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
Expand Down Expand Up @@ -111,6 +110,21 @@ def asarray(

return cp.array(obj, dtype=dtype, **kwargs)


def astype(
    x: ndarray,
    dtype: Dtype,
    /,
    *,
    copy: bool = True,
    device: Optional[Device] = None,
) -> ndarray:
    """
    Cast `x` to `dtype`, optionally moving the result to `device`.

    Parameters
    ----------
    x
        Array to convert.
    dtype
        Target data type.
    copy
        Forwarded to ``x.astype`` when no device is given. When a device is
        given, the cast itself is done with ``copy=False`` and an explicit
        copy is made only if the result would otherwise be `x` itself.
    device
        Target device, passed to ``_helpers.to_device``; None skips the
        device transfer.
    """
    if device is not None:
        # Cast without copying first, then transfer.
        converted = x.astype(dtype=dtype, copy=False)
        moved = _helpers.to_device(converted, device)
        # Honour copy=True when neither step produced a new array.
        if copy and moved is x:
            return moved.copy()
        return moved
    return x.astype(dtype=dtype, copy=copy)


# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp, 'vecdot'):
Expand All @@ -128,10 +142,10 @@ def asarray(
else:
unstack = get_xp(cp)(_aliases.unstack)

__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool',
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
'acos', 'acosh', 'asin', 'asinh', 'atan',
'atan2', 'atanh', 'bitwise_left_shift',
'bitwise_invert', 'bitwise_right_shift',
'concat', 'pow', 'sign']
'bool', 'concat', 'pow', 'sign']

_all_ignore = ['cp', 'get_xp']
Loading

0 comments on commit 528c172

Please sign in to comment.