diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h index 1d850757a9ed..543859873931 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h @@ -64,6 +64,11 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op, std::optional<Value> tosaCastTensorToType(PatternRewriter &rewriter, Value src, TensorType destType); +// Create a tosa.gather op. Casts i1 inputs to i8 internally if needed. +std::optional<Value> createGatherOp(PatternRewriter &rewriter, Location loc, + RankedTensorType resultType, Value input, + Value indices); + // Creates a TOSA operation and performs shape inference on the individual // op. This allows shape inference during the framework to TOSA lowering. template <typename TosaOp, typename... Args> @@ -119,6 +124,9 @@ FailureOr<Value> getZeroPointValue(PatternRewriter &rewriter, Operation *op, // Check if a shaped type has any dimension with size 0. bool typeHasZeroDim(ShapedType type); +// Check if a type is i1 or a shaped type with i1 element type. +bool isI1Type(Type type); + // Compute scale/offset/border parameters for TOSA resize on one dimension. 
void computeResizeParams(int inputSize, int outputSize, bool alignCorners, tosa::ResizeMode mode, int &scaleN, int &scaleD, diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index b34e0bf5ca61..367762088193 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -4545,14 +4545,17 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewriteImpl( .value(); SmallVector<int64_t> intermediateOutShape = {1, numIndices, weightShape[1]}; - auto gatherOp = tosa::GatherOp::create( - rewriter, op->getLoc(), - RankedTensorType::get(makeShapeLLVMCompatible(intermediateOutShape), - weightType.getElementType()), - reshapedWeight, castIndices); + auto gatherElemTy = weightType.getElementType(); + auto gatherTy = RankedTensorType::get( + makeShapeLLVMCompatible(intermediateOutShape), gatherElemTy); + auto gatherResult = tosa::createGatherOp(rewriter, op->getLoc(), gatherTy, + reshapedWeight, castIndices); + if (!gatherResult) + return rewriter.notifyMatchFailure( + op, "expected ranked tensor input for gather"); rewriter.replaceOpWithNewOp<tosa::ReshapeOp>( - op, outType, gatherOp, + op, outType, *gatherResult, tosa::getTosaConstShape(rewriter, op->getLoc(), makeShapeTorchCompatible(outType.getShape()))); @@ -4868,9 +4871,11 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewriteImpl( // Duplicate the 1-D index vector across the batch dimension so that we can // use a single tosa.gather to materialize the strided slice. 
auto gatherTy = RankedTensorType::get({N, W, C}, elemTy); - Value gathered = - tosa::GatherOp::create(rewriter, loc, gatherTy, reshaped, idxNW) - .getResult(); + auto gathered = + tosa::createGatherOp(rewriter, loc, gatherTy, reshaped, idxNW); + if (!gathered) + return rewriter.notifyMatchFailure( + op, "expected ranked tensor input for gather"); SmallVector<int64_t> outShape = inputShape; outShape[dim] = W; @@ -4879,7 +4884,7 @@ // Restore the original rank with the newly strided dimension size. Value result = - tosa::ReshapeOp::create(rewriter, loc, convertedResultTy, gathered, + tosa::ReshapeOp::create(rewriter, loc, convertedResultTy, *gathered, tosa::getTosaConstShape(rewriter, loc, outShape)) .getResult(); diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index a33a990758aa..5f0eae8146bb 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -422,18 +422,16 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op, // Now the gather op itself // %9 = "tosa.gather"(%2, %7) : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> // tensor<1x8x1xf32> - auto tosaGatherOp = tosa::CreateOpAndInfer<tosa::GatherOp>( - rewriter, op->getLoc(), - GetTypeFromTensorShape(tosaGatherResultShape, - resultType.getElementType()), - tosaValuesReshapeOp.getResult(), tosaIndicesReshapeOp.getResult()); + auto gatherTy = GetTypeFromTensorShape(tosaGatherResultShape, + resultType.getElementType()); + auto gatherResult = tosa::createGatherOp(rewriter, op->getLoc(), gatherTy, + tosaValuesReshapeOp.getResult(), + tosaIndicesReshapeOp.getResult()); + if (!gatherResult) + return std::nullopt; - // Finally, reshape back to the original output shape of [Indices, - // ParamChannels]. 
%10 = "tosa.reshape"(%9) {new_shape = [1, 4, 2]} : - // (tensor<1x8x1xf32>) -> tensor<1x4x2xf32> %11 = torch_c.from_builtin_tensor - // %10 : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32> return tosa::CreateOpAndInfer<tosa::ReshapeOp>( - rewriter, op->getLoc(), resultType, tosaGatherOp.getResult(), + rewriter, op->getLoc(), resultType, *gatherResult, tosa::getTosaConstShape(rewriter, op->getLoc(), resultType.getShape())) .getResult(); diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index c05f6ee82209..68f10ecaad27 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -382,6 +382,28 @@ std::optional<Value> tosaCastTensorToType(PatternRewriter &rewriter, Value src, return tosa::CastOp::create(rewriter, op->getLoc(), castedSrcType, src); } +// Create a tosa.gather op. Casts i1 inputs to i8 internally if needed. +std::optional<Value> createGatherOp(PatternRewriter &rewriter, Location loc, + RankedTensorType resultType, Value input, + Value indices) { + if (tosa::isI1Type(resultType)) { + auto i8Ty = rewriter.getI8Type(); + auto inputTy = dyn_cast<RankedTensorType>(input.getType()); + if (!inputTy) + return std::nullopt; + auto inputI8Ty = inputTy.clone(i8Ty); + auto inputI8 = + tosa::tosaCastTensorToType(rewriter, input, inputI8Ty).value(); + auto gatherI8Ty = resultType.clone(i8Ty); + auto gatheredI8 = + tosa::GatherOp::create(rewriter, loc, gatherI8Ty, inputI8, indices); + return tosa::tosaCastTensorToType(rewriter, gatheredI8, resultType).value(); + } + + return tosa::GatherOp::create(rewriter, loc, resultType, input, indices) + .getResult(); +} + // Template instantiation template std::optional<Value> getConstTensor<float>(PatternRewriter &, Operation *, ArrayRef<float> vec, @@ -586,6 +608,14 @@ bool typeHasZeroDim(ShapedType type) { return llvm::any_of(outShape, [](int64_t dim) { return dim == 0; }); } +bool isI1Type(Type type) { + if (auto shapedTy = dyn_cast<ShapedType>(type)) + type = 
shapedTy.getElementType(); + if (auto intTy = dyn_cast<IntegerType>(type)) + return intTy.getWidth() == 1; + return false; +} + void computeResizeParams(int inputSize, int outputSize, bool alignCorners, tosa::ResizeMode mode, int &scaleN, int &scaleD, int &offset, int &border) { diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index b8fa177bb93d..d16401086b89 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -1792,6 +1792,33 @@ def GatherModule_basic(module, tu: TestUtils): # ============================================================================== +class GatherBoolModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.bool, True), + ([-1, -1, -1], torch.int64, True), + ] + ) + def forward(self, tensor, indices): + return torch.gather(tensor, 2, indices) + + +@register_test_case(module_factory=lambda: GatherBoolModule()) +def GatherBoolModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(2, 3, 4, high=2).to(torch.bool), + torch.tensor([[[1, 2, 3], [1, 2, 3]]]), + ) + + +# ============================================================================== + + class GatherNegativeDimModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index f95347563fae..313f02b7fda3 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -1349,6 +1349,47 @@ func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.v return %0 : !torch.vtensor<[1,4,2],f32> } +// ----- +// CHECK-LABEL: func.func @torch.aten.gather$bool( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,4,3],i1>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,4,2],si64>) -> 
!torch.vtensor<[1,4,2],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,4,2],si64> -> tensor<1x4x2xi64> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,4,3],i1> -> tensor<1x4x3xi1> +// CHECK: %[[VAL_4:.*]] = torch.constant.int -1 +// CHECK: %[[VAL_5:.*]] = torch.constant.bool false +// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_2]] : (tensor<1x4x2xi64>) -> tensor<1x4x2xi32> +// CHECK: %[[VAL_7:.*]] = tosa.const_shape {values = dense<[1, 4, 2, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_7]] : (tensor<1x4x2xi32>, !tosa.shape<4>) -> tensor<1x4x2x1xi32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x4x2x1xi32>}> : () -> tensor<1x4x2x1xi32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{values = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]]]]> : tensor<1x4x2x1xi32>}> : () -> tensor<1x4x2x1xi32> +// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>) -> tensor<1x4x2x3xi32> +// CHECK: %[[VAL_12:.*]] = tosa.const_shape {values = dense<[1, 12, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_12]] : (tensor<1x4x3xi1>, !tosa.shape<3>) -> tensor<1x12x1xi1> +// CHECK: %[[VAL_14:.*]] = tosa.const_shape {values = dense<[8, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_11]], %[[VAL_14]] : (tensor<1x4x2x3xi32>, !tosa.shape<2>) -> tensor<8x3xi32> +// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<[12, 3, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_17:.*]] = tosa.const_shape {values = dense<[1, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_16]], %[[VAL_17]] : (tensor<3xi32>, !tosa.shape<2>) -> tensor<1x3xi32> 
+// CHECK: %[[VAL_19:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_20:.*]] = tosa.mul %[[VAL_15]], %[[VAL_18]], %[[VAL_19]] : (tensor<8x3xi32>, tensor<1x3xi32>, tensor<1xi8>) -> tensor<8x3xi32> +// CHECK: %[[VAL_21:.*]] = tosa.reduce_sum %[[VAL_20]] {axis = 1 : i32} : (tensor<8x3xi32>) -> tensor<8x1xi32> +// CHECK: %[[VAL_22:.*]] = tosa.const_shape {values = dense<[1, 8]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_23:.*]] = tosa.reshape %[[VAL_21]], %[[VAL_22]] : (tensor<8x1xi32>, !tosa.shape<2>) -> tensor<1x8xi32> +// CHECK: %[[VAL_24:.*]] = tosa.cast %[[VAL_13]] : (tensor<1x12x1xi1>) -> tensor<1x12x1xi8> +// CHECK: %[[VAL_25:.*]] = tosa.gather %[[VAL_24]], %[[VAL_23]] : (tensor<1x12x1xi8>, tensor<1x8xi32>) -> tensor<1x8x1xi8> +// CHECK: %[[VAL_26:.*]] = tosa.cast %[[VAL_25]] : (tensor<1x8x1xi8>) -> tensor<1x8x1xi1> +// CHECK: %[[VAL_27:.*]] = tosa.const_shape {values = dense<[1, 4, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_28:.*]] = tosa.reshape %[[VAL_26]], %[[VAL_27]] : (tensor<1x8x1xi1>, !tosa.shape<3>) -> tensor<1x4x2xi1> +// CHECK: %[[VAL_29:.*]] = torch_c.from_builtin_tensor %[[VAL_28]] : tensor<1x4x2xi1> -> !torch.vtensor<[1,4,2],i1> +// CHECK: return %[[VAL_29]] : !torch.vtensor<[1,4,2],i1> +// CHECK: } +func.func @torch.aten.gather$bool(%arg0: !torch.vtensor<[1,4,3],i1>, %arg1: !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],i1> { + %int-1 = torch.constant.int -1 + %false = torch.constant.bool false + %0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[1,4,3],i1>, !torch.int, !torch.vtensor<[1,4,2],si64>, !torch.bool -> !torch.vtensor<[1,4,2],i1> + return %0 : !torch.vtensor<[1,4,2],i1> +} + // ----- // CHECK-LABEL: func.func @torch.aten.add$int( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,2],si32>, @@ -1422,6 +1463,27 @@ func.func @torch.aten.slice.negative_start(%arg0: !torch.vtensor<[4,65,256],f32> return %0 : 
!torch.vtensor<[4,16,256],f32> } +// ----- +// CHECK-LABEL: func.func @torch.aten.slice.bool_strided( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,64,1],i1>) -> !torch.vtensor<[1,32,1],i1> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,64,1],i1> -> tensor<1x64x1xi1> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x64x1xi1>) -> tensor<1x64x1xi8> +// CHECK: %[[VAL_3:.*]] = tosa.gather %[[VAL_2]], %{{.*}} : (tensor<1x64x1xi8>, tensor<1x32xi32>) -> tensor<1x32x1xi8> +// CHECK: %[[VAL_4:.*]] = tosa.cast %[[VAL_3]] : (tensor<1x32x1xi8>) -> tensor<1x32x1xi1> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {values = dense<[1, 32, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_5]] : (tensor<1x32x1xi1>, !tosa.shape<3>) -> tensor<1x32x1xi1> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x32x1xi1> -> !torch.vtensor<[1,32,1],i1> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,32,1],i1> +// CHECK: } +func.func @torch.aten.slice.bool_strided(%arg0: !torch.vtensor<[1,64,1],i1>) -> !torch.vtensor<[1,32,1],i1> { + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %int64 = torch.constant.int 64 + %int2 = torch.constant.int 2 + %0 = torch.aten.slice.Tensor %arg0, %int1, %int0, %int64, %int2 : !torch.vtensor<[1,64,1],i1>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,32,1],i1> + return %0 : !torch.vtensor<[1,32,1],i1> +} + // ----- // CHECK-LABEL: func.func @torch.aten.clamp.min_none( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> {