From def6b46c7e062711a1e42c8bc62573f0eac28826 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Mon, 4 Mar 2024 18:11:15 -0500 Subject: [PATCH 1/8] Fix up asarray --- array_api_compat/common/_aliases.py | 37 ++++++++++++++--------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 8792aa2e..f998481c 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -325,31 +325,30 @@ def _asarray( else: COPY_FALSE = (False,) COPY_TRUE = (True,) - if copy in COPY_FALSE: + if copy in COPY_FALSE and namespace != "dask.array": # copy=False is not yet implemented in xp.asarray raise NotImplementedError("copy=False is not yet implemented") - if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"): - #print('hit me') + if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)): if dtype is not None and obj.dtype != dtype: copy = True - #print(copy) if copy in COPY_TRUE: - copy_kwargs = {} - if namespace != "dask.array": - copy_kwargs["copy"] = True - else: - # No copy kw in dask.asarray so we go thorugh np.asarray first - # (like dask also does) but copy after - if dtype is None: - # Same dtype copy is no-op in dask - #print("in here?") - return obj.copy() - import numpy as np - #print(obj) - obj = np.asarray(obj).copy() - #print(obj) - return xp.array(obj, dtype=dtype, **copy_kwargs) + return xp.array(obj, copy=True, dtype=dtype) return obj + elif namespace == "dask.array": + if copy in COPY_TRUE: + if dtype is None: + return obj.copy() + # Go through numpy, since dask copy is no-op by default + import numpy as np + obj = np.array(obj, dtype=dtype, copy=True) + return xp.array(obj, dtype=dtype) + else: + import dask.array as da + import numpy as np + if not isinstance(obj, da.Array): + obj = np.asarray(obj, dtype=dtype) + return da.from_array(obj) + return obj return xp.asarray(obj, dtype=dtype, **kwargs) From ec4f628fdac6d30e778836e8fca081f079f97bae Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Mon, 11 Mar 2024 18:25:40 -0400 Subject: [PATCH 2/8] last linalg fixes --- array_api_compat/dask/array/linalg.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 03f16e89..60637382 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -17,9 +17,9 @@ if TYPE_CHECKING: from ...common._typing import Array -# cupy.linalg doesn't have __all__. If it is added, replace this with +# dask.array.linalg doesn't have __all__. If it is added, replace this with # -# from cupy.linalg import __all__ as linalg_all +# from dask.array.linalg import __all__ as linalg_all _n = {} exec('from dask.array.linalg import *', _n) del _n['__builtins__'] @@ -32,7 +32,15 @@ QRResult = _linalg.QRResult SlogdetResult = _linalg.SlogdetResult SVDResult = _linalg.SVDResult -qr = get_xp(da)(_linalg.qr) +# TODO: use the QR wrapper once dask +# supports the mode keyword on QR +# https://github.com/dask/dask/issues/10388 +#qr = get_xp(da)(_linalg.qr) +def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced', + **kwargs) -> QRResult: + if mode != "reduced": + raise ValueError("dask arrays only support using mode='reduced'") + return QRResult(*da.linalg.qr(x, **kwargs)) cholesky = get_xp(da)(_linalg.cholesky) matrix_rank = get_xp(da)(_linalg.matrix_rank) matrix_norm = get_xp(da)(_linalg.matrix_norm) @@ -44,7 +52,7 @@ def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult: if full_matrices: raise ValueError("full_matrics=True is not supported by dask.") - return da.linalg.svd(x, **kwargs) + return da.linalg.svd(x, coerce_signs=False, **kwargs) def svdvals(x: Array) -> Array: # TODO: can't avoid computing U or V for dask From 9df9aa834acc4ad23737e4b973504f33ca70f558 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Mon, 11 Mar 2024 18:27:11 -0400 Subject: [PATCH 3/8] add regression test for asarray Co-Authored-By: Isaac Virshup --- tests/test_common.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_common.py b/tests/test_common.py index 66076bfe..98f454e6 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -60,3 +60,16 @@ def test_to_device_host(library): # here is that we can test portably after calling # to_device(x, "cpu") to return to host assert_allclose(x, expected) + + +@pytest.mark.parametrize("target_library,func", is_functions.items()) +@pytest.mark.parametrize("source_library", is_functions.keys()) +def test_asarray(source_library, target_library, func): + src_lib = import_(source_library, wrapper=True) + tgt_lib = import_(target_library, wrapper=True) + is_tgt_type = globals()[func] + + a = src_lib.asarray([1, 2, 3]) + b = tgt_lib.asarray(a) + + assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}" From effc041e8f3cb15877e22edc933405750c70a66d Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Mon, 11 Mar 2024 18:30:27 -0400 Subject: [PATCH 4/8] fix tests? --- array_api_compat/dask/array/linalg.py | 1 + 1 file changed, 1 insertion(+) diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 60637382..7f5b2c6e 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -16,6 +16,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from ...common._typing import Array + from typing import Literal # dask.array.linalg doesn't have __all__. If it is added, replace this with # From 826cde40f9bb3867e230c496a2bb050ac924d3ff Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Mon, 11 Mar 2024 21:35:52 -0400 Subject: [PATCH 5/8] fix tests --- dask-xfails.txt | 3 +++ tests/test_common.py | 9 ++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/dask-xfails.txt b/dask-xfails.txt index ecde5420..0d74ecbb 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -37,6 +37,9 @@ array_api_tests/test_data_type_functions.py::test_finfo[float32] # (I think the test is not forcing the op to be computed?) array_api_tests/test_creation_functions.py::test_linspace +# out.shape=(2,) but should be (1,) +array_api_tests/test_indexing_functions.py::test_take + # out=-0, but should be +0 array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] diff --git a/tests/test_common.py b/tests/test_common.py index 98f454e6..22b98d83 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -64,7 +64,14 @@ def test_to_device_host(library): @pytest.mark.parametrize("target_library,func", is_functions.items()) @pytest.mark.parametrize("source_library", is_functions.keys()) -def test_asarray(source_library, target_library, func): +def test_asarray(source_library, target_library, func, request): + if source_library == "dask.array" and target_library == "torch": + # Allow rest of test to execute instead of immediately xfailing + # xref https://github.com/pandas-dev/pandas/issues/38902 + + # TODO: remove xfail once + # https://github.com/dask/dask/issues/8260 is resolved + request.node.add_marker(pytest.mark.xfail(reason="Bug in dask raising error on conversion")) src_lib = import_(source_library, wrapper=True) tgt_lib = import_(target_library, wrapper=True) is_tgt_type = globals()[func] From c63ebe6582f8fb838e985f339b248fe626c7c25d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 18 Mar 2024 14:43:32 -0600 Subject: [PATCH 6/8] Update some xfails --- numpy-1-21-xfails.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index 2a564a3a..b1556dc0 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -121,8 +121,8 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[bi array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_divide[divide(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_equal[equal(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_equal[__eq__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_equal[equal(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)] @@ -130,12 +130,14 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[f array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_less[less(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_logaddexp array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[multiply(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[not_equal(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x1, x2)] From 8d42075864bc9d7c3bd63fbe8826f8956616068c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 18 Mar 2024 14:45:01 -0600 Subject: [PATCH 7/8] XFAIL test_prod for NumPy --- numpy-1-21-xfails.txt | 1 + numpy-dev-xfails.txt | 1 + numpy-xfails.txt | 1 + 3 files changed, 3 insertions(+) diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index b1556dc0..aa6de9f8 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -89,6 +89,7 @@ array_api_tests/test_set_functions.py::test_unique_values # The test suite is incorrectly checking sums that have loss of significance # (https://github.com/data-apis/array-api-tests/issues/168) array_api_tests/test_statistical_functions.py::test_sum +array_api_tests/test_statistical_functions.py::test_prod # NumPy 1.21 doesn't support NPY_PROMOTION_STATE=weak, so many tests fail with # type promotion issues diff --git a/numpy-dev-xfails.txt b/numpy-dev-xfails.txt index 8d291d01..51ff34ad 100644 --- a/numpy-dev-xfails.txt +++ b/numpy-dev-xfails.txt @@ -42,3 +42,4 @@ array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices # The test suite is incorrectly checking sums that have loss of significance # (https://github.com/data-apis/array-api-tests/issues/168) array_api_tests/test_statistical_functions.py::test_sum +array_api_tests/test_statistical_functions.py::test_prod diff --git a/numpy-xfails.txt b/numpy-xfails.txt index e44d7035..40c6cbc4 100644 --- a/numpy-xfails.txt +++ b/numpy-xfails.txt @@ -44,3 +44,4 @@ array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices # The test suite is incorrectly checking sums that have loss of significance # (https://github.com/data-apis/array-api-tests/issues/168) array_api_tests/test_statistical_functions.py::test_sum +array_api_tests/test_statistical_functions.py::test_prod From a3e2c0df1d7e72727f881233adfbaa79b1dd2bfb Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 18 Mar 2024 22:09:25 -0600 Subject: [PATCH 8/8] Add xfail --- numpy-1-21-xfails.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index aa6de9f8..fe53f452 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -138,6 +138,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imu array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[multiply(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[not_equal(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x, s)]