Skip to content

Commit

Permalink
Merge pull request #237 from crusaderky/is_array_object
Browse files Browse the repository at this point in the history
ENH: is_lazy_array and is_writeable_array to return False on non-arrays
  • Loading branch information
ev-br authored Jan 17, 2025
2 parents 9442237 + 0c1b7dd commit 44e1eb3
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 18 deletions.
58 changes: 40 additions & 18 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import inspect
import warnings

def _is_jax_zero_gradient_array(x):
def _is_jax_zero_gradient_array(x: object) -> bool:
"""Return True if `x` is a zero-gradient array.
These arrays are a design quirk of Jax that may one day be removed.
Expand All @@ -32,7 +32,8 @@ def _is_jax_zero_gradient_array(x):

return isinstance(x, np.ndarray) and x.dtype == jax.float0

def is_numpy_array(x):

def is_numpy_array(x: object) -> bool:
"""
Return True if `x` is a NumPy array.
Expand Down Expand Up @@ -63,7 +64,8 @@ def is_numpy_array(x):
return (isinstance(x, (np.ndarray, np.generic))
and not _is_jax_zero_gradient_array(x))

def is_cupy_array(x):

def is_cupy_array(x: object) -> bool:
"""
Return True if `x` is a CuPy array.
Expand Down Expand Up @@ -93,7 +95,8 @@ def is_cupy_array(x):
# TODO: Should we reject ndarray subclasses?
return isinstance(x, cp.ndarray)

def is_torch_array(x):

def is_torch_array(x: object) -> bool:
"""
Return True if `x` is a PyTorch tensor.
Expand All @@ -120,7 +123,8 @@ def is_torch_array(x):
# TODO: Should we reject ndarray subclasses?
return isinstance(x, torch.Tensor)

def is_ndonnx_array(x):

def is_ndonnx_array(x: object) -> bool:
"""
Return True if `x` is a ndonnx Array.
Expand All @@ -147,7 +151,8 @@ def is_ndonnx_array(x):

return isinstance(x, ndx.Array)

def is_dask_array(x):

def is_dask_array(x: object) -> bool:
"""
Return True if `x` is a dask.array Array.
Expand All @@ -174,7 +179,8 @@ def is_dask_array(x):

return isinstance(x, dask.array.Array)

def is_jax_array(x):

def is_jax_array(x: object) -> bool:
"""
Return True if `x` is a JAX array.
Expand Down Expand Up @@ -202,6 +208,7 @@ def is_jax_array(x):

return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)


def is_pydata_sparse_array(x) -> bool:
"""
Return True if `x` is an array from the `sparse` package.
Expand Down Expand Up @@ -231,7 +238,8 @@ def is_pydata_sparse_array(x) -> bool:
# TODO: Account for other backends.
return isinstance(x, sparse.SparseArray)

def is_array_api_obj(x):

def is_array_api_obj(x: object) -> bool:
"""
Return True if `x` is an array API compatible array object.
Expand All @@ -254,10 +262,12 @@ def is_array_api_obj(x):
or is_pydata_sparse_array(x) \
or hasattr(x, '__array_namespace__')

def _compat_module_name():

def _compat_module_name() -> str:
assert __name__.endswith('.common._helpers')
return __name__.removesuffix('.common._helpers')


def is_numpy_namespace(xp) -> bool:
"""
Returns True if `xp` is a NumPy namespace.
Expand All @@ -278,6 +288,7 @@ def is_numpy_namespace(xp) -> bool:
"""
return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}


def is_cupy_namespace(xp) -> bool:
"""
Returns True if `xp` is a CuPy namespace.
Expand All @@ -298,6 +309,7 @@ def is_cupy_namespace(xp) -> bool:
"""
return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}


def is_torch_namespace(xp) -> bool:
"""
Returns True if `xp` is a PyTorch namespace.
Expand All @@ -319,7 +331,7 @@ def is_torch_namespace(xp) -> bool:
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}


def is_ndonnx_namespace(xp):
def is_ndonnx_namespace(xp) -> bool:
"""
Returns True if `xp` is an NDONNX namespace.
Expand All @@ -337,7 +349,8 @@ def is_ndonnx_namespace(xp):
"""
return xp.__name__ == 'ndonnx'

def is_dask_namespace(xp):

def is_dask_namespace(xp) -> bool:
"""
Returns True if `xp` is a Dask namespace.
Expand All @@ -357,7 +370,8 @@ def is_dask_namespace(xp):
"""
return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}

def is_jax_namespace(xp):

def is_jax_namespace(xp) -> bool:
"""
Returns True if `xp` is a JAX namespace.
Expand All @@ -378,7 +392,8 @@ def is_jax_namespace(xp):
"""
return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}

def is_pydata_sparse_namespace(xp):

def is_pydata_sparse_namespace(xp) -> bool:
"""
Returns True if `xp` is a pydata/sparse namespace.
Expand All @@ -396,7 +411,8 @@ def is_pydata_sparse_namespace(xp):
"""
return xp.__name__ == 'sparse'

def is_array_api_strict_namespace(xp):

def is_array_api_strict_namespace(xp) -> bool:
"""
Returns True if `xp` is an array-api-strict namespace.
Expand All @@ -414,13 +430,15 @@ def is_array_api_strict_namespace(xp):
"""
return xp.__name__ == 'array_api_strict'

def _check_api_version(api_version):

def _check_api_version(api_version: str) -> None:
if api_version in ['2021.12', '2022.12']:
warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2023.12")
elif api_version is not None and api_version not in ['2021.12', '2022.12',
'2023.12']:
raise ValueError("Only the 2023.12 version of the array API specification is currently supported")


def array_namespace(*xs, api_version=None, use_compat=None):
"""
Get the array API compatible namespace for the arrays `xs`.
Expand Down Expand Up @@ -808,9 +826,10 @@ def size(x: Array) -> int | None:
return None if math.isnan(out) else out


def is_writeable_array(x) -> bool:
def is_writeable_array(x: object) -> bool:
"""
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
Return False if `x` is not an array API compatible object.
Warning
-------
Expand All @@ -821,10 +840,10 @@ def is_writeable_array(x) -> bool:
return x.flags.writeable
if is_jax_array(x) or is_pydata_sparse_array(x):
return False
return True
return is_array_api_obj(x)


def is_lazy_array(x) -> bool:
def is_lazy_array(x: object) -> bool:
"""Return True if x is potentially a future or it may be otherwise impossible or
expensive to eagerly read its contents, regardless of their size, e.g. by
calling ``bool(x)`` or ``float(x)``.
Expand Down Expand Up @@ -857,6 +876,9 @@ def is_lazy_array(x) -> bool:
if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x):
return True

if not is_array_api_obj(x):
return False

# Unknown Array API compatible object. Note that this test may have dire consequences
# in terms of performance, e.g. for a lazy object that eagerly computes the graph
# on __bool__ (dask is one such example, which however is special-cased above).
Expand Down
21 changes: 21 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,27 @@ def __bool__(self):
assert is_lazy_array(x)


@pytest.mark.parametrize(
'func',
list(is_array_functions.values())
+ ["is_array_api_obj", "is_lazy_array", "is_writeable_array"]
)
def test_is_array_any_object(func):
"""Test that is_*_array functions return False and don't raise on non-array objects
"""
func = globals()[func]

# These objects are missing attributes such as __name__
assert not func(object())
assert not func(None)
assert not func(1)

class C:
pass

assert not func(C())


@pytest.mark.parametrize("library", all_libraries)
def test_device(library):
xp = import_(library, wrapper=True)
Expand Down

0 comments on commit 44e1eb3

Please sign in to comment.