
Incorrect NaN handling in TensorRT 10.16.1.11 when running ONNX Clip on GPU #4773

@ALinrunrun

Description

TensorRT appears to handle NaN values incorrectly for ONNX Clip.

For an input containing NaN, ONNX Runtime preserves the NaN value in the clipped output. TensorRT instead replaces the NaN input with the lower bound value.

This appears to be a NaN-propagation bug in TensorRT's execution of the ONNX Clip operator.
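
For reference, ONNX Runtime's output matches NaN-propagating min/max semantics, which NumPy's maximum/minimum reproduce. A small illustrative check (not part of the repro, just the expected reference behavior):

import numpy as np

# np.maximum/np.minimum return NaN when either operand is NaN,
# so the NaN at index 1 survives clipping:
x = np.array([3.5, np.nan, -7.0, 12.0, 0.0], dtype=np.float32)
ref = np.minimum(np.maximum(x, np.float32(-5.0)), np.float32(8.0))
print(ref.tolist())  # [3.5, nan, -5.0, 8.0, 0.0]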

Environment

TensorRT Version: 10.16.1.11

NVIDIA GPU: N/A / not detected by nvidia-smi

NVIDIA Driver Version: N/A / nvidia-smi failed

CUDA Version: N/A / nvcc not found

CUDNN Version: N/A / torch.backends.cudnn.version() returned None

Operating System: Linux 6.17.0-20-generic x86_64, glibc 2.39

Python Version (if applicable): Python 3.11.15

Tensorflow Version (if applicable): N/A

PyTorch Version (if applicable): N/A

Baremetal or Container (if so, version): Baremetal / non-Docker environment (/proc/1/cgroup: 0::/init.scope)

Additional package versions:

ONNX Version: 1.21.0
ONNX Runtime Version: 1.25.1

Relevant Files

Model link: N/A

The ONNX model is generated inline by the minimal reproducible script below.

Steps To Reproduce

Commands or scripts:

import numpy as np
import onnx
import onnxruntime as ort
from onnx import helper, TensorProto
# _trt_helper is a local helper module; a minimal sketch is included after this script
from _trt_helper import build_engine_from_onnx, run_engine

X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [5])
MN = helper.make_tensor_value_info("MN", TensorProto.FLOAT, [])
MX = helper.make_tensor_value_info("MX", TensorProto.FLOAT, [])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [5])

# Single Clip node with scalar min/max inputs: Y = Clip(X, MN, MX)
g = helper.make_graph(
    [helper.make_node("Clip", ["X", "MN", "MX"], ["Y"])],
    "g",
    [X, MN, MX],
    [Y],
)

m = helper.make_model(g, opset_imports=[helper.make_opsetid("", 18)])
m.ir_version = 10
ob = m.SerializeToString()

# Index 1 carries the NaN whose propagation is under test
x = np.array([3.5, np.nan, -7.0, 12.0, 0.0], dtype=np.float32)
mn = np.array(-5.0, dtype=np.float32)
mx = np.array(8.0, dtype=np.float32)

# Reference output from ONNX Runtime on CPU
ort_y = ort.InferenceSession(
    ob,
    providers=["CPUExecutionProvider"],
).run(["Y"], {"X": x, "MN": mn, "MX": mx})[0]

# TensorRT output via the local helper
eng, _ = build_engine_from_onnx(ob)
trt_y = run_engine(
    eng,
    {"X": x, "MN": mn, "MX": mx},
    ["Y"],
    [(5,)],
    [np.float32],
)["Y"]

print("ORT:", ort_y.tolist())
print("TRT:", trt_y.tolist())

# ORT preserves the NaN; TensorRT does not
assert np.isnan(ort_y[1]) and not np.isnan(trt_y[1])
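
The script depends on the local _trt_helper module referenced above. The following is a minimal sketch of what it does, assuming the TensorRT 10 Python bindings and the cuda-python (cudart) package; buffer management is simplified and the actual helper may differ:

import numpy as np
import tensorrt as trt
from cuda import cudart

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def build_engine_from_onnx(onnx_bytes):
    # Parse the ONNX bytes and build a deserialized engine.
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(0)  # explicit batch is the default in TRT 10
    parser = trt.OnnxParser(network, TRT_LOGGER)
    if not parser.parse(onnx_bytes):
        errors = "; ".join(str(parser.get_error(i)) for i in range(parser.num_errors))
        raise RuntimeError("ONNX parse failed: " + errors)
    config = builder.create_builder_config()
    serialized = builder.build_serialized_network(network, config)
    engine = trt.Runtime(TRT_LOGGER).deserialize_cuda_engine(serialized)
    return engine, serialized

def run_engine(engine, feeds, output_names, output_shapes, output_dtypes):
    # Upload inputs, run execute_async_v3 on the default stream, download outputs.
    context = engine.create_execution_context()
    device_ptrs = []

    def alloc_and_bind(name, host_arr):
        _, dptr = cudart.cudaMalloc(host_arr.nbytes)
        context.set_tensor_address(name, dptr)
        device_ptrs.append(dptr)
        return dptr

    for name, arr in feeds.items():
        arr = np.ascontiguousarray(arr)
        dptr = alloc_and_bind(name, arr)
        cudart.cudaMemcpy(dptr, arr.ctypes.data, arr.nbytes,
                          cudart.cudaMemcpyKind.cudaMemcpyHostToDevice)

    host_outputs = {}
    for name, shape, dtype in zip(output_names, output_shapes, output_dtypes):
        host = np.empty(shape, dtype=dtype)
        host_outputs[name] = (host, alloc_and_bind(name, host))

    context.execute_async_v3(0)  # 0 = legacy default CUDA stream
    cudart.cudaDeviceSynchronize()

    results = {}
    for name, (host, dptr) in host_outputs.items():
        cudart.cudaMemcpy(host.ctypes.data, dptr, host.nbytes,
                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost)
        results[name] = host
    for dptr in device_ptrs:
        cudart.cudaFree(dptr)
    return results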

Have you tried the latest release?: Yes, reproduced with TensorRT 10.16.1.11.

Attach the captured .json and .bin files from TensorRT's API Capture tool if you're on an x86_64 Unix system: Not attached. The issue is reproducible from the self-contained Python script above.

Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt):

Yes. ONNX Runtime runs the same model and preserves the NaN value.
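
For the Polygraphy comparison suggested by the template, the inline model can be saved to disk first (hypothetical filename; m is the ModelProto from the repro script). Note that polygraphy run feeds random inputs by default, so the NaN element would need to be supplied explicitly (e.g. via --load-inputs) to trigger the mismatch:

import onnx

# Save the inline model, then compare runtimes from the shell with:
#   polygraphy run clip_nan.onnx --onnxrt --trt
onnx.save(m, "clip_nan.onnx")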

Actual output:

ORT: [3.5, nan, -5.0, 8.0, 0.0]
TRT: [3.5, -5.0, -5.0, 8.0, 0.0]

TensorRT replaces the NaN input with the lower bound -5.0, while ONNX Runtime preserves NaN.
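
The observed TensorRT output is consistent with a NaN-ignoring (fmin/fmax-style) min/max implementation of Clip. The following NumPy comparison illustrates the two semantics; it is an illustration only, not a claim about TensorRT internals:

import numpy as np

nan = np.float32(np.nan)
print(np.maximum(nan, np.float32(-5.0)))  # nan  -- NaN-propagating max (matches ORT)
print(np.fmax(nan, np.float32(-5.0)))     # -5.0 -- NaN-ignoring max (matches TRT output)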
