diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 08514717..66e405ca 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -144,24 +144,23 @@ def asarray( See the corresponding documentation in the array library and/or the array API specification for more details. """ + if isinstance(obj, da.Array): + if dtype is not None and dtype != obj.dtype: + if copy is False: + raise ValueError("Unable to avoid copy when changing dtype") + obj = obj.astype(dtype) + return obj.copy() if copy else obj + if copy is False: - # copy=False is not yet implemented in dask - raise NotImplementedError("copy=False is not yet implemented") - elif copy is True: - if isinstance(obj, da.Array) and dtype is None: - return obj.copy() - # Go through numpy, since dask copy is no-op by default - obj = np.array(obj, dtype=dtype, copy=True) - return da.array(obj, dtype=dtype) - else: - if not isinstance(obj, da.Array) or dtype is not None and obj.dtype != dtype: - # copy=True to be uniform across dask < 2024.12 and >= 2024.12 - # see https://github.com/dask/dask/pull/11524/ - obj = np.array(obj, dtype=dtype, copy=True) - return da.from_array(obj) - return obj - - return da.asarray(obj, dtype=dtype, **kwargs) + raise NotImplementedError( + "Unable to avoid copy when converting a non-dask object to dask" + ) + + # Always copy, also when copy=None, to be uniform across dask < 2024.12 and >= 2024.12 + # see https://github.com/dask/dask/pull/11524/ + obj = np.array(obj, dtype=dtype, copy=True) + return da.from_array(obj) + from dask.array import ( # Element wise aliases diff --git a/tests/test_all.py b/tests/test_all.py index 969d5cfb..081bb82b 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -40,5 +40,7 @@ def test_all(library): all_names = module.__all__ if set(dir_names) != set(all_names): - assert set(dir_names) - set(all_names) == set(), f"Some dir() names not included in __all__ for {mod_name}" - assert 
set(all_names) - set(dir_names) == set(), f"Some __all__ names not in dir() for {mod_name}" + extra_dir = set(dir_names) - set(all_names) + extra_all = set(all_names) - set(dir_names) + assert not extra_dir, f"Some dir() names not included in __all__ for {mod_name}: {extra_dir}" + assert not extra_all, f"Some __all__ names not in dir() for {mod_name}: {extra_all}" diff --git a/tests/test_common.py b/tests/test_common.py index 7887d4de..07afaddb 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -226,11 +226,17 @@ def test_asarray_copy(library): all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute() if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(xp, '_CopyMode') : - supports_copy_false = False - elif library in ['cupy', 'dask.array']: - supports_copy_false = False + supports_copy_false_other_ns = False + supports_copy_false_same_ns = False + elif library == 'cupy': + supports_copy_false_other_ns = False + supports_copy_false_same_ns = False + elif library == 'dask.array': + supports_copy_false_other_ns = False + supports_copy_false_same_ns = True else: - supports_copy_false = True + supports_copy_false_other_ns = True + supports_copy_false_same_ns = True a = asarray([1]) b = asarray(a, copy=True) @@ -240,7 +246,7 @@ def test_asarray_copy(library): assert all(a[0] == 0) a = asarray([1]) - if supports_copy_false: + if supports_copy_false_same_ns: b = asarray(a, copy=False) assert is_lib_func(b) a[0] = 0 @@ -249,7 +255,7 @@ def test_asarray_copy(library): pytest.raises(NotImplementedError, lambda: asarray(a, copy=False)) a = asarray([1]) - if supports_copy_false: + if supports_copy_false_same_ns: pytest.raises(ValueError, lambda: asarray(a, copy=False, dtype=xp.float64)) else: @@ -281,7 +287,7 @@ def test_asarray_copy(library): for obj in [True, 0, 0.0, 0j, [0], [[0]]]: asarray(obj, copy=True) # No error asarray(obj, copy=None) # No error - if supports_copy_false: + if supports_copy_false_other_ns: 
pytest.raises(ValueError, lambda: asarray(obj, copy=False)) else: pytest.raises(NotImplementedError, lambda: asarray(obj, copy=False)) @@ -294,7 +300,7 @@ def test_asarray_copy(library): assert all(b[0] == 1.0) a = array.array('f', [1.0]) - if supports_copy_false: + if supports_copy_false_other_ns: b = asarray(a, copy=False) assert is_lib_func(b) a[0] = 0.0 diff --git a/tests/test_dask.py b/tests/test_dask.py new file mode 100644 index 00000000..2983e696 --- /dev/null +++ b/tests/test_dask.py @@ -0,0 +1,108 @@ +from contextlib import contextmanager + +import dask +import numpy as np +import pytest +import dask.array as da + +from array_api_compat import array_namespace + + +@pytest.fixture +def xp(): + """Fixture returning the wrapped dask namespace""" + return array_namespace(da.empty(0)) + + +@contextmanager +def assert_no_compute(): + """ + Context manager that raises if at any point inside it anything calls compute() + or persist(), e.g. as it can be triggered implicitly by __bool__, __array__, etc. 
+ """ + def get(dsk, *args, **kwargs): + raise AssertionError("Called compute() or persist()") + + with dask.config.set(scheduler=get): + yield + + +def test_assert_no_compute(): + """Test the assert_no_compute context manager""" + a = da.asarray(True) + with pytest.raises(AssertionError, match="Called compute"): + with assert_no_compute(): + bool(a) + + # Exiting the context manager restores the original scheduler + assert bool(a) is True + + +# Test no_compute for functions that use generic _aliases with xp=np + +def test_unary_ops_no_compute(xp): + with assert_no_compute(): + a = xp.asarray([1.5, -1.5]) + xp.ceil(a) + xp.floor(a) + xp.trunc(a) + xp.sign(a) + + +def test_matmul_tensordot_no_compute(xp): + A = da.ones((4, 4), chunks=2) + B = da.zeros((4, 4), chunks=2) + with assert_no_compute(): + xp.matmul(A, B) + xp.tensordot(A, B) + + +# Test no_compute for functions that are fully bespoke for dask + +def test_asarray_no_compute(xp): + with assert_no_compute(): + a = xp.arange(10) + xp.asarray(a) + xp.asarray(a, dtype=np.int16) + xp.asarray(a, dtype=a.dtype) + xp.asarray(a, copy=True) + xp.asarray(a, copy=True, dtype=np.int16) + xp.asarray(a, copy=True, dtype=a.dtype) + xp.asarray(a, copy=False) + xp.asarray(a, copy=False, dtype=a.dtype) + + +@pytest.mark.parametrize("copy", [True, False]) +def test_astype_no_compute(xp, copy): + with assert_no_compute(): + a = xp.arange(10) + xp.astype(a, np.int16, copy=copy) + xp.astype(a, a.dtype, copy=copy) + + +def test_clip_no_compute(xp): + with assert_no_compute(): + a = xp.arange(10) + xp.clip(a) + xp.clip(a, 1) + xp.clip(a, 1, 8) + + +def test_generators_are_lazy(xp): + """ + Test that generator functions are fully lazy, e.g. 
that + da.ones(n) is not implemented as da.asarray(np.ones(n)) + """ + size = 100_000_000_000 # 800 GB + chunks = size // 10 # 10x 80 GB chunks + + with assert_no_compute(): + xp.zeros(size, chunks=chunks) + xp.ones(size, chunks=chunks) + xp.empty(size, chunks=chunks) + xp.full(size, fill_value=123, chunks=chunks) + a = xp.arange(size, chunks=chunks) + xp.zeros_like(a) + xp.ones_like(a) + xp.empty_like(a) + xp.full_like(a, fill_value=123)