Skip to content

Commit

Permalink
Merge pull request #235 from crusaderky/dask_asarray
Browse files Browse the repository at this point in the history
BUG: dask: `asarray` should not materialize the graph
  • Loading branch information
ev-br authored Jan 16, 2025
2 parents adbb6ef + c94ec0b commit 9442237
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 27 deletions.
33 changes: 16 additions & 17 deletions array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,24 +144,23 @@ def asarray(
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
if isinstance(obj, da.Array):
if dtype is not None and dtype != obj.dtype:
if copy is False:
raise ValueError("Unable to avoid copy when changing dtype")
obj = obj.astype(dtype)
return obj.copy() if copy else obj

if copy is False:
# copy=False is not yet implemented in dask
raise NotImplementedError("copy=False is not yet implemented")
elif copy is True:
if isinstance(obj, da.Array) and dtype is None:
return obj.copy()
# Go through numpy, since dask copy is no-op by default
obj = np.array(obj, dtype=dtype, copy=True)
return da.array(obj, dtype=dtype)
else:
if not isinstance(obj, da.Array) or dtype is not None and obj.dtype != dtype:
# copy=True to be uniform across dask < 2024.12 and >= 2024.12
# see https://github.com/dask/dask/pull/11524/
obj = np.array(obj, dtype=dtype, copy=True)
return da.from_array(obj)
return obj

return da.asarray(obj, dtype=dtype, **kwargs)
raise NotImplementedError(
"Unable to avoid copy when converting a non-dask object to dask"
)

# copy=None to be uniform across dask < 2024.12 and >= 2024.12
# see https://github.com/dask/dask/pull/11524/
obj = np.array(obj, dtype=dtype, copy=True)
return da.from_array(obj)


from dask.array import (
# Element wise aliases
Expand Down
6 changes: 4 additions & 2 deletions tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,7 @@ def test_all(library):
all_names = module.__all__

if set(dir_names) != set(all_names):
assert set(dir_names) - set(all_names) == set(), f"Some dir() names not included in __all__ for {mod_name}"
assert set(all_names) - set(dir_names) == set(), f"Some __all__ names not in dir() for {mod_name}"
extra_dir = set(dir_names) - set(all_names)
extra_all = set(all_names) - set(dir_names)
assert not extra_dir, f"Some dir() names not included in __all__ for {mod_name}: {extra_dir}"
assert not extra_all, f"Some __all__ names not in dir() for {mod_name}: {extra_all}"
22 changes: 14 additions & 8 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,11 +226,17 @@ def test_asarray_copy(library):
all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute()

if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(xp, '_CopyMode') :
supports_copy_false = False
elif library in ['cupy', 'dask.array']:
supports_copy_false = False
supports_copy_false_other_ns = False
supports_copy_false_same_ns = False
elif library == 'cupy':
supports_copy_false_other_ns = False
supports_copy_false_same_ns = False
elif library == 'dask.array':
supports_copy_false_other_ns = False
supports_copy_false_same_ns = True
else:
supports_copy_false = True
supports_copy_false_other_ns = True
supports_copy_false_same_ns = True

a = asarray([1])
b = asarray(a, copy=True)
Expand All @@ -240,7 +246,7 @@ def test_asarray_copy(library):
assert all(a[0] == 0)

a = asarray([1])
if supports_copy_false:
if supports_copy_false_same_ns:
b = asarray(a, copy=False)
assert is_lib_func(b)
a[0] = 0
Expand All @@ -249,7 +255,7 @@ def test_asarray_copy(library):
pytest.raises(NotImplementedError, lambda: asarray(a, copy=False))

a = asarray([1])
if supports_copy_false:
if supports_copy_false_same_ns:
pytest.raises(ValueError, lambda: asarray(a, copy=False,
dtype=xp.float64))
else:
Expand Down Expand Up @@ -281,7 +287,7 @@ def test_asarray_copy(library):
for obj in [True, 0, 0.0, 0j, [0], [[0]]]:
asarray(obj, copy=True) # No error
asarray(obj, copy=None) # No error
if supports_copy_false:
if supports_copy_false_other_ns:
pytest.raises(ValueError, lambda: asarray(obj, copy=False))
else:
pytest.raises(NotImplementedError, lambda: asarray(obj, copy=False))
Expand All @@ -294,7 +300,7 @@ def test_asarray_copy(library):
assert all(b[0] == 1.0)

a = array.array('f', [1.0])
if supports_copy_false:
if supports_copy_false_other_ns:
b = asarray(a, copy=False)
assert is_lib_func(b)
a[0] = 0.0
Expand Down
108 changes: 108 additions & 0 deletions tests/test_dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from contextlib import contextmanager

import dask
import numpy as np
import pytest
import dask.array as da

from array_api_compat import array_namespace


@pytest.fixture
def xp():
"""Fixture returning the wrapped dask namespace"""
return array_namespace(da.empty(0))


@contextmanager
def assert_no_compute():
"""
Context manager that raises if at any point inside it anything calls compute()
or persist(), e.g. as it can be triggered implicitly by __bool__, __array__, etc.
"""
def get(dsk, *args, **kwargs):
raise AssertionError("Called compute() or persist()")

with dask.config.set(scheduler=get):
yield


def test_assert_no_compute():
"""Test the assert_no_compute context manager"""
a = da.asarray(True)
with pytest.raises(AssertionError, match="Called compute"):
with assert_no_compute():
bool(a)

# Exiting the context manager restores the original scheduler
assert bool(a) is True


# Test no_compute for functions that use generic _aliases with xp=np

def test_unary_ops_no_compute(xp):
with assert_no_compute():
a = xp.asarray([1.5, -1.5])
xp.ceil(a)
xp.floor(a)
xp.trunc(a)
xp.sign(a)


def test_matmul_tensordot_no_compute(xp):
A = da.ones((4, 4), chunks=2)
B = da.zeros((4, 4), chunks=2)
with assert_no_compute():
xp.matmul(A, B)
xp.tensordot(A, B)


# Test no_compute for functions that are fully bespoke for dask

def test_asarray_no_compute(xp):
with assert_no_compute():
a = xp.arange(10)
xp.asarray(a)
xp.asarray(a, dtype=np.int16)
xp.asarray(a, dtype=a.dtype)
xp.asarray(a, copy=True)
xp.asarray(a, copy=True, dtype=np.int16)
xp.asarray(a, copy=True, dtype=a.dtype)
xp.asarray(a, copy=False)
xp.asarray(a, copy=False, dtype=a.dtype)


@pytest.mark.parametrize("copy", [True, False])
def test_astype_no_compute(xp, copy):
with assert_no_compute():
a = xp.arange(10)
xp.astype(a, np.int16, copy=copy)
xp.astype(a, a.dtype, copy=copy)


def test_clip_no_compute(xp):
with assert_no_compute():
a = xp.arange(10)
xp.clip(a)
xp.clip(a, 1)
xp.clip(a, 1, 8)


def test_generators_are_lazy(xp):
"""
Test that generator functions are fully lazy, e.g. that
da.ones(n) is not implemented as da.asarray(np.ones(n))
"""
size = 100_000_000_000 # 800 GB
chunks = size // 10 # 10x 80 GB chunks

with assert_no_compute():
xp.zeros(size, chunks=chunks)
xp.ones(size, chunks=chunks)
xp.empty(size, chunks=chunks)
xp.full(size, fill_value=123, chunks=chunks)
a = xp.arange(size, chunks=chunks)
xp.zeros_like(a)
xp.ones_like(a)
xp.empty_like(a)
xp.full_like(a, fill_value=123)

0 comments on commit 9442237

Please sign in to comment.