Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Modifications (c) Copyright 2026 Advanced Micro Devices, Inc. or its
// affiliates
//
//===----------------------------------------------------------------------===//
//
// Function declarations for TOSA numerical support functions and quantization
Expand All @@ -26,6 +29,12 @@ namespace tosa {
// Utility functions to support quantization handling in Tosa.
//===----------------------------------------------------------------------===//

/// Returns true if the type is a quantized type.
bool isQuantizedType(Type type);

/// Returns true if the value has a quantized type.
bool hasQuantizedType(Value value);

/// From a scale value, computes multiplier and shift values
/// for 16 or 32-bit scale widths.
void computeMultiplierAndShift(double scale, int32_t &multiplier,
Expand Down
13 changes: 13 additions & 0 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1820,6 +1820,10 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {

OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
auto input = getInput1();

if (hasQuantizedType(input))
return {};

// Element-wise log(exp(x)) = x
if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
return op.getInput1();
Expand All @@ -1830,6 +1834,10 @@ OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {

OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
auto input = getInput1();

if (hasQuantizedType(input))
return {};

// Element-wise exp(log(x)) = x
if (auto op = input.getDefiningOp<tosa::LogOp>()) {
return op.getInput1();
Expand All @@ -1840,6 +1848,10 @@ OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {

OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
auto input = getInput1();

if (hasQuantizedType(input))
return {};

// Element-wise negate(negate(x)) = x
if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
return op.getInput1();
Expand All @@ -1850,6 +1862,7 @@ OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {

OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
auto input = getInput1();

// Element-wise abs(abs(x)) = abs(x)
if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
return input;
Expand Down
11 changes: 11 additions & 0 deletions mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Modifications (c) Copyright 2026 Advanced Micro Devices, Inc. or its
// affiliates
//
//===----------------------------------------------------------------------===//
//
// This file contains TOSA numerical support functions and quantization
Expand All @@ -16,6 +19,14 @@
using namespace mlir;
using namespace mlir::tosa;

bool mlir::tosa::isQuantizedType(Type type) {
return isa<quant::UniformQuantizedType>(mlir::getElementTypeOrSelf(type));
}

bool mlir::tosa::hasQuantizedType(Value value) {
return isQuantizedType(value.getType());
}

/// From a scale value, generates multiplier and shift values where
/// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that
/// multiplier = mantissa*2^shift for 16-bit scaling.
Expand Down
62 changes: 62 additions & 0 deletions mlir/test/Dialect/Tosa/canonicalize-quant.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Modifications (c) Copyright 2026 Advanced Micro Devices, Inc. or its
// affiliates

// RUN: mlir-opt --split-input-file -canonicalize="test-convergence" %s | FileCheck %s

// CHECK-LABEL: @negate_negate_quant_nofold
// CHECK-NEXT: tosa.negate
// CHECK-NEXT: tosa.negate

// The output of negate should be clipped to the range of the storage type.
// However, the canonicalization pass illegally removes the intermediate clip.
// Thus, negate(negate(x)) = x is not valid when x carries a quant type.
// A simple counter-example is neg(neg(-128)) = neg(127) = -127 != -128
func.func @negate_negate_quant_nofold(%arg0: tensor<4x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>> {
%0 = tosa.negate %arg0 : (tensor<4x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>>
%1 = tosa.negate %0 : (tensor<4x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>>
return %1 : tensor<4x!quant.uniform<i8:f32, 0.05>>
}

// -----

// CHECK-LABEL: @exp_log_quant_nofold
// CHECK-NEXT: tosa.exp
// CHECK-NEXT: tosa.log
func.func @exp_log_quant_nofold(%arg0: tensor<4x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>> {
%0 = tosa.exp %arg0 : (tensor<4x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>>
%1 = tosa.log %0 : (tensor<4x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>>
return %1 : tensor<4x!quant.uniform<i8:f32, 0.05>>
}

// -----

// CHECK-LABEL: @log_exp_quant_nofold
// CHECK-NEXT: tosa.log
// CHECK-NEXT: tosa.exp
func.func @log_exp_quant_nofold(%arg0: tensor<4x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>> {
%0 = tosa.log %arg0 : (tensor<4x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>>
%1 = tosa.exp %0 : (tensor<4x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>>
return %1 : tensor<4x!quant.uniform<i8:f32, 0.05>>
}

// -----

// CHECK-LABEL: @min_to_clamp_quant
// CHECK: tosa.clamp %arg0 {max_fp = 6.000000e+00 : f32, max_int = 6 : i64,
// CHECK-SAME: min_fp = -3.40282347E+38 : f32, min_int = -2147483648 : i64}
func.func @min_to_clamp_quant(%arg0: tensor<4x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>> {
%0 = "tosa.const"() <{value = dense<6> : tensor<1xi8>}> : () -> tensor<1x!quant.uniform<i8:f32, 0.05>>
%1 = tosa.minimum %arg0, %0 : (tensor<4x!quant.uniform<i8:f32, 0.05>>, tensor<1x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>>
return %1 : tensor<4x!quant.uniform<i8:f32, 0.05>>
}

// -----

// CHECK-LABEL: @max_to_clamp_quant
// CHECK: tosa.clamp %arg0 {max_fp = 3.40282347E+38 : f32, max_int = 9223372036854775807 : i64,
// CHECK-SAME: min_fp = -6.000000e+00 : f32, min_int = -6 : i64}
func.func @max_to_clamp_quant(%arg0: tensor<4x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>> {
%0 = "tosa.const"() <{value = dense<-6> : tensor<1xi8>}> : () -> tensor<1x!quant.uniform<i8:f32, 0.05>>
%1 = tosa.maximum %arg0, %0 : (tensor<4x!quant.uniform<i8:f32, 0.05>>, tensor<1x!quant.uniform<i8:f32, 0.05>>) -> tensor<4x!quant.uniform<i8:f32, 0.05>>
return %1 : tensor<4x!quant.uniform<i8:f32, 0.05>>
}