From a4487103c519b14080cee2c9f8de54cdb0b9bf83 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 7 Jan 2025 17:23:29 +0200 Subject: [PATCH] torch: allow python scalars in result_type --- array_api_compat/torch/_aliases.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index df555054..f6ecb2f0 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -119,6 +119,9 @@ def _fix_promotion(x1, x2, only_scalar=True): x1 = x1.to(dtype) return x1, x2 + +_torch_dtype_and_py_scalars = (torch.dtype, bool, int, float, complex) + def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype: if len(arrays_and_dtypes) == 0: raise TypeError("At least one array or dtype must be provided") @@ -140,8 +143,8 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype: # This doesn't result_type(dtype, dtype) for non-array API dtypes # because torch.result_type only accepts tensors. This does however, allow # cross-kind promotion. - x = torch.tensor([], dtype=x) if isinstance(x, torch.dtype) else x - y = torch.tensor([], dtype=y) if isinstance(y, torch.dtype) else y + x = torch.tensor([], dtype=x) if isinstance(x, _torch_dtype_and_py_scalars) else x + y = torch.tensor([], dtype=y) if isinstance(y, _torch_dtype_and_py_scalars) else y return torch.result_type(x, y) def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: