Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
//====------ ConvertONNXToTOSA.cpp - ONNX dialects to TOSA lowering -------===//
//
// Copyright (c) 2022 Arm Limited.
// Copyright (c) 2022-2025 Advanced Micro Devices, Inc.
// Copyright (c) 2022-2026 Advanced Micro Devices, Inc.
//
// =============================================================================
//
Expand Down Expand Up @@ -174,11 +174,11 @@ void FrontendToTosaLoweringPass::runOnOperation() {

// We use the type converter to legalize types before any conversion patterns
// are executed. This ensures that we do not need to trigger separate
// conversion failures. Quantized types are not supported right now.
// conversion failures. Only per-tensor quantization is supported right now.
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) -> std::optional<Type> {
if (isTOSAInt(type) || isa<FloatType>(type) || isa<NoneType>(type) ||
isTOSABool(type))
isTOSABool(type) || isTOSAQuantizedInt(type))
return type;
return std::nullopt;
});
Expand Down
16 changes: 14 additions & 2 deletions src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

//===---------------- Elementwise.cpp - Elementwise Op --------------------===//
//
// Copyright (c) 2022 Advanced Micro Devices, Inc.
// Copyright (c) 2022-2026 Advanced Micro Devices, Inc.
//
// =============================================================================
//
Expand All @@ -31,6 +31,18 @@ struct TOSADialectOp<ONNXNegOp> {
using Op = mlir::tosa::NegateOp;
};

// Type-check policy for elementwise lowerings: a scalar type passes when it is
// a FloatType, or when isTOSAInt or isTOSAQuantizedInt accepts it. Any other
// type is reported to the rewriter as a match failure on `op`.
struct IsIntOrFloatOrQuantizedInt {
  static LogicalResult checkType(
      ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) {
    // Checks run in the order float, int, quantized int and stop at the first
    // one that accepts the type.
    const bool isSupported = isa<FloatType>(scalarType) ||
                             isTOSAInt(scalarType) ||
                             isTOSAQuantizedInt(scalarType);
    if (isSupported)
      return success();

    return rewriter.notifyMatchFailure(op,
        "this operation only supports int, float, or quantized int types");
  }
};

struct IsIntOrFloat {
static LogicalResult checkType(
ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) {
Expand Down Expand Up @@ -724,7 +736,7 @@ static void populateLoweringONNXElementwiseUnaryTemplateOpToTOSAPattern(
ONNXElementwiseUnaryOpLoweringToTOSA<ONNXNotOp, mlir::tosa::LogicalNotOp,
IsBool, IsBool>,
ONNXElementwiseUnaryOpLoweringToTOSA<ONNXAbsOp, mlir::tosa::AbsOp,
IsIntOrFloat, IsIntOrFloat>,
IsIntOrFloatOrQuantizedInt, IsIntOrFloatOrQuantizedInt>,
ONNXElementwiseUnaryOpLoweringToTOSA<ONNXErfOp, mlir::tosa::ErfOp,
IsFloat, IsFloat>,
ONNXElementwiseUnaryOpLoweringToTOSA<ONNXSinOp, mlir::tosa::SinOp,
Expand Down
7 changes: 6 additions & 1 deletion src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
//====------ ONNXToTOSACommon.hpp - ONNX dialects to TOSA lowering --------===//
//
// Copyright 2020-2024 The TensorFlow Authors. All Rights Reserved.
// Copyright (c) 2022-2024 Advanced Micro Devices, Inc.
// Copyright (c) 2022-2026 Advanced Micro Devices, Inc.
//
// =============================================================================
//
Expand Down Expand Up @@ -104,6 +104,11 @@ inline bool isTOSAFloat(mlir::Type type) {
type);
}

inline bool isTOSAQuantizedInt(mlir::Type type) {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please move the implementation to the .cpp file?

For a small function like this it does not make a big difference, but in general keeping implementations in .cpp files helps keep compile times and recompilations down.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
inline bool isTOSAQuantizedInt(mlir::Type type) {
inline bool isTOSAUniformQuantizedInt(mlir::Type type) {

auto quantizedType = mlir::dyn_cast<mlir::quant::UniformQuantizedType>(type);
return quantizedType && isTOSAInt(quantizedType.getStorageType());
}

//===----------------------------------------------------------------------===//
// This is to get a TOSA operation of a given type for a specific operation.
//===----------------------------------------------------------------------===//
Expand Down
25 changes: 25 additions & 0 deletions test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// Copyright (c) 2026 Advanced Micro Devices, Inc.

// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s

// -----
Expand Down Expand Up @@ -786,6 +788,29 @@ func.func @test_abs_f64(%arg0: tensor<3xf64>) -> tensor<3xf64> {
// CHECK: return {{.*}}: tensor<3xf64>
}

func.func @test_abs_qi8(%arg0: tensor<3x!quant.uniform<i8:f32, 1.0>>) -> tensor<3x!quant.uniform<i8:f32, 1.0>> {
%0 = "onnx.Abs"(%arg0) : (tensor<3x!quant.uniform<i8:f32, 1.0>>) -> tensor<3x!quant.uniform<i8:f32, 1.0>>
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we get a test where the input and output quantization parameters are not the same?

return %0 : tensor<3x!quant.uniform<i8:f32, 1.0>>
// CHECK-LABEL: func @test_abs_qi8
// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.abs [[PARAM_0_]] : (tensor<3x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<3x!quant.uniform<i8:f32, 1.000000e+00>>
// CHECK-NEXT: return [[VAR_0_]] : tensor<3x!quant.uniform<i8:f32, 1.000000e+00>>
// CHECK-NEXT: }
}

// Per-channel quantized types use quant.uniform<i8:f32:1, {s0, s1, ...}>,
// which is not currently handled by the ONNX-to-TOSA elementwise conversion.
// Only per-tensor uniform quantized types are supported.
// This test checks that the conversion does not fail but keeps the original op.

func.func @test_abs_qi8_channel(%arg0: tensor<3x4x!quant.uniform<i8:f32:1, {1.0, 2.0, 3.0, 4.0}>>) -> tensor<3x4x!quant.uniform<i8:f32:1, {1.0, 2.0, 3.0, 4.0}>> {
%0 = "onnx.Abs"(%arg0) : (tensor<3x4x!quant.uniform<i8:f32:1, {1.0, 2.0, 3.0, 4.0}>>) -> tensor<3x4x!quant.uniform<i8:f32:1, {1.0, 2.0, 3.0, 4.0}>>
return %0 : tensor<3x4x!quant.uniform<i8:f32:1, {1.0, 2.0, 3.0, 4.0}>>

// CHECK-LABEL: func @test_abs_qi8_channel
// CHECK: "onnx.Abs"
// CHECK: return
}

// -----

func.func @test_erf_f32(%arg0: tensor<3xf32>) -> tensor<3xf32> {
Expand Down