diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
index 46803235eeac..4b16efec1ff5 100644
--- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
+++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
@@ -64,6 +64,10 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
 std::optional<Value> tosaCastTensorToType(PatternRewriter &rewriter, Value src,
                                           TensorType destType);
 
+// Ensure TOSA argmax input is f32 by inserting a tosa.cast when needed.
+Value legalizeArgMaxInputType(PatternRewriter &rewriter, Operation *op,
+                              Value input);
+
 // Creates a TOSA operation and performs shape inference on the individual
 // op. This allows shape inference during the framework to TOSA lowering.
 template <typename TosaOp, typename... Args>
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index e051e559ae38..f3647f0b156d 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -1504,6 +1504,7 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewriteImpl(
   // Create a single instance of tosa.argmax.
   // Multiple dims require chained construct.
   auto buildArgmax = [&](int64_t reduceDim, Value input) -> Value {
+    input = tosa::legalizeArgMaxInputType(rewriter, op.getOperation(), input);
     auto inputTy = cast<RankedTensorType>(input.getType());
     auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
     SmallVector<int64_t> outputShapeArr = {};
@@ -1523,7 +1524,7 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewriteImpl(
         makeShapeLLVMCompatible(ArrayRef(outputShapeArr)),
         rewriter.getI32Type());
     auto reduceDimAttr =
-        rewriter.getIntegerAttr(rewriter.getI64Type(), reduceDim);
+        rewriter.getIntegerAttr(rewriter.getI32Type(), reduceDim);
 
     // Use default NaN Propagation mode "PROPAGATE" for tosa.argmax
     return tosa::ArgMaxOp::create(
@@ -4696,22 +4697,26 @@ class ConvertAtenMinMaxDimOp : public TorchToTosaOpConversionPattern<AtenOpT> {
     if constexpr (std::is_same<AtenOpT, AtenMinDimOp>()) {
       Value negateOp =
           tosa::NegateOp::create(rewriter, op->getLoc(), selfType, self);
+      Value argInput =
+          tosa::legalizeArgMaxInputType(rewriter, op.getOperation(), negateOp);
 
       // Use default NaN Propagation mode "PROPAGATE" for tosa.argmax
       argMaxOp = tosa::ArgMaxOp::create(
           rewriter, op->getLoc(),
           RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
                                 indicesElemType),
-          negateOp, dimAttr, /*nan_mode=*/
+          argInput, dimAttr, /*nan_mode=*/
           tosa::NanPropagationModeAttr::get(
               rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
     } else {
+      Value argInput =
+          tosa::legalizeArgMaxInputType(rewriter, op.getOperation(), self);
       // Use default NaN Propagation mode "PROPAGATE" for tosa.argmax
       argMaxOp = tosa::ArgMaxOp::create(
           rewriter, op->getLoc(),
           RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
                                 indicesElemType),
-          self, dimAttr, /*nan_mode=*/
+          argInput, dimAttr, /*nan_mode=*/
           tosa::NanPropagationModeAttr::get(
               rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
     }
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
index e358ecf66513..46d6bb54472f 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -381,6 +381,20 @@ std::optional<Value> tosaCastTensorToType(PatternRewriter &rewriter, Value src,
   return tosa::CastOp::create(rewriter, op->getLoc(), castedSrcType, src);
 }
 
+Value legalizeArgMaxInputType(PatternRewriter &rewriter, Operation *op,
+                              Value input) {
+  auto inputTy = cast<RankedTensorType>(input.getType());
+  auto elemTy = inputTy.getElementType();
+  // Keep i8 as-is (supported by TOSA pro_int argmax). Cast other integer
+  // types to f32, including i1 (handled via i1->i8->f32).
+  if (!elemTy.isInteger() || elemTy.isInteger(8))
+    return input;
+  auto castTy =
+      RankedTensorType::get(inputTy.getShape(), rewriter.getF32Type());
+  auto casted = tosa::tosaCastTensorToType(rewriter, input, castTy);
+  return casted ? *casted : input;
+}
+
 // Template instantiation
 template std::optional<Value>
 getConstTensor<float>(PatternRewriter &, Operation *, ArrayRef<float> vec,
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
index 2e4ba9c4ccfc..d500216ac401 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -1642,6 +1642,29 @@ def ArgmaxModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ArgmaxInt32Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.int32, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.ops.aten.argmax(a)
+
+
+@register_test_case(module_factory=lambda: ArgmaxInt32Module())
+def ArgmaxInt32Module_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=-100, high=100, dtype=torch.int32))
+
+
+# ==============================================================================
+
+
 class ArgmaxKeepdimModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index 0f113d77ec80..3ca014494af3 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -1076,6 +1076,51 @@ func.func @torch.aten.max.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
 
 // -----
 
+// CHECK-LABEL: func.func @torch.aten.argmax$int32(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x3xi32>) -> tensor<3x2xi64> {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xi32> -> !torch.vtensor<[3,2,3],si32>
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],si32> -> tensor<3x2x3xi32>
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_2]] : (tensor<3x2x3xi32>) -> tensor<3x2x3xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.argmax %[[VAL_5]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi32>
+// CHECK: %[[VAL_7:.*]] = arith.extsi %[[VAL_6]] : tensor<3x2xi32> to tensor<3x2xi64>
+// CHECK: return %{{.*}} : tensor<3x2xi64>
+// CHECK: }
+func.func @torch.aten.argmax$int32(%arg0: tensor<3x2x3xi32>) -> tensor<3x2xi64> {
+  %0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xi32> -> !torch.vtensor<[3,2,3],si32>
+  %false = torch.constant.bool false
+  %int2 = torch.constant.int 2
+  %1 = torch.aten.argmax %0, %int2, %false : !torch.vtensor<[3,2,3],si32>, !torch.int, !torch.bool -> !torch.vtensor<[3,2],si64>
+  %2 = torch_c.to_builtin_tensor %1 : !torch.vtensor<[3,2],si64> -> tensor<3x2xi64>
+  return %2 : tensor<3x2xi64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.argmax$i1(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x3xi1>) -> tensor<3x2xi64> {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xi1> -> !torch.vtensor<[3,2,3],i1>
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],i1> -> tensor<3x2x3xi1>
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_2]] : (tensor<3x2x3xi1>) -> tensor<3x2x3xi8>
+// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_5]] : (tensor<3x2x3xi8>) -> tensor<3x2x3xf32>
+// CHECK: %[[VAL_7:.*]] = tosa.argmax %[[VAL_6]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi32>
+// CHECK: %[[VAL_8:.*]] = arith.extsi %[[VAL_7]] : tensor<3x2xi32> to tensor<3x2xi64>
+// CHECK: return %{{.*}} : tensor<3x2xi64>
+// CHECK: }
+func.func @torch.aten.argmax$i1(%arg0: tensor<3x2x3xi1>) -> tensor<3x2xi64> {
+  %0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xi1> -> !torch.vtensor<[3,2,3],i1>
+  %false = torch.constant.bool false
+  %int2 = torch.constant.int 2
+  %1 = torch.aten.argmax %0, %int2, %false : !torch.vtensor<[3,2,3],i1>, !torch.int, !torch.bool -> !torch.vtensor<[3,2],si64>
+  %2 = torch_c.to_builtin_tensor %1 : !torch.vtensor<[3,2],si64> -> tensor<3x2xi64>
+  return %2 : tensor<3x2xi64>
+}
+
+// -----
+
 // CHECK-LABEL: @torch.vtensor.literal_si64$basic(
 // CHECK: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<-1> : tensor<1x512xi64>}> : () -> tensor<1x512xi64>
 // CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<1x512xi64> -> !torch.vtensor<[1,512],si64>