@@ -1314,9 +1314,23 @@ XLATensorPtr cross(const XLATensorPtr& input, const XLATensorPtr& other,
 
 std::tuple<XLATensorPtr, XLATensorPtr> cummax(const XLATensorPtr& input,
                                               int64_t dim) {
-  torch::lazy::NodePtr node = torch_xla::MakeNode<CumMax>(
-      input->GetIrValue(), torch::lazy::GetCanonicalDimensionIndex(
-                               dim, input->shape().get().rank()));
+  xla::Shape shape = input->shape().get();
+  int64_t canonical_dim =
+      torch::lazy::GetCanonicalDimensionIndex(dim, shape.rank());
+
+  if (shape.dimensions(canonical_dim) == 0) {
+    // Handle edge-case where the size of `dim` is 0.
+    // The current lowering crashes, setting the padding to -1.
+    absl::Span<const int64_t> dimensions = shape.dimensions();
+    at::IntArrayRef shape_(dimensions.data(), dimensions.size());
+    at::Tensor val =
+        at::empty(shape_, at::TensorOptions().dtype(input->dtype()));
+    at::Tensor idx = at::empty(shape_, at::TensorOptions().dtype(at::kLong));
+    return std::make_tuple(input->Create(val, input->GetDevice()),
+                           input->Create(idx, input->GetDevice()));
+  }
+  torch::lazy::NodePtr node =
+      torch_xla::MakeNode<CumMax>(input->GetIrValue(), canonical_dim);
   XLATensorPtr t_value = input->CreateFrom(torch::lazy::Value(node, 0),
                                            /*delay_eager_executation=*/true);
   XLATensorPtr t_index =
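
For context, the crash fixed above is reachable from user code whenever cummax runs over a zero-sized dimension. The snippet below is a minimal repro sketch, not part of the diff; it assumes a working PyTorch/XLA install and uses torch_xla.core.xla_model.xla_device() to obtain the XLA device.

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    # Dimension 0 has size 0, so with this fix the early-return path
    # produces empty values/indices tensors instead of reaching the
    # CumMax lowering, which previously crashed on this input.
    t = torch.empty(0, 3, device=device)
    values, indices = torch.cummax(t, dim=0)
    print(values.shape, indices.shape)  # torch.Size([0, 3]) torch.Size([0, 3])

CPU eager mode already returns empty values/indices tensors for this call, which is the behavior the early return mirrors.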