@@ -240,6 +240,8 @@ def get_qdq_config(
240240 keep_removable_activations : bool = False ,
241241 min_real_range : float | None = None ,
242242 tensor_quant_overrides : dict [str , list [dict [str , Any ]]] | None = None ,
243+ calibration_providers : list [str ] | None = None ,
244+ op_types_to_quantize : list [str ] | None = None ,
243245 nodes_to_exclude : list [str ] | Callable [[onnx .ModelProto , onnx .NodeProto ], bool ] | None = None ,
244246 extra_options : dict | None = None ,
245247) -> StaticQuantConfig :
@@ -294,6 +296,10 @@ def get_qdq_config(
294296 'convert["recv_nodes"] = Set : Set of node names that consume the converted activation,
295297 other nodes get the original type. If not specified,
296298 assume all consumer nodes get the converted type.
299+ calibration_providers: Execution providers to run the session during calibration. Default is None which uses
300+ [ "CPUExecutionProvider" ].
301+ op_types_to_quantize: List of operator types to quantize. If None, all operators other than Cast, DequantizeLinear,
302+ and QuantizeLinear are quantized.
297303 nodes_to_exclude: List of node names to exclude from quantization. Alternatively, can provide a function that
298304 accepts an onnx.ModelProto and onnx.NodeProto as arguments and returns true if the given onnx.NodeProto
299305 should be excluded from quantization.
@@ -324,17 +330,20 @@ def get_qdq_config(
324330 if onnx .external_data_helper .uses_external_data (initializer ):
325331 model_has_external_data = True
326332
327- final_nodes_to_exclude = []
328- if nodes_to_exclude is not None and isinstance (nodes_to_exclude , list ):
329- final_nodes_to_exclude .extend (nodes_to_exclude )
333+ op_types_to_quantize_set = set (op_types_to_quantize ) if op_types_to_quantize else None
334+ nodes_to_exclude_set = set (nodes_to_exclude ) if isinstance (nodes_to_exclude , list ) else set ()
330335
331336 # Iterate through nodes to get all operator types in the model and
332337 # call user's function to filter out nodes from quantization.
333338 for node in model .graph .node :
334- op_types .add (node .op_type )
335- if nodes_to_exclude is not None and callable (nodes_to_exclude ):
336- if nodes_to_exclude (model , node ):
337- final_nodes_to_exclude .append (node .name )
339+ if op_types_to_quantize_set and node .op_type not in op_types_to_quantize_set :
340+ continue
341+ if node .name in nodes_to_exclude_set :
342+ continue
343+ if callable (nodes_to_exclude ) and nodes_to_exclude (model , node ):
344+ nodes_to_exclude_set .add (node .name )
345+ else :
346+ op_types .add (node .op_type )
338347
339348 final_extra_options = {
340349 "MinimumRealRange" : min_real_range ,
@@ -378,11 +387,14 @@ def get_qdq_config(
378387 quant_format = QuantFormat .QDQ ,
379388 activation_type = activation_type ,
380389 weight_type = weight_type ,
381- op_types_to_quantize = list (op_types .difference (op_types_to_exclude )),
382- nodes_to_exclude = final_nodes_to_exclude ,
390+ op_types_to_quantize = (
391+ op_types_to_quantize if op_types_to_quantize else list (op_types .difference (op_types_to_exclude ))
392+ ),
393+ nodes_to_exclude = list (nodes_to_exclude_set ),
383394 per_channel = per_channel ,
384395 reduce_range = reduce_range ,
385396 use_external_data_format = (model_has_external_data or model .ByteSize () >= MODEL_SIZE_THRESHOLD ),
397+ calibration_providers = calibration_providers ,
386398 extra_options = final_extra_options ,
387399 )
388400
@@ -442,7 +454,7 @@ def check_static_quant_arguments(quant_format: QuantFormat, activation_type: Qua
442454 if activation_type != QuantType .QFLOAT8E4M3FN and weight_type == QuantType .QFLOAT8E4M3FN :
443455 raise ValueError (
444456 f"ONNXRuntime quantization doesn't support data format: activation_type={ activation_type } "
445- f "!=QuantType.QFLOAT8E4M3FN, weight_type=QuantType.QFLOAT8E4M3FN."
457+ "!=QuantType.QFLOAT8E4M3FN, weight_type=QuantType.QFLOAT8E4M3FN."
446458 )
447459
448460 if activation_type == QuantType .QFLOAT8E4M3FN and weight_type != QuantType .QFLOAT8E4M3FN :
0 commit comments