Skip to content

Commit

Permalink
Merge pull request #236 from crusaderky/dask_cosmetic
Browse files Browse the repository at this point in the history
MAINT: dask: cosmetic tweaks
  • Loading branch information
ev-br authored Jan 18, 2025
2 parents 44e1eb3 + 0e4d6d8 commit 8a79994
Showing 1 changed file with 37 additions and 25 deletions.
62 changes: 37 additions & 25 deletions array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from ...common import _aliases
from ...common._helpers import _check_device

from ..._internal import get_xp

Expand Down Expand Up @@ -40,19 +39,25 @@
isdtype = get_xp(np)(_aliases.isdtype)
unstack = get_xp(da)(_aliases.unstack)

# da.astype doesn't respect copy=True
def astype(
x: Array,
dtype: Dtype,
/,
*,
copy: bool = True,
device: Device | None = None
device: Optional[Device] = None
) -> Array:
"""
Array API compatibility wrapper for astype().
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
# TODO: respect device keyword?

if not copy and dtype == x.dtype:
return x
# dask astype doesn't respect copy=True,
# so call copy manually afterwards
x = x.astype(dtype)
return x.copy() if copy else x

Expand All @@ -61,20 +66,24 @@ def astype(
# This arange func is modified from the common one to
# not pass stop/step as keyword arguments, which will cause
# an error with dask

# TODO: delete the xp stuff, it shouldn't be necessary
def _dask_arange(
def arange(
start: Union[int, float],
/,
stop: Optional[Union[int, float]] = None,
step: Union[int, float] = 1,
*,
xp,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs,
) -> Array:
_check_device(xp, device)
"""
Array API compatibility wrapper for arange().
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
# TODO: respect device keyword?

args = [start]
if stop is not None:
args.append(stop)
Expand All @@ -83,13 +92,12 @@ def _dask_arange(
# prepend the default value for start which is 0
args.insert(0, 0)
args.append(step)
return xp.arange(*args, dtype=dtype, **kwargs)

arange = get_xp(da)(_dask_arange)
eye = get_xp(da)(_aliases.eye)
return da.arange(*args, dtype=dtype, **kwargs)


linspace = get_xp(da)(_aliases.linspace)
eye = get_xp(da)(_aliases.eye)
linspace = get_xp(da)(_aliases.linspace)
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult)
UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult)
Expand All @@ -112,7 +120,6 @@ def _dask_arange(
reshape = get_xp(da)(_aliases.reshape)
matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
vecdot = get_xp(da)(_aliases.vecdot)

nonzero = get_xp(da)(_aliases.nonzero)
ceil = get_xp(np)(_aliases.ceil)
floor = get_xp(np)(_aliases.floor)
Expand All @@ -121,6 +128,7 @@ def _dask_arange(
tensordot = get_xp(np)(_aliases.tensordot)
sign = get_xp(np)(_aliases.sign)


# asarray also adds the copy keyword, which is not present in numpy 1.0.
def asarray(
obj: Union[
Expand All @@ -135,7 +143,7 @@ def asarray(
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
copy: "Optional[Union[bool, np._CopyMode]]" = None,
copy: Optional[Union[bool, np._CopyMode]] = None,
**kwargs,
) -> Array:
"""
Expand All @@ -144,6 +152,8 @@ def asarray(
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
# TODO: respect device keyword?

if isinstance(obj, da.Array):
if dtype is not None and dtype != obj.dtype:
if copy is False:
Expand Down Expand Up @@ -183,38 +193,40 @@ def asarray(
# Furthermore, the masking workaround in common._aliases.clip cannot work with
# dask (meaning uint64 promoting to float64 is going to just be unfixed for
# now).
@get_xp(da)
def clip(
x: Array,
/,
min: Optional[Union[int, float, Array]] = None,
max: Optional[Union[int, float, Array]] = None,
*,
xp,
) -> Array:
"""
Array API compatibility wrapper for clip().
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
def _isscalar(a):
return isinstance(a, (int, float, type(None)))
min_shape = () if _isscalar(min) else min.shape
max_shape = () if _isscalar(max) else max.shape

# TODO: This won't handle dask unknown shapes
import numpy as np
result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape)

if min is not None:
min = xp.broadcast_to(xp.asarray(min), result_shape)
min = da.broadcast_to(da.asarray(min), result_shape)
if max is not None:
max = xp.broadcast_to(xp.asarray(max), result_shape)
max = da.broadcast_to(da.asarray(max), result_shape)

if min is None and max is None:
return xp.positive(x)
return da.positive(x)

if min is None:
return astype(xp.minimum(x, max), x.dtype)
return astype(da.minimum(x, max), x.dtype)
if max is None:
return astype(xp.maximum(x, min), x.dtype)
return astype(da.maximum(x, min), x.dtype)

return astype(xp.minimum(xp.maximum(x, min), max), x.dtype)
return astype(da.minimum(da.maximum(x, min), max), x.dtype)

# exclude these from all since dask.array has no sorting functions
_da_unsupported = ['sort', 'argsort']
Expand Down

0 comments on commit 8a79994

Please sign in to comment.