notebooks/openvino/vision_language_quantization.ipynb failed in Step 5c #1488

@yangxazhou

Hi all, could you help me with an error I hit when running vision_language_quantization.ipynb with the model 'Qwen/Qwen2.5-VL-3B-Instruct'?
Error:
RuntimeError: Exception from src/inference/src/cpp/infer_request.cpp:74:
Exception from src/inference/src/cpp/infer_request.cpp:66:
Check 'shape.compatible(ov::PartialShape(tensor->get_shape()))' failed at src/plugins/intel_cpu/src/infer_request.cpp:389:
Can't set the input tensor with index: 0, because the model input (shape=[?,?]) and the tensor (shape=(1.3.224.224)) are incompatible
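
To make the failing check easier to read: the message says the model input is a dynamic rank-2 shape ([?,?]) while the supplied tensor has rank 4. The sketch below is a plain-Python illustration of that kind of compatibility test, not OpenVINO's implementation; the helper name `shapes_compatible` and the use of `None` for a dynamic dimension are assumptions made for this example.

```python
from typing import Optional, Sequence


def shapes_compatible(partial: Sequence[Optional[int]], concrete: Sequence[int]) -> bool:
    """Return True if a concrete shape can satisfy a partial shape.

    None stands for a dynamic dimension (printed as '?' in the error).
    """
    # A different number of dimensions can never match, even if every
    # dimension of the partial shape is dynamic.
    if len(partial) != len(concrete):
        return False
    # Each static dimension must match exactly; dynamic ones accept anything.
    return all(p is None or p == c for p, c in zip(partial, concrete))


model_input = [None, None]            # shape=[?,?] from the error message
benchmark_tensor = (1, 3, 224, 224)   # tensor shape from the error message

print(shapes_compatible(model_input, benchmark_tensor))  # False: rank 2 vs rank 4
print(shapes_compatible(model_input, (4, 8)))            # True: any rank-2 tensor fits
```

So the dynamic dimensions are not the problem here; the rank of the generated tensor is.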

code:

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from huggingface_hub import create_repo, upload_file
from optimum_benchmark import (
    Benchmark,
    BenchmarkConfig,
    BenchmarkReport,
    InferenceConfig,
    OpenVINOConfig,
    ProcessConfig,
    PyTorchConfig,
)
from optimum_benchmark.logging_utils import setup_logging
from openvino.runtime import Core

setup_logging(level="INFO", prefix="MAIN-PROCESS")

if __name__ == '__main__':
    launcher_config = ProcessConfig()
    scenario_config = InferenceConfig(
        memory=True,
        latency=True,
        generate_kwargs={"max_new_tokens": 16, "min_new_tokens": 16},
        input_shapes={
            "batch_size": 1,
            "sequence_length": 16,
            "num_channels": 3,
            "num_images": 1,
            "height": 224,
            "width": 224,
        },
    )
    model_id = 'Qwen/Qwen2.5-VL-3B-Instruct'
    configs = {
        # "pytorch": PyTorchConfig(device="cpu", model=model_id, no_weights=True),
        "openvino": OpenVINOConfig(device="cpu", model=model_id, no_weights=True, task='image-text-to-text'),
        # "openvino": OpenVINOConfig(device="cpu", model=model_path, no_weights=False, task='image-text-to-text'),
    }

    for config_name, backend_config in configs.items():
        benchmark_config = BenchmarkConfig(
            name=f"{config_name}",
            launcher=launcher_config,
            scenario=scenario_config,
            backend=backend_config,
        )
        benchmark_report = Benchmark.launch(benchmark_config)
        benchmark_report.save_json(f"{config_name}_report.json")
        benchmark_config.save_json(f"{config_name}_config.json")

    reports = {}
    for config_name in configs.keys():
        reports[config_name] = BenchmarkReport.from_json(f"{config_name}_report.json")

    # Plotting results
    _, ax = plt.subplots()
    ax.boxplot(
        [reports[config_name].prefill.latency.values for config_name in reports.keys()],
        tick_labels=reports.keys(),
        showfliers=False,
    )
    plt.xticks(rotation=10)
    ax.set_ylabel("Latency (s)")
    ax.set_xlabel("Configurations")
    ax.set_title("Prefill Latencies")
    plt.savefig("prefill_latencies_boxplot.png")

    _, ax = plt.subplots()
    ax.bar(
        list(reports.keys()),
        [reports[config_name].decode.throughput.value for config_name in reports.keys()],
        color=["C0", "C1", "C2", "C3", "C4", "C5"],
    )
    plt.xticks(rotation=10)
    ax.set_xlabel("Configurations")
    ax.set_title("Decoding Throughput")
    ax.set_ylabel("Throughput (tokens/s)")
    plt.savefig("decode_throughput_barplot.png")

Detailed error log:
[MAIN-PROCESS][2025-10-23 08:53:51,028][process][INFO] - Allocated process launcher
[ISOLATED-PROCESS][2025-10-23 08:53:53,187][openvino][INFO] - Allocating openvino backend
[ISOLATED-PROCESS][2025-10-23 08:53:53,187][openvino][INFO] - + Seeding backend with 42
[ISOLATED-PROCESS][2025-10-23 08:53:53,188][openvino][INFO] - + Benchmarking a Transformers model
/home/intel/miniforge3/lib/python3.12/site-packages/torch/onnx/_internal/registration.py:168: OnnxExporterWarning: Symbolic function 'aten::scaled_dot_product_attention' already registered for opset 14. Replacing the existing function with new function. This is unexpected. Please report it on https://github.com/pytorch/pytorch/issues.
warnings.warn(
[ISOLATED-PROCESS][2025-10-23 08:54:04,581][openvino][INFO] - + Using OVModel class OVModelForVisualCausalLM
[ISOLATED-PROCESS][2025-10-23 08:54:04,582][inference][INFO] - Allocating inference scenario
[ISOLATED-PROCESS][2025-10-23 08:54:04,583][inference][INFO] - + Updating Text Generation kwargs with default values
[ISOLATED-PROCESS][2025-10-23 08:54:04,583][inference][INFO] - + Initializing Text Generation targets list
[ISOLATED-PROCESS][2025-10-23 08:54:04,583][inference][INFO] - + Initializing Latency tracker
[ISOLATED-PROCESS][2025-10-23 08:54:04,583][latency][INFO] - + Tracking latency using CPU performance counter
[ISOLATED-PROCESS][2025-10-23 08:54:04,583][inference][INFO] - + Initializing Per-Token Latency tracker
[ISOLATED-PROCESS][2025-10-23 08:54:04,583][latency][INFO] - + Tracking latency using CPU performance counter
[ISOLATED-PROCESS][2025-10-23 08:54:04,583][inference][INFO] - + Initializing Memory tracker
[ISOLATED-PROCESS][2025-10-23 08:54:04,583][memory][INFO] - + Tracking RAM memory of process 1228189
[ISOLATED-PROCESS][2025-10-23 08:54:04,583][inference][INFO] - + Generating inputs for task image-text-to-text
[ISOLATED-PROCESS][2025-10-23 08:54:04,584][inference][INFO] - + Running model loading tracking
[ISOLATED-PROCESS][2025-10-23 08:54:06,570][openvino][INFO] - + Creating backend temporary directory
[ISOLATED-PROCESS][2025-10-23 08:54:06,570][openvino][INFO] - + Creating no weights OVModel
[ISOLATED-PROCESS][2025-10-23 08:54:07,297][openvino][INFO] - + Loading no weights OVModel
/home/intel/miniforge3/lib/python3.12/site-packages/transformers/cache_utils.py:135: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if not self.is_initialized or self.keys.numel() == 0:
/home/intel/miniforge3/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py:852: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if position_ids.ndim == 3 and position_ids.shape[0] == 4:
/home/intel/miniforge3/lib/python3.12/site-packages/transformers/masking_utils.py:207: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if (padding_length := kv_length + kv_offset - attention_mask.shape[-1]) > 0:
/home/intel/miniforge3/lib/python3.12/site-packages/optimum/exporters/openvino/model_patcher.py:332: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
torch.tensor(0.0, device=mask.device, dtype=dtype),
/home/intel/miniforge3/lib/python3.12/site-packages/optimum/exporters/openvino/model_patcher.py:333: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
torch.tensor(torch.finfo(torch.float16).min, device=mask.device, dtype=dtype),
/home/intel/miniforge3/lib/python3.12/site-packages/transformers/integrations/sdpa_attention.py:76: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True)
INFO:nncf:Statistics of the bitwidth distribution:
┍━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑
│ Weight compression mode │ % all parameters (layers) │ % ratio-defining parameters (layers) │
┝━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥
│ int8_asym │ 100% (253 / 253) │ 100% (253 / 253) │
┕━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙
Applying Weight Compression ━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% • 0:00:07 • 0:00:00
[ISOLATED-PROCESS][2025-10-23 08:55:37,707][inference][INFO] - + Preparing inputs for backend openvino
[ISOLATED-PROCESS][2025-10-23 08:55:37,707][inference][INFO] - + Warming up backend for Text Generation
[ISOLATED-PROCESS][2025-10-23 08:55:37,711][process][ERROR] - + Sending traceback string to main process
[MAIN-PROCESS][2025-10-23 08:55:37,714][process][ERROR] - + Received traceback from isolated process
[ISOLATED-PROCESS][2025-10-23 08:55:37,714][process][INFO] - + Sending traceback string directly (3909 bytes)
[ISOLATED-PROCESS][2025-10-23 08:55:37,726][process][INFO] - + Exiting isolated process

ChildProcessError Traceback (most recent call last)
Cell In[5], line 53
46 for config_name, backend_config in configs.items():
47 benchmark_config = BenchmarkConfig(
48 name=f"{config_name}",
49 launcher=launcher_config,
50 scenario=scenario_config,
51 backend=backend_config,
52 )
---> 53 benchmark_report = Benchmark.launch(benchmark_config)
54 benchmark_report.save_json(f"{config_name}_report.json")
55 benchmark_config.save_json(f"{config_name}_config.json")

File ~/miniforge3/lib/python3.12/site-packages/optimum_benchmark/benchmark/base.py:51, in Benchmark.launch(config)
48 launcher: Launcher = launcher_factory(launcher_config)
50 # Launch the benchmark using the launcher
---> 51 report = launcher.launch(worker=Benchmark.run, worker_args=[config])
53 if config.log_report:
54 report.log()

File ~/miniforge3/lib/python3.12/site-packages/optimum_benchmark/launchers/process/launcher.py:72, in ProcessLauncher.launch(self, worker, worker_args)
70 if isinstance(response, str):
71 self.logger.error("\t+ Received traceback from isolated process")
---> 72 raise ChildProcessError(response)
73 elif isinstance(response, dict):
74 self.logger.info("\t+ Received report dictionary from isolated process")

ChildProcessError: Traceback (most recent call last):
File "/home/intel/miniforge3/lib/python3.12/site-packages/optimum_benchmark/launchers/process/launcher.py", line 107, in target
report = worker(*worker_args)
^^^^^^^^^^^^^^^^^^^^
File "/home/intel/miniforge3/lib/python3.12/site-packages/optimum_benchmark/benchmark/base.py", line 78, in run
report = scenario.run(backend)
^^^^^^^^^^^^^^^^^^^^^
File "/home/intel/miniforge3/lib/python3.12/site-packages/optimum_benchmark/scenarios/inference/scenario.py", line 140, in run
self.warmup_text_generation()
File "/home/intel/miniforge3/lib/python3.12/site-packages/optimum_benchmark/scenarios/inference/scenario.py", line 204, in warmup_text_generation
self.backend.generate(self.inputs, self.config.generate_kwargs)
File "/home/intel/miniforge3/lib/python3.12/site-packages/optimum_benchmark/backends/openvino/backend.py", line 132, in generate
return self.pretrained_model.generate(**inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/intel/miniforge3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/intel/miniforge3/lib/python3.12/site-packages/transformers/generation/utils.py", line 2567, in generate
result = decoding_method(
^^^^^^^^^^^^^^^^
File "/home/intel/miniforge3/lib/python3.12/site-packages/transformers/generation/utils.py", line 2787, in _sample
outputs = self(**model_inputs, return_dict=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/intel/miniforge3/lib/python3.12/site-packages/optimum/modeling_base.py", line 111, in __call__
return self.forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/intel/miniforge3/lib/python3.12/site-packages/optimum/intel/openvino/modeling_visual_language.py", line 826, in forward
inputs_embeds, attention_mask, position_ids = self.get_multimodal_embeddings(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/intel/miniforge3/lib/python3.12/site-packages/optimum/intel/openvino/modeling_visual_language.py", line 3247, in get_multimodal_embeddings
image_embeds = torch.from_numpy(self.get_vision_embeddings(pixel_values, image_grid_thw))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/intel/miniforge3/lib/python3.12/site-packages/optimum/intel/openvino/modeling_visual_language.py", line 3301, in get_vision_embeddings
hidden_states = self.vision_embeddings(pixel_values)[0]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/intel/miniforge3/lib/python3.12/site-packages/optimum/intel/openvino/modeling_base.py", line 869, in __call__
return self.forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/intel/miniforge3/lib/python3.12/site-packages/optimum/intel/openvino/modeling_visual_language.py", line 282, in forward
result = self.request(inputs)
^^^^^^^^^^^^^^^^^^^^
File "/home/intel/miniforge3/lib/python3.12/site-packages/openvino/_ov_api.py", line 440, in __call__
return self._infer_request.infer(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/intel/miniforge3/lib/python3.12/site-packages/openvino/_ov_api.py", line 184, in infer
return OVDict(super().infer(_data_dispatch(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Exception from src/inference/src/cpp/infer_request.cpp:74:
Exception from src/inference/src/cpp/infer_request.cpp:66:
Check 'shape.compatible(ov::PartialShape(tensor->get_shape()))' failed at src/plugins/intel_cpu/src/infer_request.cpp:389:
Can't set the input tensor with index: 0, because the model input (shape=[?,?]) and the tensor (shape=(1.3.224.224)) are incompatible
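
The traceback ends in get_vision_embeddings(pixel_values, image_grid_thw), which suggests the vision encoder is fed flattened patches plus a grid descriptor rather than an NCHW image batch. The sketch below estimates what that rank-2 layout might look like for a 224x224 image. It is only an illustration: the helper name `qwen_vl_patch_layout` is hypothetical, the defaults `patch_size=14` and `temporal_patch_size=2` are assumed values for a Qwen2-VL-style image processor and should be checked against the model's preprocessor config, and any resizing the real processor performs is ignored.

```python
from typing import Tuple


def qwen_vl_patch_layout(
    height: int,
    width: int,
    num_channels: int = 3,
    patch_size: int = 14,           # assumed default, verify in the preprocessor config
    temporal_patch_size: int = 2,   # assumed default, verify in the preprocessor config
) -> Tuple[Tuple[int, int, int], Tuple[int, int]]:
    """Return (grid_thw, flattened pixel_values shape) for one still image."""
    if height % patch_size or width % patch_size:
        raise ValueError("height and width must be multiples of patch_size")
    # One temporal step for a still image; spatial grid counts patches per axis.
    grid_t, grid_h, grid_w = 1, height // patch_size, width // patch_size
    num_patches = grid_t * grid_h * grid_w
    # Each row holds one patch: channels x temporal frames x patch pixels.
    patch_dim = num_channels * temporal_patch_size * patch_size * patch_size
    return (grid_t, grid_h, grid_w), (num_patches, patch_dim)


grid_thw, pixel_values_shape = qwen_vl_patch_layout(224, 224)
print(grid_thw)            # (1, 16, 16)
print(pixel_values_shape)  # (256, 1176)
```

Under these assumptions, a rank-2 tensor of this kind would satisfy the [?,?] model input, whereas the (1, 3, 224, 224) tensor generated for the benchmark does not.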

Python 3.12.11 ENV:
Package Version
about-time 4.2.1
accelerate 1.11.0
aiohappyeyeballs 2.6.1
aiohttp 3.13.1
aiosignal 1.4.0
alive-progress 3.3.0
antlr4-python3-runtime 4.9.3
anyio 4.11.0
attrs 25.4.0
autograd 1.8.0
certifi 2025.10.5
charset-normalizer 3.4.4
cma 4.4.0
colorlog 6.10.1
contourpy 1.3.3
cycler 0.12.1
datasets 4.2.0
Deprecated 1.2.18
dill 0.4.0
docopt 0.6.2
filelock 3.20.0
flatten-dict 0.4.2
fonttools 4.60.1
frozenlist 1.8.0
fsspec 2025.9.0
graphemeu 0.7.2
h11 0.16.0
hf-xet 1.1.10
httpcore 1.0.9
httpx 0.28.1
huggingface-hub 0.35.3
hydra-core 1.3.2
idna 3.11
Jinja2 3.1.6
joblib 1.5.2
jsonschema 4.25.1
jsonschema-specifications 2025.9.1
kiwisolver 1.4.9
markdown-it-py 4.0.0
MarkupSafe 3.0.3
matplotlib 3.10.7
mdurl 0.1.2
ml_dtypes 0.5.3
mpmath 1.3.0
multidict 6.7.0
multiprocess 0.70.16
natsort 8.4.0
networkx 3.4.2
ninja 1.13.0
nncf 2.18.0
num2words 0.5.14
numpy 2.2.6
nvidia-cublas-cu12 12.8.4.1
nvidia-cuda-cupti-cu12 12.8.90
nvidia-cuda-nvrtc-cu12 12.8.93
nvidia-cuda-runtime-cu12 12.8.90
nvidia-cudnn-cu12 9.10.2.21
nvidia-cufft-cu12 11.3.3.83
nvidia-cufile-cu12 1.13.1.3
nvidia-curand-cu12 10.3.9.90
nvidia-cusolver-cu12 11.7.3.90
nvidia-cusparse-cu12 12.5.8.93
nvidia-cusparselt-cu12 0.7.1
nvidia-ml-py 13.580.82
nvidia-nccl-cu12 2.27.5
nvidia-nvjitlink-cu12 12.8.93
nvidia-nvshmem-cu12 3.3.20
nvidia-nvtx-cu12 12.8.90
omegaconf 2.3.0
onnx 1.19.1
openvino 2025.3.0
openvino-telemetry 2025.2.0
openvino-tokenizers 2025.3.0.0
optimum 1.27.0
optimum-benchmark 0.7.0.dev0
optimum-intel 1.25.2
packaging 25.0
pandas 2.3.3
pillow 12.0.0
pip 25.0.1
propcache 0.4.1
protobuf 6.33.0
psutil 7.1.1
pyarrow 21.0.0
pydot 3.0.4
Pygments 2.19.2
pymoo 0.6.1.5
pyparsing 3.2.5
python-dateutil 2.9.0.post0
pytz 2025.2
PyYAML 6.0.3
referencing 0.37.0
regex 2025.10.23
requests 2.32.5
rich 14.2.0
rpds-py 0.28.0
safetensors 0.6.2
scikit-learn 1.7.2
scipy 1.16.2
setuptools 80.9.0
six 1.17.0
sniffio 1.3.1
sympy 1.14.0
tabulate 0.9.0
threadpoolctl 3.6.0
tokenizers 0.21.4
torch 2.9.0
torchvision 0.24.0
tqdm 4.67.1
transformers 4.53.3
triton 3.5.0
typing_extensions 4.15.0
tzdata 2025.2
urllib3 2.5.0
wrapt 1.17.3
xxhash 3.6.0
yarl 1.22.0

Sys env:
OS: Ubuntu 25.04 (Plucky Puffin)
Mem: 64G
CPU: i7-13700
