Skip to content

[ONNX] Add eliminate_nop_cast pass #3376

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from
5 changes: 2 additions & 3 deletions nncf/onnx/graph/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,8 @@ class ONNXModelTransformer(ModelTransformer):
ZERO_POINT_NAME_PREFIX = "zero_point_"

def __init__(self, model: onnx.ModelProto):
inferred_model = onnx.shape_inference.infer_shapes(model)
super().__init__(inferred_model)
self.onnx_model_extractor = onnx.utils.Extractor(inferred_model)
super().__init__(model)
self.onnx_model_extractor = onnx.utils.Extractor(model)

def _get_target_edge(
self,
Expand Down
48 changes: 48 additions & 0 deletions nncf/onnx/graph/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXDequantizeLinearMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXQuantizeLinearMetatype
from nncf.onnx.graph.onnx_helper import get_children
from nncf.onnx.graph.onnx_helper import get_children_node_mapping
from nncf.onnx.graph.transformations.commands import ONNXQDQNodeRemovingCommand
from nncf.onnx.graph.transformations.commands import ONNXTargetPoint

Expand Down Expand Up @@ -51,3 +53,49 @@ def remove_fq_from_inputs(model: onnx.ModelProto, nncf_graph: NNCFGraph) -> onnx
nodes_queue.extend(nncf_graph.get_next_nodes(current_node))

return model_transformer.transform(transformation_layout)


def eliminate_nop_cast(model: onnx.ModelProto) -> onnx.ModelProto:
    """
    Inspects the provided ONNX model to identify and remove any 'No-op' (no-operation)
    cast nodes, which are operations that do not change the data type of their input.

    :param model: The ONNX model to be processed. Modified in place.
    :return: The ONNX model with the redundant cast nodes removed.
    """
    # NOTE: value_info/input/output entries are ValueInfoProto, while initializers
    # are TensorProto — the two store their element type differently (see below).
    tensor_name_to_info = {
        tensor.name: tensor
        for tensor in (*model.graph.value_info, *model.graph.input, *model.graph.output, *model.graph.initializer)
    }
    redundant_cast_nodes = []
    for node in model.graph.node:
        if node.op_type != "Cast":
            continue

        to_attr = None
        for attr in node.attribute:
            if attr.name == "to":
                to_attr = onnx.helper.get_attribute_value(attr)

        if to_attr is None:
            continue

        info = tensor_name_to_info.get(node.input[0])
        if info is None:
            # No type information is available for the input tensor (e.g. shape
            # inference was not run), so we cannot prove the cast is a no-op — keep it.
            continue

        # TensorProto (initializer) keeps its dtype in `data_type`; ValueInfoProto
        # keeps it in `type.tensor_type.elem_type`. Accessing the wrong field raises.
        if isinstance(info, onnx.TensorProto):
            elem_type = info.data_type
        else:
            elem_type = info.type.tensor_type.elem_type
        if elem_type == to_attr:
            redundant_cast_nodes.append(node)

    value_infos = {i.name: i for i in model.graph.value_info}
    input_name_to_nodes_map = get_children_node_mapping(model)

    for cast_node in redundant_cast_nodes:
        # Unlink Cast node from the graph: re-wire every consumer of the Cast
        # output to read the Cast input directly.
        children = get_children(cast_node, input_name_to_nodes_map)
        for child in children:
            for i, input_name in enumerate(child.input):
                if input_name == cast_node.output[0]:
                    child.input[i] = cast_node.input[0]

        # Remove the now-dangling value_info, if one exists — the Cast output may
        # instead be a graph output, which has no value_info entry.
        if cast_node.output[0] in value_infos:
            model.graph.value_info.remove(value_infos[cast_node.output[0]])
        # Remove Cast node from the graph
        model.graph.node.remove(cast_node)

    return model
5 changes: 3 additions & 2 deletions nncf/onnx/graph/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def convert_onnx_dtype_to_nncf_dtype(onnx_dtype: int) -> Dtype:
return Dtype.FLOAT if onnx_dtype == int(onnx.TensorProto.FLOAT) else Dtype.INTEGER

@staticmethod
def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph:
def create_nncf_graph(onnx_model: onnx.ModelProto, infer_shapes: bool = True) -> NNCFGraph:
"""
Creates NNCFGraph from 'onnx_model'.
Initially, ONNXGraph is built. All nodes from onnx_model which have valid metatype are added to NNCFGraph.
Expand All @@ -347,7 +347,8 @@ def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph:
:return: NNCFGraph.
"""
onnx_model = GraphConverter._replace_empty_node_name(onnx_model)
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
if infer_shapes:
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
edge_info_mapping = get_edge_info_mapping(onnx_model)
children_node_mapping = get_children_node_mapping(onnx_model)
parents_node_mapping = get_parents_node_mapping(onnx_model)
Expand Down
47 changes: 46 additions & 1 deletion nncf/onnx/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
from pathlib import Path
from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union

import onnx
Expand All @@ -18,6 +20,7 @@
from nncf.common.quantization.structs import QuantizationPreset
from nncf.data import Dataset
from nncf.onnx.graph.metatypes.groups import OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS
from nncf.onnx.graph.model_utils import eliminate_nop_cast
from nncf.onnx.graph.nncf_graph_builder import GraphConverter
from nncf.parameters import DropType
from nncf.parameters import ModelType
Expand All @@ -37,6 +40,47 @@
TTensor = TypeVar("TTensor")


def quantize_pre_process(model: onnx.ModelProto, save_as_external_data: bool = True):
    """
    Preprocesses the provided ONNX model for quantization.

    This method performs the following steps:
    1. Infers shapes in the model.
    2. Removes redundant 'No-op' cast nodes from the model.

    :param model: The ONNX model to be preprocessed.
    :param save_as_external_data: A boolean flag indicating whether to
        save the model with external data. If `True`, external data is
        saved separately; otherwise, the model is saved as a single file.
    :return: A preprocessed ONNX model, ready for quantization.
    """
    with tempfile.TemporaryDirectory(dir=tempfile.gettempdir()) as tmp_dir:
        tmp_root = Path(tmp_dir)
        source_model_path = str(tmp_root / "input_model.onnx")

        # Serialize the model to disk so that shape inference can run on the
        # file path (required for models whose weights live in external data).
        if not save_as_external_data:
            onnx.save(model, source_model_path)
        else:
            onnx.save_model(
                model,
                source_model_path,
                save_as_external_data=True,
                all_tensors_to_one_file=True,
                location="model.data",
                size_threshold=1024,
                convert_attribute=False,
            )
        # Drop the in-memory copy before loading the inferred model back.
        model = None

        inferred_model_path = str(tmp_root / "shape_inferred_model.onnx")
        onnx.shape_inference.infer_shapes_path(source_model_path, inferred_model_path)

        reloaded_model = onnx.load(inferred_model_path)
        return eliminate_nop_cast(reloaded_model)


def quantize_impl(
model: onnx.ModelProto,
calibration_dataset: Dataset,
Expand Down Expand Up @@ -81,7 +125,8 @@ def quantize_impl(
advanced_parameters=advanced_parameters,
)

graph = GraphConverter.create_nncf_graph(model)
model = quantize_pre_process(model)
graph = GraphConverter.create_nncf_graph(model, infer_shapes=False)
warning_model_no_batchwise_support(graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS)
quantized_model = quantization_algorithm.apply(model, graph, dataset=calibration_dataset)

Expand Down
Loading
Loading