Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
std::optional<Value> tosaCastTensorToType(PatternRewriter &rewriter, Value src,
TensorType destType);

// Legalize the element type of a tosa.argmax input: integer inputs other than
// i8 are cast to f32 via tosa.cast; float and i8 inputs are returned unchanged.
Value legalizeArgMaxInputType(PatternRewriter &rewriter, Operation *op,
Value input);

// Creates a TOSA operation and performs shape inference on the individual
// op. This allows shape inference during the framework to TOSA lowering.
template <typename TosaOp, typename... Args>
Expand Down
11 changes: 8 additions & 3 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1504,6 +1504,7 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewriteImpl(
// Create a single instance of tosa.argmax.
// Multiple dims require chained construct.
auto buildArgmax = [&](int64_t reduceDim, Value input) -> Value {
input = tosa::legalizeArgMaxInputType(rewriter, op.getOperation(), input);
auto inputTy = cast<RankedTensorType>(input.getType());
auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
SmallVector<int64_t> outputShapeArr = {};
Expand All @@ -1523,7 +1524,7 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewriteImpl(
makeShapeLLVMCompatible(ArrayRef<int64_t>(outputShapeArr)),
rewriter.getI32Type());
auto reduceDimAttr =
rewriter.getIntegerAttr(rewriter.getI64Type(), reduceDim);
rewriter.getIntegerAttr(rewriter.getI32Type(), reduceDim);

// Use default NaN Propagation mode "PROPAGATE" for tosa.argmax
return tosa::ArgMaxOp::create(
Expand Down Expand Up @@ -4696,22 +4697,26 @@ class ConvertAtenMinMaxDimOp : public TorchToTosaOpConversionPattern<AtenOpT> {
if constexpr (std::is_same<AtenOpT, AtenMinDimOp>()) {
Value negateOp =
tosa::NegateOp::create(rewriter, op->getLoc(), selfType, self);
Value argInput =
tosa::legalizeArgMaxInputType(rewriter, op.getOperation(), negateOp);

// Use default NaN Propagation mode "PROPAGATE" for tosa.argmax
argMaxOp = tosa::ArgMaxOp::create(
rewriter, op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
indicesElemType),
negateOp, dimAttr, /*nan_mode=*/
argInput, dimAttr, /*nan_mode=*/
tosa::NanPropagationModeAttr::get(
rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
} else {
Value argInput =
tosa::legalizeArgMaxInputType(rewriter, op.getOperation(), self);
// Use default NaN Propagation mode "PROPAGATE" for tosa.argmax
argMaxOp = tosa::ArgMaxOp::create(
rewriter, op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
indicesElemType),
self, dimAttr, /*nan_mode=*/
argInput, dimAttr, /*nan_mode=*/
tosa::NanPropagationModeAttr::get(
rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
}
Expand Down
14 changes: 14 additions & 0 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,20 @@ std::optional<Value> tosaCastTensorToType(PatternRewriter &rewriter, Value src,
return tosa::CastOp::create(rewriter, op->getLoc(), castedSrcType, src);
}

Value legalizeArgMaxInputType(PatternRewriter &rewriter, Operation *op,
                              Value input) {
  // Float inputs pass through untouched, and i8 is kept as-is because TOSA
  // pro_int argmax accepts it natively. Every other integer width is widened
  // to f32 (i1 is routed through i8 inside the cast helper).
  // NOTE(review): i32/i64 values beyond f32's 24-bit mantissa may collapse to
  // equal floats, which could perturb argmax tie-breaking — confirm this is
  // acceptable for the targeted profiles.
  auto rankedTy = cast<RankedTensorType>(input.getType());
  Type elemTy = rankedTy.getElementType();
  const bool needsCast = elemTy.isInteger() && !elemTy.isInteger(8);
  if (!needsCast)
    return input;
  // `op` is currently unused by this helper; kept for interface stability.
  (void)op;
  auto f32TensorTy =
      RankedTensorType::get(rankedTy.getShape(), rewriter.getF32Type());
  // If the cast cannot be materialized, fall back to the original value.
  if (auto maybeCast = tosa::tosaCastTensorToType(rewriter, input, f32TensorTy))
    return *maybeCast;
  return input;
}

// Template instantiation
template std::optional<Value>
getConstTensor<bool>(PatternRewriter &, Operation *, ArrayRef<bool> vec,
Expand Down
23 changes: 23 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -1642,6 +1642,29 @@ def ArgmaxModule_basic(module, tu: TestUtils):
# ==============================================================================


class ArgmaxInt32Module(torch.nn.Module):
    """Full-reduction argmax over a rank-2 si32 tensor (no `dim` argument).

    Regression coverage for backends that must legalize integer argmax
    inputs — e.g. the TorchToTosa lowering inserts a cast because
    tosa.argmax does not accept i32 operands directly.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            # Dynamic 2-D tensor of signed 32-bit integers.
            ([-1, -1], torch.int32, True),
        ]
    )
    def forward(self, a):
        # No `dim`: reduces over all elements of `a`.
        return torch.ops.aten.argmax(a)


@register_test_case(module_factory=lambda: ArgmaxInt32Module())
def ArgmaxInt32Module_basic(module, tu: TestUtils):
    # Random signed values (including negatives) exercise the i32 argmax path.
    sample = tu.randint(3, 4, low=-100, high=100, dtype=torch.int32)
    module.forward(sample)


# ==============================================================================


class ArgmaxKeepdimModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
45 changes: 45 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,51 @@ func.func @torch.aten.max.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf3

// -----

// Verifies that an i32 argmax input is cast to f32 before tosa.argmax, and
// that the i32 index result is sign-extended to the expected i64 output.
// CHECK-LABEL: func.func @torch.aten.argmax$int32(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x3xi32>) -> tensor<3x2xi64> {
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xi32> -> !torch.vtensor<[3,2,3],si32>
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],si32> -> tensor<3x2x3xi32>
// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_2]] : (tensor<3x2x3xi32>) -> tensor<3x2x3xf32>
// CHECK: %[[VAL_6:.*]] = tosa.argmax %[[VAL_5]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi32>
// CHECK: %[[VAL_7:.*]] = arith.extsi %[[VAL_6]] : tensor<3x2xi32> to tensor<3x2xi64>
// CHECK: return %{{.*}} : tensor<3x2xi64>
// CHECK: }
func.func @torch.aten.argmax$int32(%arg0: tensor<3x2x3xi32>) -> tensor<3x2xi64> {
%0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xi32> -> !torch.vtensor<[3,2,3],si32>
%false = torch.constant.bool false
%int2 = torch.constant.int 2
%1 = torch.aten.argmax %0, %int2, %false : !torch.vtensor<[3,2,3],si32>, !torch.int, !torch.bool -> !torch.vtensor<[3,2],si64>
%2 = torch_c.to_builtin_tensor %1 : !torch.vtensor<[3,2],si64> -> tensor<3x2xi64>
return %2 : tensor<3x2xi64>
}

// -----

// Verifies that an i1 argmax input is legalized via a two-step cast
// (i1 -> i8 -> f32) before tosa.argmax, then sign-extended to i64.
// CHECK-LABEL: func.func @torch.aten.argmax$i1(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x3xi1>) -> tensor<3x2xi64> {
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xi1> -> !torch.vtensor<[3,2,3],i1>
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],i1> -> tensor<3x2x3xi1>
// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_2]] : (tensor<3x2x3xi1>) -> tensor<3x2x3xi8>
// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_5]] : (tensor<3x2x3xi8>) -> tensor<3x2x3xf32>
// CHECK: %[[VAL_7:.*]] = tosa.argmax %[[VAL_6]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi32>
// CHECK: %[[VAL_8:.*]] = arith.extsi %[[VAL_7]] : tensor<3x2xi32> to tensor<3x2xi64>
// CHECK: return %{{.*}} : tensor<3x2xi64>
// CHECK: }
func.func @torch.aten.argmax$i1(%arg0: tensor<3x2x3xi1>) -> tensor<3x2xi64> {
%0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xi1> -> !torch.vtensor<[3,2,3],i1>
%false = torch.constant.bool false
%int2 = torch.constant.int 2
%1 = torch.aten.argmax %0, %int2, %false : !torch.vtensor<[3,2,3],i1>, !torch.int, !torch.bool -> !torch.vtensor<[3,2],si64>
%2 = torch_c.to_builtin_tensor %1 : !torch.vtensor<[3,2],si64> -> tensor<3x2xi64>
return %2 : tensor<3x2xi64>
}

// -----

// CHECK-LABEL: @torch.vtensor.literal_si64$basic(
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<-1> : tensor<1x512xi64>}> : () -> tensor<1x512xi64>
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<1x512xi64> -> !torch.vtensor<[1,512],si64>
Expand Down
Loading