Skip to content

Commit 3a20910

Browse files
authored
[NvTensorRT RTX] Add Bfloat16 (#24743)
### Description TRT supports Bfloat 16 and ORT does as well. In addition the `setup.py` was missing a copy for NVTRT EP and TRT EP can only be built against the packaged parser with TRT RTX.
1 parent 9dad9af commit 3a20910

28 files changed

+144
-23
lines changed

cmake/onnxruntime_providers_nv.cmake

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
23
# Licensed under the MIT License.
34
find_package(CUDAToolkit REQUIRED 12.8)
45
enable_language(CUDA)
@@ -9,6 +10,9 @@
910
if (onnxruntime_NV_PLACEHOLDER_BUILDER)
1011
add_definitions(-DORT_NV_PLACEHOLDER_BUILDER)
1112
endif()
13+
if (NOT onnxruntime_USE_TENSORRT_BUILTIN_PARSER)
14+
message(FATAL_ERROR "TensorRT RTX can not be used with the open source parser.")
15+
endif ()
1216
set(BUILD_LIBRARY_ONLY 1)
1317
add_definitions("-DONNX_ML=1")
1418
add_definitions("-DONNX_NAMESPACE=onnx")

include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ struct OrtTensorRTProviderOptionsV2 {
2121
int trt_min_subgraph_size{1}; // minimum size of TensorRT subgraphs
2222
size_t trt_max_workspace_size{0}; // maximum workspace size for TensorRT. Default is 0 means max device memory size
2323
int trt_fp16_enable{0}; // enable TensorRT FP16 precision. Default 0 = false, nonzero = true
24+
int trt_bf16_enable{0}; // enable TensorRT BF16 precision. Default 0 = false, nonzero = true
2425
int trt_int8_enable{0}; // enable TensorRT INT8 precision. Default 0 = false, nonzero = true
2526
const char* trt_int8_calibration_table_name{nullptr}; // TensorRT INT8 calibration table name.
2627
int trt_int8_use_native_calibration_table{0}; // use native TensorRT generated calibration table. Default 0 = false, nonzero = true

onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
23
// Licensed under the MIT License.
34

45
#include "nv_allocator.h"

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,7 @@ Status BindContextInput(Ort::KernelContext& ctx,
745745
switch (tensor_type) {
746746
CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float)
747747
CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t)
748+
CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t)
748749
CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool)
749750
CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
750751
CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
@@ -831,6 +832,7 @@ Status BindContextOutput(Ort::KernelContext& ctx,
831832
switch (output_type) {
832833
CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float)
833834
CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t)
835+
CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t)
834836
CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool)
835837
CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
836838
CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
@@ -894,6 +896,7 @@ Status BindKernelOutput(Ort::KernelContext& ctx,
894896
switch (output_type) {
895897
CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float)
896898
CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t)
899+
CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t)
897900
CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool)
898901
CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
899902
CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
23
// Licensed under the MIT License.
34

45
#pragma once

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
23
// Licensed under the MIT License.
34

45
#include <unordered_set>

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
23
// Licensed under the MIT License.
34

45
#pragma once

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_helper.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
23
// Licensed under the MIT License.
34

45
#include "core/providers/shared_library/provider_api.h"

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
23
// Licensed under the MIT License.
34

45
#include "core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h"

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
23
// Licensed under the MIT License.
34

45
#pragma once

0 commit comments

Comments
 (0)