-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Description
Expected behavior
Per ONNX Round operator spec:
import onnx
from onnx import defs
# Report the installed onnx version so the schema listing below can be tied to it.
print('onnx.__version__ =', onnx.__version__)

# Gather the distinct since_version values of every registered schema named 'Round'.
round_versions = sorted(
    {
        schema.since_version
        for schema in defs.get_all_schemas_with_history()
        if schema.name == 'Round'
    }
)
print('Round schema versions:', round_versions)

# Dump the domain and documentation text of the schema looked up for each version.
for version in round_versions:
    schema = defs.get_schema('Round', version)
    print('\n=== Round since_version', version, '===')
    print('domain:', schema.domain)
    print('doc:')
    print(schema.doc)
onnx.__version__ = 1.17.0
Round schema versions: [11, 22]

=== Round since_version 11 ===
domain:
doc:
Round takes one input Tensor and rounds the values, element-wise, meaning
it finds the nearest integer for each value.
In case of halves, the rule is to round them to the nearest even integer.
If input x is integral, +0, -0, NaN, or infinite, x itself is returned.
The output tensor has the same shape and type as the input.

Examples:
round([0.9]) = [1.0]
round([2.5]) = [2.0]
round([2.3]) = [2.0]
round([1.5]) = [2.0]
round([-4.5]) = [-4.0]

=== Round since_version 22 ===
domain:
doc:
Round takes one input Tensor and rounds the values, element-wise, meaning
it finds the nearest integer for each value.
In case of halves, the rule is to round them to the nearest even integer.
If input x is integral, +0, -0, NaN, or infinite, x itself is returned.
The output tensor has the same shape and type as the input.

Examples:
round([0.9]) = [1.0]
round([2.5]) = [2.0]
round([2.3]) = [2.0]
round([1.5]) = [2.0]
round([-4.5]) = [-4.0]
Therefore, for this repro (where sigmoid(0)=0.5):
Round(0.5) == 0 (nearest-even / ties-to-even)
Actual behavior
For the model built in "Steps to reproduce" below, TVM (Relax, LLVM target) instead produces:
Round(0.5) == 1
Environment
Operating System: Ubuntu 22.04.4 LTS
TVM version: 0.23.0dev
pytorch version: 2.9.1
ort version: 1.23.2
onnx version: 1.20.0
openvino: 2025.4.0
python: 3.11.14
Steps to reproduce
Build the model with the following script:
from __future__ import annotations
import argparse
from pathlib import Path
import onnx
from onnx import TensorProto, helper
def make_model(path="round_sigmoid.onnx"):
    """Build a no-input ONNX graph computing Round(Sigmoid(0.0)) and save it to *path*.

    The graph is Constant(0.0, double scalar) -> Sigmoid -> Round. Its single
    output "y" is declared as a double scalar and the model imports opset 18
    of the default ("") domain. The destination path is printed after saving.
    """
    # x = 0.0 (double scalar)
    zero_tensor = helper.make_tensor("c0", TensorProto.DOUBLE, dims=[], vals=[0.0])
    nodes = [
        helper.make_node("Constant", inputs=[], outputs=["x"], value=zero_tensor),
        helper.make_node("Sigmoid", inputs=["x"], outputs=["s"]),
        helper.make_node("Round", inputs=["s"], outputs=["y"]),
    ]
    output_info = helper.make_tensor_value_info("y", TensorProto.DOUBLE, [])
    graph = helper.make_graph(nodes, "g", inputs=[], outputs=[output_info])
    opset = helper.make_opsetid("", 18)
    model = helper.make_model(graph, opset_imports=[opset])
    onnx.save(model, path)
    print("saved:", path)
def main() -> int:
    """Parse ``--out``, create its parent directory, and write the repro model there.

    Returns:
        0 once the model has been written.
    """
    parser = argparse.ArgumentParser(description="Minimal repro for Round(sigmoid(0)) tie-breaking.")
    parser.add_argument(
        "--out",
        type=Path,
        default=Path("round_sigmoid_half.onnx"),
        help="Where to save the ONNX model.",
    )
    cli = parser.parse_args()

    # Resolve to an absolute path so the parent directory can be created reliably.
    destination = cli.out.resolve()
    destination.parent.mkdir(parents=True, exist_ok=True)
    make_model(destination.as_posix())
    return 0
# Script entry point: terminate the process with main()'s return value as exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)
Comparison results
from __future__ import annotations
import argparse
import sys
from pathlib import Path
from typing import Any, Optional
import numpy as np
import onnx
def _fmt(x: Any) -> str:
    """Render *x* as a string for the per-runtime log lines.

    NumPy arrays are shown with their shape, dtype and contents; any other
    object falls back to ``repr``.
    """
    if not isinstance(x, np.ndarray):
        return repr(x)
    return f"ndarray(shape={x.shape}, dtype={x.dtype}, value={x})"
def _run_torch_reference() -> Optional[np.ndarray]:
    """Compute Round(sigmoid(0.0)) in float64 with PyTorch as a reference value.

    Returns:
        The result as a 0-d float64 NumPy array, or None (after printing the
        reason) when torch cannot be imported.
    """
    try:
        import torch  # type: ignore
    except Exception as exc:
        print("[torch] not available:", exc)
        return None

    zero = torch.tensor(0.0, dtype=torch.float64)
    half = torch.sigmoid(zero)
    rounded = torch.round(half)
    result = np.array(rounded.item(), dtype=np.float64)

    print("[torch] torch.__version__ =", getattr(torch, "__version__", None))
    print("[torch] y =", _fmt(result))
    return result
def _run_ort(model_bytes: bytes) -> Optional[np.ndarray]:
    """Run the serialized model with ONNX Runtime on the CPU provider.

    Args:
        model_bytes: Serialized ONNX model; it is run with an empty feed dict.

    Returns:
        The first session output as a NumPy array, or None (after printing
        the reason) when onnxruntime cannot be imported.
    """
    try:
        import onnxruntime as ort  # type: ignore
    except Exception as exc:
        print("[ort] not available:", exc)
        return None

    options = ort.SessionOptions()
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
    session = ort.InferenceSession(
        model_bytes,
        sess_options=options,
        providers=["CPUExecutionProvider"],
    )
    # no inputs, so the feed dict is empty
    result = np.array(session.run(None, {})[0])

    print("[ort] onnxruntime.__version__ =", getattr(ort, "__version__", None))
    print("[ort] y =", _fmt(result))
    return result
def _run_openvino(model_path: Path) -> Optional[np.ndarray]:
    """Compile the model with OpenVINO for the "CPU" device and run one inference.

    Args:
        model_path: Location of the ONNX file; it is inferred with no inputs.

    Returns:
        The value of the compiled model's first output as a NumPy array, or
        None (after printing the reason) when openvino cannot be imported.
    """
    try:
        import openvino as ov  # type: ignore
    except Exception as exc:
        print("[ov] not available:", exc)
        return None

    core = ov.Core()
    compiled = core.compile_model(core.read_model(model_path.as_posix()), "CPU")
    request = compiled.create_infer_request()
    # no inputs, so the feed dict is empty
    results = request.infer({})
    first_port = compiled.outputs[0]
    result = np.array(results[first_port])

    print("[ov] openvino.__version__ =", getattr(ov, "__version__", None))
    print("[ov] y =", _fmt(result))
    return result
def _ensure_repo_tvm_python_on_syspath() -> None:
    """Prepend ``<ancestor>/tvm/python`` to ``sys.path`` when that directory exists.

    ``<ancestor>`` is the directory four levels above this file (index 3 of
    the resolved path's ``parents``). When the file lives too close to the
    filesystem root for such an ancestor to exist, or the ``tvm/python``
    directory is absent, ``sys.path`` is left untouched.
    """
    ancestors = Path(__file__).resolve().parents
    # Indexing parents[3] directly raises IndexError for shallow locations
    # such as /tmp/repro.py; treat that as "no in-repo TVM checkout".
    if len(ancestors) <= 3:
        return
    tvm_python = ancestors[3] / "tvm" / "python"
    if tvm_python.exists():
        sys.path.insert(0, tvm_python.as_posix())
def _tvm_to_numpy(x: Any) -> np.ndarray:
    """Convert a runtime result to a NumPy array.

    Objects exposing a ``numpy`` attribute are converted by calling it;
    everything else, plain scalars included, is wrapped with ``np.array``.
    """
    return x.numpy() if hasattr(x, "numpy") else np.array(x)
def _run_tvm(model_path: Path) -> Optional[np.ndarray]:
    """Import the ONNX model into TVM Relax, build it for "llvm" and run it on CPU.

    Args:
        model_path: Location of the ONNX file; the model is run with no inputs.

    Returns:
        The output of the "main" function converted to a NumPy array, or None
        (after printing the reason) when the TVM modules cannot be imported.
    """
    _ensure_repo_tvm_python_on_syspath()
    try:
        import tvm  # type: ignore
        from tvm import relax  # type: ignore
        from tvm.relax.frontend.onnx import from_onnx  # type: ignore
    except Exception as exc:
        print("[tvm] not available:", exc)
        return None

    ir_mod = from_onnx(onnx.load(model_path.as_posix()), shape_dict={})
    # If the importer hands back a list/tuple, the module is its first element.
    if isinstance(ir_mod, (list, tuple)):
        ir_mod = ir_mod[0]

    target = tvm.target.Target("llvm")
    default_pipeline = relax.pipeline.get_default_pipeline(target)
    with tvm.transform.PassContext(opt_level=3):
        executable = relax.build(ir_mod, target=target, relax_pipeline=default_pipeline)
    machine = relax.VirtualMachine(executable, tvm.cpu())

    # no inputs: try a direct call first, then fall back to the stateful API.
    try:
        raw_output = machine["main"]()
    except Exception:
        machine.set_input("main")
        machine.invoke_stateful("main")
        raw_output = machine.get_outputs("main")

    result = _tvm_to_numpy(raw_output)
    print("[tvm] tvm.__file__ =", getattr(tvm, "__file__", None))
    print("[tvm] tvm.__version__ =", getattr(tvm, "__version__", None))
    print("[tvm] y =", _fmt(result))
    return result
def _eq(a: Optional[np.ndarray], b: Optional[np.ndarray]) -> Optional[bool]:
    """Tell whether two runtime outputs are equal according to ``np.array_equal``.

    Returns:
        True or False for the comparison, or None when either operand is
        None or the comparison itself raises.
    """
    if a is not None and b is not None:
        try:
            same = np.array_equal(a, b)
        except Exception:
            return None
        return bool(same)
    return None
def main() -> int:
    """Run the given ONNX model through each runtime and print every output.

    The torch reference (unless ``--no-torch`` is given), ONNX Runtime, TVM
    and OpenVINO runners are invoked in that order; a summary line per
    runtime is printed afterwards.

    Returns:
        0 on completion.

    Raises:
        FileNotFoundError: If the resolved ``--model`` path does not exist.
    """
    parser = argparse.ArgumentParser(description="Run ONNX model across runtimes and print outputs.")
    parser.add_argument("--model", type=Path, required=True, help="Path to ONNX model (must have no inputs).")
    parser.add_argument(
        "--no-torch",
        action="store_true",
        help="Skip torch reference printing (useful if torch not installed).",
    )
    cli = parser.parse_args()

    model_path = cli.model.resolve()
    if not model_path.exists():
        raise FileNotFoundError(model_path)
    # Read the file up front so an unreadable model fails before any runtime runs.
    model_bytes = model_path.read_bytes()

    y_torch = _run_torch_reference() if not cli.no_torch else None
    y_ort = _run_ort(model_bytes)
    y_tvm = _run_tvm(model_path)
    y_ov = _run_openvino(model_path)

    summary = (
        ("torch :", y_torch),
        ("tvm:", y_tvm),
        ("ov :", y_ov),
        ("ort :", y_ort),
    )
    for label, value in summary:
        print(label, value)
    return 0
# Script entry point: terminate the process with main()'s return value as exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)
Triage
Please refer to the list of label tags here to find the relevant tags and add them below in a bullet format (example below).
- needs-triage