Skip to content

[ONNX] Add eliminate_nop_cast pass #3376

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
79 changes: 79 additions & 0 deletions nncf/onnx/graph/passes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import onnx

from nncf.onnx.graph.onnx_helper import get_children
from nncf.onnx.graph.onnx_helper import get_children_node_mapping


def eliminate_nop_cast(model: onnx.ModelProto) -> onnx.ModelProto:
"""
Inspects the provided ONNX model to identify and remove any 'No-op' (no-operation)
cast nodes, which are operations that do not change the data type of their input.

:param model: The ONNX model to be processed.
:return: The ONNX model with the redundant cast nodes removed.
"""
tensor_name_to_info = {
tensor.name: tensor
for tensor in (*model.graph.value_info, *model.graph.input, *model.graph.output, *model.graph.initializer)
}
redundant_cast_nodes = []
for node in model.graph.node:
if node.op_type == "Cast":
to_attr = None
for attr in node.attribute:
if attr.name == "to":
to_attr = onnx.helper.get_attribute_value(attr)

if to_attr is None:
continue

inp = node.input[0]
info = tensor_name_to_info[inp]
if info.type.tensor_type.elem_type == to_attr:
redundant_cast_nodes.append(node)

value_infos = {i.name: i for i in model.graph.value_info}
input_name_to_nodes_map = get_children_node_mapping(model)

for cast_node in redundant_cast_nodes:
# Unlink Cast node from the graph
children = get_children(cast_node, input_name_to_nodes_map)
for child in children:
for i, input_name in enumerate(child.input):
if input_name == cast_node.output[0]:
child.input[i] = cast_node.input[0]

# Remove Cast node from the graph
model.graph.value_info.remove(value_infos[cast_node.output[0]])
model.graph.node.remove(cast_node)

return model


def apply_preprocess_passes(model: onnx.ModelProto) -> onnx.ModelProto:
"""
Preprocesses the provided ONNX model for quantization.

This method performs the following steps:
1. Infers shapes in the model.
2. Removes redundant 'No-op' cast nodes from the model.

:param model: The ONNX model to be preprocessed.
:return: A preprocessed ONNX model, ready for quantization.
"""
preprocessed_model = onnx.shape_inference.infer_shapes(model)
# The `eliminate_nop_cast` pass should be applied after onnx.shape_inference.infer_shapes() call.
# Otherwise, not all no-op Cast nodes will be found.
preprocessed_model = eliminate_nop_cast(preprocessed_model)
return preprocessed_model
2 changes: 2 additions & 0 deletions nncf/onnx/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from nncf.onnx.graph.model_metadata import remove_metadata
from nncf.onnx.graph.model_metadata import set_metadata
from nncf.onnx.graph.nncf_graph_builder import GraphConverter
from nncf.onnx.graph.passes import apply_preprocess_passes
from nncf.onnx.quantization.backend_parameters import get_external_data_dir
from nncf.parameters import BackupMode
from nncf.parameters import CompressionFormat
Expand Down Expand Up @@ -155,6 +156,7 @@ def quantize_impl(
external_data_dir = check_external_data_location(model, external_data_dir)
if external_data_dir:
set_metadata(model, MetadataKey.EXTERNAL_DATA_DIR, external_data_dir)
model = apply_preprocess_passes(model)

quantization_algorithm = PostTrainingQuantization(
preset=preset,
Expand Down
19 changes: 19 additions & 0 deletions tests/onnx/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1897,3 +1897,22 @@ def build_matmul_model() -> onnx.ModelProto:
graph = onnx.helper.make_graph([matmul], "matmul-model", [X], [A], [W_initializer])
model = onnx.helper.make_model(graph)
return model


def build_matmul_model_with_nop_cast() -> onnx.ModelProto:
"""
Builds an ONNX model that contains a MatMul operation with a no-op Cast applied to the input.
"""
X = onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [2, 3])
A = onnx.helper.make_tensor_value_info("A", onnx.TensorProto.FLOAT, [2, 2])
cast = onnx.helper.make_node("Cast", inputs=["X"], outputs=["X_cast"], name="cast", to=onnx.TensorProto.FLOAT)
matmul = onnx.helper.make_node("MatMul", inputs=["X_cast", "W"], outputs=["A"], name="matmul")

W_values = np.random.rand(3, 2).astype(np.float32)
W_initializer = onnx.helper.make_tensor(
name="W", data_type=onnx.TensorProto.FLOAT, dims=[3, 2], vals=W_values.tobytes(), raw=True
)

graph = onnx.helper.make_graph([cast, matmul], "matmul-model", [X], [A], [W_initializer])
model = onnx.helper.make_model(graph)
return model
23 changes: 23 additions & 0 deletions tests/onnx/test_passes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nncf.onnx.graph.passes import apply_preprocess_passes
from tests.onnx.models import build_matmul_model_with_nop_cast


def test_apply_preprocess_passes():
model = build_matmul_model_with_nop_cast()
before_nodes = [node.name for node in model.graph.node]
preprocessed_model = apply_preprocess_passes(model)
after_nodes = [node.name for node in preprocessed_model.graph.node]

assert set(after_nodes) - set(before_nodes) == set()
assert set(before_nodes) - set(after_nodes) == set(["cast"])
Loading