Description
TVM's Relax ONNX frontend fails to import the attached model (model.onnx); see Steps to reproduce below.
Expected behavior
ONNX Min supports NumPy-style broadcasting across inputs. For this repro:
- relu_1 has shape (30,)
- atan has shape (1,)
- Min(relu_1, atan) should broadcast and produce shape (30,)
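A quick NumPy check of the same broadcasting rule (a minimal sketch; the values are placeholders, not taken from the model):

import numpy as np

a = np.zeros(30, dtype=np.float32)  # stands in for relu_1, shape (30,)
b = np.zeros(1, dtype=np.float32)   # stands in for atan, shape (1,)
print(np.minimum(a, b).shape)       # -> (30,): the (1,) input broadcasts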
Actual behavior
ONNX Runtime runs the model successfully and produces a valid output.
[ort] OK
- out shape= (30,) dtype= float32 min/max= (1.4074832, 1.4074832)
TVM Relax ONNX frontend fails during import (from_onnx) with:
Concat expects the input tensors to have the same shape on every dimension except the one indicated by the input axis. However, the input contains tensors whose shapes on dimension 1 is T.int64(30) and T.int64(1)
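For reference, a standalone model with the same shapes can be built with onnx.helper. This is a sketch inferred from the tensor names and shapes above; the exact topology of the attached model is an assumption:

import onnx
from onnx import TensorProto, helper

x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [30])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1])
out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [30])

graph = helper.make_graph(
    [
        helper.make_node("Relu", ["x"], ["relu_1"]),            # shape (30,)
        helper.make_node("Atan", ["y"], ["atan"]),              # shape (1,)
        helper.make_node("Min", ["relu_1", "atan"], ["out"]),   # should broadcast to (30,)
    ],
    "min_broadcast_repro",
    [x, y],
    [out],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
onnx.checker.check_model(model)
onnx.save(model, "model.onnx")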
Environment
- Operating System: Ubuntu 22.04.4 LTS
- TVM version: 0.23.0dev
- PyTorch version: 2.9.1
- onnxruntime version: 1.23.2
- onnx version: 1.20.0
- Python version: 3.11.14
Steps to reproduce
This bug can be reproduced with the script below and the model in the attachment. As the script shows, onnxruntime executes the model successfully, but TVM raises an error while importing it.
from __future__ import annotations

import argparse
import pickle
from pathlib import Path

import numpy as np
import onnx


def _load_oracle_inputs(path: Path) -> dict[str, np.ndarray]:
    obj = pickle.loads(path.read_bytes())
    inp = obj.get("input")
    if not isinstance(inp, dict):
        raise ValueError(f"oracle.pkl missing 'input' dict: keys={list(obj.keys())}")
    return {k: np.array(v) for k, v in inp.items()}


def main() -> int:
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", type=Path, default=Path("model.onnx"))
    ap.add_argument("--oracle", type=Path, default=Path("oracle.pkl"))
    args = ap.parse_args()

    model_path = args.model.resolve()
    oracle_path = args.oracle.resolve()
    onnx_model = onnx.load(model_path.as_posix())
    inputs = _load_oracle_inputs(oracle_path)

    print("[model]", model_path)
    print("[oracle]", oracle_path)
    print("[inputs]", {k: v.shape for k, v in inputs.items()})

    # ORT: expected OK
    try:
        import onnxruntime as ort  # type: ignore

        ort_session = ort.InferenceSession(onnx_model.SerializeToString(), providers=["CPUExecutionProvider"])
        ort_out = ort_session.run(None, inputs)
        out_names = [o.name for o in ort_session.get_outputs()]
        print("[ort] OK")
        for name, val in zip(out_names, ort_out):
            arr = np.array(val)
            print(" -", name, "shape=", arr.shape, "dtype=", arr.dtype, "min/max=", (arr.min(), arr.max()))
    except Exception as e:
        print("[ort] FAILED:", type(e).__name__, e)
        return 1

    # TVM: expected FAIL during from_onnx
    try:
        import tvm  # type: ignore
        from tvm import relax  # type: ignore
        from tvm.relax.frontend.onnx import from_onnx  # type: ignore

        shape_dict = {k: v.shape for k, v in inputs.items()}
        print("[tvm] shape_dict:", shape_dict)
        tvm_mod = from_onnx(onnx_model, shape_dict=shape_dict)
        if isinstance(tvm_mod, (list, tuple)):
            tvm_mod = tvm_mod[0]

        # If it unexpectedly imports, try build+run.
        tgt = tvm.target.Target("llvm")
        pipeline = relax.pipeline.get_default_pipeline(tgt)
        with tvm.transform.PassContext(opt_level=3):
            ex = relax.build(tvm_mod, target=tgt, relax_pipeline=pipeline)
        vm = relax.VirtualMachine(ex, tvm.cpu())
        vm.set_input("main", **inputs)
        vm.invoke_stateful("main")
        out = vm.get_outputs("main")
        print("[tvm] UNEXPECTED: succeeded. output type:", type(out))
    except Exception as e:
        print("[tvm] EXPECTED FAIL:", type(e).__name__, str(e).splitlines()[0])
        return 0

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
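The script expects oracle.pkl to be a pickled dict whose "input" entry maps graph input names to arrays. A minimal sketch for generating one (the names "x" and "y" are placeholders; substitute the input names of the attached model):

import pickle
from pathlib import Path

import numpy as np

inputs = {
    "x": np.random.rand(30).astype(np.float32),  # placeholder name, shape (30,)
    "y": np.random.rand(1).astype(np.float32),   # placeholder name, shape (1,)
}
Path("oracle.pkl").write_bytes(pickle.dumps({"input": inputs}))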