Skip to content

Commit fa04951

Browse files
minor fixes
1 parent c673688 commit fa04951

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

nncf/onnx/graph/nncf_graph_builder.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,23 @@ def convert_onnx_dtype_to_nncf_dtype(onnx_dtype: int) -> Dtype:
338338
return Dtype.FLOAT if onnx_dtype == int(onnx.TensorProto.FLOAT) else Dtype.INTEGER
339339

340340
@staticmethod
341-
def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph:
341+
def preprocess_model(model: onnx.ModelProto) -> onnx.ModelProto:
342+
"""
343+
Applies the following transformations to the input model:
344+
- Replace empty node names
345+
- Infer shapes
346+
- Eliminate nop casts
347+
348+
:param model: Input model.
349+
:return: Preprocessed model.
350+
"""
351+
preprocessed_model = GraphConverter._replace_empty_node_name(model)
352+
preprocessed_model = onnx.shape_inference.infer_shapes(preprocessed_model)
353+
preprocessed_model = onnxoptimizer.optimize(preprocessed_model, ["eliminate_nop_cast"])
354+
return preprocessed_model
355+
356+
@staticmethod
357+
def create_nncf_graph(onnx_model: onnx.ModelProto, preprocess_model: bool = True) -> NNCFGraph:
342358
"""
343359
Creates NNCFGraph from 'onnx_model'.
344360
Initially, ONNXGraph is built. All nodes from onnx_model which have valid metatype are added to NNCFGraph.
@@ -347,9 +363,9 @@ def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph:
347363
:param onnx_model: ONNX model.
348364
:return: NNCFGraph.
349365
"""
350-
onnx_model = GraphConverter._replace_empty_node_name(onnx_model)
351-
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
352-
onnx_model = onnxoptimizer.optimize(onnx_model, ["eliminate_nop_cast"])
366+
if preprocess_model:
367+
onnx_model = GraphConverter.preprocess_model(onnx_model)
368+
353369
edge_info_mapping = get_edge_info_mapping(onnx_model)
354370
children_node_mapping = get_children_node_mapping(onnx_model)
355371
parents_node_mapping = get_parents_node_mapping(onnx_model)

nncf/onnx/quantization/quantize_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ def quantize_impl(
8181
advanced_parameters=advanced_parameters,
8282
)
8383

84-
graph = GraphConverter.create_nncf_graph(model)
84+
model = GraphConverter.preprocess_model(model)
85+
graph = GraphConverter.create_nncf_graph(model, preprocess_model=False)
8586
warning_model_no_batchwise_support(graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS)
8687
quantized_model = quantization_algorithm.apply(model, graph, dataset=calibration_dataset)
8788

0 commit comments

Comments
 (0)