Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: astype: add device kwarg #240

Merged
merged 2 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,6 @@ def unique_values(x: ndarray, /, xp) -> ndarray:
**kwargs,
)

def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray:
if not copy and dtype == x.dtype:
return x
return x.astype(dtype=dtype, copy=copy)

# These functions have different keyword argument names

def std(
Expand Down Expand Up @@ -549,7 +544,7 @@ def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
'unstack', 'sign']
22 changes: 18 additions & 4 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import cupy as cp

from ..common import _aliases
from ..common import _aliases, _helpers
from .._internal import get_xp

from ._info import __array_namespace_info__
Expand Down Expand Up @@ -46,7 +46,6 @@
unique_counts = get_xp(cp)(_aliases.unique_counts)
unique_inverse = get_xp(cp)(_aliases.unique_inverse)
unique_values = get_xp(cp)(_aliases.unique_values)
astype = _aliases.astype
std = get_xp(cp)(_aliases.std)
var = get_xp(cp)(_aliases.var)
cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
Expand Down Expand Up @@ -110,6 +109,21 @@ def asarray(

return cp.array(obj, dtype=dtype, **kwargs)


def astype(
x: ndarray,
dtype: Dtype,
/,
*,
copy: bool = True,
device: Optional[Device] = None,
) -> ndarray:
if device is None:
return x.astype(dtype=dtype, copy=copy)
out = _helpers.to_device(x.astype(dtype=dtype, copy=False), device)
return out.copy() if copy and out is x else out


# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp, 'vecdot'):
Expand All @@ -127,10 +141,10 @@ def asarray(
else:
unstack = get_xp(cp)(_aliases.unstack)

__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool',
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
'acos', 'acosh', 'asin', 'asinh', 'atan',
'atan2', 'atanh', 'bitwise_left_shift',
'bitwise_invert', 'bitwise_right_shift',
'concat', 'pow', 'sign']
'bool', 'concat', 'pow', 'sign']

_all_ignore = ['cp', 'get_xp']
2 changes: 1 addition & 1 deletion array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def _isscalar(a):

_common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]

__all__ = _common_aliases + ['__array_namespace_info__', 'asarray', 'acos',
__all__ = _common_aliases + ['__array_namespace_info__', 'asarray', 'astype', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow', 'iinfo', 'finfo', 'can_cast',
Expand Down
17 changes: 14 additions & 3 deletions array_api_compat/numpy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
unique_counts = get_xp(np)(_aliases.unique_counts)
unique_inverse = get_xp(np)(_aliases.unique_inverse)
unique_values = get_xp(np)(_aliases.unique_values)
astype = _aliases.astype
std = get_xp(np)(_aliases.std)
var = get_xp(np)(_aliases.var)
cumulative_sum = get_xp(np)(_aliases.cumulative_sum)
Expand Down Expand Up @@ -115,6 +114,18 @@ def asarray(

return np.array(obj, copy=copy, dtype=dtype, **kwargs)


def astype(
x: ndarray,
dtype: Dtype,
/,
*,
copy: bool = True,
device: Optional[Device] = None,
) -> ndarray:
return x.astype(dtype=dtype, copy=copy)


# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(np, 'vecdot'):
Expand All @@ -132,10 +143,10 @@ def asarray(
else:
unstack = get_xp(np)(_aliases.unstack)

__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool',
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
'acos', 'acosh', 'asin', 'asinh', 'atan',
'atan2', 'atanh', 'bitwise_left_shift',
'bitwise_invert', 'bitwise_right_shift',
'concat', 'pow']
'bool', 'concat', 'pow']

_all_ignore = ['np', 'get_xp']
15 changes: 13 additions & 2 deletions array_api_compat/torch/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,8 +613,19 @@ def triu(x: array, /, *, k: int = 0) -> array:
def expand_dims(x: array, /, *, axis: int = 0) -> array:
return torch.unsqueeze(x, axis)

def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array:
return x.to(dtype, copy=copy)

def astype(
x: array,
dtype: Dtype,
/,
*,
copy: bool = True,
device: Optional[Device] = None,
) -> array:
if device is not None:
return x.to(device, dtype=dtype, copy=copy)
return x.to(dtype=dtype, copy=copy)


def broadcast_arrays(*arrays: array) -> List[array]:
shape = torch.broadcast_shapes(*[a.shape for a in arrays])
Expand Down
1 change: 0 additions & 1 deletion cupy-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -181,5 +181,4 @@ array_api_tests/test_fft.py::test_irfftn
# cupy.ndaray cannot be specified as `repeats` argument.
array_api_tests/test_manipulation_functions.py::test_repeat
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
array_api_tests/test_signatures.py::test_func_signature[astype]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
1 change: 0 additions & 1 deletion dask-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -154,4 +154,3 @@ array_api_tests/test_statistical_functions.py::test_prod
# 2023.12 support
array_api_tests/test_manipulation_functions.py::test_repeat
array_api_tests/test_searching_functions.py::test_searchsorted
array_api_tests/test_signatures.py::test_func_signature[astype]
1 change: 0 additions & 1 deletion numpy-1-21-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is
# 2023.12 support
array_api_tests/test_searching_functions.py::test_searchsorted
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
array_api_tests/test_signatures.py::test_func_signature[astype]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
# uint64 repeats not supported
array_api_tests/test_manipulation_functions.py::test_repeat
1 change: 0 additions & 1 deletion numpy-1-26-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ array_api_tests/test_statistical_functions.py::test_prod
# 2023.12 support
array_api_tests/test_searching_functions.py::test_searchsorted
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
array_api_tests/test_signatures.py::test_func_signature[astype]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
# uint64 repeats not supported
array_api_tests/test_manipulation_functions.py::test_repeat
1 change: 0 additions & 1 deletion numpy-dev-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot]

# 2023.12 support
# Argument 'device' missing from signature
array_api_tests/test_signatures.py::test_func_signature[astype]
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
# uint64 repeats not supported
Expand Down
1 change: 0 additions & 1 deletion numpy-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot]
# 2023.12 support
array_api_tests/test_searching_functions.py::test_searchsorted
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
array_api_tests/test_signatures.py::test_func_signature[astype]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
# uint64 repeats not supported
array_api_tests/test_manipulation_functions.py::test_repeat
2 changes: 0 additions & 2 deletions torch-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -202,5 +202,3 @@ array_api_tests/test_signatures.py::test_func_signature[repeat]
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
# Argument 'max_version' missing from signature
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
# Argument 'device' missing from signature
array_api_tests/test_signatures.py::test_func_signature[astype]
Loading