-
Notifications
You must be signed in to change notification settings - Fork 259
[ONNX] Add eliminate_nop_cast pass #3376
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
alexsu52
merged 15 commits into
openvinotoolkit:develop
from
andrey-churkin:ac/eliminate_nop_cast
May 2, 2025
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
21625a3
Add eliminate_nop_cast pass
andrey-churkin 5419716
Update cspell_dict.txt
andrey-churkin e4dcd5d
Update tests/onnx/requirements.txt
andrey-churkin 78aa580
Update reference graphs
andrey-churkin 3f7039a
minor fixes
andrey-churkin 226a734
Use onnx.shape_inference.infer_shapes_path() method
andrey-churkin 08a75e4
update
andrey-churkin 70ea5da
update
andrey-churkin ea63b52
update
andrey-churkin 7b8dcb2
minor update
andrey-churkin ac7b3c2
Fix tests
andrey-churkin c1564f0
minor
andrey-churkin 17f658e
minor
andrey-churkin e147504
revert
andrey-churkin b743745
Add tests
andrey-churkin File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Copyright (c) 2025 Intel Corporation | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import onnx | ||
|
||
from nncf.onnx.graph.onnx_helper import get_children | ||
from nncf.onnx.graph.onnx_helper import get_children_node_mapping | ||
|
||
|
||
def eliminate_nop_cast(model: onnx.ModelProto) -> onnx.ModelProto: | ||
""" | ||
Inspects the provided ONNX model to identify and remove any 'No-op' (no-operation) | ||
cast nodes, which are operations that do not change the data type of their input. | ||
|
||
:param model: The ONNX model to be processed. | ||
:return: The ONNX model with the redundant cast nodes removed. | ||
""" | ||
tensor_name_to_info = { | ||
tensor.name: tensor | ||
for tensor in (*model.graph.value_info, *model.graph.input, *model.graph.output, *model.graph.initializer) | ||
} | ||
redundant_cast_nodes = [] | ||
for node in model.graph.node: | ||
if node.op_type == "Cast": | ||
to_attr = None | ||
for attr in node.attribute: | ||
if attr.name == "to": | ||
to_attr = onnx.helper.get_attribute_value(attr) | ||
|
||
if to_attr is None: | ||
continue | ||
|
||
inp = node.input[0] | ||
info = tensor_name_to_info[inp] | ||
if info.type.tensor_type.elem_type == to_attr: | ||
redundant_cast_nodes.append(node) | ||
|
||
value_infos = {i.name: i for i in model.graph.value_info} | ||
input_name_to_nodes_map = get_children_node_mapping(model) | ||
|
||
for cast_node in redundant_cast_nodes: | ||
# Unlink Cast node from the graph | ||
children = get_children(cast_node, input_name_to_nodes_map) | ||
for child in children: | ||
for i, input_name in enumerate(child.input): | ||
if input_name == cast_node.output[0]: | ||
child.input[i] = cast_node.input[0] | ||
|
||
# Remove Cast node from the graph | ||
model.graph.value_info.remove(value_infos[cast_node.output[0]]) | ||
model.graph.node.remove(cast_node) | ||
|
||
return model | ||
|
||
|
||
def apply_preprocess_passes(model: onnx.ModelProto) -> onnx.ModelProto: | ||
""" | ||
Preprocesses the provided ONNX model for quantization. | ||
|
||
This method performs the following steps: | ||
1. Infers shapes in the model. | ||
2. Removes redundant 'No-op' cast nodes from the model. | ||
|
||
:param model: The ONNX model to be preprocessed. | ||
:return: A preprocessed ONNX model, ready for quantization. | ||
""" | ||
preprocessed_model = onnx.shape_inference.infer_shapes(model) | ||
# The `eliminate_nop_cast` pass should be applied after onnx.shape_inference.infer_shapes() call. | ||
# Otherwise, not all no-op Cast nodes will be found. | ||
preprocessed_model = eliminate_nop_cast(preprocessed_model) | ||
return preprocessed_model |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Copyright (c) 2025 Intel Corporation | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from nncf.onnx.graph.passes import apply_preprocess_passes | ||
from tests.onnx.models import build_matmul_model_with_nop_cast | ||
|
||
|
||
def test_apply_preprocess_passes(): | ||
model = build_matmul_model_with_nop_cast() | ||
before_nodes = [node.name for node in model.graph.node] | ||
preprocessed_model = apply_preprocess_passes(model) | ||
after_nodes = [node.name for node in preprocessed_model.graph.node] | ||
|
||
assert set(after_nodes) - set(before_nodes) == set() | ||
assert set(before_nodes) - set(after_nodes) == set(["cast"]) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.