[Bug] ONNX Round tie-breaking mismatch on 0.5: TVM lowers to llvm.round (ties-away-from-zero) so Round(sigmoid(0))=1, while ONNX spec requires Round(0.5)=0 (ties-to-even) #18590

@dutZ1855

Description

Expected behavior

Per the ONNX Round operator spec, as dumped by the following script:

import onnx
from onnx import defs
print('onnx.__version__ =', onnx.__version__)
# Print available versions for Round
schemas = defs.get_all_schemas_with_history()
round_s = [s for s in schemas if s.name == 'Round']
print('Round schema versions:', sorted({s.since_version for s in round_s}))
for v in sorted({s.since_version for s in round_s}):
    s = defs.get_schema('Round', v)
    print('\n=== Round since_version', v, '===')
    print('domain:', s.domain)
    print('doc:')
    print(s.doc)

onnx.__version__ = 1.17.0
Round schema versions: [11, 22]

=== Round since_version 11 ===
domain:
doc:

Round takes one input Tensor and rounds the values, element-wise, meaning
it finds the nearest integer for each value.
In case of halves, the rule is to round them to the nearest even integer.
If input x is integral, +0, -0, NaN, or infinite, x itself is returned.
The output tensor has the same shape and type as the input.

Examples:

round([0.9]) = [1.0]
round([2.5]) = [2.0]
round([2.3]) = [2.0]
round([1.5]) = [2.0]
round([-4.5]) = [-4.0]

=== Round since_version 22 ===
domain:
doc: (identical to the since_version 11 doc and examples above)

Therefore, for this repro (where sigmoid(0) = 0.5 exactly in float64):

  • Round(0.5) == 0 (nearest-even / ties-to-even)
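
For reference, NumPy's np.round also uses round-half-to-even, so it can serve as an independent check of the spec's examples (an illustrative aside, not part of the original repro):

import numpy as np

# np.round rounds halves to the nearest even integer, matching the ONNX spec
print(np.round(np.array([0.9, 2.5, 2.3, 1.5, -4.5])))  # [ 1.  2.  2.  2. -4.]
print(np.round(0.5))  # 0.0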

Actual behavior

For the following model,

(Model graph image: Constant(0.0) → Sigmoid → Round; see the build script under "Steps to reproduce".)

With TVM (Relax, LLVM target) for this repro:

  • Round(0.5) == 1
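
This is consistent with C round() / llvm.round semantics, which break ties away from zero. A quick NumPy emulation of that rule (an illustrative sketch, not TVM code) shows how it diverges on halves:

import numpy as np

x = np.array([0.5, 1.5, 2.5, -4.5])
# C round() / llvm.round: ties are rounded away from zero
print(np.trunc(x + np.copysign(0.5, x)))  # [ 1.  2.  3. -5.]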

Environment

Operating System: Ubuntu 22.04.4 LTS
TVM version: 0.23.0dev
PyTorch version: 2.9.1
onnxruntime version: 1.23.2
onnx version: 1.20.0
OpenVINO version: 2025.4.0
Python version: 3.11.14

Steps to reproduce

Build the model:

from __future__ import annotations

import argparse
from pathlib import Path

import onnx
from onnx import TensorProto, helper


def make_model(path="round_sigmoid.onnx"):
    # x = 0.0 (double scalar)
    const0 = helper.make_node(
        "Constant",
        inputs=[],
        outputs=["x"],
        value=helper.make_tensor("c0", TensorProto.DOUBLE, dims=[], vals=[0.0]),
    )
    sig = helper.make_node("Sigmoid", inputs=["x"], outputs=["s"])
    rnd = helper.make_node("Round", inputs=["s"], outputs=["y"])

    y_info = helper.make_tensor_value_info("y", TensorProto.DOUBLE, [])
    graph = helper.make_graph([const0, sig, rnd], "g", inputs=[], outputs=[y_info])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
    onnx.save(model, path)
    print("saved:", path)

def main() -> int:
    ap = argparse.ArgumentParser(description="Minimal repro for Round(sigmoid(0)) tie-breaking.")
    ap.add_argument("--out", type=Path, default=Path("round_sigmoid_half.onnx"), help="Where to save the ONNX model.")
    args = ap.parse_args()

    out_path = args.out.resolve()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    make_model(out_path.as_posix())
    return 0

if __name__ == "__main__":
    raise SystemExit(main())
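
(Assuming the script above is saved as build_round_model.py — the file name here is illustrative — it can be run as: python build_round_model.py --out round_sigmoid_half.onnx.)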

Comparison script

from __future__ import annotations

import argparse
import sys
from pathlib import Path
from typing import Any, Optional

import numpy as np
import onnx


def _fmt(x: Any) -> str:
    if isinstance(x, np.ndarray):
        return f"ndarray(shape={x.shape}, dtype={x.dtype}, value={x})"
    return repr(x)


def _run_torch_reference() -> Optional[np.ndarray]:
    """Reference for this specific model: Round(sigmoid(0.0))."""
    try:
        import torch  # type: ignore
    except Exception as e:
        print("[torch] not available:", e)
        return None
    x = torch.tensor(0.0, dtype=torch.float64)
    y = torch.round(torch.sigmoid(x))
    out = np.array(y.item(), dtype=np.float64)
    print("[torch] torch.__version__ =", getattr(torch, "__version__", None))
    print("[torch] y =", _fmt(out))
    return out


def _run_ort(model_bytes: bytes) -> Optional[np.ndarray]:
    try:
        import onnxruntime as ort  # type: ignore
    except Exception as e:
        print("[ort] not available:", e)
        return None
    sess_opts = ort.SessionOptions()
    sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
    sess = ort.InferenceSession(model_bytes, sess_options=sess_opts, providers=["CPUExecutionProvider"])
    outs = sess.run(None, {})  # no inputs
    out = np.array(outs[0])
    print("[ort] onnxruntime.__version__ =", getattr(ort, "__version__", None))
    print("[ort] y =", _fmt(out))
    return out


def _run_openvino(model_path: Path) -> Optional[np.ndarray]:
    try:
        import openvino as ov  # type: ignore
    except Exception as e:
        print("[ov] not available:", e)
        return None
    core = ov.Core()
    model = core.read_model(model_path.as_posix())
    compiled = core.compile_model(model, "CPU")
    req = compiled.create_infer_request()
    raw = req.infer({})  # no inputs
    out_port = compiled.outputs[0]
    out = np.array(raw[out_port])
    print("[ov] openvino.__version__ =", getattr(ov, "__version__", None))
    print("[ov] y =", _fmt(out))
    return out


def _ensure_repo_tvm_python_on_syspath() -> None:
    # Prefer an in-repo TVM checkout (tvm/python) over any installed wheel.
    repo_root = Path(__file__).resolve().parents[3]
    tvm_python = repo_root / "tvm" / "python"
    if tvm_python.exists():
        sys.path.insert(0, tvm_python.as_posix())


def _tvm_to_numpy(x: Any) -> np.ndarray:
    # TVM NDArrays expose .numpy(); anything else falls back to np.array.
    if hasattr(x, "numpy"):
        return x.numpy()
    return np.array(x)


def _run_tvm(model_path: Path) -> Optional[np.ndarray]:
    _ensure_repo_tvm_python_on_syspath()
    try:
        import tvm  # type: ignore
        from tvm import relax  # type: ignore
        from tvm.relax.frontend.onnx import from_onnx  # type: ignore
    except Exception as e:
        print("[tvm] not available:", e)
        return None

    onnx_model = onnx.load(model_path.as_posix())
    mod = from_onnx(onnx_model, shape_dict={})
    if isinstance(mod, (list, tuple)):
        mod = mod[0]

    tgt = tvm.target.Target("llvm")
    pipeline = relax.pipeline.get_default_pipeline(tgt)
    with tvm.transform.PassContext(opt_level=3):
        ex = relax.build(mod, target=tgt, relax_pipeline=pipeline)
    vm = relax.VirtualMachine(ex, tvm.cpu())

    # no inputs
    try:
        out = vm["main"]()
    except Exception:
        vm.set_input("main")
        vm.invoke_stateful("main")
        out = vm.get_outputs("main")

    out_np = _tvm_to_numpy(out)
    print("[tvm] tvm.__file__ =", getattr(tvm, "__file__", None))
    print("[tvm] tvm.__version__ =", getattr(tvm, "__version__", None))
    print("[tvm] y =", _fmt(out_np))
    return out_np


def _eq(a: Optional[np.ndarray], b: Optional[np.ndarray]) -> Optional[bool]:
    """Exact equality between two runtime outputs; None if either is missing."""
    if a is None or b is None:
        return None
    try:
        return bool(np.array_equal(a, b))
    except Exception:
        return None


def main() -> int:
    ap = argparse.ArgumentParser(description="Run ONNX model across runtimes and print outputs.")
    ap.add_argument("--model", type=Path, required=True, help="Path to ONNX model (must have no inputs).")
    ap.add_argument(
        "--no-torch",
        action="store_true",
        help="Skip torch reference printing (useful if torch not installed).",
    )
    args = ap.parse_args()

    model_path = args.model.resolve()
    if not model_path.exists():
        raise FileNotFoundError(model_path)
    model_bytes = model_path.read_bytes()


    y_torch = None if args.no_torch else _run_torch_reference()
    y_ort = _run_ort(model_bytes)
    y_tvm = _run_tvm(model_path)
    y_ov = _run_openvino(model_path)

    print("torch :", y_torch)
    print("tvm:",  y_tvm)
    print("ov :",  y_ov)
    print("ort :", y_ort)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
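
As an extra spec-level oracle, the ONNX reference evaluator (onnx.reference.ReferenceEvaluator, available in onnx >= 1.13) can confirm the mandated result; a minimal sketch, assuming the model path produced by the build script:

import onnx
from onnx.reference import ReferenceEvaluator

sess = ReferenceEvaluator(onnx.load("round_sigmoid_half.onnx"))
(y,) = sess.run(None, {})  # the model takes no inputs
print(y)  # expected 0.0: Round(0.5) ties to even per the spec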

Triage

Please refer to the list of label tags here to find the relevant tags and add them below in a bullet format (example below).

  • needs-triage
