Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/Dialect/ONNX/Transforms/ResultNamesUpdater.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (C) 2022 - 2025 Advanced Micro Devices, Inc. All rights reserved.

#include <deque>
#include <iterator>
#include <memory>
#include <unordered_set>

Expand Down Expand Up @@ -57,13 +58,22 @@ void inferTensorNames(ValueRange replOperands) {
} while (workList.size() > 0 && wlen > workList.size());
}

bool hasNameAndManyUses(Value value) {
auto numUses = std::distance(value.use_begin(), value.use_end());
return TensorName(value) && numUses > 1;
}

} // namespace

void ResultNamesUpdater::notifyOperationReplaced(
Operation *op, Operation *replacement) {
if (!op->hasAttrOfType<ArrayAttr>("ResultNames"))
return;

// If replacements have existing name and many uses, don't update ResultNames
if (llvm::any_of(replacement->getResults(), hasNameAndManyUses))
return;

// First, copy the ResultNames attribute for the last value
auto resultNamesArray = op->getAttrOfType<ArrayAttr>("ResultNames");
replacement->setAttr("ResultNames", resultNamesArray);
Expand All @@ -82,6 +92,10 @@ void ResultNamesUpdater::notifyOperationReplaced(
replSingleOp && replSingleOp->getResults() == replacement)
return notifyOperationReplaced(op, replSingleOp);

// If replacements have existing name and many uses, don't update ResultNames
if (llvm::any_of(replacement, hasNameAndManyUses))
return;

// First, copy the ResultNames attribute for the last value
auto resultNamesArray = op->getAttrOfType<ArrayAttr>("ResultNames");
MLIRContext *ctx = op->getContext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,11 @@ struct PushTransposeThroughSCast
if (!perm)
return failure();

if (llvm::any_of(op->getUsers(),
[](Operation *op) { return isa<ONNXDequantizeLinearOp>(op); }))
return rewriter.notifyMatchFailure(
op, "Not pushing through boundary scast");

auto outputType = mlir::cast<RankedTensorType>(op.getType());

// The new scast takes the transpose's input directly, so its output must
Expand Down
41 changes: 31 additions & 10 deletions test/mlir/onnx/onnx_resultnames_prop.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func.func @canonicalize(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: "onnx.Add"(%arg0, %0)
// CHECK-SAME: ResultNames = ["add0"]

func.func @qdq_canonicalize(%arg0: tensor<1x128xf32>) -> tensor<1x1x128xf32> {
func.func @remove_qdq_around_ops(%arg0: tensor<1x128xf32>) -> tensor<1x1x128xf32> {
%0 = onnx.Constant {ResultNames = ["scale"]} dense<1.000000e+00> : tensor<f32>
%1 = onnx.Constant {ResultNames = ["zp"]} dense<128> : tensor<ui8>
%2 = onnx.Constant {ResultNames = ["shape"]} dense<[1, 1, 128]> : tensor<3xi64>
Expand All @@ -50,7 +50,7 @@ func.func @qdq_canonicalize(%arg0: tensor<1x128xf32>) -> tensor<1x1x128xf32> {
return %7 : tensor<1x1x128xf32>
}

// CHECK-LABEL: @qdq_canonicalize
// CHECK-LABEL: @remove_qdq_around_ops
// CHECK: onnx.QuantizeLinear
// CHECK-SAME: ResultNames = ["q0"]
// CHECK-NOT: onnx.DequantizeLinear
Expand All @@ -59,13 +59,34 @@ func.func @qdq_canonicalize(%arg0: tensor<1x128xf32>) -> tensor<1x1x128xf32> {
// CHECK-NOT: onnx.QuantizeLinear
// CHECK: onnx.DequantizeLinear

func.func @complex_names(%arg0: tensor<f32>) -> tensor<f32> {
%0 = onnx.Constant {ResultNames = ["const0"]} dense<2.000000e+00> : tensor<f32>
%1 = "onnx.Add"(%0, %arg0) {ResultNames = [["add0", "with", "array", [1, 2, 3, 4]]]} : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %1 : tensor<f32>
func.func @remove_dq_q_single_use(%arg0: tensor<1x128xf32>) -> tensor<1x128xf32> {
%0 = onnx.Constant {ResultNames = ["scale"]} dense<1.000000e+00> : tensor<f32>
%1 = onnx.Constant {ResultNames = ["zp"]} dense<128> : tensor<ui8>
%3 = "onnx.QuantizeLinear"(%arg0, %0, %1) {ResultNames = ["q0"], axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x128xf32>, tensor<f32>, tensor<ui8>) -> tensor<1x128xui8>
%4 = "onnx.DequantizeLinear"(%3, %0, %1) {ResultNames = ["dq0"], axis = 1 : si64, block_size = 0 : si64} : (tensor<1x128xui8>, tensor<f32>, tensor<ui8>) -> tensor<1x128xf32>
%6 = "onnx.QuantizeLinear"(%4, %0, %1) {ResultNames = ["q1"], axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x128xf32>, tensor<f32>, tensor<ui8>) -> tensor<1x128xui8>
%7 = "onnx.DequantizeLinear"(%6, %0, %1) {ResultNames = ["dq1"], axis = 1 : si64, block_size = 0 : si64} : (tensor<1x128xui8>, tensor<f32>, tensor<ui8>) -> tensor<1x128xf32>
return %7 : tensor<1x128xf32>
}

// CHECK-LABEL: @complex_names
// CHECK: "onnx.Add"(%arg0, %0)
// CHECK-SAME: ResultNames = [
// CHECK-SAME: ["add0", "with", "array", [1, 2, 3, 4]]]
// CHECK-LABEL: @remove_dq_q_single_use
// CHECK: onnx.QuantizeLinear
// CHECK-SAME: ResultNames = ["q1"]
// CHECK: onnx.DequantizeLinear
// CHECK-SAME: ResultNames = ["dq1"]

func.func @remove_dq_q_multi_use(%arg0: tensor<1x128xf32>) -> (tensor<1x128xui8>, tensor<1x128xf32>) {
%0 = onnx.Constant {ResultNames = ["scale"]} dense<1.000000e+00> : tensor<f32>
%1 = onnx.Constant {ResultNames = ["zp"]} dense<128> : tensor<ui8>
%3 = "onnx.QuantizeLinear"(%arg0, %0, %1) {ResultNames = ["q0"], axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x128xf32>, tensor<f32>, tensor<ui8>) -> tensor<1x128xui8>
%4 = "onnx.DequantizeLinear"(%3, %0, %1) {ResultNames = ["dq0"], axis = 1 : si64, block_size = 0 : si64} : (tensor<1x128xui8>, tensor<f32>, tensor<ui8>) -> tensor<1x128xf32>
%6 = "onnx.QuantizeLinear"(%4, %0, %1) {ResultNames = ["q1"], axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x128xf32>, tensor<f32>, tensor<ui8>) -> tensor<1x128xui8>
%7 = "onnx.DequantizeLinear"(%6, %0, %1) {ResultNames = ["dq1"], axis = 1 : si64, block_size = 0 : si64} : (tensor<1x128xui8>, tensor<f32>, tensor<ui8>) -> tensor<1x128xf32>
return %3, %7 : tensor<1x128xui8>, tensor<1x128xf32>
}

// CHECK-LABEL: @remove_dq_q_multi_use
// CHECK: onnx.QuantizeLinear
// CHECK-SAME: ResultNames = ["q0"]
// CHECK: onnx.DequantizeLinear
// CHECK-SAME: ResultNames = ["dq1"]
34 changes: 34 additions & 0 deletions test/mlir/onnx/xmc/onnx_transpose_optimization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1286,6 +1286,40 @@ func.func @test_push_transpose_through_scast_i8(%arg0: tensor<1x32x7x7x!quant.un

// -----

// Test: scast feeding DequantizeLinear is a boundary - should NOT push transpose through
// CHECK-LABEL: func @test_no_push_transpose_through_boundary_scast
func.func @test_no_push_transpose_through_boundary_scast(
%arg0: tensor<1x3x4x4x!quant.uniform<i8:f32, 0.05:0>>,
%scale: tensor<f32>,
%zp: tensor<i8>) -> tensor<1x4x4x3xf32> {
// CHECK: "onnx.Transpose"(%arg0) {perm = [0, 2, 3, 1]}
// CHECK: quant.scast
// CHECK: "onnx.DequantizeLinear"
%0 = "onnx.Transpose"(%arg0) {perm = [0, 2, 3, 1]} : (tensor<1x3x4x4x!quant.uniform<i8:f32, 0.05:0>>) -> tensor<1x4x4x3x!quant.uniform<i8:f32, 0.05:0>>
%1 = quant.scast %0 : tensor<1x4x4x3x!quant.uniform<i8:f32, 0.05:0>> to tensor<1x4x4x3xi8>
%2 = "onnx.DequantizeLinear"(%1, %scale, %zp) : (tensor<1x4x4x3xi8>, tensor<f32>, tensor<i8>) -> tensor<1x4x4x3xf32>
return %2 : tensor<1x4x4x3xf32>
}

// -----

// Test: scast with mixed users including DequantizeLinear - should NOT push transpose
// CHECK-LABEL: func @test_no_push_transpose_through_boundary_scast_mixed_users
func.func @test_no_push_transpose_through_boundary_scast_mixed_users(
%arg0: tensor<1x3x4x4x!quant.uniform<i8:f32, 0.05:0>>,
%scale: tensor<f32>,
%zp: tensor<i8>) -> (tensor<1x4x4x3xf32>, tensor<1x4x4x3xi8>) {
// CHECK: "onnx.Transpose"(%arg0) {perm = [0, 2, 3, 1]}
// CHECK: quant.scast
// CHECK: "onnx.DequantizeLinear"
%0 = "onnx.Transpose"(%arg0) {perm = [0, 2, 3, 1]} : (tensor<1x3x4x4x!quant.uniform<i8:f32, 0.05:0>>) -> tensor<1x4x4x3x!quant.uniform<i8:f32, 0.05:0>>
%1 = quant.scast %0 : tensor<1x4x4x3x!quant.uniform<i8:f32, 0.05:0>> to tensor<1x4x4x3xi8>
%2 = "onnx.DequantizeLinear"(%1, %scale, %zp) : (tensor<1x4x4x3xi8>, tensor<f32>, tensor<i8>) -> tensor<1x4x4x3xf32>
return %2, %1 : tensor<1x4x4x3xf32>, tensor<1x4x4x3xi8>
}

// -----

// Test: scast without transpose input should NOT be modified
// CHECK-LABEL: func @test_no_push_scast_without_transpose
func.func @test_no_push_scast_without_transpose(%arg0: tensor<1x3x4x4x!quant.uniform<i8:f32, 0.05:0>>) -> tensor<1x3x4x4xi8> {
Expand Down
Loading