Skip to content

Commit 76053b7

Browse files
Fix tests
1 parent 35aa8af commit 76053b7

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

nncf/onnx/graph/nncf_graph_builder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def convert_onnx_dtype_to_nncf_dtype(onnx_dtype: int) -> Dtype:
337337
return Dtype.FLOAT if onnx_dtype == int(onnx.TensorProto.FLOAT) else Dtype.INTEGER
338338

339339
@staticmethod
340-
def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph:
340+
def create_nncf_graph(onnx_model: onnx.ModelProto, infer_shapes: bool = True) -> NNCFGraph:
341341
"""
342342
Creates NNCFGraph from 'onnx_model'.
343343
Initially, ONNXGraph is built. All nodes from onnx_model which have valid metatype are added to NNCFGraph.
@@ -347,6 +347,8 @@ def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph:
347347
:return: NNCFGraph.
348348
"""
349349
onnx_model = GraphConverter._replace_empty_node_name(onnx_model)
350+
if infer_shapes:
351+
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
350352
edge_info_mapping = get_edge_info_mapping(onnx_model)
351353
children_node_mapping = get_children_node_mapping(onnx_model)
352354
parents_node_mapping = get_parents_node_mapping(onnx_model)

nncf/onnx/quantization/quantize_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def quantize_impl(
126126
)
127127

128128
model = quantize_pre_process(model)
129-
graph = GraphConverter.create_nncf_graph(model)
129+
graph = GraphConverter.create_nncf_graph(model, infer_shapes=False)
130130
warning_model_no_batchwise_support(graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS)
131131
quantized_model = quantization_algorithm.apply(model, graph, dataset=calibration_dataset)
132132

0 commit comments

Comments
 (0)