
Commit a8ff3f3

[TRT RTX EP] Add support for D3D12 external resource import (#26948)
This PR adds a test for CiG inference to demonstrate how it should be used. It is important not to call `cudaSetDevice` in that flow, since it will create a new context. @nieubank I am not sure why there was a `cudaSetDevice` on each import call 🤔 Was this done to enable importing semaphores of e.g. GPU:1 into a session running on GPU:0? Context management is unreliable when the current CUDA runtime and CUDA driver APIs are mixed.
1 parent 06fe9a4 commit a8ff3f3
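
For context, here is a minimal sketch of the pattern this PR moves toward. It uses only standard CUDA runtime and driver API calls and is not code from this PR; the function names are illustrative. The idea is to check whether the calling thread already has a context before touching `cudaSetDevice`, because `cudaSetDevice` binds the device's primary context to the thread and, as the message above notes, would displace a context the application (e.g. CiG) created itself.

#include <cuda.h>
#include <cuda_runtime.h>

// Returns true if the calling thread already has a CUDA context bound,
// e.g. one created by the application for CUDA-in-Graphics (CiG).
static bool ThreadHasCudaContext() {
  CUcontext current = nullptr;
  // cuCtxGetCurrent fails if the driver is not initialized; treat that as "no context".
  if (cuCtxGetCurrent(&current) != CUDA_SUCCESS) return false;
  return current != nullptr;
}

static void PrepareThreadForImport(int device_id) {
  if (!ThreadHasCudaContext()) {
    // No user-created context yet: let the runtime bind the device's primary context.
    cudaSetDevice(device_id);
  }
  // Otherwise leave the caller's context untouched; calling cudaSetDevice here
  // would silently switch the thread onto the primary context.
}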

File tree

8 files changed (+1762, -33 lines)


cmake/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -1441,7 +1441,7 @@ get_property(onnxruntime_GENERATOR_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_
 if (onnxruntime_USE_CUDA)
   set(CMAKE_CUDA_STANDARD 17)
   if(onnxruntime_CUDA_HOME)
-    file(TO_CMAKE_PATH CUDAToolkit_ROOT ${onnxruntime_CUDA_HOME})
+    file(TO_CMAKE_PATH ${onnxruntime_CUDA_HOME} CUDAToolkit_ROOT)
   endif()
   find_package(CUDAToolkit REQUIRED)
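
Note on the fix: the signature is `file(TO_CMAKE_PATH <path> <out-var>)`, input path first, output variable second. The old call therefore converted the literal string `CUDAToolkit_ROOT` and stored the result in a variable named after the value of `${onnxruntime_CUDA_HOME}`, so `CUDAToolkit_ROOT` itself was never set and the subsequent `find_package(CUDAToolkit)` could not be pointed at the custom CUDA home.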

cmake/onnxruntime_providers_nv.cmake

Lines changed: 6 additions & 3 deletions
@@ -1,7 +1,10 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # Licensed under the MIT License.
-find_package(CUDAToolkit REQUIRED 12.8)
+if(onnxruntime_CUDA_HOME)
+  file(TO_CMAKE_PATH ${onnxruntime_CUDA_HOME} CUDAToolkit_ROOT)
+endif()
+find_package(CUDAToolkit REQUIRED)
 enable_language(CUDA)
 if(onnxruntime_DISABLE_CONTRIB_OPS)
   message( FATAL_ERROR "To compile TensorRT execution provider contrib ops have to be enabled to dump an engine using com.microsoft:EPContext node." )

@@ -146,9 +149,9 @@ endif ()
 target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE Eigen3::Eigen onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface Eigen3::Eigen)
 add_dependencies(onnxruntime_providers_nv_tensorrt_rtx onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES})
 if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER)
-  target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart)
+  target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart CUDA::cuda_driver)
 else()
-  target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart)
+  target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart CUDA::cuda_driver)
 endif()
 target_include_directories(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${TENSORRT_RTX_INCLUDE_DIR} ${onnx_tensorrt_SOURCE_DIR}
   PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
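
Linking `CUDA::cuda_driver` alongside `CUDA::cudart` is what makes the new driver-API calls in this PR (`cuCtxGetCurrent`, `cuCtxPushCurrent`, `cuGetErrorString`) resolvable: those entry points live in the driver library (`nvcuda.dll` / `libcuda.so`), not in the CUDA runtime that `CUDA::cudart` provides.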

onnxruntime/core/providers/cuda/shared_inc/cuda_call.h

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@ std::conditional_t<THRW, void, Status> CudaCall(
     const char* file, const int line);

 #define CUDA_CALL(expr) (::onnxruntime::CudaCall<cudaError, false>((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__))
+#define CU_CALL(expr) (::onnxruntime::CudaCall<CUresult, false>((expr), #expr, "CUDA", CUDA_SUCCESS, "", __FILE__, __LINE__))
 #define CUBLAS_CALL(expr) (::onnxruntime::CudaCall<cublasStatus_t, false>((expr), #expr, "CUBLAS", CUBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__))

 #define CUSPARSE_CALL(expr) (::onnxruntime::CudaCall<cusparseStatus_t, false>((expr), #expr, "CUSPARSE", CUSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__))

@@ -26,6 +27,7 @@ std::conditional_t<THRW, void, Status> CudaCall(
 #define CUFFT_CALL(expr) (::onnxruntime::CudaCall<cufftResult, false>((expr), #expr, "CUFFT", CUFFT_SUCCESS, "", __FILE__, __LINE__))

 #define CUDA_CALL_THROW(expr) (::onnxruntime::CudaCall<cudaError, true>((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__))
+#define CU_CALL_THROW(expr) (::onnxruntime::CudaCall<CUresult, true>((expr), #expr, "CUDA", CUDA_SUCCESS, "", __FILE__, __LINE__))
 #define CUBLAS_CALL_THROW(expr) (::onnxruntime::CudaCall<cublasStatus_t, true>((expr), #expr, "CUBLAS", CUBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__))

 #define CUSPARSE_CALL_THROW(expr) (::onnxruntime::CudaCall<cusparseStatus_t, true>((expr), #expr, "CUSPARSE", CUSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__))
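
For orientation, a hypothetical call site for the two new macros (the function names are illustrative; the macros are the ones defined above): `CU_CALL` wraps a CUDA driver API call into a `Status` for fallible paths, while `CU_CALL_THROW` throws on failure.

#include <cuda.h>
#include "core/providers/cuda/shared_inc/cuda_call.h"

// Status-returning variant: propagate driver errors to the caller.
onnxruntime::common::Status QueryCurrentContext(CUcontext& ctx) {
  ORT_RETURN_IF_ERROR(CU_CALL(cuCtxGetCurrent(&ctx)));
  return onnxruntime::common::Status::OK();
}

// Throwing variant: abort the surrounding operation on any driver error.
void QueryCurrentContextOrThrow(CUcontext& ctx) {
  CU_CALL_THROW(cuCtxGetCurrent(&ctx));
}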

onnxruntime/core/providers/nv_tensorrt_rtx/nv_cuda_call.cc

Lines changed: 9 additions & 0 deletions
@@ -31,6 +31,13 @@ const char* CudaErrString<cudaError_t>(cudaError_t x) {
   return cudaGetErrorString(x);
 }

+template <>
+const char* CudaErrString<CUresult>(CUresult x) {
+  const char* errorStr = NULL;
+  cuGetErrorString(x, &errorStr);
+  return errorStr;
+}
+
 #ifndef USE_CUDA_MINIMAL
 template <>
 const char* CudaErrString<cublasStatus_t>(cublasStatus_t e) {

@@ -141,5 +148,7 @@ std::conditional_t<THRW, void, Status> CudaCall(

 template Status CudaCall<cudaError, false>(cudaError retCode, const char* exprString, const char* libName, cudaError successCode, const char* msg, const char* file, const int line);
 template void CudaCall<cudaError, true>(cudaError retCode, const char* exprString, const char* libName, cudaError successCode, const char* msg, const char* file, const int line);
+template Status CudaCall<CUresult, false>(CUresult retCode, const char* exprString, const char* libName, CUresult successCode, const char* msg, const char* file, const int line);
+template void CudaCall<CUresult, true>(CUresult retCode, const char* exprString, const char* libName, CUresult successCode, const char* msg, const char* file, const int line);

 }  // namespace onnxruntime
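
One caveat worth keeping in mind: the CUDA driver documents that `cuGetErrorString` leaves the output pointer set to NULL when it does not recognize the error code, so `CudaErrString<CUresult>` can return a null pointer that downstream message formatting should tolerate.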

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc

Lines changed: 0 additions & 2 deletions
@@ -959,8 +959,6 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info)
       device_id_(info.device_id) {
   InitProviderOrtApi();

-  // TODO(maximlianm) remove this since we should be able to compile an AOT context file without GPU
-
   if (!info.has_user_compute_stream) {
     // If the app is passing in a compute stream, it already has initialized cuda and created a context.
     // Calling cudaSetDevice() will set the default context in the current thread

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h

Lines changed: 49 additions & 24 deletions
@@ -2,6 +2,8 @@
 // SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 // Licensed under the MIT License.

+#pragma once
+
 #include <fstream>
 #include <unordered_map>
 #include <string>

@@ -10,10 +12,11 @@
 #include <iostream>
 #include <filesystem>
 #include "flatbuffers/idl.h"
-#include <NvInferVersion.h>
+#include "nv_includes.h"
 #include "core/providers/cuda/cuda_pch.h"
 #include "core/common/path_string.h"
 #include "core/framework/murmurhash3.h"
+#include "core/providers/cuda/shared_inc/cuda_call.h"

 namespace fs = std::filesystem;

@@ -31,7 +34,7 @@ namespace onnxruntime {
  * }
  *
  */
-int GetNumProfiles(std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_shapes) {
+static int GetNumProfiles(std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_shapes) {
   int num_profile = 0;
   for (auto it = profile_shapes.begin(); it != profile_shapes.end(); it++) {
     num_profile = static_cast<int>(it->second.size());

@@ -52,7 +55,7 @@ int GetNumProfiles(std::unordered_map<std::string, std::vector<std::vector<int64
  *
  * [Deprecated] Use SerializeProfileV2
  */
-void SerializeProfile(const std::string& file_name, std::unordered_map<std::string, std::unordered_map<size_t, std::pair<int64_t, int64_t>>>& shape_ranges) {
+static void SerializeProfile(const std::string& file_name, std::unordered_map<std::string, std::unordered_map<size_t, std::pair<int64_t, int64_t>>>& shape_ranges) {
   // Serialize profile
   flexbuffers::Builder builder;
   auto profile_start = builder.StartMap();

@@ -78,7 +81,7 @@ void SerializeProfile(const std::string& file_name, std::unordered_map<std::stri

 // Deserialize engine profile
 // [Deprecated] Use DeserializeProfileV2
-std::unordered_map<std::string, std::unordered_map<size_t, std::pair<int64_t, int64_t>>> DeserializeProfile(std::ifstream& infile) {
+static std::unordered_map<std::string, std::unordered_map<size_t, std::pair<int64_t, int64_t>>> DeserializeProfile(std::ifstream& infile) {
   // Load flexbuffer
   infile.seekg(0, std::ios::end);
   size_t length = infile.tellg();

@@ -153,7 +156,7 @@ std::unordered_map<std::string, std::unordered_map<size_t, std::pair<int64_t, in
  * }
  *
  */
-void SerializeProfileV2(const std::string& file_name, std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>>& shape_ranges) {
+static void SerializeProfileV2(const std::string& file_name, std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>>& shape_ranges) {
   LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] In SerializeProfileV2()";
   // Serialize profile
   flexbuffers::Builder builder;

@@ -233,7 +236,7 @@ void SerializeProfileV2(const std::string& file_name, std::unordered_map<std::st
  * }
  * }
  */
-std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>> DeserializeProfileV2(std::ifstream& infile) {
+static std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>> DeserializeProfileV2(std::ifstream& infile) {
   LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] In DeserializeProfileV2()";
   // Load flexbuffer
   infile.seekg(0, std::ios::end);

@@ -278,10 +281,10 @@ std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vect
  * Return false meaning no need to rebuild engine if everything is same.
  * Otherwise return true and engine needs to be rebuilt.
  */
-bool CompareProfiles(const std::string& file_name,
-                     std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_min_shapes,
-                     std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_max_shapes,
-                     std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_opt_shapes) {
+static bool CompareProfiles(const std::string& file_name,
+                            std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_min_shapes,
+                            std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_max_shapes,
+                            std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_opt_shapes) {
   std::ifstream profile_file(file_name, std::ios::binary | std::ios::in);
   if (!profile_file) {
     LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] " << file_name << " doesn't exist.";

@@ -372,7 +375,7 @@ bool CompareProfiles(const std::string& file_name,
  * Get cache by name
  *
  */
-std::string GetCachePath(const std::string& root, const std::string& name) {
+static std::string GetCachePath(const std::string& root, const std::string& name) {
   if (root.empty()) {
     return name;
   } else {

@@ -386,7 +389,7 @@ std::string GetCachePath(const std::string& root, const std::string& name) {
  * Get compute capability
  *
  */
-std::string GetComputeCapability(const cudaDeviceProp& prop) {
+static std::string GetComputeCapability(const cudaDeviceProp& prop) {
   const std::string compute_capability = std::to_string(prop.major * 10 + prop.minor);
   return compute_capability;
 }

@@ -397,7 +400,7 @@ std::string GetComputeCapability(const cudaDeviceProp& prop) {
  * \param root root path of the cache
  * \param file_extension It could be ".engine", ".profile" or ".timing"
  */
-std::vector<fs::path> GetCachesByType(const std::string& root, std::string file_extension) {
+static std::vector<fs::path> GetCachesByType(const std::string& root, std::string file_extension) {
   std::vector<fs::path> cache_files;
   for (const auto& entry : fs::directory_iterator(root)) {
     if (fs::path(file_extension) == fs::path(entry).extension()) {

@@ -407,15 +410,15 @@ std::vector<fs::path> GetCachesByType(const std::string& root, std::string file_
   return cache_files;
 }

-bool IsCacheExistedByType(const std::string& root, std::string file_extension) {
+static bool IsCacheExistedByType(const std::string& root, std::string file_extension) {
   auto cache_files = GetCachesByType(root, file_extension);
   if (cache_files.size() == 0) {
     return false;
   }
   return true;
 }

-void RemoveCachesByType(const std::string& root, std::string file_extension) {
+static void RemoveCachesByType(const std::string& root, std::string file_extension) {
   auto cache_files = GetCachesByType(root, file_extension);
   for (const auto& entry : cache_files) {
     fs::remove(entry);

@@ -431,7 +434,7 @@ void RemoveCachesByType(const std::string& root, std::string file_extension) {
  * compiled kernels, so the name must be unique and deterministic across models and sessions.
  * </remarks>
  */
-HashValue TRTGenerateId(const GraphViewer& graph_viewer, std::string trt_version, std::string cuda_version) {
+static HashValue TRTGenerateId(const GraphViewer& graph_viewer, std::string trt_version, std::string cuda_version) {
   HashValue model_hash = 0;

   // find the top level graph

@@ -507,9 +510,9 @@ HashValue TRTGenerateId(const GraphViewer& graph_viewer, std::string trt_version
   return model_hash;
 }

-bool ValidateProfileShapes(std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_min_shapes,
-                           std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_max_shapes,
-                           std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_opt_shapes) {
+static bool ValidateProfileShapes(std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_min_shapes,
+                                  std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_max_shapes,
+                                  std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_opt_shapes) {
   if (profile_min_shapes.empty() && profile_max_shapes.empty() && profile_opt_shapes.empty()) {
     return true;
   }

@@ -552,7 +555,7 @@ bool ValidateProfileShapes(std::unordered_map<std::string, std::vector<std::vect
  *
  * Return true if string can be successfully parsed or false if string has wrong format.
  */
-bool MakeInputNameShapePair(std::string pair_string, std::pair<std::string, std::vector<int64_t>>& pair) {
+static bool MakeInputNameShapePair(std::string pair_string, std::pair<std::string, std::vector<int64_t>>& pair) {
   if (pair_string.empty()) {
     return true;
   }

@@ -595,7 +598,7 @@ bool MakeInputNameShapePair(std::string pair_string, std::pair<std::string, std:
  *
  * Return true if string can be successfully parsed or false if string has wrong format.
  */
-bool ParseProfileShapes(std::string profile_shapes_string, std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_shapes) {
+static bool ParseProfileShapes(std::string profile_shapes_string, std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_shapes) {
   if (profile_shapes_string.empty()) {
     return true;
   }

@@ -628,7 +631,7 @@ bool ParseProfileShapes(std::string profile_shapes_string, std::unordered_map<st
   return true;
 }

-std::vector<std::string> split(const std::string& str, char delimiter) {
+static std::vector<std::string> split(const std::string& str, char delimiter) {
   std::vector<std::string> tokens;
   std::string token;
   std::istringstream tokenStream(str);

@@ -638,7 +641,7 @@ std::vector<std::string> split(const std::string& str, char delimiter) {
   return tokens;
 }

-std::string join(const std::vector<std::string>& vec, const std::string& delimiter) {
+static std::string join(const std::vector<std::string>& vec, const std::string& delimiter) {
   std::string result;
   for (size_t i = 0; i < vec.size(); ++i) {
     result += vec[i];

@@ -657,7 +660,7 @@ std::string join(const std::vector<std::string>& vec, const std::string& delimit
  * This func will generate the suffix "2068723788287043730_189_fp16"
  *
  */
-std::string GetCacheSuffix(const std::string& fused_node_name, const std::string& trt_node_name_with_precision) {
+static std::string GetCacheSuffix(const std::string& fused_node_name, const std::string& trt_node_name_with_precision) {
   std::vector<std::string> split_fused_node_name = split(fused_node_name, '_');
   if (split_fused_node_name.size() >= 3) {
     // Get index of model hash from fused_node_name
@@ -697,4 +700,26 @@ static bool checkTrtTensorIsDynamic(nvinfer1::ITensor* tensor) {
   return checkTrtDimIsDynamic(tensor->getDimensions());
 }
 }
+
+struct ScopedContext {
+  explicit ScopedContext(int device_id) {
+    CUcontext cu_context = 0;
+    CU_CALL_THROW(cuCtxGetCurrent(&cu_context));
+    if (!cu_context) {
+      // cuCtxGetCurrent succeeded but returned nullptr, which indicates that no CUDA context
+      // is currently set for this thread, i.e. there is no user-created context.
+      // We use the runtime API to initialize a context for the specified device.
+      CUDA_CALL_THROW(cudaSetDevice(device_id));
+      CU_CALL_THROW(cuCtxGetCurrent(&cu_context));
+    }
+    CU_CALL_THROW(cuCtxPushCurrent(cu_context));
+  }
+
+  ScopedContext(const ScopedContext&) = delete;
+
+  ~ScopedContext() {
+    // Destructor must not throw. Perform a best-effort pop of the current context.
+    cuCtxPopCurrent(nullptr);
+  }
+};
 }  // namespace onnxruntime
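
To illustrate the intended usage of the new guard (a sketch; `ImportD3D12SharedHandle` is a hypothetical name, not a function from this PR): each import path wraps its driver-API work in a `ScopedContext`, so an existing caller-owned context is reused and restored rather than replaced by a per-call `cudaSetDevice`.

// Hypothetical import path built on the RAII guard above.
void ImportD3D12SharedHandle(int device_id /*, HANDLE shared_handle, ... */) {
  onnxruntime::ScopedContext context_guard(device_id);
  // Driver-API work runs against the pushed context, e.g.
  // cuImportExternalMemory(...) / cuExternalMemoryGetMappedBuffer(...).
}  // ~ScopedContext pops the context; no thread-global state leaks.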
