Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
std::optional<Value> tosaCastTensorToType(PatternRewriter &rewriter, Value src,
TensorType destType);

// Cast i32/i64 argmax inputs to f32 via tosa.cast; all other element
// types (floats, i8, ...) are returned unchanged.
Value ensureF32Input(PatternRewriter &rewriter, Operation *op, Value input);

// Creates a TOSA operation and performs shape inference on the individual
// op. This allows shape inference during the framework to TOSA lowering.
template <typename TosaOp, typename... Args>
Expand Down
10 changes: 7 additions & 3 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1504,6 +1504,7 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewriteImpl(
// Create a single instance of tosa.argmax.
// Multiple dims require chained construct.
auto buildArgmax = [&](int64_t reduceDim, Value input) -> Value {
input = tosa::ensureF32Input(rewriter, op.getOperation(), input);
auto inputTy = cast<RankedTensorType>(input.getType());
auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
SmallVector<int64_t> outputShapeArr = {};
Expand All @@ -1523,7 +1524,7 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewriteImpl(
makeShapeLLVMCompatible(ArrayRef<int64_t>(outputShapeArr)),
rewriter.getI32Type());
auto reduceDimAttr =
rewriter.getIntegerAttr(rewriter.getI64Type(), reduceDim);
rewriter.getIntegerAttr(rewriter.getI32Type(), reduceDim);

// Use default NaN Propagation mode "PROPAGATE" for tosa.argmax
return tosa::ArgMaxOp::create(
Expand Down Expand Up @@ -4696,22 +4697,25 @@ class ConvertAtenMinMaxDimOp : public TorchToTosaOpConversionPattern<AtenOpT> {
if constexpr (std::is_same<AtenOpT, AtenMinDimOp>()) {
Value negateOp =
tosa::NegateOp::create(rewriter, op->getLoc(), selfType, self);
Value argInput =
tosa::ensureF32Input(rewriter, op.getOperation(), negateOp);

// Use default NaN Propagation mode "PROPAGATE" for tosa.argmax
argMaxOp = tosa::ArgMaxOp::create(
rewriter, op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
indicesElemType),
negateOp, dimAttr, /*nan_mode=*/
argInput, dimAttr, /*nan_mode=*/
tosa::NanPropagationModeAttr::get(
rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
} else {
Value argInput = tosa::ensureF32Input(rewriter, op.getOperation(), self);
// Use default NaN Propagation mode "PROPAGATE" for tosa.argmax
argMaxOp = tosa::ArgMaxOp::create(
rewriter, op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
indicesElemType),
self, dimAttr, /*nan_mode=*/
argInput, dimAttr, /*nan_mode=*/
tosa::NanPropagationModeAttr::get(
rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
}
Expand Down
10 changes: 10 additions & 0 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,16 @@ std::optional<Value> tosaCastTensorToType(PatternRewriter &rewriter, Value src,
return tosa::CastOp::create(rewriter, op->getLoc(), castedSrcType, src);
}

// Ensure TOSA argmax input is f32 by inserting a tosa.cast when needed.
// Only i32/i64 element types are converted; every other element type
// (floats, i8, ...) is passed through unchanged.
Value ensureF32Input(PatternRewriter &rewriter, Operation *op, Value input) {
  const auto srcType = cast<RankedTensorType>(input.getType());
  Type srcElemTy = srcType.getElementType();

  // Anything that is not i32/i64 is left untouched.
  if (!srcElemTy.isInteger(32) && !srcElemTy.isInteger(64))
    return input;

  // Same shape, f32 element type.
  auto f32TensorTy =
      RankedTensorType::get(srcType.getShape(), rewriter.getF32Type());
  return tosa::CastOp::create(rewriter, op->getLoc(), f32TensorTy, input);
}

// Template instantiation
template std::optional<Value>
getConstTensor<bool>(PatternRewriter &, Operation *, ArrayRef<bool> vec,
Expand Down
23 changes: 23 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -1642,6 +1642,29 @@ def ArgmaxModule_basic(module, tu: TestUtils):
# ==============================================================================


class ArgmaxInt32Module(torch.nn.Module):
    """E2E test module exercising aten.argmax on an int32 input tensor."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.int32, True),
        ]
    )
    def forward(self, a):
        # Argmax over all elements (no dim argument given).
        return torch.ops.aten.argmax(a)


@register_test_case(module_factory=lambda: ArgmaxInt32Module())
def ArgmaxInt32Module_basic(module, tu: TestUtils):
    # Random int32 values spanning negatives and positives.
    module.forward(tu.randint(3, 4, low=-100, high=100, dtype=torch.int32))


# ==============================================================================


class ArgmaxKeepdimModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
22 changes: 22 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,28 @@ func.func @torch.aten.max.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf3

// -----

// Regression test for i32 argmax lowering: the i32 input is cast to f32
// before tosa.argmax, and the i32 tosa.argmax result is sign-extended
// (arith.extsi) to the i64 the torch op produces.
// CHECK-LABEL: func.func @torch.aten.argmax$int32(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x3xi32>) -> tensor<3x2xi64> {
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xi32> -> !torch.vtensor<[3,2,3],si32>
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],si32> -> tensor<3x2x3xi32>
// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_2]] : (tensor<3x2x3xi32>) -> tensor<3x2x3xf32>
// CHECK: %[[VAL_6:.*]] = tosa.argmax %[[VAL_5]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi32>
// CHECK: %[[VAL_7:.*]] = arith.extsi %[[VAL_6]] : tensor<3x2xi32> to tensor<3x2xi64>
// CHECK: return %{{.*}} : tensor<3x2xi64>
// CHECK: }
func.func @torch.aten.argmax$int32(%arg0: tensor<3x2x3xi32>) -> tensor<3x2xi64> {
%0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xi32> -> !torch.vtensor<[3,2,3],si32>
%false = torch.constant.bool false
%int2 = torch.constant.int 2
%1 = torch.aten.argmax %0, %int2, %false : !torch.vtensor<[3,2,3],si32>, !torch.int, !torch.bool -> !torch.vtensor<[3,2],si64>
%2 = torch_c.to_builtin_tensor %1 : !torch.vtensor<[3,2],si64> -> tensor<3x2xi64>
return %2 : tensor<3x2xi64>
}

// -----

// CHECK-LABEL: @torch.vtensor.literal_si64$basic(
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<-1> : tensor<1x512xi64>}> : () -> tensor<1x512xi64>
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<1x512xi64> -> !torch.vtensor<[1,512],si64>
Expand Down
Loading