From 02d5854ee902016e28d77486da688bb1654cabd4 Mon Sep 17 00:00:00 2001 From: Ilango Rajagopal Date: Fri, 27 Mar 2026 11:31:55 -0600 Subject: [PATCH 1/2] Transpose not to be moved through boundary scast --- .../xmc/ONNXTransposeOptimizationPass.cpp | 5 +++ .../onnx/xmc/onnx_transpose_optimization.mlir | 34 +++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/src/Dialect/ONNX/Transforms/xmc/ONNXTransposeOptimizationPass.cpp b/src/Dialect/ONNX/Transforms/xmc/ONNXTransposeOptimizationPass.cpp index 13bae45c807..e85edfd34e2 100644 --- a/src/Dialect/ONNX/Transforms/xmc/ONNXTransposeOptimizationPass.cpp +++ b/src/Dialect/ONNX/Transforms/xmc/ONNXTransposeOptimizationPass.cpp @@ -881,6 +881,11 @@ struct PushTransposeThroughSCast if (!perm) return failure(); + if (llvm::any_of(op->getUsers(), + [](Operation *op) { return isa(op); })) + return rewriter.notifyMatchFailure( + op, "Not pushing through boundary scast"); + auto outputType = mlir::cast(op.getType()); // The new scast takes the transpose's input directly, so its output must diff --git a/test/mlir/onnx/xmc/onnx_transpose_optimization.mlir b/test/mlir/onnx/xmc/onnx_transpose_optimization.mlir index 565c23f2508..034b07d1fdc 100644 --- a/test/mlir/onnx/xmc/onnx_transpose_optimization.mlir +++ b/test/mlir/onnx/xmc/onnx_transpose_optimization.mlir @@ -1286,6 +1286,40 @@ func.func @test_push_transpose_through_scast_i8(%arg0: tensor<1x32x7x7x!quant.un // ----- +// Test: scast feeding DequantizeLinear is a boundary - should NOT push transpose through +// CHECK-LABEL: func @test_no_push_transpose_through_boundary_scast +func.func @test_no_push_transpose_through_boundary_scast( + %arg0: tensor<1x3x4x4x!quant.uniform>, + %scale: tensor, + %zp: tensor) -> tensor<1x4x4x3xf32> { + // CHECK: "onnx.Transpose"(%arg0) {perm = [0, 2, 3, 1]} + // CHECK: quant.scast + // CHECK: "onnx.DequantizeLinear" + %0 = "onnx.Transpose"(%arg0) {perm = [0, 2, 3, 1]} : (tensor<1x3x4x4x!quant.uniform>) -> tensor<1x4x4x3x!quant.uniform> + 
%1 = quant.scast %0 : tensor<1x4x4x3x!quant.uniform> to tensor<1x4x4x3xi8> + %2 = "onnx.DequantizeLinear"(%1, %scale, %zp) : (tensor<1x4x4x3xi8>, tensor, tensor) -> tensor<1x4x4x3xf32> + return %2 : tensor<1x4x4x3xf32> +} + +// ----- + +// Test: scast with mixed users including DequantizeLinear - should NOT push transpose +// CHECK-LABEL: func @test_no_push_transpose_through_boundary_scast_mixed_users +func.func @test_no_push_transpose_through_boundary_scast_mixed_users( + %arg0: tensor<1x3x4x4x!quant.uniform>, + %scale: tensor, + %zp: tensor) -> (tensor<1x4x4x3xf32>, tensor<1x4x4x3xi8>) { + // CHECK: "onnx.Transpose"(%arg0) {perm = [0, 2, 3, 1]} + // CHECK: quant.scast + // CHECK: "onnx.DequantizeLinear" + %0 = "onnx.Transpose"(%arg0) {perm = [0, 2, 3, 1]} : (tensor<1x3x4x4x!quant.uniform>) -> tensor<1x4x4x3x!quant.uniform> + %1 = quant.scast %0 : tensor<1x4x4x3x!quant.uniform> to tensor<1x4x4x3xi8> + %2 = "onnx.DequantizeLinear"(%1, %scale, %zp) : (tensor<1x4x4x3xi8>, tensor, tensor) -> tensor<1x4x4x3xf32> + return %2, %1 : tensor<1x4x4x3xf32>, tensor<1x4x4x3xi8> +} + +// ----- + // Test: scast without transpose input should NOT be modified // CHECK-LABEL: func @test_no_push_scast_without_transpose func.func @test_no_push_scast_without_transpose(%arg0: tensor<1x3x4x4x!quant.uniform>) -> tensor<1x3x4x4xi8> { From 949e6679eeeb234654b999cac519cd1c3d5b512e Mon Sep 17 00:00:00 2001 From: Ilango Rajagopal Date: Mon, 4 May 2026 02:05:15 -0600 Subject: [PATCH 2/2] Replacements with name & many uses, no tname update --- .../ONNX/Transforms/ResultNamesUpdater.cpp | 14 +++++++ test/mlir/onnx/onnx_resultnames_prop.mlir | 41 ++++++++++++++----- 2 files changed, 45 insertions(+), 10 deletions(-) diff --git a/src/Dialect/ONNX/Transforms/ResultNamesUpdater.cpp b/src/Dialect/ONNX/Transforms/ResultNamesUpdater.cpp index 89afcc427d6..a71741ab19b 100644 --- a/src/Dialect/ONNX/Transforms/ResultNamesUpdater.cpp +++ b/src/Dialect/ONNX/Transforms/ResultNamesUpdater.cpp @@ -1,6 +1,7 @@ // 
Copyright (C) 2022 - 2025 Advanced Micro Devices, Inc. All rights reserved. #include +#include #include #include @@ -57,6 +58,11 @@ void inferTensorNames(ValueRange replOperands) { } while (workList.size() > 0 && wlen > workList.size()); } +bool hasNameAndManyUses(Value value) { + auto numUses = std::distance(value.use_begin(), value.use_end()); + return TensorName(value) && numUses > 1; +} + } // namespace void ResultNamesUpdater::notifyOperationReplaced( @@ -64,6 +70,10 @@ void ResultNamesUpdater::notifyOperationReplaced( if (!op->hasAttrOfType("ResultNames")) return; + // If replacements have existing name and many uses, don't update ResultNames + if (llvm::any_of(replacement->getResults(), hasNameAndManyUses)) + return; + // First, copy the ResultNames attribute for the last value auto resultNamesArray = op->getAttrOfType("ResultNames"); replacement->setAttr("ResultNames", resultNamesArray); @@ -82,6 +92,10 @@ void ResultNamesUpdater::notifyOperationReplaced( replSingleOp && replSingleOp->getResults() == replacement) return notifyOperationReplaced(op, replSingleOp); + // If replacements have existing name and many uses, don't update ResultNames + if (llvm::any_of(replacement, hasNameAndManyUses)) + return; + // First, copy the ResultNames attribute for the last value auto resultNamesArray = op->getAttrOfType("ResultNames"); MLIRContext *ctx = op->getContext(); diff --git a/test/mlir/onnx/onnx_resultnames_prop.mlir b/test/mlir/onnx/onnx_resultnames_prop.mlir index a83a4b8a1e6..98113892a02 100644 --- a/test/mlir/onnx/onnx_resultnames_prop.mlir +++ b/test/mlir/onnx/onnx_resultnames_prop.mlir @@ -38,7 +38,7 @@ func.func @canonicalize(%arg0: tensor) -> tensor { // CHECK: "onnx.Add"(%arg0, %0) // CHECK-SAME: ResultNames = ["add0"] -func.func @qdq_canonicalize(%arg0: tensor<1x128xf32>) -> tensor<1x1x128xf32> { +func.func @remove_qdq_around_ops(%arg0: tensor<1x128xf32>) -> tensor<1x1x128xf32> { %0 = onnx.Constant {ResultNames = ["scale"]} dense<1.000000e+00> : tensor 
%1 = onnx.Constant {ResultNames = ["zp"]} dense<128> : tensor %2 = onnx.Constant {ResultNames = ["shape"]} dense<[1, 1, 128]> : tensor<3xi64> @@ -50,7 +50,7 @@ func.func @qdq_canonicalize(%arg0: tensor<1x128xf32>) -> tensor<1x1x128xf32> { return %7 : tensor<1x1x128xf32> } -// CHECK-LABEL: @qdq_canonicalize +// CHECK-LABEL: @remove_qdq_around_ops // CHECK: onnx.QuantizeLinear // CHECK-SAME: ResultNames = ["q0"] // CHECK-NOT: onnx.DequantizeLinear @@ -59,13 +59,34 @@ func.func @qdq_canonicalize(%arg0: tensor<1x128xf32>) -> tensor<1x1x128xf32> { // CHECK-NOT: onnx.QuantizeLinear // CHECK: onnx.DequantizeLinear -func.func @complex_names(%arg0: tensor) -> tensor { - %0 = onnx.Constant {ResultNames = ["const0"]} dense<2.000000e+00> : tensor - %1 = "onnx.Add"(%0, %arg0) {ResultNames = [["add0", "with", "array", [1, 2, 3, 4]]]} : (tensor, tensor) -> tensor - return %1 : tensor +func.func @remove_dq_q_single_use(%arg0: tensor<1x128xf32>) -> tensor<1x128xf32> { + %0 = onnx.Constant {ResultNames = ["scale"]} dense<1.000000e+00> : tensor + %1 = onnx.Constant {ResultNames = ["zp"]} dense<128> : tensor + %3 = "onnx.QuantizeLinear"(%arg0, %0, %1) {ResultNames = ["q0"], axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x128xf32>, tensor, tensor) -> tensor<1x128xui8> + %4 = "onnx.DequantizeLinear"(%3, %0, %1) {ResultNames = ["dq0"], axis = 1 : si64, block_size = 0 : si64} : (tensor<1x128xui8>, tensor, tensor) -> tensor<1x128xf32> + %6 = "onnx.QuantizeLinear"(%4, %0, %1) {ResultNames = ["q1"], axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x128xf32>, tensor, tensor) -> tensor<1x128xui8> + %7 = "onnx.DequantizeLinear"(%6, %0, %1) {ResultNames = ["dq1"], axis = 1 : si64, block_size = 0 : si64} : (tensor<1x128xui8>, tensor, tensor) -> tensor<1x128xf32> + return %7 : tensor<1x128xf32> } -// CHECK-LABEL: @complex_names -// CHECK: "onnx.Add"(%arg0, %0) -// CHECK-SAME: ResultNames = [ -// 
CHECK-SAME: ["add0", "with", "array", [1, 2, 3, 4]]] +// CHECK-LABEL: @remove_dq_q_single_use +// CHECK: onnx.QuantizeLinear +// CHECK-SAME: ResultNames = ["q1"] +// CHECK: onnx.DequantizeLinear +// CHECK-SAME: ResultNames = ["dq1"] + +func.func @remove_dq_q_multi_use(%arg0: tensor<1x128xf32>) -> (tensor<1x128xui8>, tensor<1x128xf32>) { + %0 = onnx.Constant {ResultNames = ["scale"]} dense<1.000000e+00> : tensor + %1 = onnx.Constant {ResultNames = ["zp"]} dense<128> : tensor + %3 = "onnx.QuantizeLinear"(%arg0, %0, %1) {ResultNames = ["q0"], axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x128xf32>, tensor, tensor) -> tensor<1x128xui8> + %4 = "onnx.DequantizeLinear"(%3, %0, %1) {ResultNames = ["dq0"], axis = 1 : si64, block_size = 0 : si64} : (tensor<1x128xui8>, tensor, tensor) -> tensor<1x128xf32> + %6 = "onnx.QuantizeLinear"(%4, %0, %1) {ResultNames = ["q1"], axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x128xf32>, tensor, tensor) -> tensor<1x128xui8> + %7 = "onnx.DequantizeLinear"(%6, %0, %1) {ResultNames = ["dq1"], axis = 1 : si64, block_size = 0 : si64} : (tensor<1x128xui8>, tensor, tensor) -> tensor<1x128xf32> + return %3, %7 : tensor<1x128xui8>, tensor<1x128xf32> +} + +// CHECK-LABEL: @remove_dq_q_multi_use +// CHECK: onnx.QuantizeLinear +// CHECK-SAME: ResultNames = ["q0"] +// CHECK: onnx.DequantizeLinear +// CHECK-SAME: ResultNames = ["dq1"]