Skip to content

Commit 307b680

Browse files
committed
Change shape-mismatch error to warning in infer()
1 parent a696f38 commit 307b680

1 file changed

Lines changed: 10 additions & 5 deletions

File tree

runtime/onert/api/python/package/infer/session.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import List, Union, Tuple, Dict
22
import numpy as np
33
import time
4+
import warnings
45
from contextlib import contextmanager
56

67
from ..native.libnnfw_api_pybind import infer, tensorinfo
@@ -92,6 +93,7 @@ def infer(
9293
for idx, info in enumerate(original_infos):
9394
input_shape = inputs_array[idx].shape
9495
new_dims = []
96+
static_dim_changed = False
9597
# only the first `info.rank` entries matter
9698
for j, d in enumerate(info.dims[:info.rank]):
9799
if d == -1:
@@ -101,11 +103,14 @@ def infer(
101103
# static dim must match the provided array
102104
new_dims.append(d)
103105
else:
104-
raise ValueError(
105-
f"Input #{idx} dim {j} mismatch: "
106-
f"tensorinfo={d}, actual input shape={input_shape[j]}")
107-
# Preserve any trailing dims beyond rank
108-
# new_dims += list(info.dims[info.rank:])
106+
static_dim_changed = True
107+
108+
if static_dim_changed:
109+
warnings.warn(
110+
f"infer() called with input {idx}'s shape={input_shape}, "
111+
f"which differs from model’s expected shape={tuple(info.dims)}. "
112+
"Ensure this is intended.", UserWarning)
113+
109114
info.dims = new_dims
110115
fixed_infos.append(info)
111116

0 commit comments

Comments
 (0)