Skip to content

Commit c1564f0

Browse files
minor
1 parent ac7b3c2 commit c1564f0

File tree

4 files changed

+80
-92
lines changed

4 files changed

+80
-92
lines changed

nncf/onnx/graph/model_utils.py

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
from nncf.common.graph.transformations.layout import TransformationLayout
2020
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXDequantizeLinearMetatype
2121
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXQuantizeLinearMetatype
22-
from nncf.onnx.graph.onnx_helper import get_children
23-
from nncf.onnx.graph.onnx_helper import get_children_node_mapping
2422
from nncf.onnx.graph.transformations.commands import ONNXQDQNodeRemovingCommand
2523
from nncf.onnx.graph.transformations.commands import ONNXTargetPoint
2624

@@ -54,49 +52,3 @@ def remove_fq_from_inputs(model: onnx.ModelProto, nncf_graph: NNCFGraph) -> onnx
5452
nodes_queue.extend(nncf_graph.get_next_nodes(current_node))
5553

5654
return model_transformer.transform(transformation_layout)
57-
58-
59-
def eliminate_nop_cast(model: onnx.ModelProto) -> onnx.ModelProto:
60-
"""
61-
Inspects the provided ONNX model to identify and remove any 'No-op' (no-operation)
62-
cast nodes, which are operations that do not change the data type of their input.
63-
64-
:param model: The ONNX model to be processed.
65-
:return: The ONNX model with the redundant cast nodes removed.
66-
"""
67-
tensor_name_to_info = {
68-
tensor.name: tensor
69-
for tensor in (*model.graph.value_info, *model.graph.input, *model.graph.output, *model.graph.initializer)
70-
}
71-
redundant_cast_nodes = []
72-
for node in model.graph.node:
73-
if node.op_type == "Cast":
74-
to_attr = None
75-
for attr in node.attribute:
76-
if attr.name == "to":
77-
to_attr = onnx.helper.get_attribute_value(attr)
78-
79-
if to_attr is None:
80-
continue
81-
82-
inp = node.input[0]
83-
info = tensor_name_to_info[inp]
84-
if info.type.tensor_type.elem_type == to_attr:
85-
redundant_cast_nodes.append(node)
86-
87-
value_infos = {i.name: i for i in model.graph.value_info}
88-
input_name_to_nodes_map = get_children_node_mapping(model)
89-
90-
for cast_node in redundant_cast_nodes:
91-
# Unlink Cast node from the graph
92-
children = get_children(cast_node, input_name_to_nodes_map)
93-
for child in children:
94-
for i, input_name in enumerate(child.input):
95-
if input_name == cast_node.output[0]:
96-
child.input[i] = cast_node.input[0]
97-
98-
# Remove Cast node from the graph
99-
model.graph.value_info.remove(value_infos[cast_node.output[0]])
100-
model.graph.node.remove(cast_node)
101-
102-
return model

nncf/onnx/graph/nncf_graph_builder.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def convert_onnx_dtype_to_nncf_dtype(onnx_dtype: int) -> Dtype:
337337
return Dtype.FLOAT if onnx_dtype == int(onnx.TensorProto.FLOAT) else Dtype.INTEGER
338338

339339
@staticmethod
340-
def create_nncf_graph(onnx_model: onnx.ModelProto, infer_shapes: bool = True) -> NNCFGraph:
340+
def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph:
341341
"""
342342
Creates NNCFGraph from 'onnx_model'.
343343
Initially, ONNXGraph is built. All nodes from onnx_model which have valid metatype are added to NNCFGraph.
@@ -347,8 +347,7 @@ def create_nncf_graph(onnx_model: onnx.ModelProto, infer_shapes: bool = True) ->
347347
:return: NNCFGraph.
348348
"""
349349
onnx_model = GraphConverter._replace_empty_node_name(onnx_model)
350-
if infer_shapes:
351-
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
350+
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
352351
edge_info_mapping = get_edge_info_mapping(onnx_model)
353352
children_node_mapping = get_children_node_mapping(onnx_model)
354353
parents_node_mapping = get_parents_node_mapping(onnx_model)

nncf/onnx/graph/passes.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright (c) 2025 Intel Corporation
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import onnx
13+
14+
from nncf.onnx.graph.onnx_helper import get_children
15+
from nncf.onnx.graph.onnx_helper import get_children_node_mapping
16+
17+
18+
def eliminate_nop_cast(model: onnx.ModelProto) -> onnx.ModelProto:
19+
"""
20+
Inspects the provided ONNX model to identify and remove any 'No-op' (no-operation)
21+
cast nodes, which are operations that do not change the data type of their input.
22+
23+
:param model: The ONNX model to be processed.
24+
:return: The ONNX model with the redundant cast nodes removed.
25+
"""
26+
tensor_name_to_info = {
27+
tensor.name: tensor
28+
for tensor in (*model.graph.value_info, *model.graph.input, *model.graph.output, *model.graph.initializer)
29+
}
30+
redundant_cast_nodes = []
31+
for node in model.graph.node:
32+
if node.op_type == "Cast":
33+
to_attr = None
34+
for attr in node.attribute:
35+
if attr.name == "to":
36+
to_attr = onnx.helper.get_attribute_value(attr)
37+
38+
if to_attr is None:
39+
continue
40+
41+
inp = node.input[0]
42+
info = tensor_name_to_info[inp]
43+
if info.type.tensor_type.elem_type == to_attr:
44+
redundant_cast_nodes.append(node)
45+
46+
value_infos = {i.name: i for i in model.graph.value_info}
47+
input_name_to_nodes_map = get_children_node_mapping(model)
48+
49+
for cast_node in redundant_cast_nodes:
50+
# Unlink Cast node from the graph
51+
children = get_children(cast_node, input_name_to_nodes_map)
52+
for child in children:
53+
for i, input_name in enumerate(child.input):
54+
if input_name == cast_node.output[0]:
55+
child.input[i] = cast_node.input[0]
56+
57+
# Remove Cast node from the graph
58+
model.graph.value_info.remove(value_infos[cast_node.output[0]])
59+
model.graph.node.remove(cast_node)
60+
61+
return model
62+
63+
64+
def apply_preprocess_passes(model: onnx.ModelProto) -> None:
65+
"""
66+
Preprocesses the provided ONNX model for quantization.
67+
68+
This method performs the following steps:
69+
1. Infers shapes in the model.
70+
2. Removes redundant 'No-op' cast nodes from the model.
71+
72+
:param model: The ONNX model to be preprocessed.
73+
:return: A preprocessed ONNX model, ready for quantization.
74+
"""
75+
preprocessed_model = onnx.shape_inference.infer_shapes(model)
76+
preprocessed_model = eliminate_nop_cast(preprocessed_model)
77+
return preprocessed_model

nncf/onnx/quantization/quantize_model.py

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -116,45 +116,6 @@ def check_external_data_location(model: onnx.ModelProto, external_data_dir: Opti
116116

117117
# If len(data_path) == 0, it means there are no tensors that use external data.
118118
return str(external_data_dir) if data_paths else None
119-
def quantize_pre_process(model: onnx.ModelProto, save_as_external_data: bool = True):
120-
"""
121-
Preprocesses the provided ONNX model for quantization.
122-
123-
This method performs the following steps:
124-
1. Infers shapes in the model.
125-
2. Removes redundant 'No-op' cast nodes from the model.
126-
127-
:param model: The ONNX model to be preprocessed.
128-
:param save_as_external_data: A boolean flag indicating whether to
129-
save the model with external data. If `True`, external data is
130-
saved separately; otherwise, the model is saved as a single file.
131-
:return: A preprocessed ONNX model, ready for quantization.
132-
"""
133-
with tempfile.TemporaryDirectory(dir=tempfile.gettempdir()) as temp_dir:
134-
temp_path = Path(temp_dir)
135-
input_model_path = str(temp_path / "input_model.onnx")
136-
137-
if save_as_external_data:
138-
onnx.save_model(
139-
model,
140-
input_model_path,
141-
save_as_external_data=True,
142-
all_tensors_to_one_file=True,
143-
location="model.data",
144-
size_threshold=1024,
145-
convert_attribute=False,
146-
)
147-
else:
148-
onnx.save(model, input_model_path)
149-
model = None
150-
151-
shape_inferred_model_path = str(temp_path / "shape_inferred_model.onnx")
152-
onnx.shape_inference.infer_shapes_path(input_model_path, shape_inferred_model_path)
153-
154-
preprocessed_model = onnx.load(shape_inferred_model_path)
155-
preprocessed_model = eliminate_nop_cast(preprocessed_model)
156-
157-
return preprocessed_model
158119

159120

160121
def quantize_impl(
@@ -207,8 +168,7 @@ def quantize_impl(
207168
advanced_parameters=advanced_parameters,
208169
)
209170

210-
model = quantize_pre_process(model)
211-
graph = GraphConverter.create_nncf_graph(model, infer_shapes=False)
171+
graph = GraphConverter.create_nncf_graph(model)
212172
warning_model_no_batchwise_support(graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS)
213173
quantized_model = quantization_algorithm.apply(model, graph, dataset=calibration_dataset)
214174

0 commit comments

Comments
 (0)