
Commit 9f95908

[QNN EP] Fix Clip op with min or max from QDQ (#26601)
## Motivation

The QDQ node group selection logic currently routes the `Clip` op to `UnaryNodeGroupSelector`. This does not properly handle the case where the `Clip` op's `min`/`max` inputs are provided through Q/DQ ops (while still being constant initializers).

<img width="255" height="378" alt="image-2025-11-18-11-49-19-156" src="https://github.com/user-attachments/assets/ec6250ee-68f3-40fa-8f60-93b1a400d5a0" />

## Changes

- Implement a custom NodeGroupSelector so that the `Clip` op is properly tagged for the backend to consume.
- Fix QNN EP `Clip` min/max parsing and dequantize when needed.
- Add unit tests for both changes.
1 parent 07724d5 commit 9f95908

File tree

6 files changed: +407 -48 lines changed

onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc

Lines changed: 35 additions & 0 deletions
```diff
@@ -245,6 +245,41 @@ bool UnaryNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node&
   return true;
 }
 
+bool ClipNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node,
+                                  const std::vector<const Node*>& dq_nodes,
+                                  const std::vector<const Node*>& q_nodes) const {
+  // Clip can have 1, 2, or 3 DQ inputs:
+  // - 1 DQ: only data input is quantized
+  // - 2 DQ: data and min or max are quantized
+  // - 3 DQ: data, min, and max are all quantized
+  const size_t num_dq_nodes = dq_nodes.size();
+  if (num_dq_nodes < 1 || num_dq_nodes > 3) {
+    return false;
+  }
+
+  if (!CheckQDQNodes(graph_viewer, node, redundant_clip_node, dq_nodes, q_nodes, static_cast<int>(num_dq_nodes))) {
+    return false;
+  }
+
+  int32_t dt_input = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
+  int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
+
+  if (dt_input != dt_output) {
+    return false;
+  }
+
+  // 16-bit int types must be explicitly allowed.
+  if (!allow_16bit_ && Is16BitIntType(dt_input)) {
+    return false;
+  }
+
+  if (!allow_4bit_ && Is4BitIntType(dt_input)) {
+    return false;
+  }
+
+  return true;
+}
+
 bool BinaryNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node,
                                     const std::vector<const Node*>& dq_nodes,
                                     const std::vector<const Node*>& q_nodes) const {
```
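For orientation, the acceptance rule in `ClipNodeGroupSelector::Check` boils down to a handful of scalar conditions. The sketch below restates them as a standalone function with hypothetical parameter names; it is not part of the commit and leaves out the structural validation done by `CheckQDQNodes`.

```cpp
// Illustrative restatement of the Clip node-group acceptance rule; names are hypothetical.
#include <cstddef>
#include <cstdint>

bool AcceptClipQdqGroupSketch(std::size_t num_dq_inputs,  // DQ nodes feeding Clip: data, and optionally min/max
                              int32_t data_elem_type,     // quantized element type behind the data DQ
                              int32_t output_elem_type,   // quantized element type produced by the Q node
                              bool is_16bit_int, bool is_4bit_int,
                              bool allow_16bit, bool allow_4bit) {
  if (num_dq_inputs < 1 || num_dq_inputs > 3) return false;  // 1 = data only, 2 = data + min or max, 3 = all three
  if (data_elem_type != output_elem_type) return false;      // quantized type must be preserved across Clip
  if (!allow_16bit && is_16bit_int) return false;            // 16-bit int types must be explicitly allowed
  if (!allow_4bit && is_4bit_int) return false;              // 4-bit int types must be explicitly allowed
  return true;
}
```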

onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h

Lines changed: 14 additions & 0 deletions
```diff
@@ -92,6 +92,20 @@ class UnaryNodeGroupSelector : public NodeGroupSelector {
   bool allow_4bit_;
 };
 
+class ClipNodeGroupSelector : public NodeGroupSelector {
+ public:
+  explicit ClipNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true)
+      : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {}
+
+ private:
+  bool Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node,
+             const std::vector<const Node*>& dq_nodes,
+             const std::vector<const Node*>& q_nodes) const override;
+
+  bool allow_16bit_;
+  bool allow_4bit_;
+};
+
 // 2 DQ nodes providing input -> node -> Q
 class BinaryNodeGroupSelector : public NodeGroupSelector {
  public:
```

onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc

Lines changed: 13 additions & 3 deletions
```diff
@@ -87,9 +87,11 @@ static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() {
           {"Neg", {}},
           {"DepthToSpace", {}},
           {"SpaceToDepth", {}},
-          {"Clip", {}},
           {"LpNormalization", {}}};
 }
+static const OpVersionsAndSelector::OpVersionsMap GetClipOpVersionsMap() {
+  return {{"Clip", {}}};
+}
 static const OpVersionsAndSelector::OpVersionsMap GetBinaryOpVersionsMap() {
   return {{"Add", {}},
           {"Div", {}},
@@ -168,19 +170,26 @@ void RegisterMiscSelectors(Selectors& qdq_selectors) {
 }
 
 void RegisterDropDQSelectors(Selectors& qdq_selectors) {
-  /* register selectors for ops that have a sigle DQ -> node */
+  /* register selectors for ops that have a single DQ -> node */
   std::unique_ptr<NodeGroupSelector> selector = std::make_unique<DropDQNodeGroupSelector>();
   qdq_selectors.RegisterSelector(GetDropDQOpVersionsMap(),
                                  std::move(selector));
 }
 
 void RegisterUnarySelectors(Selectors& qdq_selectors) {
-  /* regsiter selectors for unary ops */
+  /* register selectors for unary ops */
   std::unique_ptr<NodeGroupSelector> selector = std::make_unique<UnaryNodeGroupSelector>();
   qdq_selectors.RegisterSelector(GetUnaryOpVersionsMap(),
                                  std::move(selector));
 }
 
+void RegisterClipSelector(Selectors& qdq_selectors) {
+  /* register selector for Clip op */
+  std::unique_ptr<NodeGroupSelector> selector = std::make_unique<ClipNodeGroupSelector>();
+  qdq_selectors.RegisterSelector(GetClipOpVersionsMap(),
+                                 std::move(selector));
+}
+
 void RegisterBinarySelectors(Selectors& qdq_selectors) {
   /* register selectors for binary ops */
   std::unique_ptr<NodeGroupSelector> selector = std::make_unique<BinaryNodeGroupSelector>();
@@ -305,6 +314,7 @@ void SelectorManager::CreateSelectors() {
   RegisterMiscSelectors(qdq_selectors_);
   RegisterDropDQSelectors(qdq_selectors_);
   RegisterUnarySelectors(qdq_selectors_);
+  RegisterClipSelector(qdq_selectors_);
   RegisterBinarySelectors(qdq_selectors_);
   RegisterVariadicSelectors(qdq_selectors_);
   RegisterSplitSelector(qdq_selectors_);
```
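As a rough mental model for the registration above (hypothetical names below, not the actual `Selectors`/`SelectorManager` implementation), each `Register*` call adds an entry to an op-type-keyed table, and node-group selection later looks up a node's op type to decide which `Check` to run. With this commit, "Clip" resolves to its own entry instead of the shared unary one.

```cpp
// Hypothetical sketch of selector dispatch by op type; opset tracking and the real
// ORT interfaces are omitted.
#include <functional>
#include <map>
#include <string>

int main() {
  // op type -> "does this DQ -> op -> Q group qualify?" predicate
  std::map<std::string, std::function<bool()>> selectors;
  selectors["Neg"] = [] { return true; };   // stays with the unary-style rule
  selectors["Clip"] = [] { return true; };  // now its own rule that tolerates 1-3 DQ inputs

  auto it = selectors.find("Clip");
  return (it != selectors.end() && it->second()) ? 0 : 1;
}
```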

onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc

Lines changed: 101 additions & 45 deletions
```diff
@@ -30,7 +30,7 @@ class ClipOpBuilder : public BaseOpBuilder {
                          bool do_op_validation) const override ORT_MUST_USE_RESULT;
 
  private:
-  Status ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const;
+  Status ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const;
 };
 
 static Status ProcessClipMinMax(QnnModelWrapper& qnn_model_wrapper,
@@ -41,56 +41,112 @@ static Status ProcessClipMinMax(QnnModelWrapper& qnn_model_wrapper,
   ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(input, input_info));
   assert(input_info.is_initializer);  // Checked by ExplicitOpCheck().
   ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_info.initializer_tensor, val_bytes));
-  switch (input_info.qnn_data_type) {
-    case QNN_DATATYPE_INT_8: {
-      float_value = static_cast<float>(*reinterpret_cast<int8_t*>(val_bytes.data()));
-      break;
-    }
-    case QNN_DATATYPE_INT_16: {
-      float_value = static_cast<float>(*reinterpret_cast<int16_t*>(val_bytes.data()));
-      break;
-    }
-    case QNN_DATATYPE_INT_32: {
-      float_value = static_cast<float>(*reinterpret_cast<int32_t*>(val_bytes.data()));
-      break;
-    }
-    case QNN_DATATYPE_INT_64: {
-      float_value = static_cast<float>(*reinterpret_cast<int64_t*>(val_bytes.data()));
-      break;
-    }
-    case QNN_DATATYPE_UINT_8: {
-      float_value = static_cast<float>(*val_bytes.data());
-      break;
-    }
-    case QNN_DATATYPE_UINT_16: {
-      float_value = static_cast<float>(*reinterpret_cast<uint16_t*>(val_bytes.data()));
-      break;
-    }
-    case QNN_DATATYPE_UINT_32: {
-      float_value = static_cast<float>(*reinterpret_cast<uint32_t*>(val_bytes.data()));
-      break;
-    }
-    case QNN_DATATYPE_UINT_64: {
-      float_value = static_cast<float>(*reinterpret_cast<uint64_t*>(val_bytes.data()));
-      break;
-    }
-    case QNN_DATATYPE_FLOAT_16: {
-      MLFloat16 fp16_value = *reinterpret_cast<const MLFloat16*>(val_bytes.data());
-      float_value = fp16_value.ToFloat();
-      break;
+
+  // If the input is quantized, we need to dequantize it
+  if (input.quant_param.has_value()) {
+    ORT_RETURN_IF_NOT(input_info.quant_param.IsPerTensor(),
+                      "Clip's min/max must use per-tensor quantization");
+    const Qnn_QuantizeParams_t& quant_param = input_info.quant_param.Get();
+
+    switch (input_info.qnn_data_type) {
+      case QNN_DATATYPE_SFIXED_POINT_8: {
+        int8_t quantized_value = *reinterpret_cast<int8_t*>(val_bytes.data());
+        float_value = static_cast<float>(utils::Dequantize(quant_param.scaleOffsetEncoding.offset,
+                                                           quant_param.scaleOffsetEncoding.scale,
+                                                           static_cast<double>(quantized_value)));
+        break;
+      }
+      case QNN_DATATYPE_SFIXED_POINT_16: {
+        int16_t quantized_value = *reinterpret_cast<int16_t*>(val_bytes.data());
+        float_value = static_cast<float>(utils::Dequantize(quant_param.scaleOffsetEncoding.offset,
+                                                           quant_param.scaleOffsetEncoding.scale,
+                                                           static_cast<double>(quantized_value)));
+        break;
+      }
+      case QNN_DATATYPE_SFIXED_POINT_32: {
+        int32_t quantized_value = *reinterpret_cast<int32_t*>(val_bytes.data());
+        float_value = static_cast<float>(utils::Dequantize(quant_param.scaleOffsetEncoding.offset,
+                                                           quant_param.scaleOffsetEncoding.scale,
+                                                           static_cast<double>(quantized_value)));
+        break;
+      }
+      case QNN_DATATYPE_UFIXED_POINT_8: {
+        uint8_t quantized_value = *val_bytes.data();
+        float_value = static_cast<float>(utils::Dequantize(quant_param.scaleOffsetEncoding.offset,
+                                                           quant_param.scaleOffsetEncoding.scale,
+                                                           static_cast<double>(quantized_value)));
+        break;
+      }
+      case QNN_DATATYPE_UFIXED_POINT_16: {
+        uint16_t quantized_value = *reinterpret_cast<uint16_t*>(val_bytes.data());
+        float_value = static_cast<float>(utils::Dequantize(quant_param.scaleOffsetEncoding.offset,
+                                                           quant_param.scaleOffsetEncoding.scale,
+                                                           static_cast<double>(quantized_value)));
+        break;
+      }
+      case QNN_DATATYPE_UFIXED_POINT_32: {
+        uint32_t quantized_value = *reinterpret_cast<uint32_t*>(val_bytes.data());
+        float_value = static_cast<float>(utils::Dequantize(quant_param.scaleOffsetEncoding.offset,
+                                                           quant_param.scaleOffsetEncoding.scale,
+                                                           static_cast<double>(quantized_value)));
+        break;
+      }
+      default:
+        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Quantized min/max input data type not supported.");
     }
-    case QNN_DATATYPE_FLOAT_32: {
-      float_value = *reinterpret_cast<const float*>(val_bytes.data());
-      break;
+  } else {
+    // Non-quantized input, just cast to float
+    switch (input_info.qnn_data_type) {
+      case QNN_DATATYPE_INT_8: {
+        float_value = static_cast<float>(*reinterpret_cast<int8_t*>(val_bytes.data()));
+        break;
+      }
+      case QNN_DATATYPE_INT_16: {
+        float_value = static_cast<float>(*reinterpret_cast<int16_t*>(val_bytes.data()));
+        break;
+      }
+      case QNN_DATATYPE_INT_32: {
+        float_value = static_cast<float>(*reinterpret_cast<int32_t*>(val_bytes.data()));
+        break;
+      }
+      case QNN_DATATYPE_INT_64: {
+        float_value = static_cast<float>(*reinterpret_cast<int64_t*>(val_bytes.data()));
+        break;
+      }
+      case QNN_DATATYPE_UINT_8: {
+        float_value = static_cast<float>(*val_bytes.data());
+        break;
+      }
+      case QNN_DATATYPE_UINT_16: {
+        float_value = static_cast<float>(*reinterpret_cast<uint16_t*>(val_bytes.data()));
+        break;
+      }
+      case QNN_DATATYPE_UINT_32: {
+        float_value = static_cast<float>(*reinterpret_cast<uint32_t*>(val_bytes.data()));
+        break;
+      }
+      case QNN_DATATYPE_UINT_64: {
+        float_value = static_cast<float>(*reinterpret_cast<uint64_t*>(val_bytes.data()));
+        break;
+      }
+      case QNN_DATATYPE_FLOAT_16: {
+        MLFloat16 fp16_value = *reinterpret_cast<const MLFloat16*>(val_bytes.data());
+        float_value = fp16_value.ToFloat();
+        break;
+      }
+      case QNN_DATATYPE_FLOAT_32: {
+        float_value = *reinterpret_cast<const float*>(val_bytes.data());
+        break;
+      }
+      default:
+        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Non-quantized min/max input data type not supported.");
     }
-    default:
-      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "min/max input data type not supported.");
   }
 
   return Status::OK();
 }
 
-Status ClipOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const {
+Status ClipOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const {
   if (node_unit.Inputs().size() > 1) {
     const auto& min_input_name = node_unit.Inputs()[1].node_arg.Name();
     if (!min_input_name.empty() && !qnn_model_wrapper.IsConstantInput(min_input_name)) {
@@ -112,7 +168,7 @@ Status ClipOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
                                     std::vector<std::string>& input_names,
                                     bool do_op_validation) const {
   if (do_op_validation) {
-    ORT_RETURN_IF_ERROR(ExplictOpCheck(qnn_model_wrapper, node_unit));
+    ORT_RETURN_IF_ERROR(ExplicitOpCheck(qnn_model_wrapper, node_unit));
   }
 
   return ProcessInput(qnn_model_wrapper, node_unit.Inputs()[0], logger, input_names);
```
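To make the new dequantization path concrete: when a `min`/`max` initializer arrives quantized, the builder recovers the real-valued bound before using it as the Clip limit. The sketch below is a self-contained approximation of that arithmetic; it assumes QNN's `scaleOffsetEncoding` stores `offset` as the negated zero point, so the helper the diff calls (`utils::Dequantize`) is presumed to compute `scale * (quantized + offset)`. The function name below is local to the sketch.

```cpp
// Standalone sketch (not ORT/QNN code) of per-tensor scalar dequantization under the
// assumption offset = -zero_point, i.e. value = scale * (quantized + offset).
#include <cstdint>
#include <cstdio>

static double DequantizeScalarSketch(int32_t offset, float scale, double quant_value) {
  return static_cast<double>(scale) * (quant_value + static_cast<double>(offset));
}

int main() {
  // Example: a uint8 "max" initializer holding 255, quantized with scale 6/255 and
  // zero point 0, dequantizes back to ~6.0, the float bound Clip actually needs.
  const float scale = 6.0f / 255.0f;
  const int32_t offset = 0;  // -zero_point
  std::printf("dequantized max = %f\n", DequantizeScalarSketch(offset, scale, 255.0));
  return 0;
}
```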