Skip to content

Commit

Permalink
Merge pull request #244 from crusaderky/torch_uint
Browse files Browse the repository at this point in the history
ENH: More uint types for torch
  • Loading branch information
ev-br authored Jan 31, 2025
2 parents 6e897a1 + c787bea commit 2eafb97
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 12 deletions.
6 changes: 6 additions & 0 deletions array_api_compat/torch/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@
torch.int32,
torch.int64,
}
try:
# torch >=2.3
_int_dtypes |= {torch.uint16, torch.uint32, torch.uint64}
except AttributeError:
pass


_array_api_dtypes = {
torch.bool,
Expand Down
12 changes: 0 additions & 12 deletions torch-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,12 @@ array_api_tests/test_array_object.py::test_getitem
array_api_tests/test_array_object.py::test_setitem
# Masking doesn't suport 0 dimensions in the mask
array_api_tests/test_array_object.py::test_getitem_masking
# torch doesn't have uint dtypes other than uint8
array_api_tests/test_array_object.py::test_scalar_casting[__int__(uint16)]
array_api_tests/test_array_object.py::test_scalar_casting[__int__(uint32)]
array_api_tests/test_array_object.py::test_scalar_casting[__int__(uint64)]
array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint16)]
array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint32)]
array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint64)]

# Overflow error from large inputs
array_api_tests/test_creation_functions.py::test_arange
# pytorch linspace bug (should be fixed in torch 2.0)
array_api_tests/test_creation_functions.py::test_linspace

# torch doesn't have higher uint dtypes
array_api_tests/test_data_type_functions.py::test_iinfo[uint16]
array_api_tests/test_data_type_functions.py::test_iinfo[uint32]
array_api_tests/test_data_type_functions.py::test_iinfo[uint64]

# We cannot wrap the tensor object
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
Expand Down

0 comments on commit 2eafb97

Please sign in to comment.