Skip to content

Commit 6b8ac8d

Browse files
committed
Fix ClipQuantFusion crash when Clip has multiple input edges
1 parent 4665804 commit 6b8ac8d

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ static bool GetQConstantLowerUpper(const Graph& graph, const Node& node, float&
8181
return true;
8282
}
8383

84-
bool ClipQuantFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& /*logger*/) const {
84+
bool ClipQuantFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
8585
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Clip", {1, 6, 11, 12, 13}) ||
8686
!graph_utils::IsSupportedProvider(node, {kCpuExecutionProvider}) ||
8787
!optimizer_utils::CheckOutputEdges(graph, node, 1)) {
@@ -95,6 +95,10 @@ bool ClipQuantFusion::SatisfyCondition(const Graph& graph, const Node& node, con
9595
return false;
9696
}
9797

98+
if (!graph_utils::CanRemoveNode(graph, node, logger)) {
99+
return false;
100+
}
101+
98102
return true;
99103
}
100104

onnxruntime/test/optimizer/qdq_transformer_test.cc

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3221,6 +3221,37 @@ TEST(QDQTransformerTests, ReluQuantFusion_Level2Only) {
32213221
test_case(TransformerLevel::Level3, 0); // Will not fuse Relu into QuantizeLinear due to zero-point != -128
32223222
}
32233223

3224+
// Test skip removing edge when min/max come from DequantizeLinear nodes instead of initializers).
3225+
TEST(QDQTransformerTests, ClipQuantFusion_MultipleInputEdges) {
3226+
auto build_test_case = [&](ModelTestBuilder& builder) {
3227+
// Clip's min coming from another DQ node (creating 2 input edges to Clip)
3228+
auto* input_arg = builder.MakeInput<uint8_t>({1, 2, 2, 2}, std::numeric_limits<uint8_t>::min(),
3229+
std::numeric_limits<uint8_t>::max());
3230+
auto* data_dq = builder.MakeIntermediate();
3231+
builder.AddDequantizeLinearNode<uint8_t>(input_arg, 0.04f, static_cast<uint8_t>(0), data_dq);
3232+
auto* min_q = builder.MakeScalarInitializer<uint8_t>(0);
3233+
auto* min_dq = builder.MakeIntermediate();
3234+
builder.AddDequantizeLinearNode<uint8_t>(min_q, 0.04f, static_cast<uint8_t>(0), min_dq);
3235+
auto* clip_output = builder.MakeIntermediate();
3236+
builder.AddNode("Clip", {data_dq, min_dq}, {clip_output});
3237+
auto* output_q = builder.MakeIntermediate();
3238+
builder.AddQuantizeLinearNode<uint8_t>(clip_output, 0.04f, static_cast<uint8_t>(0), output_q);
3239+
auto* output_arg = builder.MakeOutput();
3240+
builder.AddDequantizeLinearNode<uint8_t>(output_q, 0.04f, static_cast<uint8_t>(0), output_arg);
3241+
};
3242+
3243+
auto check_graph = [&](InferenceSessionWrapper& session) {
3244+
auto op_to_count = CountOpsInGraph(session.GetGraph());
3245+
// ClipQuantFusion should skip it due to CanRemoveNode check
3246+
EXPECT_EQ(op_to_count["Clip"], 1);
3247+
};
3248+
3249+
TransformerTester(build_test_case, check_graph,
3250+
TransformerLevel::Default,
3251+
TransformerLevel::Level2,
3252+
18); // opset
3253+
}
3254+
32243255
template <typename ScaleType, typename ZpType>
32253256
void TestWhereWithDqInput(bool is_dq_1,
32263257
bool is_dq_2,

0 commit comments

Comments
 (0)