diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
index 1d850757a9ed..044020bf8ee0 100644
--- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
+++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
@@ -119,6 +119,9 @@ FailureOr<Value> getZeroPointValue(PatternRewriter &rewriter, Operation *op,
 // Check if a shaped type has any dimension with size 0.
 bool typeHasZeroDim(ShapedType type);
 
+// Check if a type is i1 or a shaped type with i1 element type.
+bool isI1Type(Type type);
+
 // Compute scale/offset/border parameters for TOSA resize on one dimension.
 void computeResizeParams(int inputSize, int outputSize, bool alignCorners,
                          tosa::ResizeMode mode, int &scaleN, int &scaleD,
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index b34e0bf5ca61..8c11e5ffb8a8 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -748,6 +748,27 @@ class ConvertAtenCompareOp : public TorchToTosaOpConversionPattern<AtenOpT> {
         OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
             op.getType()));
     if (isBitwiseOp) {
+      // TOSA bitwise ops do not support i1. Use logical ops for bool tensors.
+      if (tosa::isI1Type(resultTy)) {
+        auto lhsBool =
+            tosa::tosaCastTensorToType(rewriter, lhs, resultTy).value();
+        auto rhsBool =
+            tosa::tosaCastTensorToType(rewriter, rhsTensor, resultTy).value();
+        if constexpr (std::is_same<AtenOpT, AtenBitwiseAndTensorOp>() ||
+                      std::is_same<AtenOpT, AtenBitwiseAndScalarOp>()) {
+          rewriter.replaceOpWithNewOp<tosa::LogicalAndOp>(op, resultTy, lhsBool,
+                                                          rhsBool);
+          return success();
+        } else if constexpr (std::is_same<AtenOpT, AtenBitwiseOrTensorOp>()) {
+          rewriter.replaceOpWithNewOp<tosa::LogicalOrOp>(op, resultTy, lhsBool,
+                                                         rhsBool);
+          return success();
+        } else if constexpr (std::is_same<AtenOpT, AtenBitwiseXorTensorOp>()) {
+          rewriter.replaceOpWithNewOp<tosa::LogicalXorOp>(op, resultTy, lhsBool,
+                                                          rhsBool);
+          return success();
+        }
+      }
       lhs = tosa::tosaCastTensorToType(rewriter, lhs, resultTy).value();
       rhsTensor =
           tosa::tosaCastTensorToType(rewriter, rhsTensor, resultTy).value();
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
index c05f6ee82209..7ab6d7b7e6a6 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -586,6 +586,14 @@ bool typeHasZeroDim(ShapedType type) {
   return llvm::any_of(outShape, [](int64_t dim) { return dim == 0; });
 }
 
+bool isI1Type(Type type) {
+  if (auto shapedTy = dyn_cast<ShapedType>(type))
+    type = shapedTy.getElementType();
+  if (auto intTy = dyn_cast<IntegerType>(type))
+    return intTy.getWidth() == 1;
+  return false;
+}
+
 void computeResizeParams(int inputSize, int outputSize, bool alignCorners,
                          tosa::ResizeMode mode, int &scaleN, int &scaleD,
                          int &offset, int &border) {
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
index 40b6b1b873af..bb302a3c5907 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -4899,6 +4899,33 @@ def ElementwiseBitwiseAndModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseBitwiseAndBoolModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.bool, True),
+            ([-1, -1], torch.bool, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.bitwise_and(x, y)
+
+
+@register_test_case(module_factory=lambda: ElementwiseBitwiseAndBoolModule())
+def ElementwiseBitwiseAndBoolModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.randint(3, 4, low=0, high=2).to(torch.bool),
+        tu.randint(3, 4, low=0, high=2).to(torch.bool),
+    )
+
+
+# ==============================================================================
+
+
 class ElementwiseBitwiseAndStaticShapeModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -4953,6 +4980,33 @@ def ElementwiseBitwiseOrModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseBitwiseOrBoolModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.bool, True),
+            ([-1, -1], torch.bool, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.bitwise_or(x, y)
+
+
+@register_test_case(module_factory=lambda: ElementwiseBitwiseOrBoolModule())
+def ElementwiseBitwiseOrBoolModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.randint(3, 4, low=0, high=2).to(torch.bool),
+        tu.randint(3, 4, low=0, high=2).to(torch.bool),
+    )
+
+
+# ==============================================================================
+
+
 class ElementwiseBitwiseOrStaticShapeModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index f95347563fae..1f0731f5bed3 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -895,6 +895,23 @@ func.func @torch.aten.bitwise_and.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>
 
 // -----
 
+// CHECK-LABEL: func.func @torch.aten.bitwise_and.Tensor$bool(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],i1>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
+// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
+// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
+// CHECK-NOT: tosa.bitwise_and
+// CHECK: %[[VAL_4:.*]] = tosa.logical_and %[[VAL_2]], %[[VAL_3]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
+// CHECK: }
+func.func @torch.aten.bitwise_and.Tensor$bool(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.log2$basic(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
diff --git a/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir
index d50a2c830d7c..ae1c887aee61 100644
--- a/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir
+++ b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir
@@ -91,6 +91,17 @@ func.func @torch.aten.bitwise_and.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?],
 
 // -----
 
+// CHECK-LABEL: torch.aten.bitwise_xor.Tensor$bool
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi1>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi1>
+// CHECK: %[[VAL_2:.*]] = tosa.logical_xor %[[VAL_0]], %[[VAL_1]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
+func.func @torch.aten.bitwise_xor.Tensor$bool(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.bitwise_xor.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.div.Tensor$mixed_type_fp(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
 // CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi32>) -> tensor<?x?xf32> {