Skip to content

Commit

Permalink
Merge pull request #228 from crusaderky/is_lazy_array
Browse files Browse the repository at this point in the history
ENH: is_lazy_array()
  • Loading branch information
ev-br authored Jan 15, 2025
2 parents e5dd419 + 7950eaa commit 5ef0e18
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 6 deletions.
58 changes: 58 additions & 0 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,63 @@ def is_writeable_array(x) -> bool:
return True


def is_lazy_array(x) -> bool:
"""Return True if x is potentially a future or it may be otherwise impossible or
expensive to eagerly read its contents, regardless of their size, e.g. by
calling ``bool(x)`` or ``float(x)``.
Return False otherwise; e.g. ``bool(x)`` etc. is guaranteed to succeed and to be
cheap as long as the array has the right dtype and size.
Note
----
This function errs on the side of caution for array types that may or may not be
lazy, e.g. JAX arrays, by always returning True for them.
"""
if (
is_numpy_array(x)
or is_cupy_array(x)
or is_torch_array(x)
or is_pydata_sparse_array(x)
):
return False

# **JAX note:** while it is possible to determine if you're inside or outside
# jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
# as we do below for unknown arrays, this is not recommended by JAX best practices.

# **Dask note:** Dask eagerly computes the graph on __bool__, __float__, and so on.
# This behaviour, while impossible to change without breaking backwards
# compatibility, is highly detrimental to performance as the whole graph will end
# up being computed multiple times.

if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x):
return True

# Unknown Array API compatible object. Note that this test may have dire consequences
# in terms of performance, e.g. for a lazy object that eagerly computes the graph
# on __bool__ (dask is one such example, which however is special-cased above).

# Select a single point of the array
s = size(x)
if s is None:
return True
xp = array_namespace(x)
if s > 1:
x = xp.reshape(x, (-1,))[0]
# Cast to dtype=bool and deal with size 0 arrays
x = xp.any(x)

try:
bool(x)
return False
# The Array API standard dictactes that __bool__ should raise TypeError if the
# output cannot be defined.
# Here we allow for it to raise arbitrary exceptions, e.g. like Dask does.
except Exception:
return True


__all__ = [
"array_namespace",
"device",
Expand All @@ -845,6 +902,7 @@ def is_writeable_array(x) -> bool:
"is_pydata_sparse_array",
"is_pydata_sparse_namespace",
"is_writeable_array",
"is_lazy_array",
"size",
"to_device",
]
Expand Down
1 change: 1 addition & 0 deletions docs/helper-functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ yet.
.. autofunction:: is_pydata_sparse_array
.. autofunction:: is_ndonnx_array
.. autofunction:: is_writeable_array
.. autofunction:: is_lazy_array
.. autofunction:: is_numpy_namespace
.. autofunction:: is_cupy_namespace
.. autofunction:: is_torch_namespace
Expand Down
54 changes: 48 additions & 6 deletions tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
import math

import pytest
import numpy as np
import array
from numpy.testing import assert_allclose

from array_api_compat import ( # noqa: F401
is_numpy_array, is_cupy_array, is_torch_array,
is_dask_array, is_jax_array, is_pydata_sparse_array,
Expand All @@ -6,15 +13,10 @@
)

from array_api_compat import (
device, is_array_api_obj, is_writeable_array, size, to_device
device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device
)
from ._helpers import import_, wrapped_libraries, all_libraries

import pytest
import numpy as np
import array
from numpy.testing import assert_allclose

is_array_functions = {
'numpy': 'is_numpy_array',
'cupy': 'is_cupy_array',
Expand Down Expand Up @@ -115,6 +117,45 @@ def test_size_none(library):
assert size(x) in (None, 5)


@pytest.mark.parametrize("library", all_libraries)
def test_is_lazy_array(library):
lib = import_(library)
x = lib.asarray([1, 2, 3])
assert isinstance(is_lazy_array(x), bool)


@pytest.mark.parametrize("shape", [(math.nan,), (1, math.nan), (None, ), (1, None)])
def test_is_lazy_array_nan_size(shape, monkeypatch):
"""Test is_lazy_array() on an unknown Array API compliant object
with NaN (like Dask) or None (like ndonnx) in its shape
"""
xp = import_("array_api_strict")
x = xp.asarray(1)
assert not is_lazy_array(x)
monkeypatch.setattr(type(x), "shape", shape)
assert is_lazy_array(x)


@pytest.mark.parametrize("exc", [TypeError, AssertionError])
def test_is_lazy_array_bool_raises(exc, monkeypatch):
"""Test is_lazy_array() on an unknown Array API compliant object
where calling bool() raises:
- TypeError: e.g. like jitted JAX. This is the proper exception which
lazy arrays should raise as per the Array API specification
- something else: e.g. like Dask, where bool() triggers compute()
which can result in any kind of exception to be raised
"""
xp = import_("array_api_strict")
x = xp.asarray(1)
assert not is_lazy_array(x)

def __bool__(self):
raise exc("Hello world")

monkeypatch.setattr(type(x), "__bool__", __bool__)
assert is_lazy_array(x)


@pytest.mark.parametrize("library", all_libraries)
def test_device(library):
xp = import_(library, wrapper=True)
Expand Down Expand Up @@ -172,6 +213,7 @@ def test_asarray_cross_library(source_library, target_library, request):

assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"


@pytest.mark.parametrize("library", wrapped_libraries)
def test_asarray_copy(library):
# Note, we have this test here because the test suite currently doesn't
Expand Down

0 comments on commit 5ef0e18

Please sign in to comment.