Description
OpenVINO Version
openvino-2025.2.0.dev20250408
Operating System
Other (Please specify in description)
Device used for inference
CPU
Framework
ONNX
Model used
pytorch
Issue description
I found that switching my model to OpenVINO produced completely different results from PyTorch and ONNX.
After a long time troubleshooting, I narrowed the problem down to the following example.
I think the bug comes from inference (the model compile stage).
Netron validation: confirmed that the IR model structure is as expected; it only generates a (3,3)-shaped randn tensor.
Observed behavior: the output contains unexpected zeros only when y has shape (?,?) or another two-dimensional shape.
When y is changed to (3,) or (3,3,3), the results are correct.
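For reference, the three ranks of y mentioned above fall into different cases of the matmul shape rules (torch.matmul and ONNX MatMul both follow numpy.matmul semantics): a 1-D y gives a matrix-vector product that drops the last axis, a 2-D y is broadcast across every batch of x, and a 3-D y is matched batch by batch. The pure-Python sketch below computes the expected output shape for each case. `matmul_output_shape` and `broadcast_batch` are hypothetical helper names written for illustration; they are not part of PyTorch or OpenVINO.

```python
from typing import Sequence, Tuple


def broadcast_batch(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Broadcast two batch-dimension tuples using NumPy/PyTorch rules."""
    result = []
    # Walk from the trailing dimension backwards, padding the shorter tuple with 1s.
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and 1 not in (da, db):
            raise ValueError(f"cannot broadcast batch dims {tuple(a)} and {tuple(b)}")
        result.append(db if da == 1 else da)
    return tuple(reversed(result))


def matmul_output_shape(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Expected output shape of matmul(a, b) for operands with the given shapes."""
    if len(a) == 0 or len(b) == 0:
        raise ValueError("matmul does not accept 0-d operands")
    # A 1-D operand is promoted to a matrix by inserting a size-1 dimension,
    # which is removed again from the result.
    a_vec, b_vec = len(a) == 1, len(b) == 1
    a2 = (1, a[0]) if a_vec else tuple(a)
    b2 = (b[0], 1) if b_vec else tuple(b)
    if a2[-1] != b2[-2]:
        raise ValueError(f"inner dimensions differ: {a2[-1]} vs {b2[-2]}")
    out = list(broadcast_batch(a2[:-2], b2[:-2])) + [a2[-2], b2[-1]]
    if b_vec:
        out.pop(-1)  # drop the column dimension added for a 1-D b
    if a_vec:
        out.pop(-1 if b_vec else -2)  # drop the row dimension added for a 1-D a
    return tuple(out)


# The reproducer's case: x of shape (3, 4, 2) times a 2-D y of shape (2, 2).
print(matmul_output_shape((3, 4, 2), (2, 2)))
```

This only checks shapes; it says nothing about the values OpenVINO computes, so it can rule out a shape-inference mismatch but not a numerical bug.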
CPU: INTEL(R) XEON(R) PLATINUM 8581C
System: Rocky Linux 9.5
Step-by-step reproduction
import torch
import torch.nn as nn
import numpy as np
import random
import onnxruntime as ort
import openvino as ov
import openvino.properties as properties
class BugModel(nn.Module):
    def __init__(self):
        super(BugModel, self).__init__()

    def forward(self, x) -> torch.Tensor:
        # Try changing the size, for example to (3,) or (3,3,3)
        # (the last dimension of x must then match y's first matrix dimension)
        y = torch.randn(size=(2, 2), dtype=x.dtype, device=x.device)
        return torch.matmul(x, y)


x = torch.randn(size=(3, 4, 2), dtype=torch.float32)
bug_model = BugModel().eval()
print(bug_model(x))
seq_len = torch.export.Dim('seq_len', min=2, max=1600)
onnx_path = '/root/model.onnx'
torch.onnx.export(
    bug_model,
    (x,),  # example inputs passed as a tuple
    f=onnx_path,
    input_names=['x'],
    output_names=['out'],
    optimize=True,
    opset_version=22,
    dynamo=True,
    # When dynamic_shapes is removed (static export), the result is all zeros
    dynamic_shapes={
        'x': {0: seq_len},
    },
)
ov_model = ov.convert_model(onnx_path, verbose=True)
ov.save_model(ov_model, '/root/model.xml')
config = {
    properties.inference_num_threads: 60,
}
compiled_model = ov.Core().compile_model(
    model=ov_model,
    device_name='CPU',
    config=config,
)
infer_request = compiled_model.create_infer_request()
infer_request.set_input_tensor(0, ov.Tensor(torch.randn(size=(3,4,2),dtype=torch.float32).numpy()))
infer_request.infer()
output_tensor = infer_request.get_output_tensor(0)
out = torch.from_numpy(output_tensor.data)
print('openvino -------------')
print(out.shape)
print(out)
Relevant log output
With dynamic shape:
openvino -------------
torch.Size([3, 4, 3])
tensor([[[ 0., 0., -2232.],
         [ 0., 0., -3480.],
         [ 0., 0., -8352.],
         [ 0., 0., 9216.]],

        [[ 0., 0., -43776.],
         [ 0., 0., 17856.],
         [ 0., 0., -2472.],
         [ 0., 0., 21024.]],

        [[ 0., 0., 18432.],
         [ 0., 0., -4800.],
         [ 0., 0., 55680.],
         [ 0., 0., -6384.]]])
With fixed (static) shape:
tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]])
When the randn shape is (3,) or (3,3,3):
openvino -------------
torch.Size([3, 4])
tensor([[-0.2812, 1.0469, 1.2109, 0.4141],
        [ 0.4043, 1.2344, -1.5469, -0.0835],
        [-1.2812, 0.9570, 1.1562, 0.0413]])
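Because y is drawn from randn on every call, the OpenVINO output cannot be compared element-wise with PyTorch or ONNX Runtime. A value-independent check for the symptom is still possible: the pure-Python sketch below (`all_zero_columns` and `max_abs` are hypothetical helper names, applied to a nested list such as `out.tolist()`) reports which positions along the last axis are zero in every row, plus the largest absolute value. With x and y both drawn from a standard normal and an inner dimension of 2, magnitudes in the thousands, as in the dynamic-shape log, are not plausible.

```python
from typing import List, Sequence


def _rows(output: Sequence) -> List[list]:
    """Flatten a nested list into its innermost rows (lists along the last axis)."""
    if output and isinstance(output[0], (list, tuple)):
        rows: List[list] = []
        for child in output:
            rows.extend(_rows(child))
        return rows
    return [list(output)]


def all_zero_columns(output: Sequence) -> List[int]:
    """Indices along the last axis whose entries are zero in every row."""
    rows = _rows(output)
    if not rows or not rows[0]:
        return []
    width = len(rows[0])
    return [j for j in range(width) if all(row[j] == 0 for row in rows)]


def max_abs(output: Sequence) -> float:
    """Largest absolute value anywhere in the nested list (0.0 if empty)."""
    return max((abs(v) for row in _rows(output) for v in row), default=0.0)


# First batch of the dynamic-shape log above.
sample = [[0., 0., -2232.], [0., 0., -3480.], [0., 0., -8352.], [0., 0., 9216.]]
print(all_zero_columns(sample), max_abs(sample))
```

After running the reproducer, `all_zero_columns(out.tolist())` should be empty and `max_abs(out.tolist())` should be small (single digits) for a correct result, as in the (3,) case above.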
Issue submission checklist
- I'm reporting an issue. It's not a question.
- I checked the problem with the documentation, FAQ, open issues, Stack Overflow, etc., and have not found a solution.
- There is reproducer code and related data files such as images, videos, models, etc.