update promotion table and can_cast table
HydrogenSulfate committed Nov 26, 2024
1 parent 7118894 commit 85dc3ba
Showing 5 changed files with 64 additions and 77 deletions.
3 changes: 1 addition & 2 deletions array_api_compat/common/_helpers.py
@@ -144,7 +144,6 @@ def is_paddle_array(x):

import paddle

# TODO: Should we reject ndarray subclasses?
return paddle.is_tensor(x)

def is_ndonnx_array(x):
@@ -725,7 +724,7 @@ def device(x: Array, /) -> Device:
return "cpu"
elif "gpu" in raw_place_str:
return "gpu"
raise NotImplementedError(f"Unsupported device {raw_place_str}")
raise ValueError(f"Unsupported Paddle device: {x.place}")

return x.device

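For context on the second hunk above: the Paddle branch of `device()` now reports an unsupported placement with a `ValueError` carrying `x.place` instead of a `NotImplementedError`. A minimal sketch of the resulting behaviour (not the library's exact code; `raw_place_str` is assumed to be the lower-cased string form of `x.place`, mirroring the surrounding context):

```python
import paddle

def paddle_device(x: paddle.Tensor) -> str:
    # Derive a plain device string from the tensor's placement.
    raw_place_str = str(x.place).lower()
    if "cpu" in raw_place_str:
        return "cpu"
    elif "gpu" in raw_place_str:
        return "gpu"
    # After this commit an unrecognised placement is reported as a ValueError
    # rather than a NotImplementedError.
    raise ValueError(f"Unsupported Paddle device: {x.place}")

# x = paddle.to_tensor([1.0, 2.0])
# paddle_device(x)  # -> "cpu" on a CPU build
```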
121 changes: 51 additions & 70 deletions array_api_compat/paddle/_aliases.py
@@ -42,37 +42,18 @@
paddle.complex128,
}

# NOTE: Paddle's implicit promotion rules are a bit stricter than those of other frameworks;
# for details, see: https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/advanced/auto_type_promotion_cn.html
_promotion_table = {
# bool
(paddle.bool, paddle.bool): paddle.bool,
# ints
(paddle.int8, paddle.int8): paddle.int8,
(paddle.int8, paddle.int16): paddle.int16,
(paddle.int8, paddle.int32): paddle.int32,
(paddle.int8, paddle.int64): paddle.int64,
(paddle.int16, paddle.int8): paddle.int16,
(paddle.int16, paddle.int16): paddle.int16,
(paddle.int16, paddle.int32): paddle.int32,
(paddle.int16, paddle.int64): paddle.int64,
(paddle.int32, paddle.int8): paddle.int32,
(paddle.int32, paddle.int16): paddle.int32,
(paddle.int32, paddle.int32): paddle.int32,
(paddle.int32, paddle.int64): paddle.int64,
(paddle.int64, paddle.int8): paddle.int64,
(paddle.int64, paddle.int16): paddle.int64,
(paddle.int64, paddle.int32): paddle.int64,
(paddle.int64, paddle.int64): paddle.int64,
# uints
(paddle.uint8, paddle.uint8): paddle.uint8,
# ints and uints (mixed sign)
(paddle.int8, paddle.uint8): paddle.int16,
(paddle.int16, paddle.uint8): paddle.int16,
(paddle.int32, paddle.uint8): paddle.int32,
(paddle.int64, paddle.uint8): paddle.int64,
(paddle.uint8, paddle.int8): paddle.int16,
(paddle.uint8, paddle.int16): paddle.int16,
(paddle.uint8, paddle.int32): paddle.int32,
(paddle.uint8, paddle.int64): paddle.int64,
# floats
(paddle.float32, paddle.float32): paddle.float32,
(paddle.float32, paddle.float64): paddle.float64,
@@ -158,12 +139,12 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.float64: True,
paddle.complex64: True,
paddle.complex128: True,
paddle.uint8: False,
paddle.int8: False,
paddle.int16: False,
paddle.int32: False,
paddle.int64: False,
paddle.bool: False,
paddle.uint8: True,
paddle.int8: True,
paddle.int16: True,
paddle.int32: True,
paddle.int64: True,
paddle.bool: True,
},
paddle.float16: {
paddle.bfloat16: True,
@@ -172,12 +153,12 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.float64: True,
paddle.complex64: True,
paddle.complex128: True,
paddle.uint8: False,
paddle.int8: False,
paddle.int16: False,
paddle.int32: False,
paddle.int64: False,
paddle.bool: False,
paddle.uint8: True,
paddle.int8: True,
paddle.int16: True,
paddle.int32: True,
paddle.int64: True,
paddle.bool: True,
},
paddle.float32: {
paddle.bfloat16: True,
@@ -186,12 +167,12 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.float64: True,
paddle.complex64: True,
paddle.complex128: True,
paddle.uint8: False,
paddle.int8: False,
paddle.int16: False,
paddle.int32: False,
paddle.int64: False,
paddle.bool: False,
paddle.uint8: True,
paddle.int8: True,
paddle.int16: True,
paddle.int32: True,
paddle.int64: True,
paddle.bool: True,
},
paddle.float64: {
paddle.bfloat16: True,
@@ -200,40 +181,40 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.float64: True,
paddle.complex64: True,
paddle.complex128: True,
paddle.uint8: False,
paddle.int8: False,
paddle.int16: False,
paddle.int32: False,
paddle.int64: False,
paddle.bool: False,
paddle.uint8: True,
paddle.int8: True,
paddle.int16: True,
paddle.int32: True,
paddle.int64: True,
paddle.bool: True,
},
paddle.complex64: {
paddle.bfloat16: False,
paddle.float16: False,
paddle.float32: False,
paddle.float64: False,
paddle.bfloat16: True,
paddle.float16: True,
paddle.float32: True,
paddle.float64: True,
paddle.complex64: True,
paddle.complex128: True,
paddle.uint8: False,
paddle.int8: False,
paddle.int16: False,
paddle.int32: False,
paddle.int64: False,
paddle.bool: False,
paddle.uint8: True,
paddle.int8: True,
paddle.int16: True,
paddle.int32: True,
paddle.int64: True,
paddle.bool: True,
},
paddle.complex128: {
paddle.bfloat16: False,
paddle.float16: False,
paddle.float32: False,
paddle.float64: False,
paddle.bfloat16: True,
paddle.float16: True,
paddle.float32: True,
paddle.float64: True,
paddle.complex64: True,
paddle.complex128: True,
paddle.uint8: False,
paddle.int8: False,
paddle.int16: False,
paddle.int32: False,
paddle.int64: False,
paddle.bool: False,
paddle.uint8: True,
paddle.int8: True,
paddle.int16: True,
paddle.int32: True,
paddle.int64: True,
paddle.bool: True,
},
paddle.uint8: {
paddle.bfloat16: True,
@@ -247,7 +228,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.int16: True,
paddle.int32: True,
paddle.int64: True,
paddle.bool: False,
paddle.bool: True,
},
paddle.int8: {
paddle.bfloat16: True,
@@ -261,7 +242,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.int16: True,
paddle.int32: True,
paddle.int64: True,
paddle.bool: False,
paddle.bool: True,
},
paddle.int16: {
paddle.bfloat16: True,
@@ -275,7 +256,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.int16: True,
paddle.int32: True,
paddle.int64: True,
paddle.bool: False,
paddle.bool: True,
},
paddle.int32: {
paddle.bfloat16: True,
@@ -289,7 +270,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.int16: True,
paddle.int32: True,
paddle.int64: True,
paddle.bool: False,
paddle.bool: True,
},
paddle.int64: {
paddle.bfloat16: True,
@@ -303,7 +284,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.int16: True,
paddle.int32: True,
paddle.int64: True,
paddle.bool: False,
paddle.bool: True,
},
paddle.bool: {
paddle.bfloat16: True,
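The table edited above backs the wrapped `can_cast()` (its signature appears in the hunk headers). A hypothetical sketch of how such a nested lookup can be consulted — the actual dictionary name and error handling in `_aliases.py` are not visible in this diff, so the names below are illustrative only:

```python
import paddle

# Illustrative stand-in for the nested dtype table shown in the diff above;
# only a few entries are reproduced here.
_CAN_CAST_SKETCH = {
    paddle.float32: {paddle.float64: True, paddle.int64: True, paddle.bool: True},
    paddle.bool: {paddle.float32: True, paddle.int64: True},
}

def can_cast_sketch(from_, to):
    # Accept either a dtype or an array carrying one, mirroring the
    # Union[Dtype, array] annotation in the wrapped signature.
    dtype = getattr(from_, "dtype", from_)
    return _CAN_CAST_SKETCH.get(dtype, {}).get(to, False)

# After this commit the table also answers True for float -> int casts,
# e.g. can_cast_sketch(paddle.float32, paddle.int64) -> True.
```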
2 changes: 1 addition & 1 deletion tests/_helpers.py
@@ -3,7 +3,7 @@

import pytest

wrapped_libraries = ["numpy", "paddle"]
wrapped_libraries = ["numpy", "paddle", "torch"]
all_libraries = wrapped_libraries + []

# `sparse` added array API support as of Python 3.10.
4 changes: 2 additions & 2 deletions tests/test_all.py
@@ -40,5 +40,5 @@ def test_all(library):
all_names = module.__all__

if set(dir_names) != set(all_names):
assert set(dir_names) - set(all_names) == set(), f"Some dir() names not included in __all__ for {mod_name}"
assert set(all_names) - set(dir_names) == set(), f"Some __all__ names not in dir() for {mod_name}"
assert set(dir_names) - set(all_names) == set(), f"Failed in library '{library}', some dir() names not included in __all__ for {mod_name}"
assert set(all_names) - set(dir_names) == set(), f"Failed in library '{library}', some __all__ names not in dir() for {mod_name}"
11 changes: 9 additions & 2 deletions tests/test_common.py
@@ -17,7 +17,7 @@
is_array_functions = {
'numpy': 'is_numpy_array',
# 'cupy': 'is_cupy_array',
# 'torch': 'is_torch_array',
'torch': 'is_torch_array',
# 'dask.array': 'is_dask_array',
# 'jax.numpy': 'is_jax_array',
# 'sparse': 'is_pydata_sparse_array',
@@ -27,7 +27,7 @@
is_namespace_functions = {
'numpy': 'is_numpy_namespace',
# 'cupy': 'is_cupy_namespace',
# 'torch': 'is_torch_namespace',
'torch': 'is_torch_namespace',
# 'dask.array': 'is_dask_namespace',
# 'jax.numpy': 'is_jax_namespace',
# 'sparse': 'is_pydata_sparse_namespace',
@@ -103,6 +103,13 @@ def test_asarray_cross_library(source_library, target_library, request):
if source_library == "cupy" and target_library != "cupy":
# cupy explicitly disallows implicit conversions to CPU
pytest.skip(reason="cupy does not support implicit conversion to CPU")
if source_library == "paddle" or target_library == "paddle":
pytest.skip(
reason=(
"paddle does not support implicit conversion from/to other framework "
"via 'asarray', dlpack is recommend now."
)
)
elif source_library == "sparse" and target_library != "sparse":
pytest.skip(reason="`sparse` does not allow implicit densification")
src_lib = import_(source_library, wrapper=True)
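The new skip in `test_asarray_cross_library` points users at DLPack for cross-framework exchange with Paddle. A hedged sketch of what that looks like between Paddle and PyTorch (assumes recent builds of both; `paddle.utils.dlpack` and `torch.utils.dlpack` are the public entry points):

```python
import paddle
import torch
from paddle.utils import dlpack as paddle_dlpack
from torch.utils import dlpack as torch_dlpack

x = paddle.to_tensor([1.0, 2.0, 3.0])

# Paddle -> PyTorch: export a DLPack capsule, then import it (zero-copy where possible).
capsule = paddle_dlpack.to_dlpack(x)
y = torch_dlpack.from_dlpack(capsule)

# PyTorch -> Paddle works the same way in the other direction.
z = paddle_dlpack.from_dlpack(torch_dlpack.to_dlpack(y))
```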
