diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 66e405ca..861b0bd0 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -1,7 +1,6 @@ from __future__ import annotations from ...common import _aliases -from ...common._helpers import _check_device from ..._internal import get_xp @@ -40,19 +39,25 @@ isdtype = get_xp(np)(_aliases.isdtype) unstack = get_xp(da)(_aliases.unstack) +# da.astype doesn't respect copy=True def astype( x: Array, dtype: Dtype, /, *, copy: bool = True, - device: Device | None = None + device: Optional[Device] = None ) -> Array: + """ + Array API compatibility wrapper for astype(). + + See the corresponding documentation in the array library and/or the array API + specification for more details. + """ # TODO: respect device keyword? + if not copy and dtype == x.dtype: return x - # dask astype doesn't respect copy=True, - # so call copy manually afterwards x = x.astype(dtype) return x.copy() if copy else x @@ -61,20 +66,24 @@ def astype( # This arange func is modified from the common one to # not pass stop/step as keyword arguments, which will cause # an error with dask - -# TODO: delete the xp stuff, it shouldn't be necessary -def _dask_arange( +def arange( start: Union[int, float], /, stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, - xp, dtype: Optional[Dtype] = None, device: Optional[Device] = None, **kwargs, ) -> Array: - _check_device(xp, device) + """ + Array API compatibility wrapper for arange(). + + See the corresponding documentation in the array library and/or the array API + specification for more details. + """ + # TODO: respect device keyword? + args = [start] if stop is not None: args.append(stop) @@ -83,13 +92,12 @@ def _dask_arange( # prepend the default value for start which is 0 args.insert(0, 0) args.append(step) - return xp.arange(*args, dtype=dtype, **kwargs) -arange = get_xp(da)(_dask_arange) -eye = get_xp(da)(_aliases.eye) + return da.arange(*args, dtype=dtype, **kwargs) + -linspace = get_xp(da)(_aliases.linspace) eye = get_xp(da)(_aliases.eye) +linspace = get_xp(da)(_aliases.linspace) UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult) UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult) UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult) @@ -112,7 +120,6 @@ def _dask_arange( reshape = get_xp(da)(_aliases.reshape) matrix_transpose = get_xp(da)(_aliases.matrix_transpose) vecdot = get_xp(da)(_aliases.vecdot) - nonzero = get_xp(da)(_aliases.nonzero) ceil = get_xp(np)(_aliases.ceil) floor = get_xp(np)(_aliases.floor) @@ -121,6 +128,7 @@ def _dask_arange( tensordot = get_xp(np)(_aliases.tensordot) sign = get_xp(np)(_aliases.sign) + # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( obj: Union[ @@ -135,7 +143,7 @@ def asarray( *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - copy: "Optional[Union[bool, np._CopyMode]]" = None, + copy: Optional[Union[bool, np._CopyMode]] = None, **kwargs, ) -> Array: """ @@ -144,6 +152,8 @@ def asarray( See the corresponding documentation in the array library and/or the array API specification for more details. """ + # TODO: respect device keyword? + if isinstance(obj, da.Array): if dtype is not None and dtype != obj.dtype: if copy is False: @@ -183,38 +193,40 @@ def asarray( # Furthermore, the masking workaround in common._aliases.clip cannot work with # dask (meaning uint64 promoting to float64 is going to just be unfixed for # now). -@get_xp(da) def clip( x: Array, /, min: Optional[Union[int, float, Array]] = None, max: Optional[Union[int, float, Array]] = None, - *, - xp, ) -> Array: + """ + Array API compatibility wrapper for clip(). + + See the corresponding documentation in the array library and/or the array API + specification for more details. + """ def _isscalar(a): return isinstance(a, (int, float, type(None))) min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape # TODO: This won't handle dask unknown shapes - import numpy as np result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape) if min is not None: - min = xp.broadcast_to(xp.asarray(min), result_shape) + min = da.broadcast_to(da.asarray(min), result_shape) if max is not None: - max = xp.broadcast_to(xp.asarray(max), result_shape) + max = da.broadcast_to(da.asarray(max), result_shape) if min is None and max is None: - return xp.positive(x) + return da.positive(x) if min is None: - return astype(xp.minimum(x, max), x.dtype) + return astype(da.minimum(x, max), x.dtype) if max is None: - return astype(xp.maximum(x, min), x.dtype) + return astype(da.maximum(x, min), x.dtype) - return astype(xp.minimum(xp.maximum(x, min), max), x.dtype) + return astype(da.minimum(da.maximum(x, min), max), x.dtype) # exclude these from all since dask.array has no sorting functions _da_unsupported = ['sort', 'argsort']