Skip to content

Commit cd8bb69

Browse files
committed
[TorchToTosa] Cast argmax inputs to f32 and add bf16 test
Ensure tosa.argmax receives f32 inputs by inserting tosa.cast in the argmax lowering (and the min/max-dim argmax paths), and fix the axis attribute type. Add a bf16 argmax conversion test in basic.mlir to validate the cast+extsi sequence.

Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: I5f8847cc7400152c20001905dfacc92af0c1583c
1 parent 7f1d4b2 commit cd8bb69

File tree

2 files changed: +46 −3 lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1501,9 +1501,19 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewriteImpl(
15011501
getTypeConverter()->convertType(op.getResult().getType()));
15021502
auto outputETy = resultTy.getElementType();
15031503

1504+
auto ensureF32Input = [&](Value input) -> Value {
1505+
auto inputTy = cast<RankedTensorType>(input.getType());
1506+
if (inputTy.getElementType().isF32())
1507+
return input;
1508+
auto castTy =
1509+
RankedTensorType::get(inputTy.getShape(), rewriter.getF32Type());
1510+
return tosa::CastOp::create(rewriter, op->getLoc(), castTy, input);
1511+
};
1512+
15041513
// Create a single instance of tosa.argmax.
15051514
// Multiple dims require chained construct.
15061515
auto buildArgmax = [&](int64_t reduceDim, Value input) -> Value {
1516+
input = ensureF32Input(input);
15071517
auto inputTy = cast<RankedTensorType>(input.getType());
15081518
auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
15091519
SmallVector<int64_t> outputShapeArr = {};
@@ -1523,7 +1533,7 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewriteImpl(
15231533
makeShapeLLVMCompatible(ArrayRef<int64_t>(outputShapeArr)),
15241534
rewriter.getI32Type());
15251535
auto reduceDimAttr =
1526-
rewriter.getIntegerAttr(rewriter.getI64Type(), reduceDim);
1536+
rewriter.getIntegerAttr(rewriter.getI32Type(), reduceDim);
15271537

15281538
// Use default NaN Propagation mode "PROPAGATE" for tosa.argmax
15291539
return tosa::ArgMaxOp::create(
@@ -4692,26 +4702,37 @@ class ConvertAtenMinMaxDimOp : public TorchToTosaOpConversionPattern<AtenOpT> {
46924702

46934703
// To handle ReduceMinDim indices, we apply ArgMaxOp on the negate
46944704
// of the input tensor, which will return indices of input's min values
4705+
auto ensureF32Input = [&](Value input) -> Value {
4706+
auto inputTy = cast<RankedTensorType>(input.getType());
4707+
if (inputTy.getElementType().isF32())
4708+
return input;
4709+
auto castTy =
4710+
RankedTensorType::get(inputTy.getShape(), rewriter.getF32Type());
4711+
return tosa::CastOp::create(rewriter, op->getLoc(), castTy, input);
4712+
};
4713+
46954714
Value argMaxOp;
46964715
if constexpr (std::is_same<AtenOpT, AtenMinDimOp>()) {
46974716
Value negateOp =
46984717
tosa::NegateOp::create(rewriter, op->getLoc(), selfType, self);
4718+
Value argInput = ensureF32Input(negateOp);
46994719

47004720
// Use default NaN Propagation mode "PROPAGATE" for tosa.argmax
47014721
argMaxOp = tosa::ArgMaxOp::create(
47024722
rewriter, op->getLoc(),
47034723
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
47044724
indicesElemType),
4705-
negateOp, dimAttr, /*nan_mode=*/
4725+
argInput, dimAttr, /*nan_mode=*/
47064726
tosa::NanPropagationModeAttr::get(
47074727
rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
47084728
} else {
4729+
Value argInput = ensureF32Input(self);
47094730
// Use default NaN Propagation mode "PROPAGATE" for tosa.argmax
47104731
argMaxOp = tosa::ArgMaxOp::create(
47114732
rewriter, op->getLoc(),
47124733
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
47134734
indicesElemType),
4714-
self, dimAttr, /*nan_mode=*/
4735+
argInput, dimAttr, /*nan_mode=*/
47154736
tosa::NanPropagationModeAttr::get(
47164737
rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
47174738
}

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,6 +1076,28 @@ func.func @torch.aten.max.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf3
10761076

10771077
// -----
10781078

// Verifies the bf16 argmax lowering: the input is first cast to f32
// (tosa.argmax requires f32 here), reduced with an i32 axis attribute,
// and the i32 index result is sign-extended to the expected i64 output.
// CHECK-LABEL: func.func @torch.aten.argmax$bf16(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x3xbf16>) -> tensor<3x2xi64> {
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xbf16> -> !torch.vtensor<[3,2,3],bf16>
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],bf16> -> tensor<3x2x3xbf16>
// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_2]] : (tensor<3x2x3xbf16>) -> tensor<3x2x3xf32>
// CHECK: %[[VAL_6:.*]] = tosa.argmax %[[VAL_5]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi32>
// CHECK: %[[VAL_7:.*]] = arith.extsi %[[VAL_6]] : tensor<3x2xi32> to tensor<3x2xi64>
// CHECK: return %{{.*}} : tensor<3x2xi64>
// CHECK: }
func.func @torch.aten.argmax$bf16(%arg0: tensor<3x2x3xbf16>) -> tensor<3x2xi64> {
  %0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xbf16> -> !torch.vtensor<[3,2,3],bf16>
  %false = torch.constant.bool false
  %int2 = torch.constant.int 2
  %1 = torch.aten.argmax %0, %int2, %false : !torch.vtensor<[3,2,3],bf16>, !torch.int, !torch.bool -> !torch.vtensor<[3,2],si64>
  %2 = torch_c.to_builtin_tensor %1 : !torch.vtensor<[3,2],si64> -> tensor<3x2xi64>
  return %2 : tensor<3x2xi64>
}
1098+
1099+
// -----
1100+
10791101
// CHECK-LABEL: @torch.vtensor.literal_si64$basic(
10801102
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<-1> : tensor<1x512xi64>}> : () -> tensor<1x512xi64>
10811103
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<1x512xi64> -> !torch.vtensor<[1,512],si64>

0 commit comments

Comments (0)