From 2db504379214c1af0dec4a3bcd19dbf9c5904dd1 Mon Sep 17 00:00:00 2001 From: Tarik Rosin Date: Fri, 6 Mar 2026 02:44:44 -0700 Subject: [PATCH] [MLIR][TOSA] guard against illegal rw with quant types --- .../mlir/Dialect/Tosa/Utils/QuantUtils.h | 9 +++ .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 13 ++++ mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp | 11 ++++ .../test/Dialect/Tosa/canonicalize-quant.mlir | 62 +++++++++++++++++++ 4 files changed, 95 insertions(+) create mode 100644 mlir/test/Dialect/Tosa/canonicalize-quant.mlir diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h index 5e80745777b3..8876d8645663 100644 --- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h @@ -4,6 +4,9 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // +// Modifications (c) Copyright 2026 Advanced Micro Devices, Inc. or its +// affiliates +// //===----------------------------------------------------------------------===// // // Function declarations for TOSA numerical support functions and quantization @@ -26,6 +29,12 @@ namespace tosa { // Utility functions to support quantization handling in Tosa. //===----------------------------------------------------------------------===// +/// Returns true if the type is a quantized type. +bool isQuantizedType(Type type); + +/// Returns true if the value has a quantized type. +bool hasQuantizedType(Value value); + /// From a scale value, computes multiplier and shift values /// for 16 or 32-bit scale widths. void computeMultiplierAndShift(double scale, int32_t &multiplier, diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 061ea1232864..461316b61d11 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1820,6 +1820,10 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) { auto input = getInput1(); + + if (hasQuantizedType(input)) + return {}; + // Element-wise log(exp(x)) = x if (auto op = input.getDefiningOp()) { return op.getInput1(); @@ -1830,6 +1834,10 @@ OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) { OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) { auto input = getInput1(); + + if (hasQuantizedType(input)) + return {}; + // Element-wise exp(log(x)) = x if (auto op = input.getDefiningOp()) { return op.getInput1(); @@ -1840,6 +1848,10 @@ OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) { OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) { auto input = getInput1(); + + if (hasQuantizedType(input)) + return {}; + // Element-wise negate(negate(x)) = x if (auto op = input.getDefiningOp()) { return op.getInput1(); @@ -1850,6 +1862,7 @@ OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) { OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) { auto input = getInput1(); + // Element-wise abs(abs(x)) = abs(x) if (auto op = input.getDefiningOp()) { return input; diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp index 5c546f59cde4..633cbd9ebd1e 100644 --- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp @@ -4,6 +4,9 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // +// Modifications (c) Copyright 2026 Advanced Micro Devices, Inc. or its +// affiliates +// //===----------------------------------------------------------------------===// // // This file contains TOSA numerical support functions and quantization @@ -16,6 +19,14 @@ using namespace mlir; using namespace mlir::tosa; +bool mlir::tosa::isQuantizedType(Type type) { + return isa(mlir::getElementTypeOrSelf(type)); +} + +bool mlir::tosa::hasQuantizedType(Value value) { + return isQuantizedType(value.getType()); +} + /// From a scale value, generates multiplier and shift values where /// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that /// multiplier = mantissa*2^shift for 16-bit scaling. diff --git a/mlir/test/Dialect/Tosa/canonicalize-quant.mlir b/mlir/test/Dialect/Tosa/canonicalize-quant.mlir new file mode 100644 index 000000000000..2255bf858fc2 --- /dev/null +++ b/mlir/test/Dialect/Tosa/canonicalize-quant.mlir @@ -0,0 +1,62 @@ +// Modifications (c) Copyright 2026 Advanced Micro Devices, Inc. or its +// affiliates + +// RUN: mlir-opt --split-input-file -canonicalize="test-convergence" %s | FileCheck %s + +// CHECK-LABEL: @negate_negate_quant_nofold +// CHECK-NEXT: tosa.negate +// CHECK-NEXT: tosa.negate + +// The output of negate should be clipped to the range of the storage type. +// However, the canonicalization pass illegally removes the intermediate clip. +// Thus, negate(negate(x)) = x is not valid when x carries a quant type. +// A simple counter-example is neg(neg(-128)) = neg(127) = -127 != -128 +func.func @negate_negate_quant_nofold(%arg0: tensor<4x!quant.uniform>) -> tensor<4x!quant.uniform> { + %0 = tosa.negate %arg0 : (tensor<4x!quant.uniform>) -> tensor<4x!quant.uniform> + %1 = tosa.negate %0 : (tensor<4x!quant.uniform>) -> tensor<4x!quant.uniform> + return %1 : tensor<4x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: @exp_log_quant_nofold +// CHECK-NEXT: tosa.exp +// CHECK-NEXT: tosa.log +func.func @exp_log_quant_nofold(%arg0: tensor<4x!quant.uniform>) -> tensor<4x!quant.uniform> { + %0 = tosa.exp %arg0 : (tensor<4x!quant.uniform>) -> tensor<4x!quant.uniform> + %1 = tosa.log %0 : (tensor<4x!quant.uniform>) -> tensor<4x!quant.uniform> + return %1 : tensor<4x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: @log_exp_quant_nofold +// CHECK-NEXT: tosa.log +// CHECK-NEXT: tosa.exp +func.func @log_exp_quant_nofold(%arg0: tensor<4x!quant.uniform>) -> tensor<4x!quant.uniform> { + %0 = tosa.log %arg0 : (tensor<4x!quant.uniform>) -> tensor<4x!quant.uniform> + %1 = tosa.exp %0 : (tensor<4x!quant.uniform>) -> tensor<4x!quant.uniform> + return %1 : tensor<4x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: @min_to_clamp_quant +// CHECK: tosa.clamp %arg0 {max_fp = 6.000000e+00 : f32, max_int = 6 : i64, +// CHECK-SAME: min_fp = -3.40282347E+38 : f32, min_int = -2147483648 : i64} +func.func @min_to_clamp_quant(%arg0: tensor<4x!quant.uniform>) -> tensor<4x!quant.uniform> { + %0 = "tosa.const"() <{value = dense<6> : tensor<1xi8>}> : () -> tensor<1x!quant.uniform> + %1 = tosa.minimum %arg0, %0 : (tensor<4x!quant.uniform>, tensor<1x!quant.uniform>) -> tensor<4x!quant.uniform> + return %1 : tensor<4x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: @max_to_clamp_quant +// CHECK: tosa.clamp %arg0 {max_fp = 3.40282347E+38 : f32, max_int = 9223372036854775807 : i64, +// CHECK-SAME: min_fp = -6.000000e+00 : f32, min_int = -6 : i64} +func.func @max_to_clamp_quant(%arg0: tensor<4x!quant.uniform>) -> tensor<4x!quant.uniform> { + %0 = "tosa.const"() <{value = dense<-6> : tensor<1xi8>}> : () -> tensor<1x!quant.uniform> + %1 = tosa.maximum %arg0, %0 : (tensor<4x!quant.uniform>, tensor<1x!quant.uniform>) -> tensor<4x!quant.uniform> + return %1 : tensor<4x!quant.uniform> +}