Skip to content

Commit c72f787

Browse files
[ONNX] Add eliminate_nop_cast pass (#3376)
### Changes Apply the `eliminate_nop_cast` optimization pass before building the NNCF graph. This pass eliminates no-op Cast nodes from the graph. A no-op cast is a cast that does not change the value, such as casting a tensor to its own type. ### Reason for changes Statistics cannot be collected for the outputs of no-op Cast nodes because such nodes are removed from the ONNX inference graph during the session. ### Related tickets Ref: 164211
1 parent 549d3ed commit c72f787

File tree

4 files changed

+123
-0
lines changed

4 files changed

+123
-0
lines changed

nncf/onnx/graph/passes.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright (c) 2025 Intel Corporation
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import onnx
13+
14+
from nncf.onnx.graph.onnx_helper import get_children
15+
from nncf.onnx.graph.onnx_helper import get_children_node_mapping
16+
17+
18+
def eliminate_nop_cast(model: onnx.ModelProto) -> onnx.ModelProto:
    """
    Inspects the provided ONNX model to identify and remove any 'No-op' (no-operation)
    cast nodes, which are operations that do not change the data type of their input.

    A Cast node is treated as a no-op when the element type of its input is known and
    equals the value of its `to` attribute. The following Cast nodes are left untouched:

    - Cast nodes without a `to` attribute.
    - Cast nodes whose input element type is unknown (no type information is available).
    - Cast nodes that produce a graph output, because removing them would drop the
      output tensor name from the graph.

    The model is modified in place and also returned.

    :param model: The ONNX model to be processed.
    :return: The ONNX model with the redundant cast nodes removed.
    """
    # `ValueInfoProto` entries keep the element type under `type.tensor_type.elem_type`.
    tensor_name_to_elem_type = {
        info.name: info.type.tensor_type.elem_type
        for info in (*model.graph.value_info, *model.graph.input, *model.graph.output)
    }
    # Initializers are `TensorProto` objects: they have no `type` field and expose
    # their element type through `data_type` instead.
    for initializer in model.graph.initializer:
        tensor_name_to_elem_type[initializer.name] = initializer.data_type

    graph_output_names = {output.name for output in model.graph.output}

    redundant_cast_nodes = []
    for node in model.graph.node:
        if node.op_type != "Cast":
            continue

        to_attr = next(
            (onnx.helper.get_attribute_value(attr) for attr in node.attribute if attr.name == "to"),
            None,
        )
        if to_attr is None:
            continue

        # The output name of such a node is part of the model interface, so the node must stay.
        if node.output[0] in graph_output_names:
            continue

        # `.get()` returns None for tensors without type information, which never matches `to`.
        if tensor_name_to_elem_type.get(node.input[0]) == to_attr:
            redundant_cast_nodes.append(node)

    value_infos = {i.name: i for i in model.graph.value_info}
    input_name_to_nodes_map = get_children_node_mapping(model)

    for cast_node in redundant_cast_nodes:
        # Unlink Cast node from the graph
        children = get_children(cast_node, input_name_to_nodes_map)
        for child in children:
            for i, input_name in enumerate(child.input):
                if input_name == cast_node.output[0]:
                    child.input[i] = cast_node.input[0]

        # Remove Cast node from the graph.
        # The output may have no `value_info` entry (e.g. shape inference was not run).
        output_info = value_infos.get(cast_node.output[0])
        if output_info is not None:
            model.graph.value_info.remove(output_info)
        model.graph.node.remove(cast_node)

    return model
62+
63+
64+
def apply_preprocess_passes(model: onnx.ModelProto) -> onnx.ModelProto:
    """
    Runs the preprocessing passes that prepare an ONNX model for quantization.

    The passes are applied in this order:
        1. Shape inference over the model.
        2. Removal of redundant 'No-op' cast nodes.

    :param model: The ONNX model to be preprocessed.
    :return: The preprocessed ONNX model, ready for quantization.
    """
    # Shape inference has to come first: `eliminate_nop_cast` will not find every no-op
    # Cast node unless onnx.shape_inference.infer_shapes() has already been applied.
    model_with_shapes = onnx.shape_inference.infer_shapes(model)
    return eliminate_nop_cast(model_with_shapes)

nncf/onnx/quantization/quantize_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from nncf.onnx.graph.model_metadata import remove_metadata
3030
from nncf.onnx.graph.model_metadata import set_metadata
3131
from nncf.onnx.graph.nncf_graph_builder import GraphConverter
32+
from nncf.onnx.graph.passes import apply_preprocess_passes
3233
from nncf.onnx.quantization.backend_parameters import get_external_data_dir
3334
from nncf.parameters import BackupMode
3435
from nncf.parameters import CompressionFormat
@@ -155,6 +156,7 @@ def quantize_impl(
155156
external_data_dir = check_external_data_location(model, external_data_dir)
156157
if external_data_dir:
157158
set_metadata(model, MetadataKey.EXTERNAL_DATA_DIR, external_data_dir)
159+
model = apply_preprocess_passes(model)
158160

159161
quantization_algorithm = PostTrainingQuantization(
160162
preset=preset,

tests/onnx/models.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1897,3 +1897,22 @@ def build_matmul_model() -> onnx.ModelProto:
18971897
graph = onnx.helper.make_graph([matmul], "matmul-model", [X], [A], [W_initializer])
18981898
model = onnx.helper.make_model(graph)
18991899
return model
1900+
1901+
1902+
def build_matmul_model_with_nop_cast() -> onnx.ModelProto:
    """
    Creates an ONNX model in which a no-op Cast (FLOAT -> FLOAT) feeds the first input of a MatMul.

    The graph is: X -> Cast -> MatMul(X_cast, W) -> A, where W is a random float32 initializer.
    """
    float_type = onnx.TensorProto.FLOAT

    # Graph interface: a single input and a single output.
    input_info = onnx.helper.make_tensor_value_info("X", float_type, [2, 3])
    output_info = onnx.helper.make_tensor_value_info("A", float_type, [2, 2])

    # Weights of the MatMul, stored as a raw-bytes initializer.
    weights = np.random.rand(3, 2).astype(np.float32)
    weights_initializer = onnx.helper.make_tensor(
        name="W",
        data_type=float_type,
        dims=[3, 2],
        vals=weights.tobytes(),
        raw=True,
    )

    # The Cast keeps the FLOAT element type of "X", so it does not change the data.
    nodes = [
        onnx.helper.make_node("Cast", inputs=["X"], outputs=["X_cast"], name="cast", to=float_type),
        onnx.helper.make_node("MatMul", inputs=["X_cast", "W"], outputs=["A"], name="matmul"),
    ]

    graph = onnx.helper.make_graph(nodes, "matmul-model", [input_info], [output_info], [weights_initializer])
    return onnx.helper.make_model(graph)

tests/onnx/test_passes.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright (c) 2025 Intel Corporation
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from nncf.onnx.graph.passes import apply_preprocess_passes
13+
from tests.onnx.models import build_matmul_model_with_nop_cast
14+
15+
16+
def test_apply_preprocess_passes():
    """Checks that preprocessing removes only the no-op Cast node and adds no new nodes."""
    model = build_matmul_model_with_nop_cast()
    names_before = {node.name for node in model.graph.node}

    preprocessed_model = apply_preprocess_passes(model)
    names_after = {node.name for node in preprocessed_model.graph.node}

    # No node may appear that was not in the original model.
    assert names_after - names_before == set()
    # The only node that disappears is the no-op Cast.
    assert names_before - names_after == {"cast"}

0 commit comments

Comments
 (0)