Skip to content

Commit 17f658e

Browse files
minor
1 parent c1564f0 commit 17f658e

File tree

4 files changed

+9
-6
lines changed

4 files changed

+9
-6
lines changed

.ci/cspell_dict.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,9 +223,9 @@ logicalor
223223
logicalxor
224224
logit
225225
loglikelihoods
226+
lspec
226227
lstmsequence
227228
lstsq
228-
lspec
229229
lyalyushkin
230230
mapillary
231231
maskrcnn

nncf/onnx/graph/model_transformer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,9 @@ class ONNXModelTransformer(ModelTransformer):
4747
ZERO_POINT_NAME_PREFIX = "zero_point_"
4848

4949
def __init__(self, model: onnx.ModelProto):
50-
super().__init__(model)
51-
self.onnx_model_extractor = onnx.utils.Extractor(model)
50+
inferred_model = onnx.shape_inference.infer_shapes(model)
51+
super().__init__(inferred_model)
52+
self.onnx_model_extractor = onnx.utils.Extractor(inferred_model)
5253

5354
def _get_target_edge(
5455
self,

nncf/onnx/graph/passes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def eliminate_nop_cast(model: onnx.ModelProto) -> onnx.ModelProto:
6161
return model
6262

6363

64-
def apply_preprocess_passes(model: onnx.ModelProto) -> None:
64+
def apply_preprocess_passes(model: onnx.ModelProto) -> onnx.ModelProto:
6565
"""
6666
Preprocesses the provided ONNX model for quantization.
6767
@@ -73,5 +73,7 @@ def apply_preprocess_passes(model: onnx.ModelProto) -> None:
7373
:return: A preprocessed ONNX model, ready for quantization.
7474
"""
7575
preprocessed_model = onnx.shape_inference.infer_shapes(model)
76+
# The `eliminate_nop_cast` pass should be applied after onnx.shape_inference.infer_shapes() call.
77+
# Otherwise, not all no-op Cast nodes will be found.
7678
preprocessed_model = eliminate_nop_cast(preprocessed_model)
7779
return preprocessed_model

nncf/onnx/quantization/quantize_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import sys
1313
from pathlib import Path
1414
from typing import Any, Callable, Iterable, Optional, TypeVar, Union
15-
import tempfile
1615

1716
import onnx
1817
from onnx.external_data_helper import ExternalDataInfo
@@ -29,8 +28,8 @@
2928
from nncf.onnx.graph.model_metadata import MetadataKey
3029
from nncf.onnx.graph.model_metadata import remove_metadata
3130
from nncf.onnx.graph.model_metadata import set_metadata
32-
from nncf.onnx.graph.model_utils import eliminate_nop_cast
3331
from nncf.onnx.graph.nncf_graph_builder import GraphConverter
32+
from nncf.onnx.graph.passes import apply_preprocess_passes
3433
from nncf.onnx.quantization.backend_parameters import get_external_data_dir
3534
from nncf.parameters import BackupMode
3635
from nncf.parameters import CompressionFormat
@@ -157,6 +156,7 @@ def quantize_impl(
157156
external_data_dir = check_external_data_location(model, external_data_dir)
158157
if external_data_dir:
159158
set_metadata(model, MetadataKey.EXTERNAL_DATA_DIR, external_data_dir)
159+
model = apply_preprocess_passes(model)
160160

161161
quantization_algorithm = PostTrainingQuantization(
162162
preset=preset,

0 commit comments

Comments
 (0)