
onnx.onnx_cpp2py_export.checker.ValidationError when calling quantize_static() in onnxruntime==1.20.1 #23268


Describe the issue

When I try to quantize a model with large weights in onnxruntime 1.20.1, the following error appears:

WARNING:root:Please consider to run pre-processing before quantization. Refer to example: https://github.com/microsoft/onnxruntime-inference-examples/blob/main/quantization/image_classification/cpu/ReadMe.md 
Traceback (most recent call last):
  File "/home/engineer/tetramem/ml-experimental/quantize/ort_bug.py", line 75, in <module>
    onnxruntime.quantization.quantize(float_model, "test_quantized_model.onnx", quant_config=quant_config)
  File "/opt/tetramem/python_env/lib/python3.10/site-packages/onnxruntime/quantization/quantize.py", line 878, in quantize
    quantize_static(
  File "/opt/tetramem/python_env/lib/python3.10/site-packages/onnxruntime/quantization/quantize.py", line 693, in quantize_static
    calibrator = create_calibrator(
  File "/opt/tetramem/python_env/lib/python3.10/site-packages/onnxruntime/quantization/calibrate.py", line 1186, in create_calibrator
    calibrator = MinMaxCalibrater(
  File "/opt/tetramem/python_env/lib/python3.10/site-packages/onnxruntime/quantization/calibrate.py", line 321, in __init__
    super().__init__(
  File "/opt/tetramem/python_env/lib/python3.10/site-packages/onnxruntime/quantization/calibrate.py", line 208, in __init__
    self.model = load_model_with_shape_infer(model_path)
  File "/opt/tetramem/python_env/lib/python3.10/site-packages/onnxruntime/quantization/quant_utils.py", line 983, in load_model_with_shape_infer
    model = onnx.load(inferred_model_path.as_posix())
  File "/opt/tetramem/python_env/lib/python3.10/site-packages/onnx/__init__.py", line 216, in load_model
    load_external_data_for_model(model, base_dir)
  File "/opt/tetramem/python_env/lib/python3.10/site-packages/onnx/external_data_helper.py", line 64, in load_external_data_for_model
    load_external_data_for_tensor(tensor, base_dir)
  File "/opt/tetramem/python_env/lib/python3.10/site-packages/onnx/external_data_helper.py", line 42, in load_external_data_for_tensor
    external_data_file_path = c_checker._resolve_external_data_location(  # type: ignore[attr-defined]
onnx.onnx_cpp2py_export.checker.ValidationError: Data of TensorProto ( tensor name: weights) should be stored in /tmp/ort.quant.i4tq_bcf/5f6cc3dc-cc94-11ef-a03a-c87f54034d33, but it doesn't exist or is not accessible.
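
The last few frames suggest that the calibrator writes a shape-inferred copy of the model into a temporary directory and that onnx.load then tries to resolve the "weights" initializer from an external-data file that is not present or not accessible there. As a hypothetical diagnostic (the path below is a placeholder; the temp directory from the traceback is transient), the initializers that reference external data can be listed like this:

import onnx
from onnx.external_data_helper import uses_external_data

# Hypothetical diagnostic: load the model proto without resolving external data,
# then list every initializer that points at an external file.
m = onnx.load("/path/to/inferred_model.onnx", load_external_data=False)  # placeholder path
for init in m.graph.initializer:
    if uses_external_data(init):
        print(init.name, [(kv.key, kv.value) for kv in init.external_data])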

To reproduce

  • Python 3.10
  • onnx==1.16.2
  • onnxruntime==1.20.1
  • Test code
import typing

import numpy as np
import onnx
import onnxruntime
import onnxruntime.quantization


class RandomCalibrationDataGenerator(onnxruntime.quantization.CalibrationDataReader):
    """Generates pseudo-random calibration data for testing."""

    def __init__(self, seed: int, name: str, shape: typing.Sequence[int], length: int):
        self.rng = np.random.default_rng(seed=seed)
        self.name = name
        self.shape = shape
        self.length = length
        self.counter = 0

    def get_next(self):
        """See base class."""
        if self.counter >= self.length:
            return None

        array = self.rng.normal(size=self.shape)
        datum = {self.name: array.astype(np.float32)}

        self.counter += 1
        return datum


rng = np.random.default_rng(seed=54321)
calibration_data_reader = RandomCalibrationDataGenerator(
    seed=54322,
    name="input",
    shape=[1, 128],
    length=5,
)

weight_array = rng.normal(size=(128, 256))
weight_array = weight_array.astype(np.float32)
weight_proto = onnx.numpy_helper.from_array(weight_array, name="weights")

node = onnx.helper.make_node(
    op_type="MatMul",
    inputs=["input", "weights"],
    outputs=["output"],
    name="dense",
)

input_info = onnx.helper.make_tensor_value_info(
    "input", onnx.TensorProto.FLOAT, ["batch", 128]
)
output_info = onnx.helper.make_tensor_value_info(
    "output", onnx.TensorProto.FLOAT, ["batch", 256]
)

graph = onnx.helper.make_graph(
    nodes=[node],
    initializer=[weight_proto],
    inputs=[input_info],
    outputs=[output_info],
    name="matmul_graph",
)
float_model = onnx.helper.make_model(graph)


quant_config = onnxruntime.quantization.StaticQuantConfig(
    calibration_data_reader,
    quant_format=onnxruntime.quantization.QuantFormat.QDQ,
    weight_type=onnxruntime.quantization.QuantType.QInt8,
    per_channel=True,
    extra_options=None,
)

onnxruntime.quantization.quantize(float_model, "test_quantized_model.onnx", quant_config=quant_config)
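
A possible workaround sketch (untested assumption: the failure is tied to passing an in-memory ModelProto, so writing the model to disk first keeps any external data next to the saved model file) is to serialize the float model and pass the file path to quantize() instead:

# Untested workaround sketch: serialize the model and pass a path rather than a ModelProto.
onnx.save(float_model, "float_model.onnx")
onnxruntime.quantization.quantize(
    "float_model.onnx", "test_quantized_model.onnx", quant_config=quant_config
)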

Urgency

This is an urgent request; the issue directly blocks our model quantization software product development.

Platform

Linux

OS Version

Ubuntu 22.04

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.20.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response
