Skip to content

Commit 2db5043

Browse files
committed
[MLIR][TOSA] guard against illegal rw with quant types
1 parent 0778a83 commit 2db5043

4 files changed

Lines changed: 95 additions & 0 deletions

File tree

mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
7+
// Modifications (c) Copyright 2026 Advanced Micro Devices, Inc. or its
8+
// affiliates
9+
//
710
//===----------------------------------------------------------------------===//
811
//
912
// Function declarations for TOSA numerical support functions and quantization
@@ -26,6 +29,12 @@ namespace tosa {
2629
// Utility functions to support quantization handling in Tosa.
2730
//===----------------------------------------------------------------------===//
2831

32+
/// Returns true if the type or its element type is a uniform quantized type.
33+
bool isQuantizedType(Type type);
34+
35+
/// Returns true if the value's type satisfies isQuantizedType().
36+
bool hasQuantizedType(Value value);
37+
2938
/// From a scale value, computes multiplier and shift values
3039
/// for 16 or 32-bit scale widths.
3140
void computeMultiplierAndShift(double scale, int32_t &multiplier,

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1820,6 +1820,10 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
18201820

18211821
OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
18221822
auto input = getInput1();
1823+
1824+
if (hasQuantizedType(input))
1825+
return {};
1826+
18231827
// Element-wise log(exp(x)) = x
18241828
if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
18251829
return op.getInput1();
@@ -1830,6 +1834,10 @@ OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
18301834

18311835
OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
18321836
auto input = getInput1();
1837+
1838+
if (hasQuantizedType(input))
1839+
return {};
1840+
18331841
// Element-wise exp(log(x)) = x
18341842
if (auto op = input.getDefiningOp<tosa::LogOp>()) {
18351843
return op.getInput1();
@@ -1840,6 +1848,10 @@ OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
18401848

18411849
OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
18421850
auto input = getInput1();
1851+
1852+
if (hasQuantizedType(input))
1853+
return {};
1854+
18431855
// Element-wise negate(negate(x)) = x
18441856
if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
18451857
return op.getInput1();
@@ -1850,6 +1862,7 @@ OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
18501862

18511863
OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
18521864
auto input = getInput1();
1865+
18531866
// Element-wise abs(abs(x)) = abs(x)
18541867
if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
18551868
return input;

mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
7+
// Modifications (c) Copyright 2026 Advanced Micro Devices, Inc. or its
8+
// affiliates
9+
//
710
//===----------------------------------------------------------------------===//
811
//
912
// This file contains TOSA numerical support functions and quantization
@@ -16,6 +19,14 @@
1619
using namespace mlir;
1720
using namespace mlir::tosa;
1821

22+
bool mlir::tosa::isQuantizedType(Type type) {
23+
return isa<quant::UniformQuantizedType>(mlir::getElementTypeOrSelf(type));
24+
}
25+
26+
bool mlir::tosa::hasQuantizedType(Value value) {
27+
return isQuantizedType(value.getType());
28+
}
29+
1930
/// From a scale value, generates multiplier and shift values where
2031
/// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that
2132
/// multiplier = mantissa*2^shift for 16-bit scaling.
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// Modifications (c) Copyright 2026 Advanced Micro Devices, Inc. or its
2+
// affiliates
3+
4+
// RUN: mlir-opt --split-input-file -canonicalize="test-convergence" %s | FileCheck %s
5+
6+
// CHECK-LABEL: @negate_negate_quant_nofold
7+
// CHECK-NEXT: tosa.negate
8+
// CHECK-NEXT: tosa.negate
9+
10+
// The output of negate should be clipped to the range of the storage type.
11+
// Folding negate(negate(x)) to x would illegally drop the intermediate clip.
12+
// Thus, negate(negate(x)) = x is not valid when x carries a quant type.
13+
// Counter-example for i8 storage: neg(neg(-128)) = neg(127) = -127 != -128
14+
func.func @negate_negate_quant_nofold(%arg0: tensor<4x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>> {
15+
%0 = tosa.negate %arg0 : (tensor<4x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>>
16+
%1 = tosa.negate %0 : (tensor<4x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>>
17+
return %1 : tensor<4x!quant.uniform<i8:f32, 0.05>>
18+
}
19+
20+
// -----
21+
22+
// CHECK-LABEL: @exp_log_quant_nofold
23+
// CHECK-NEXT: tosa.exp
24+
// CHECK-NEXT: tosa.log
25+
func.func @exp_log_quant_nofold(%arg0: tensor<4x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>> {
26+
%0 = tosa.exp %arg0 : (tensor<4x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>>
27+
%1 = tosa.log %0 : (tensor<4x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>>
28+
return %1 : tensor<4x!quant.uniform<i8:f32, 0.05>>
29+
}
30+
31+
// -----
32+
33+
// CHECK-LABEL: @log_exp_quant_nofold
34+
// CHECK-NEXT: tosa.log
35+
// CHECK-NEXT: tosa.exp
36+
func.func @log_exp_quant_nofold(%arg0: tensor<4x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>> {
37+
%0 = tosa.log %arg0 : (tensor<4x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>>
38+
%1 = tosa.exp %0 : (tensor<4x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>>
39+
return %1 : tensor<4x!quant.uniform<i8:f32, 0.05>>
40+
}
41+
42+
// -----
43+
44+
// CHECK-LABEL: @min_to_clamp_quant
45+
// CHECK: tosa.clamp %arg0 {max_fp = 6.000000e+00 : f32, max_int = 6 : i64,
46+
// CHECK-SAME: min_fp = -3.40282347E+38 : f32, min_int = -2147483648 : i64}
47+
func.func @min_to_clamp_quant(%arg0: tensor<4x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>> {
48+
%0 = "tosa.const"() <{value = dense<6> : tensor<1xi8>}> : () -> tensor<1x!quant.uniform<i8:f32, 0.05>>
49+
%1 = tosa.minimum %arg0, %0 : (tensor<4x!quant.uniform<i8:f32, 0.05>>, tensor<1x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>>
50+
return %1 : tensor<4x!quant.uniform<i8:f32, 0.05>>
51+
}
52+
53+
// -----
54+
55+
// CHECK-LABEL: @max_to_clamp_quant
56+
// CHECK: tosa.clamp %arg0 {max_fp = 3.40282347E+38 : f32, max_int = 9223372036854775807 : i64,
57+
// CHECK-SAME: min_fp = -6.000000e+00 : f32, min_int = -6 : i64}
58+
func.func @max_to_clamp_quant(%arg0: tensor<4x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>> {
59+
%0 = "tosa.const"() <{value = dense<-6> : tensor<1xi8>}> : () -> tensor<1x!quant.uniform<i8:f32, 0.05>>
60+
%1 = tosa.maximum %arg0, %0 : (tensor<4x!quant.uniform<i8:f32, 0.05>>, tensor<1x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>>
61+
return %1 : tensor<4x!quant.uniform<i8:f32, 0.05>>
62+
}

0 commit comments

Comments
 (0)