Skip to content

Commit 1c53790

Browse files
authored
cummax: fix 0-sized dimension reduction. (#8653)
1 parent 3578940 commit 1c53790

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

test/test_operations.py

+14
Original file line numberDiff line numberDiff line change
@@ -2370,6 +2370,20 @@ def foo(x: torch.Tensor) -> torch.Tensor:
23702370
self.assertEqual(out.dtype, out_xla.dtype)
23712371
self.assertEqual(out.cpu(), out_xla.cpu(), prec=1e-4)
23722372

2373+
def test_cummax_0_sized_dimension(self):
  # Regression test: cummax along a zero-sized dimension must not crash.
  # PyTorch returns a tuple of empty (value, index) tensors for this
  # case, and the XLA backend should behave identically.
  reduce_dim = 2
  cpu_input = torch.rand(5, 5, 0, 5)
  xla_input = cpu_input.to(xm.xla_device())

  expected = torch.cummax(cpu_input, reduce_dim)
  actual = torch.cummax(xla_input, reduce_dim)

  self.assertEqual(actual, expected)
23732387

23742388
class MNISTComparator(nn.Module):
23752389

torch_xla/csrc/tensor_methods.cpp

+17-3
Original file line numberDiff line numberDiff line change
@@ -1314,9 +1314,23 @@ XLATensorPtr cross(const XLATensorPtr& input, const XLATensorPtr& other,
13141314

13151315
std::tuple<XLATensorPtr, XLATensorPtr> cummax(const XLATensorPtr& input,
13161316
int64_t dim) {
1317-
torch::lazy::NodePtr node = torch_xla::MakeNode<CumMax>(
1318-
input->GetIrValue(), torch::lazy::GetCanonicalDimensionIndex(
1319-
dim, input->shape().get().rank()));
1317+
xla::Shape shape = input->shape().get();
1318+
int64_t canonical_dim =
1319+
torch::lazy::GetCanonicalDimensionIndex(dim, shape.rank());
1320+
1321+
if (shape.dimensions(canonical_dim) == 0) {
1322+
// Handle edge-case where the size of `dim` is 0.
1323+
// The current lowering crashes, setting the padding to -1.
1324+
absl::Span<const int64_t> dimensions = shape.dimensions();
1325+
at::IntArrayRef shape_(dimensions.data(), dimensions.size());
1326+
at::Tensor val =
1327+
at::empty(shape_, at::TensorOptions().dtype(input->dtype()));
1328+
at::Tensor idx = at::empty(shape_, at::TensorOptions().dtype(at::kLong));
1329+
return std::make_tuple(input->Create(val, input->GetDevice()),
1330+
input->Create(idx, input->GetDevice()));
1331+
}
1332+
torch::lazy::NodePtr node =
1333+
torch_xla::MakeNode<CumMax>(input->GetIrValue(), canonical_dim);
13201334
XLATensorPtr t_value = input->CreateFrom(torch::lazy::Value(node, 0),
13211335
/*delay_eager_executation=*/true);
13221336
XLATensorPtr t_index =

0 commit comments

Comments
 (0)