From a46f8a895d84e2419d534028bc8a9c01def45964 Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Wed, 15 Jan 2025 19:42:15 +0000
Subject: [PATCH] `da.asarray` should not materialize the graph

---
 array_api_compat/dask/array/_aliases.py | 38 ++++++++++++-----------
 tests/test_dask.py                      | 40 +++++++++++++++++++++++++
 2 files changed, 61 insertions(+), 17 deletions(-)
 create mode 100644 tests/test_dask.py

diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py
index df8fede8..f04334df 100644
--- a/array_api_compat/dask/array/_aliases.py
+++ b/array_api_compat/dask/array/_aliases.py
@@ -129,24 +129,28 @@ def asarray(
     See the corresponding documentation in the array library and/or the array API
     specification for more details.
     """
+    if isinstance(obj, da.Array):
+        if dtype is not None:
+            # Note: at the moment of writing, dask ignores the copy parameter
+            # and always behaves with copy=False. We pass the parameter anyway
+            # for the sake of forward compatibility.
+            res = obj.astype(dtype, copy=(copy is True))
+            if copy is False and res is not obj:
+                raise ValueError("Unable to avoid copy")
+        else:
+            res = obj
+        return res.copy() if copy else res
+
     if copy is False:
-        # copy=False is not yet implemented in dask
-        raise NotImplementedError("copy=False is not yet implemented")
-    elif copy is True:
-        if isinstance(obj, da.Array) and dtype is None:
-            return obj.copy()
-        # Go through numpy, since dask copy is no-op by default
-        obj = np.array(obj, dtype=dtype, copy=True)
-        return da.array(obj, dtype=dtype)
-    else:
-        if not isinstance(obj, da.Array) or dtype is not None and obj.dtype != dtype:
-            # copy=True to be uniform across dask < 2024.12 and >= 2024.12
-            # see https://github.com/dask/dask/pull/11524/
-            obj = np.array(obj, dtype=dtype, copy=True)
-            return da.from_array(obj)
-        return obj
-
-    return da.asarray(obj, dtype=dtype, **kwargs)
+        raise NotImplementedError(
+            "copy=False is not possible when converting a non-dask object to dask"
+        )
+
+    # copy=True to be uniform across dask < 2024.12 and >= 2024.12
+    # see https://github.com/dask/dask/pull/11524/
+    obj = np.array(obj, dtype=dtype, copy=True)
+    return da.from_array(obj)
+
 
 from dask.array import (
     # Element wise aliases
diff --git a/tests/test_dask.py b/tests/test_dask.py
new file mode 100644
index 00000000..05716d36
--- /dev/null
+++ b/tests/test_dask.py
@@ -0,0 +1,40 @@
+import dask
+import numpy as np
+import pytest
+import array_api_compat.dask.array as xp
+
+@pytest.fixture
+def no_compute():
+    """
+    Cause the test to raise if at any point anything calls compute() or persist(),
+    e.g. as it can be triggered implicitly by __bool__, __array__, etc.
+    """
+    def get(dsk, *args, **kwargs):
+        raise AssertionError("Called compute() or persist()")
+
+    with dask.config.set(scheduler=get):
+        yield
+
+
+def test_no_compute(no_compute):
+    """Test the no_compute fixture"""
+    a = xp.asarray(True)
+    with pytest.raises(AssertionError, match="Called compute"):
+        bool(a)
+
+
+def test_asarray_no_compute(no_compute):
+    a = xp.arange(10)
+    xp.asarray(a)
+    assert xp.asarray(a, dtype=np.int16).dtype == np.int16
+    xp.asarray(a, dtype=a.dtype)
+    xp.asarray(a, copy=True)
+    assert xp.asarray(a, copy=True, dtype=np.int16).dtype == np.int16
+    xp.asarray(a, copy=True, dtype=a.dtype)
+
+
+def test_clip_no_compute(no_compute):
+    a = xp.arange(10)
+    xp.clip(a)
+    xp.clip(a, 1)
+    xp.clip(a, 1, 8)