
[Bug] Importing ONNX Min with broadcasting fails #18592

@dutZ1855

Description

The failure occurs with the following model (attached as model.zip):

[Image]

Expected behavior

ONNX Min computes an element-wise minimum over its inputs and supports multidirectional (NumPy-style) broadcasting. For this repro:

  • relu_1 has shape (30,)
  • atan has shape (1,)
  • Min(relu_1, atan) should broadcast and produce shape (30,)

Reference: ONNX Broadcasting
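For illustration, the expected semantics can be checked with NumPy, whose np.minimum applies the same broadcasting rule that ONNX Min follows. This is a minimal sketch that uses random placeholder values with the shapes from the report; the real inputs come from the attached model and oracle file.

```python
import numpy as np

# Shapes from the report: relu_1 is (30,) and atan is (1,).
# Values are random placeholders, not the tensors from the attached model.
rng = np.random.default_rng(0)
relu_1 = rng.random(30, dtype=np.float32)
atan = rng.random(1, dtype=np.float32)

# np.minimum follows NumPy broadcasting, the same rule ONNX Min uses:
# the size-1 axis of `atan` is stretched to length 30.
out = np.minimum(relu_1, atan)
print(out.shape, out.dtype)  # (30,) float32

# The result shape can also be derived from the input shapes alone.
print(np.broadcast_shapes(relu_1.shape, atan.shape))  # (30,)
```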

Actual behavior

ONNX Runtime runs the model successfully and produces a valid output.

[ort] OK
  - out shape= (30,) dtype= float32 min/max= (1.4074832, 1.4074832)

The TVM Relax ONNX frontend fails during import (from_onnx) with:

Concat expects the input tensors to have the same shape on every dimension except the one indicated by the input axis. However, the input contains tensors whose shapes on dimension 1 is T.int64(30) and T.int64(1)
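One plausible reading of this error is based on an assumption inferred from the message text, not verified against the frontend source. Suppose the frontend lowers a variadic Min by adding a leading axis to each input, concatenating along that axis, and reducing with min. Then the concatenation requires every other dimension to match, so (1, 30) and (1, 1) are rejected. The NumPy sketch below reproduces that failure mode. It also shows that broadcasting all inputs to a common shape before stacking, or folding a pairwise minimum, avoids the problem. All tensor values are arbitrary placeholders.

```python
import functools

import numpy as np

# Placeholder tensors with the shapes from the report (values are arbitrary).
relu_1 = np.arange(30, dtype=np.float32)  # shape (30,)
atan = np.array([1.4], dtype=np.float32)  # shape (1,)
inputs = [relu_1, atan]

# ASSUMED lowering (inferred from the error text, not from the frontend source):
# add a leading axis to each input, concatenate on axis 0, then reduce with min.
expanded = [np.expand_dims(x, axis=0) for x in inputs]  # (1, 30) and (1, 1)
try:
    np.concatenate(expanded, axis=0)
    naive_failed = False
except ValueError as err:
    # NumPy rejects the size mismatch on dimension 1 (30 vs 1),
    # which parallels the Concat error reported by TVM.
    naive_failed = True
    print("naive lowering failed:", err)

# Broadcasting every input to the common shape first makes the concat valid.
common = np.broadcast_shapes(*(x.shape for x in inputs))
aligned = [np.expand_dims(np.broadcast_to(x, common), axis=0) for x in inputs]
fixed = np.concatenate(aligned, axis=0).min(axis=0)

# Folding np.minimum pairwise broadcasts at each step and gives the same result.
pairwise = functools.reduce(np.minimum, inputs)
print(fixed.shape, np.array_equal(fixed, pairwise))  # (30,) True
```

The pairwise fold handles any number of inputs because each np.minimum call broadcasts its two operands. The broadcast-then-stack variant keeps the single-reduction structure of the assumed lowering.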

Environment

Operating system: Ubuntu 22.04.4 LTS
TVM version: 0.23.0dev
PyTorch version: 2.9.1
ONNX Runtime version: 1.23.2
ONNX version: 1.20.0
Python version: 3.11.14

Steps to reproduce

The bug can be reproduced with the script below and the attached model. ONNX Runtime executes the model successfully, but TVM raises an error while importing it with from_onnx.

from __future__ import annotations

import argparse
import pickle
from pathlib import Path

import numpy as np
import onnx


def _load_oracle_inputs(path: Path) -> dict[str, np.ndarray]:
    obj = pickle.loads(path.read_bytes())
    inp = obj.get("input")
    if not isinstance(inp, dict):
        raise ValueError(f"oracle.pkl missing 'input' dict: keys={list(obj.keys())}")
    return {k: np.array(v) for k, v in inp.items()}


def main() -> int:
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", type=Path, default=Path("model.onnx"))
    ap.add_argument("--oracle", type=Path, default=Path("oracle.pkl"))
    args = ap.parse_args()

    model_path = args.model.resolve()
    oracle_path = args.oracle.resolve()

    onnx_model = onnx.load(model_path.as_posix())
    inputs = _load_oracle_inputs(oracle_path)
    print("[model]", model_path)
    print("[oracle]", oracle_path)
    print("[inputs]", {k: v.shape for k, v in inputs.items()})

    # ORT: expected OK
    try:
        import onnxruntime as ort  # type: ignore

        ort_session = ort.InferenceSession(onnx_model.SerializeToString(), providers=["CPUExecutionProvider"])
        ort_out = ort_session.run(None, inputs)
        out_names = [o.name for o in ort_session.get_outputs()]
        print("[ort] OK")
        for name, val in zip(out_names, ort_out):
            arr = np.array(val)
            print("  -", name, "shape=", arr.shape, "dtype=", arr.dtype, "min/max=", (arr.min(), arr.max()))
    except Exception as e:
        print("[ort] FAILED:", type(e).__name__, e)
        return 1

    # TVM: expected FAIL during from_onnx
    try:
        import tvm  # type: ignore
        from tvm import relax  # type: ignore
        from tvm.relax.frontend.onnx import from_onnx  # type: ignore

        shape_dict = {k: v.shape for k, v in inputs.items()}
        print("[tvm] shape_dict:", shape_dict)
        tvm_mod = from_onnx(onnx_model, shape_dict=shape_dict)
        if isinstance(tvm_mod, (list, tuple)):
            tvm_mod = tvm_mod[0]
        # If it unexpectedly imports, try build+run.
        tgt = tvm.target.Target("llvm")
        pipeline = relax.pipeline.get_default_pipeline(tgt)
        with tvm.transform.PassContext(opt_level=3):
            ex = relax.build(tvm_mod, target=tgt, relax_pipeline=pipeline)
        vm = relax.VirtualMachine(ex, tvm.cpu())
        vm.set_input("main", **inputs)
        vm.invoke_stateful("main")
        out = vm.get_outputs("main")
        print("[tvm] UNEXPECTED: succeeded. output type:", type(out))
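    except Exception as e:
        # Note: this also catches errors raised by build/run, not only by from_onnx.
        msg_lines = str(e).splitlines()
        print("[tvm] EXPECTED FAIL:", type(e).__name__, msg_lines[0] if msg_lines else "")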
        return 0

    return 0


if __name__ == "__main__":
    raise SystemExit(main())

Attachment: model.zip

Labels: needs-triage, type: bug
