@@ -338,7 +338,23 @@ def convert_onnx_dtype_to_nncf_dtype(onnx_dtype: int) -> Dtype:
338338 return Dtype .FLOAT if onnx_dtype == int (onnx .TensorProto .FLOAT ) else Dtype .INTEGER
339339
340340 @staticmethod
341- def create_nncf_graph (onnx_model : onnx .ModelProto ) -> NNCFGraph :
341+ def preprocess_model (model : onnx .ModelProto ) -> onnx .ModelProto :
342+ """
343+ Applies the following transformations to the input model:
344+ - Replace empty node names
345+ - Infer shapes
346+ - Eliminate nop casts
347+
348+ :param model: Input model.
349+ :return: Preprocessed model.
350+ """
351+ preprocessed_model = GraphConverter ._replace_empty_node_name (model )
352+ preprocessed_model = onnx .shape_inference .infer_shapes (preprocessed_model )
353+ preprocessed_model = onnxoptimizer .optimize (preprocessed_model , ["eliminate_nop_cast" ])
354+ return preprocessed_model
355+
356+ @staticmethod
357+ def create_nncf_graph (onnx_model : onnx .ModelProto , preprocess_model : bool = True ) -> NNCFGraph :
342358 """
343359 Creates NNCFGraph from 'onnx_model'.
344360 Initially, ONNXGraph is built. All nodes from onnx_model which have valid metatype are added to NNCFGraph.
@@ -347,9 +363,9 @@ def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph:
347363 :param onnx_model: ONNX model.
348364 :return: NNCFGraph.
349365 """
350- onnx_model = GraphConverter . _replace_empty_node_name ( onnx_model )
351- onnx_model = onnx . shape_inference . infer_shapes (onnx_model )
352- onnx_model = onnxoptimizer . optimize ( onnx_model , [ "eliminate_nop_cast" ])
366+ if preprocess_model :
367+ onnx_model = GraphConverter . preprocess_model (onnx_model )
368+
353369 edge_info_mapping = get_edge_info_mapping (onnx_model )
354370 children_node_mapping = get_children_node_mapping (onnx_model )
355371 parents_node_mapping = get_parents_node_mapping (onnx_model )
0 commit comments