Skip to content

Commit

Permalink
Merge branch 'no-dependencies' of github.com:asmeurer/array-api-compa…
Browse files Browse the repository at this point in the history
…t into no-dependencies
  • Loading branch information
asmeurer committed Mar 20, 2024
2 parents ea96a84 + 554e77d commit 93b8a53
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 24 deletions.
37 changes: 18 additions & 19 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,31 +325,30 @@ def _asarray(
else:
COPY_FALSE = (False,)
COPY_TRUE = (True,)
if copy in COPY_FALSE:
if copy in COPY_FALSE and namespace != "dask.array":
# copy=False is not yet implemented in xp.asarray
raise NotImplementedError("copy=False is not yet implemented")
if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"):
#print('hit me')
if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)):
if dtype is not None and obj.dtype != dtype:
copy = True
#print(copy)
if copy in COPY_TRUE:
copy_kwargs = {}
if namespace != "dask.array":
copy_kwargs["copy"] = True
else:
# No copy kw in dask.asarray so we go thorugh np.asarray first
# (like dask also does) but copy after
if dtype is None:
# Same dtype copy is no-op in dask
#print("in here?")
return obj.copy()
import numpy as np
#print(obj)
obj = np.asarray(obj).copy()
#print(obj)
return xp.array(obj, dtype=dtype, **copy_kwargs)
return xp.array(obj, copy=True, dtype=dtype)
return obj
elif namespace == "dask.array":
if copy in COPY_TRUE:
if dtype is None:
return obj.copy()
# Go through numpy, since dask copy is no-op by default
import numpy as np
obj = np.array(obj, dtype=dtype, copy=True)
return xp.array(obj, dtype=dtype)
else:
import dask.array as da
import numpy as np
if not isinstance(obj, da.Array):
obj = np.asarray(obj, dtype=dtype)
return da.from_array(obj)
return obj

return xp.asarray(obj, dtype=dtype, **kwargs)

Expand Down
17 changes: 13 additions & 4 deletions array_api_compat/dask/array/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ...common._typing import Array
from typing import Literal

# cupy.linalg doesn't have __all__. If it is added, replace this with
# dask.array.linalg doesn't have __all__. If it is added, replace this with
#
# from cupy.linalg import __all__ as linalg_all
# from dask.array.linalg import __all__ as linalg_all
_n = {}
exec('from dask.array.linalg import *', _n)
del _n['__builtins__']
Expand All @@ -32,7 +33,15 @@
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
qr = get_xp(da)(_linalg.qr)
# TODO: use the QR wrapper once dask
# supports the mode keyword on QR
# https://github.com/dask/dask/issues/10388
#qr = get_xp(da)(_linalg.qr)
def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced',
**kwargs) -> QRResult:
if mode != "reduced":
raise ValueError("dask arrays only support using mode='reduced'")
return QRResult(*da.linalg.qr(x, **kwargs))
cholesky = get_xp(da)(_linalg.cholesky)
matrix_rank = get_xp(da)(_linalg.matrix_rank)
matrix_norm = get_xp(da)(_linalg.matrix_norm)
Expand All @@ -44,7 +53,7 @@
def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult:
if full_matrices:
raise ValueError("full_matrics=True is not supported by dask.")
return da.linalg.svd(x, **kwargs)
return da.linalg.svd(x, coerce_signs=False, **kwargs)

def svdvals(x: Array) -> Array:
# TODO: can't avoid computing U or V for dask
Expand Down
3 changes: 3 additions & 0 deletions dask-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ array_api_tests/test_data_type_functions.py::test_finfo[float32]
# (I think the test is not forcing the op to be computed?)
array_api_tests/test_creation_functions.py::test_linspace

# out.shape=(2,) but should be (1,)
array_api_tests/test_indexing_functions.py::test_take

# out=-0, but should be +0
array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
Expand Down
6 changes: 5 additions & 1 deletion numpy-1-21-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ array_api_tests/test_set_functions.py::test_unique_values
# The test suite is incorrectly checking sums that have loss of significance
# (https://github.com/data-apis/array-api-tests/issues/168)
array_api_tests/test_statistical_functions.py::test_sum
array_api_tests/test_statistical_functions.py::test_prod

# NumPy 1.21 doesn't support NPY_PROMOTION_STATE=weak, so many tests fail with
# type promotion issues
Expand Down Expand Up @@ -121,21 +122,24 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[bi
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[divide(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_equal[equal(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_equal[__eq__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_equal[equal(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_less[less(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_logaddexp
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[multiply(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[not_equal(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x1, x2)]
Expand Down
1 change: 1 addition & 0 deletions numpy-dev-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,4 @@ array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices
# The test suite is incorrectly checking sums that have loss of significance
# (https://github.com/data-apis/array-api-tests/issues/168)
array_api_tests/test_statistical_functions.py::test_sum
array_api_tests/test_statistical_functions.py::test_prod
1 change: 1 addition & 0 deletions numpy-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@ array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices
# The test suite is incorrectly checking sums that have loss of significance
# (https://github.com/data-apis/array-api-tests/issues/168)
array_api_tests/test_statistical_functions.py::test_sum
array_api_tests/test_statistical_functions.py::test_prod
20 changes: 20 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,23 @@ def test_to_device_host(library):
# here is that we can test portably after calling
# to_device(x, "cpu") to return to host
assert_allclose(x, expected)


@pytest.mark.parametrize("target_library,func", is_functions.items())
@pytest.mark.parametrize("source_library", is_functions.keys())
def test_asarray(source_library, target_library, func, request):
if source_library == "dask.array" and target_library == "torch":
# Allow rest of test to execute instead of immediately xfailing
# xref https://github.com/pandas-dev/pandas/issues/38902

# TODO: remove xfail once
# https://github.com/dask/dask/issues/8260 is resolved
request.node.add_marker(pytest.mark.xfail(reason="Bug in dask raising error on conversion"))
src_lib = import_(source_library, wrapper=True)
tgt_lib = import_(target_library, wrapper=True)
is_tgt_type = globals()[func]

a = src_lib.asarray([1, 2, 3])
b = tgt_lib.asarray(a)

assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"

0 comments on commit 93b8a53

Please sign in to comment.