Skip to content

Commit f5742c4

Browse files
committed
Add bf16 e2e test and update utils func
Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: I5e240a4dbaaee76d11e5895c82d93b718aabeff2
1 parent cd8bb69 commit f5742c4

File tree

4 files changed

+39
-21
lines changed

4 files changed

+39
-21
lines changed

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
6464
std::optional<Value> tosaCastTensorToType(PatternRewriter &rewriter, Value src,
6565
TensorType destType);
6666

67+
// Ensure TOSA argmax input is f32 by inserting a tosa.cast when needed.
68+
Value ensureF32Input(PatternRewriter &rewriter, Operation *op, Value input);
69+
6770
// Creates a TOSA operation and performs shape inference on the individual
6871
// op. This allows shape inference during the framework to TOSA lowering.
6972
template <typename TosaOp, typename... Args>

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1501,19 +1501,10 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewriteImpl(
15011501
getTypeConverter()->convertType(op.getResult().getType()));
15021502
auto outputETy = resultTy.getElementType();
15031503

1504-
auto ensureF32Input = [&](Value input) -> Value {
1505-
auto inputTy = cast<RankedTensorType>(input.getType());
1506-
if (inputTy.getElementType().isF32())
1507-
return input;
1508-
auto castTy =
1509-
RankedTensorType::get(inputTy.getShape(), rewriter.getF32Type());
1510-
return tosa::CastOp::create(rewriter, op->getLoc(), castTy, input);
1511-
};
1512-
15131504
// Create a single instance of tosa.argmax.
15141505
// Multiple dims require chained construct.
15151506
auto buildArgmax = [&](int64_t reduceDim, Value input) -> Value {
1516-
input = ensureF32Input(input);
1507+
input = tosa::ensureF32Input(rewriter, op.getOperation(), input);
15171508
auto inputTy = cast<RankedTensorType>(input.getType());
15181509
auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
15191510
SmallVector<int64_t> outputShapeArr = {};
@@ -4702,20 +4693,12 @@ class ConvertAtenMinMaxDimOp : public TorchToTosaOpConversionPattern<AtenOpT> {
47024693

47034694
// To handle ReduceMinDim indices, we apply ArgMaxOp on the negate
47044695
// of the input tensor, which will return indices of input's min values
4705-
auto ensureF32Input = [&](Value input) -> Value {
4706-
auto inputTy = cast<RankedTensorType>(input.getType());
4707-
if (inputTy.getElementType().isF32())
4708-
return input;
4709-
auto castTy =
4710-
RankedTensorType::get(inputTy.getShape(), rewriter.getF32Type());
4711-
return tosa::CastOp::create(rewriter, op->getLoc(), castTy, input);
4712-
};
4713-
47144696
Value argMaxOp;
47154697
if constexpr (std::is_same<AtenOpT, AtenMinDimOp>()) {
47164698
Value negateOp =
47174699
tosa::NegateOp::create(rewriter, op->getLoc(), selfType, self);
4718-
Value argInput = ensureF32Input(negateOp);
4700+
Value argInput =
4701+
tosa::ensureF32Input(rewriter, op.getOperation(), negateOp);
47194702

47204703
// Use default NaN Propagation mode "PROPAGATE" for tosa.argmax
47214704
argMaxOp = tosa::ArgMaxOp::create(
@@ -4726,7 +4709,7 @@ class ConvertAtenMinMaxDimOp : public TorchToTosaOpConversionPattern<AtenOpT> {
47264709
tosa::NanPropagationModeAttr::get(
47274710
rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
47284711
} else {
4729-
Value argInput = ensureF32Input(self);
4712+
Value argInput = tosa::ensureF32Input(rewriter, op.getOperation(), self);
47304713
// Use default NaN Propagation mode "PROPAGATE" for tosa.argmax
47314714
argMaxOp = tosa::ArgMaxOp::create(
47324715
rewriter, op->getLoc(),

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,15 @@ std::optional<Value> tosaCastTensorToType(PatternRewriter &rewriter, Value src,
381381
return tosa::CastOp::create(rewriter, op->getLoc(), castedSrcType, src);
382382
}
383383

384+
// Ensure a tensor value has an f32 element type, materializing a tosa.cast
// in front of it when it does not (e.g. bf16/f16 inputs to tosa.argmax).
//
// \param rewriter  Pattern rewriter used to create the cast op.
// \param op        Operation whose location is attached to the new cast.
// \param input     Tensor value to normalize; must have RankedTensorType.
// \returns `input` unchanged when it is already f32, otherwise the result
//          of a tosa.cast to an f32 tensor of the same shape.
Value ensureF32Input(PatternRewriter &rewriter, Operation *op, Value input) {
  auto tensorTy = cast<RankedTensorType>(input.getType());
  if (!tensorTy.getElementType().isF32()) {
    // Same shape, element type swapped to f32.
    auto f32TensorTy =
        RankedTensorType::get(tensorTy.getShape(), rewriter.getF32Type());
    return tosa::CastOp::create(rewriter, op->getLoc(), f32TensorTy, input);
  }
  return input;
}
392+
384393
// Template instantiation
385394
template std::optional<Value>
386395
getConstTensor<bool>(PatternRewriter &, Operation *, ArrayRef<bool> vec,

projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1642,6 +1642,29 @@ def ArgmaxModule_basic(module, tu: TestUtils):
16421642
# ==============================================================================
16431643

16441644

1645+
class ArgmaxBFloat16Module(torch.nn.Module):
    """E2E test module: flattened argmax over a dynamically-shaped rank-2
    bfloat16 tensor (no `dim` argument, so aten.argmax reduces over all
    elements)."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([-1, -1], torch.bfloat16, True)])
    def forward(self, a):
        # Call the raw aten op so the lowering under test is exercised
        # directly rather than going through a composite decomposition.
        return torch.ops.aten.argmax(a)
1658+
1659+
1660+
@register_test_case(module_factory=lambda: ArgmaxBFloat16Module())
def ArgmaxBFloat16Module_basic(module, tu: TestUtils):
    # Drive the module with a random 3x4 tensor downcast to bfloat16.
    sample = tu.rand(3, 4).to(torch.bfloat16)
    module.forward(sample)
1663+
1664+
1665+
# ==============================================================================
1666+
1667+
16451668
class ArgmaxKeepdimModule(torch.nn.Module):
16461669
def __init__(self):
16471670
super().__init__()

0 commit comments

Comments
 (0)