Skip to content

Commit

Permalink
torch: allow python scalars in result_type
Browse files Browse the repository at this point in the history
  • Loading branch information
ev-br committed Jan 7, 2025
1 parent bfe3fcc commit a448710
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions array_api_compat/torch/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ def _fix_promotion(x1, x2, only_scalar=True):
x1 = x1.to(dtype)
return x1, x2


_torch_dtype_and_py_scalars = (torch.dtype, bool, int, float, complex)

def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
if len(arrays_and_dtypes) == 0:
raise TypeError("At least one array or dtype must be provided")
Expand All @@ -140,8 +143,8 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
# This doesn't result_type(dtype, dtype) for non-array API dtypes
# because torch.result_type only accepts tensors. This does however, allow
# cross-kind promotion.
x = torch.tensor([], dtype=x) if isinstance(x, torch.dtype) else x
y = torch.tensor([], dtype=y) if isinstance(y, torch.dtype) else y
x = torch.tensor([], dtype=x) if isinstance(x, _torch_dtype_and_py_scalars) else x
y = torch.tensor([], dtype=y) if isinstance(y, _torch_dtype_and_py_scalars) else y
return torch.result_type(x, y)

def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
Expand Down

0 comments on commit a448710

Please sign in to comment.