diff --git a/test/test_operations.py b/test/test_operations.py
index 20365cb4f1e0..6893bf093a0a 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -2370,6 +2370,20 @@ def foo(x: torch.Tensor) -> torch.Tensor:
     self.assertEqual(out.dtype, out_xla.dtype)
     self.assertEqual(out.cpu(), out_xla.cpu(), prec=1e-4)
 
+  def test_cummax_0_sized_dimension(self):
+    # Test cummax on dim=2 (a 0-sized dimension).
+    #
+    # Make sure we are not crashing, here. Instead, we should return a tuple of
+    # empty tensors, just like PyTorch.
+
+    dim = 2
+    a = torch.rand(5, 5, 0, 5)
+
+    expected = torch.cummax(a, dim)
+    actual = torch.cummax(a.to(xm.xla_device()), dim)
+
+    self.assertEqual(actual, expected)
+
 
 class MNISTComparator(nn.Module):
 
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index b0bcda9e0967..b43b7c6f5f8b 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -1314,9 +1314,23 @@ XLATensorPtr cross(const XLATensorPtr& input, const XLATensorPtr& other,
 
 std::tuple<XLATensorPtr, XLATensorPtr> cummax(const XLATensorPtr& input,
                                               int64_t dim) {
-  torch::lazy::NodePtr node = torch_xla::MakeNode<CumMax>(
-      input->GetIrValue(), torch::lazy::GetCanonicalDimensionIndex(
-                               dim, input->shape().get().rank()));
+  xla::Shape shape = input->shape().get();
+  int64_t canonical_dim =
+      torch::lazy::GetCanonicalDimensionIndex(dim, shape.rank());
+
+  if (shape.dimensions(canonical_dim) == 0) {
+    // Handle edge-case where the size of `dim` is 0.
+    // The current lowering crashes, setting the padding to -1.
+    // NOTE(review): template arguments below were restored from context —
+    // `shape.dimensions()` yields a span of int64_t dimension sizes.
+    absl::Span<const int64_t> dimensions = shape.dimensions();
+    at::IntArrayRef shape_(dimensions.data(), dimensions.size());
+    at::Tensor val =
+        at::empty(shape_, at::TensorOptions().dtype(input->dtype()));
+    at::Tensor idx = at::empty(shape_, at::TensorOptions().dtype(at::kLong));
+    return std::make_tuple(input->Create(val, input->GetDevice()),
+                           input->Create(idx, input->GetDevice()));
+  }
+  torch::lazy::NodePtr node =
+      torch_xla::MakeNode<CumMax>(input->GetIrValue(), canonical_dim);
   XLATensorPtr t_value = input->CreateFrom(torch::lazy::Value(node, 0),
                                            /*delay_eager_executation=*/true);
   XLATensorPtr t_index =