From e6dbf3a4bb12a8d6eb9480cd08459848ca97e5d2 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 24 Jan 2025 14:31:42 +0000 Subject: [PATCH] ENH: Dask: `sort` and `argsort` --- array_api_compat/dask/array/_aliases.py | 118 ++++++++++++++++++++++-- dask-xfails.txt | 22 +---- tests/test_dask.py | 73 ++++++++++++++- 3 files changed, 187 insertions(+), 26 deletions(-) diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index a8ed6f26..ab18fd71 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -1,6 +1,8 @@ from __future__ import annotations -from ...common import _aliases +from typing import Callable + +from ...common import _aliases, array_namespace from ..._internal import get_xp @@ -29,16 +31,24 @@ ) from typing import TYPE_CHECKING + if TYPE_CHECKING: from typing import Optional, Union - from ...common._typing import Device, Dtype, Array, NestedSequence, SupportsBufferProtocol + from ...common._typing import ( + Device, + Dtype, + Array, + NestedSequence, + SupportsBufferProtocol, + ) import dask.array as da isdtype = get_xp(np)(_aliases.isdtype) unstack = get_xp(da)(_aliases.unstack) + # da.astype doesn't respect copy=True def astype( x: Array, @@ -46,7 +56,7 @@ def astype( /, *, copy: bool = True, - device: Optional[Device] = None + device: Optional[Device] = None, ) -> Array: """ Array API compatibility wrapper for astype(). @@ -61,8 +71,10 @@ def astype( x = x.astype(dtype) return x.copy() if copy else x + # Common aliases + # This arange func is modified from the common one to # not pass stop/step as keyword arguments, which will cause # an error with dask @@ -189,6 +201,7 @@ def asarray( concatenate as concat, ) + # dask.array.clip does not work unless all three arguments are provided. 
# Furthermore, the masking workaround in common._aliases.clip cannot work with # dask (meaning uint64 promoting to float64 is going to just be unfixed for @@ -205,8 +218,10 @@ def clip( See the corresponding documentation in the array library and/or the array API specification for more details. """ + def _isscalar(a): return isinstance(a, (int, float, type(None))) + min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape @@ -228,12 +243,99 @@ def _isscalar(a): return astype(da.minimum(da.maximum(x, min), max), x.dtype) -# exclude these from all since dask.array has no sorting functions -_da_unsupported = ['sort', 'argsort'] -_common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported] +def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array], Array]]: + """ + Make sure that Array is not broken into multiple chunks along axis. + + Returns + ------- + x : Array + The input Array with a single chunk along axis. + restore : Callable[[Array], Array] + function to apply to the output to rechunk it back into reasonable chunks + """ + if axis < 0: + axis += x.ndim + if x.numblocks[axis] < 2: + return x, lambda x: x + + # Break chunks on other axes in an attempt to keep chunk size low + x = x.rechunk({i: -1 if i == axis else "auto" for i in range(x.ndim)}) + + # Rather than reconstructing the original chunks, which can be a + # very expensive affair, just break down oversized chunks without + # incurring any transfers over the network. + # This has the downside of a risk of overchunking if the array is + # then used in operations against other arrays that match the + # original chunking pattern. + return x, lambda x: x.rechunk() + + +def sort( + x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True +) -> Array: + """ + Array API compatibility layer around the lack of sort() in Dask. 
+ + Warnings + -------- + This function temporarily rechunks the array along `axis` to a single chunk. + This can be extremely inefficient and can lead to out-of-memory errors. + + See the corresponding documentation in the array library and/or the array API + specification for more details. + """ + x, restore = _ensure_single_chunk(x, axis) + + meta_xp = array_namespace(x._meta) + x = da.map_blocks( + meta_xp.sort, + x, + axis=axis, + meta=x._meta, + dtype=x.dtype, + descending=descending, + stable=stable, + ) + + return restore(x) -__all__ = _common_aliases + ['__array_namespace_info__', 'asarray', 'astype', 'acos', + +def argsort( + x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True +) -> Array: + """ + Array API compatibility layer around the lack of argsort() in Dask. + + See the corresponding documentation in the array library and/or the array API + specification for more details. + + Warnings + -------- + This function temporarily rechunks the array along `axis` into a single chunk. + This can be extremely inefficient and can lead to out-of-memory errors. 
+ """ + x, restore = _ensure_single_chunk(x, axis) + + meta_xp = array_namespace(x._meta) + dtype = meta_xp.argsort(x._meta).dtype + meta = meta_xp.astype(x._meta, dtype) + x = da.map_blocks( + meta_xp.argsort, + x, + axis=axis, + meta=meta, + dtype=dtype, + descending=descending, + stable=stable, + ) + + return restore(x) + + +__all__ = _aliases.__all__ + [ + '__array_namespace_info__', 'asarray', 'astype', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', 'concat', 'pow', 'iinfo', 'finfo', 'can_cast', @@ -242,4 +344,4 @@ def _isscalar(a): 'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type'] -_all_ignore = ["get_xp", "da", "np"] +_all_ignore = ["Callable", "array_namespace", "get_xp", "da", "np"] diff --git a/dask-xfails.txt b/dask-xfails.txt index 1631ea12..353c6f1e 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -23,17 +23,13 @@ array_api_tests/test_array_object.py::test_setitem_masking # Various indexing errors array_api_tests/test_array_object.py::test_getitem_masking -# asarray(copy=False) is not yet implemented -# copied from numpy xfails, TODO: should this pass with dask? -array_api_tests/test_creation_functions.py::test_asarray_arrays - # zero division error, and typeerror: tuple indices must be integers or slices not tuple array_api_tests/test_creation_functions.py::test_eye # finfo(float32).eps returns float32 but should return float array_api_tests/test_data_type_functions.py::test_finfo[float32] -# out[-1]=dask.aray but should be some floating number +# out[-1]=dask.array but should be some floating number # (I think the test is not forcing the op to be computed?) 
array_api_tests/test_creation_functions.py::test_linspace @@ -48,15 +44,7 @@ array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] -# No sorting in dask -array_api_tests/test_has_names.py::test_has_names[sorting-argsort] -array_api_tests/test_has_names.py::test_has_names[sorting-sort] -array_api_tests/test_sorting_functions.py::test_argsort -array_api_tests/test_sorting_functions.py::test_sort -array_api_tests/test_signatures.py::test_func_signature[argsort] -array_api_tests/test_signatures.py::test_func_signature[sort] - -# Array methods and attributes not already on np.ndarray cannot be wrapped +# Array methods and attributes not already on da.Array cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] array_api_tests/test_has_names.py::test_has_names[array_method-to_device] array_api_tests/test_has_names.py::test_has_names[array_attribute-device] @@ -76,6 +64,7 @@ array_api_tests/test_set_functions.py::test_unique_values # fails for ndim > 2 array_api_tests/test_linalg.py::test_svdvals array_api_tests/test_linalg.py::test_cholesky + # dtype mismatch got uint64, but should be uint8, NPY_PROMOTION_STATE=weak doesn't help :( array_api_tests/test_linalg.py::test_tensordot @@ -105,6 +94,8 @@ array_api_tests/test_linalg.py::test_cross array_api_tests/test_linalg.py::test_det array_api_tests/test_linalg.py::test_eigh array_api_tests/test_linalg.py::test_eigvalsh +array_api_tests/test_linalg.py::test_matrix_norm +array_api_tests/test_linalg.py::test_matrix_rank array_api_tests/test_linalg.py::test_pinv array_api_tests/test_linalg.py::test_slogdet 
array_api_tests/test_has_names.py::test_has_names[linalg-cross] @@ -115,9 +106,6 @@ array_api_tests/test_has_names.py::test_has_names[linalg-matrix_power] array_api_tests/test_has_names.py::test_has_names[linalg-pinv] array_api_tests/test_has_names.py::test_has_names[linalg-slogdet] -array_api_tests/test_linalg.py::test_matrix_norm -array_api_tests/test_linalg.py::test_matrix_rank - # missing mode kw # https://github.com/dask/dask/issues/10388 array_api_tests/test_linalg.py::test_qr diff --git a/tests/test_dask.py b/tests/test_dask.py index 2983e696..be2b1e39 100644 --- a/tests/test_dask.py +++ b/tests/test_dask.py @@ -1,5 +1,6 @@ from contextlib import contextmanager +import array_api_strict import dask import numpy as np import pytest @@ -20,9 +21,10 @@ def assert_no_compute(): Context manager that raises if at any point inside it anything calls compute() or persist(), e.g. as it can be triggered implicitly by __bool__, __array__, etc. """ + def get(dsk, *args, **kwargs): raise AssertionError("Called compute() or persist()") - + with dask.config.set(scheduler=get): yield @@ -40,6 +42,7 @@ def test_assert_no_compute(): # Test no_compute for functions that use generic _aliases with xp=np + def test_unary_ops_no_compute(xp): with assert_no_compute(): a = xp.asarray([1.5, -1.5]) @@ -59,6 +62,7 @@ def test_matmul_tensordot_no_compute(xp): # Test no_compute for functions that are fully bespoke for dask + def test_asarray_no_compute(xp): with assert_no_compute(): a = xp.arange(10) @@ -88,6 +92,14 @@ def test_clip_no_compute(xp): xp.clip(a, 1, 8) +@pytest.mark.parametrize("chunks", (5, 10)) +def test_sort_argsort_nocompute(xp, chunks): + with assert_no_compute(): + a = xp.arange(10, chunks=chunks) + xp.sort(a) + xp.argsort(a) + + def test_generators_are_lazy(xp): """ Test that generator functions are fully lazy, e.g. 
that @@ -106,3 +118,62 @@ def test_generators_are_lazy(xp): xp.ones_like(a) xp.empty_like(a) xp.full_like(a, fill_value=123) + + +@pytest.mark.parametrize("axis", [0, 1]) +@pytest.mark.parametrize("func", ["sort", "argsort"]) +def test_sort_argsort_chunks(xp, func, axis): + """Test that sort and argsort are functionally correct when + the array is chunked along the sort axis, e.g. the sort is + not just local to each chunk. + """ + a = da.random.random((10, 10), chunks=(5, 5)) + actual = getattr(xp, func)(a, axis=axis) + expect = getattr(np, func)(a.compute(), axis=axis) + np.testing.assert_array_equal(actual, expect) + + +@pytest.mark.parametrize( + "shape,chunks", + [ + # 3 GiB; 128 MiB per chunk; must rechunk before sorting. + # Sort chunks can be 128 MiB each; no need for final rechunk. + ((20_000, 20_000), "auto"), + # 3 GiB; 128 MiB per chunk; must rechunk before sorting. + # Must sort on two 1.5 GiB chunks; benefits from final rechunk. + ((2, 2**30 * 3 // 16), "auto"), + # 3 GiB; 1.5 GiB per chunk; no need to rechunk before sorting. + # Surely the user must know what they're doing, so don't + # perform the final rechunk. + ((2, 2**30 * 3 // 16), (1, -1)), + ], +) +@pytest.mark.parametrize("func", ["sort", "argsort"]) +def test_sort_argsort_chunk_size(xp, func, shape, chunks): + """ + Test that sort and argsort produce reasonably-sized chunks + in the output array, even if they had to go through a singular + huge one to perform the operation. 
+ """ + a = da.random.random(shape, chunks=chunks) + b = getattr(xp, func)(a) + max_chunk_size = max(b.chunks[0]) * max(b.chunks[1]) * b.dtype.itemsize + assert ( + max_chunk_size <= 128 * 1024 * 1024 # 128 MiB + or b.chunks == a.chunks + ) + + +@pytest.mark.parametrize("func", ["sort", "argsort"]) +def test_sort_argsort_meta(xp, func): + """Test meta-namespace other than numpy""" + typ = type(array_api_strict.asarray(0)) + a = da.random.random(10) + b = a.map_blocks(array_api_strict.asarray) + assert isinstance(b._meta, typ) + c = getattr(xp, func)(b) + assert isinstance(c._meta, typ) + d = c.compute() + # Note: np.sort(array_api_strict.asarray(0)) would return a numpy array + assert isinstance(d, typ) + np.testing.assert_array_equal(d, getattr(np, func)(a.compute()))