diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp
index 95435dd5805b..314f90dfa807 100644
--- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp
+++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp
@@ -61,12 +61,37 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
                                   Value input, Value indices, int64_t axis,
                                   size_t dimSizeIndexBits) {
   auto loc = op->getLoc();
+
+  auto indicesRankTy = dyn_cast<RankedTensorType>(indices.getType());
+  auto indicesShape = indicesRankTy.getShape();
+  auto inputRankTy = dyn_cast<RankedTensorType>(input.getType());
+
+  // Check if indices tensor is empty (has any dimension with size 0)
+  bool isEmpty =
+      llvm::any_of(indicesShape, [](int64_t dim) { return dim == 0; });
+
+  if (isEmpty) {
+    // Special case: StableHLO doesn't support gather operations on empty
+    // tensors. Return an empty tensor with the correct output shape.
+    auto inputShape = inputRankTy.getShape();
+    SmallVector<int64_t> outputShape(inputShape.begin(),
+                                     inputShape.begin() + axis);
+    outputShape.insert(outputShape.end(), indicesShape.begin(),
+                       indicesShape.end());
+    outputShape.insert(outputShape.end(), inputShape.begin() + axis + 1,
+                       inputShape.end());
+
+    auto outputTy =
+        RankedTensorType::get(outputShape, inputRankTy.getElementType());
+    auto emptyAttr = cast<DenseElementsAttr>(rewriter.getZeroAttr(outputTy));
+    return stablehlo::ConstantOp::create(rewriter, loc, emptyAttr);
+  }
+
   Type intType = rewriter.getIntegerType(dimSizeIndexBits);
   Value one = arith::ConstantOp::create(rewriter, loc,
                                         rewriter.getIntegerAttr(intType, 1));
 
   // sliceSizes
-  auto inputRankTy = dyn_cast<RankedTensorType>(input.getType());
   auto inputRank = inputRankTy.getRank();
   SmallVector<Value, 4> sliceSizes;
   sliceSizes.reserve(inputRank);
@@ -88,7 +113,6 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
   for (int64_t r = 0; r < axis; ++r) {
     offsetDims.push_back(r);
   }
-  auto indicesRankTy = dyn_cast<RankedTensorType>(indices.getType());
   auto indicesRank = indicesRankTy.getRank();
   for (int64_t r = axis + 1; r < inputRank; ++r) {
     offsetDims.push_back(r + indicesRank - 1);
@@ -112,7 +136,6 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
   // outputShape = input.shape[:axis] + indices.shape +
   //               input.shape[axis + 1:]
   auto inputShape = inputRankTy.getShape();
-  auto indicesShape = indicesRankTy.getShape();
   SmallVector<int64_t> outputShape(inputShape.begin(),
                                    inputShape.begin() + axis);
   outputShape.insert(outputShape.end(), indicesShape.begin(),
diff --git a/test/Conversion/TorchToStablehlo/gather.mlir b/test/Conversion/TorchToStablehlo/gather.mlir
index 14581bcc658c..d3d1539f4e08 100644
--- a/test/Conversion/TorchToStablehlo/gather.mlir
+++ b/test/Conversion/TorchToStablehlo/gather.mlir
@@ -63,3 +63,51 @@ func.func @torch.aten.embedding$rank_two_indices(%weight: !torch.vtensor<[?,?],f
   %ret = torch.aten.embedding %weight, %indices, %int-1, %false, %false : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,1], si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[?,1,?],f32>
   return %ret: !torch.vtensor<[?,1,?],f32>
 }
+
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.index_select$empty_indices(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,8],f32>, %[[ARG1:.*]]: !torch.vtensor<[0],si64>) -> !torch.vtensor<[0,8],f32> {
+// CHECK: %[[INT0:.*]] = torch.constant.int 0
+// CHECK: %[[CST:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<0x8xf32>
+// CHECK: %[[CONVERT:.*]] = stablehlo.convert %[[CST]] : tensor<0x8xf32>
+// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CONVERT]] : tensor<0x8xf32> -> !torch.vtensor<[0,8],f32>
+// CHECK: return %[[RESULT]] : !torch.vtensor<[0,8],f32>
+func.func @torch.aten.index_select$empty_indices(%arg0: !torch.vtensor<[4,8],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[0,8],f32> {
+  %int0 = torch.constant.int 0
+  %0 = torch.aten.index_select %arg0, %int0, %arg1 : !torch.vtensor<[4,8],f32>, !torch.int, !torch.vtensor<[0],si64> -> !torch.vtensor<[0,8],f32>
+  return %0 : !torch.vtensor<[0,8],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.embedding$empty_indices(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[10,8],f32>, %[[ARG1:.*]]: !torch.vtensor<[0],si64>) -> !torch.vtensor<[0,8],f32> {
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1
+// CHECK: %[[CST:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<0x8xf32>
+// CHECK: %[[CONVERT:.*]] = stablehlo.convert %[[CST]] : tensor<0x8xf32>
+// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CONVERT]] : tensor<0x8xf32> -> !torch.vtensor<[0,8],f32>
+// CHECK: return %[[RESULT]] : !torch.vtensor<[0,8],f32>
+func.func @torch.aten.embedding$empty_indices(%weight: !torch.vtensor<[10,8],f32>, %indices: !torch.vtensor<[0], si64>) -> !torch.vtensor<[0,8],f32> {
+  %false = torch.constant.bool false
+  %int-1 = torch.constant.int -1
+  %ret = torch.aten.embedding %weight, %indices, %int-1, %false, %false : !torch.vtensor<[10,8],f32>, !torch.vtensor<[0], si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[0,8],f32>
+  return %ret: !torch.vtensor<[0,8],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.index_select$empty_indices_dim1(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,8],f32>, %[[ARG1:.*]]: !torch.vtensor<[0],si64>) -> !torch.vtensor<[4,0],f32> {
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[CST:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<4x0xf32>
+// CHECK: %[[CONVERT:.*]] = stablehlo.convert %[[CST]] : tensor<4x0xf32>
+// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CONVERT]] : tensor<4x0xf32> -> !torch.vtensor<[4,0],f32>
+// CHECK: return %[[RESULT]] : !torch.vtensor<[4,0],f32>
+func.func @torch.aten.index_select$empty_indices_dim1(%arg0: !torch.vtensor<[4,8],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[4,0],f32> {
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.index_select %arg0, %int1, %arg1 : !torch.vtensor<[4,8],f32>, !torch.int, !torch.vtensor<[0],si64> -> !torch.vtensor<[4,0],f32>
+  return %0 : !torch.vtensor<[4,0],f32>
+}