Skip to content

Commit

Permalink
da.asarray should not materialize the graph
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jan 15, 2025
1 parent 5ef0e18 commit a46f8a8
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 17 deletions.
38 changes: 21 additions & 17 deletions array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,24 +129,28 @@ def asarray(
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
if isinstance(obj, da.Array):
if dtype is not None:
# Note: at the moment of writing, dask ignores the copy parameter
# and always behaves with copy=False. We pass the parameter anyway
# for the sake of forward compatibility.
res = obj.astype(dtype, copy=True if copy is True else False)
if copy is False and res is not obj:
raise ValueError("Unable to avoid copy")
else:
res = obj
return obj.copy() if copy else obj

if copy is False:
# copy=False is not yet implemented in dask
raise NotImplementedError("copy=False is not yet implemented")
elif copy is True:
if isinstance(obj, da.Array) and dtype is None:
return obj.copy()
# Go through numpy, since dask copy is no-op by default
obj = np.array(obj, dtype=dtype, copy=True)
return da.array(obj, dtype=dtype)
else:
if not isinstance(obj, da.Array) or dtype is not None and obj.dtype != dtype:
# copy=True to be uniform across dask < 2024.12 and >= 2024.12
# see https://github.com/dask/dask/pull/11524/
obj = np.array(obj, dtype=dtype, copy=True)
return da.from_array(obj)
return obj

return da.asarray(obj, dtype=dtype, **kwargs)
raise NotImplementedError(
"copy=False is not possible when converting a non-dask object to dask"
)

# copy=None to be uniform across dask < 2024.12 and >= 2024.12
# see https://github.com/dask/dask/pull/11524/
obj = np.asarray(obj, dtype=dtype, copy=True)
return da.from_array(obj)


from dask.array import (
# Element wise aliases
Expand Down
40 changes: 40 additions & 0 deletions tests/test_dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import dask
import numpy as np
import pytest
import array_api_compat.dask.array as xp

@pytest.fixture
def no_compute():
"""
Cause the test to raise if at any point anything calls compute() or persist(),
e.g. as it can be triggered implicitly by __bool__, __array__, etc.
"""
def get(dsk, *args, **kwargs):
raise AssertionError("Called compute() or persist()")

with dask.config.set(scheduler=get):
yield


def test_no_compute(no_compute):
"""Test the no_compute_fixture"""
a = xp.asarray(True)
with pytest.raises(AssertionError, match="Called compute"):
bool(a)


def test_asarray_no_compute(no_compute):
a = xp.arange(10)
xp.asarray(a)
xp.asarray(a, dtype=np.int16)
xp.asarray(a, dtype=a.dtype)
xp.asarray(a, copy=True)
xp.asarray(a, copy=True, dtype=np.int16)
xp.asarray(a, copy=True, dtype=a.dtype)


def test_clip_no_compute(no_compute):
a = xp.arange(10)
xp.clip(a)
xp.clip(a, 1)
xp.clip(a, 1, 8)

0 comments on commit a46f8a8

Please sign in to comment.