
Commit daf9565

Quant tool: Consistent get_qdq_config and get_qnn_qdq_config behavior (#23856)
1 parent 0a6b05f commit daf9565
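
For readers skimming the diff, here is a minimal sketch of how the two new `get_qdq_config()` keyword arguments introduced below might be used. The model path, input name, and toy data reader are hypothetical placeholders, not part of this commit:

import numpy as np
from onnxruntime.quantization.calibrate import CalibrationDataReader
from onnxruntime.quantization.quantize import get_qdq_config, quantize

class RandomDataReader(CalibrationDataReader):
    """Feeds a few random calibration samples; assumes one input named 'input_0'."""

    def __init__(self, num_samples: int = 4):
        self._data = iter(
            [{"input_0": np.random.rand(1, 8, 8).astype(np.float32)} for _ in range(num_samples)]
        )

    def get_next(self):
        return next(self._data, None)  # None signals end of calibration data

qdq_config = get_qdq_config(
    "model.onnx",                                    # hypothetical model path
    RandomDataReader(),
    op_types_to_quantize=["Conv", "MatMul"],         # new arg: quantize only these op types
    calibration_providers=["CPUExecutionProvider"],  # new arg: EPs for the calibration session
)
quantize("model.onnx", "model.qdq.onnx", qdq_config)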

3 files changed: +81 -13 lines changed

onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py

Lines changed: 3 additions & 3 deletions
@@ -204,9 +204,9 @@ def get_qnn_qdq_config(
         calibrate_method=calibrate_method,
         activation_type=activation_type,
         weight_type=weight_type,
-        op_types_to_quantize=op_types_to_quantize
-        if op_types_to_quantize
-        else list(op_types.difference(OP_TYPES_TO_EXCLUDE)),
+        op_types_to_quantize=(
+            op_types_to_quantize if op_types_to_quantize else list(op_types.difference(OP_TYPES_TO_EXCLUDE))
+        ),
         nodes_to_exclude=nodes_to_exclude,
         per_channel=per_channel,
         use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD),
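
The hunk above is a formatting-only change (parenthesized conditional expression); the semantics are unchanged: an explicit `op_types_to_quantize` list is used as-is, otherwise every op type found in the model minus `OP_TYPES_TO_EXCLUDE` is quantized. A hedged call sketch, where the model path and data reader are placeholders:

from onnxruntime.quantization import QuantType
from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config

qnn_config = get_qnn_qdq_config(
    "model.onnx",                   # hypothetical model path
    data_reader,                    # any CalibrationDataReader instance
    activation_type=QuantType.QUInt16,
    weight_type=QuantType.QUInt8,
    op_types_to_quantize=None,      # None => all op types except OP_TYPES_TO_EXCLUDE
)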

onnxruntime/python/tools/quantization/quantize.py

Lines changed: 22 additions & 10 deletions
@@ -240,6 +240,8 @@ def get_qdq_config(
     keep_removable_activations: bool = False,
     min_real_range: float | None = None,
     tensor_quant_overrides: dict[str, list[dict[str, Any]]] | None = None,
+    calibration_providers: list[str] | None = None,
+    op_types_to_quantize: list[str] | None = None,
     nodes_to_exclude: list[str] | Callable[[onnx.ModelProto, onnx.NodeProto], bool] | None = None,
     extra_options: dict | None = None,
 ) -> StaticQuantConfig:
@@ -294,6 +296,10 @@ def get_qdq_config(
             'convert["recv_nodes"] = Set : Set of node names that consume the converted activation,
                                           other nodes get the original type. If not specified,
                                           assume all consumer nodes get the converted type.
+        calibration_providers: Execution providers to run the session during calibration. Default is None, which uses
+            ["CPUExecutionProvider"].
+        op_types_to_quantize: List of operator types to quantize. If None, all operators other than Cast, DequantizeLinear,
+            and QuantizeLinear are quantized.
         nodes_to_exclude: List of node names to exclude from quantization. Alternatively, can provide a function that
             accepts an onnx.ModelProto and onnx.NodeProto as arguments and returns True if the given onnx.NodeProto
             should be excluded from quantization.
@@ -324,17 +330,20 @@ def get_qdq_config(
         if onnx.external_data_helper.uses_external_data(initializer):
             model_has_external_data = True
 
-    final_nodes_to_exclude = []
-    if nodes_to_exclude is not None and isinstance(nodes_to_exclude, list):
-        final_nodes_to_exclude.extend(nodes_to_exclude)
+    op_types_to_quantize_set = set(op_types_to_quantize) if op_types_to_quantize else None
+    nodes_to_exclude_set = set(nodes_to_exclude) if isinstance(nodes_to_exclude, list) else set()
 
     # Iterate through nodes to get all operator types in the model and
     # call user's function to filter out nodes from quantization.
     for node in model.graph.node:
-        op_types.add(node.op_type)
-        if nodes_to_exclude is not None and callable(nodes_to_exclude):
-            if nodes_to_exclude(model, node):
-                final_nodes_to_exclude.append(node.name)
+        if op_types_to_quantize_set and node.op_type not in op_types_to_quantize_set:
+            continue
+        if node.name in nodes_to_exclude_set:
+            continue
+        if callable(nodes_to_exclude) and nodes_to_exclude(model, node):
+            nodes_to_exclude_set.add(node.name)
+        else:
+            op_types.add(node.op_type)
 
     final_extra_options = {
         "MinimumRealRange": min_real_range,
@@ -378,11 +387,14 @@ def get_qdq_config(
         quant_format=QuantFormat.QDQ,
         activation_type=activation_type,
         weight_type=weight_type,
-        op_types_to_quantize=list(op_types.difference(op_types_to_exclude)),
-        nodes_to_exclude=final_nodes_to_exclude,
+        op_types_to_quantize=(
+            op_types_to_quantize if op_types_to_quantize else list(op_types.difference(op_types_to_exclude))
+        ),
+        nodes_to_exclude=list(nodes_to_exclude_set),
         per_channel=per_channel,
         reduce_range=reduce_range,
         use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD),
+        calibration_providers=calibration_providers,
         extra_options=final_extra_options,
     )
 
@@ -442,7 +454,7 @@ def check_static_quant_arguments(quant_format: QuantFormat, activation_type: Qua
     if activation_type != QuantType.QFLOAT8E4M3FN and weight_type == QuantType.QFLOAT8E4M3FN:
         raise ValueError(
             f"ONNXRuntime quantization doesn't support data format: activation_type={activation_type} "
-            f"!=QuantType.QFLOAT8E4M3FN, weight_type=QuantType.QFLOAT8E4M3FN."
+            "!=QuantType.QFLOAT8E4M3FN, weight_type=QuantType.QFLOAT8E4M3FN."
         )
 
     if activation_type == QuantType.QFLOAT8E4M3FN and weight_type != QuantType.QFLOAT8E4M3FN:
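
The reworked loop also changes the filtering order: nodes ruled out by `op_types_to_quantize` or by an explicit name list never reach a callable `nodes_to_exclude`, and only nodes that remain quantizable contribute to the collected `op_types` set. A hedged sketch of a callable filter; the model path, data reader, and node-naming convention are hypothetical:

import onnx
from onnxruntime.quantization.quantize import get_qdq_config

def exclude_residual_adds(model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
    # Consulted only for nodes that survive the op-type filter below.
    return node.op_type == "Add" and node.name.startswith("residual_")

qdq_config = get_qdq_config(
    "model.onnx",                           # hypothetical model path
    data_reader,                            # any CalibrationDataReader instance
    op_types_to_quantize=["Add", "Conv"],   # other op types are skipped before the callable runs
    nodes_to_exclude=exclude_residual_adds,
)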

onnxruntime/test/python/quantization/test_get_qdq_config.py

Lines changed: 56 additions & 0 deletions
@@ -156,6 +156,62 @@ def should_exclude_node_(model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
         self.assertTrue(bool(expected_excluded_nodes))
         self.assertEqual(set(qdq_config.nodes_to_exclude), expected_excluded_nodes)
 
+    def test_op_types_to_quantize(self):
+        """
+        Test that get_qdq_config() returns a config that sets the op_types_to_quantize arg.
+        """
+        shape = [1, 8, 8]
+        tensor_type = onnx.TensorProto.FLOAT
+        np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type)
+        weight = onnx.numpy_helper.from_array(np.ones(shape, dtype=np_dtype), "weight")
+        float_model = self.build_add_model(shape, tensor_type, weight)
+
+        input_data_list = [
+            {"input_0": np.ones(shape, dtype=np_dtype) * np.array(-2, dtype=np_dtype)},
+            {"input_0": np.ones(shape, dtype=np_dtype) * np.array(2, dtype=np_dtype)},
+        ]
+        data_reader = TestDataFeeds(input_data_list)
+
+        # No op_types_to_quantize arg means all ops are quantized.
+        qdq_config = get_qdq_config(float_model, data_reader, op_types_to_quantize=None)
+        self.assertEqual(set(qdq_config.op_types_to_quantize), {"Add"})
+
+        # Specify custom op_types_to_quantize arg.
+        qdq_config = get_qdq_config(float_model, data_reader, op_types_to_quantize=["Mul"])
+        self.assertEqual(set(qdq_config.op_types_to_quantize), {"Mul"})
+
+        # Exclude op_type indirectly by specifying nodes_to_exclude arg.
+        qdq_config = get_qdq_config(
+            float_model,
+            data_reader,
+            nodes_to_exclude=[node.name for node in float_model.graph.node if node.op_type == "Add"],
+        )
+        self.assertEqual(set(qdq_config.op_types_to_quantize), set())
+
+    def test_calibration_providers(self):
+        """
+        Test that get_qdq_config() returns a config that sets the calibration providers arg.
+        """
+        shape = [1, 8, 8]
+        tensor_type = onnx.TensorProto.FLOAT
+        np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type)
+        weight = onnx.numpy_helper.from_array(np.ones(shape, dtype=np_dtype), "weight")
+        float_model = self.build_add_model(shape, tensor_type, weight)
+
+        input_data_list = [
+            {"input_0": np.ones(shape, dtype=np_dtype) * np.array(-2, dtype=np_dtype)},
+            {"input_0": np.ones(shape, dtype=np_dtype) * np.array(2, dtype=np_dtype)},
+        ]
+        data_reader = TestDataFeeds(input_data_list)
+
+        qdq_config = get_qdq_config(
+            float_model,
+            data_reader,
+            calibration_providers=["CPUExecutionProvider"],
+        )
+        self.assertEqual(qdq_config.calibration_providers, ["CPUExecutionProvider"])
+
     def test_external_data(self):
         """
         Test that get_qdq_config() returns a config that enables external data
