Skip to content

Commit 3f5f753

Browse files
committed
Adding CIG context creation in OrtFactory
- Added Init and Deinit APIs, which are to be called from the application before calling any interop or ORT APIs
1 parent 99b06df commit 3f5f753

File tree

14 files changed

+472
-4
lines changed

14 files changed

+472
-4
lines changed

cmake/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,15 @@ option(onnxruntime_USE_AZURE "Build with azure inferencing support" OFF)
247247
option(onnxruntime_USE_LOCK_FREE_QUEUE "Build with lock-free task queue for threadpool." OFF)
248248
option(onnxruntime_FORCE_GENERIC_ALGORITHMS "Disable optimized arch-specific algorithms. Use only for testing and debugging generic algorithms." OFF)
249249

250+
# DX interop feature option
251+
option(onnxruntime_USE_DX_INTEROP "Build with the DX Interop feature for graphics API synchronization." OFF)
252+
253+
if (onnxruntime_USE_DX_INTEROP)
254+
add_compile_definitions(USE_DX_INTEROP=1)
255+
else()
256+
add_compile_definitions(USE_DX_INTEROP=0)
257+
endif()
258+
250259
option(onnxruntime_USE_TENSORRT_INTERFACE "Build ONNXRuntime shared lib which is compatible with TensorRT EP interface" OFF)
251260
option(onnxruntime_USE_NV_INTERFACE "Build ONNXRuntime shared lib which is compatible with NV EP interface" OFF)
252261
option(onnxruntime_USE_CUDA_INTERFACE "Build ONNXRuntime shared lib which is compatible with Cuda EP interface" OFF)

cmake/onnxruntime_providers_nv.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,9 @@ endif ()
146146
target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE Eigen3::Eigen onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface Eigen3::Eigen)
147147
add_dependencies(onnxruntime_providers_nv_tensorrt_rtx onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES})
148148
if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER)
149-
target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart)
149+
target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart CUDA::cuda_driver)
150150
else()
151-
target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart)
151+
target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart CUDA::cuda_driver)
152152
endif()
153153
target_include_directories(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${TENSORRT_RTX_INCLUDE_DIR} ${onnx_tensorrt_SOURCE_DIR}
154154
PUBLIC ${CUDAToolkit_INCLUDE_DIRS})

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,6 +1020,74 @@ typedef struct OrtExternalSemaphoreDescriptor {
10201020
void* native_handle; /**< Platform-specific handle (e.g., Windows HANDLE) */
10211021
} OrtExternalSemaphoreDescriptor;
10221022

1023+
/** \brief Graphics API type for interop configuration.
1024+
*
1025+
* Specifies the graphics API used for GPU interop with the execution provider.
1026+
* This enables synchronization between graphics workloads (e.g., rendering, compute shaders)
1027+
* and ONNX Runtime inference.
1028+
*
1029+
* \since Version 1.25.
1030+
*/
1031+
typedef enum OrtGraphicsApi {
1032+
ORT_GRAPHICS_API_NONE = 0, /**< No graphics interop (default) */
1033+
ORT_GRAPHICS_API_D3D12 = 1, /**< Direct3D 12 interop */
1034+
ORT_GRAPHICS_API_VULKAN = 2, /**< Vulkan interop */
1035+
} OrtGraphicsApi;
1036+
1037+
/** \brief Configuration for initializing graphics interop on an EP factory.
1038+
*
1039+
* This structure contains all parameters needed to set up graphics interop between
1040+
* ONNX Runtime and an external graphics API (D3D12, Vulkan). The factory stores this
1041+
* configuration and uses it when creating synchronization streams.
1042+
*
1043+
* Design rationale (following Scott McKay's suggestions):
1044+
* - Single init function with all required params to avoid multiple init signatures
1045+
* - Factory stores the context and uses it in stream creation
1046+
* - Supports extensibility via additional_options for future requirements
1047+
*
1048+
* Example usage for D3D12:
1049+
* \code
1050+
* OrtGraphicsInteropConfig config = {0};
1051+
* config.version = ORT_API_VERSION;
1052+
* config.graphics_api = ORT_GRAPHICS_API_D3D12;
1053+
* config.command_queue = my_d3d12_command_queue; // ID3D12CommandQueue*
1054+
* config.device = my_d3d12_device; // ID3D12Device* (optional)
1055+
* status = ort_api->InitGraphicsInteropForEpDevice(ep_device, &config);
1056+
* \endcode
1057+
*
1058+
* \note The version field must be set to ORT_API_VERSION.
1059+
* This ensures forward compatibility as fields may be added in future versions.
1060+
*
1061+
* \since Version 1.25.
1062+
*/
1063+
typedef struct OrtGraphicsInteropConfig {
1064+
uint32_t version; /**< Must be ORT_API_VERSION */
1065+
OrtGraphicsApi graphics_api; /**< The graphics API to use for interop */
1066+
1067+
/** \brief Command queue/submission queue for graphics workloads.
1068+
*
1069+
* For D3D12: ID3D12CommandQueue*
1070+
* For Vulkan: VkQueue (cast to void*)
1071+
*
1072+
* The factory stores this and uses it for synchronization with inference streams.
1073+
*/
1074+
void* command_queue;
1075+
1076+
/** \brief Graphics device handle (optional, may be inferred from command_queue).
1077+
*
1078+
* For D3D12: ID3D12Device* (optional, can be obtained from command queue)
1079+
* For Vulkan: VkDevice (cast to void*)
1080+
*/
1081+
void* device;
1082+
1083+
/** \brief Additional API-specific options (optional).
1084+
*
1085+
* Can be used for future extensibility without changing the struct layout.
1086+
* For example, Vulkan-specific queue family index, or D3D12 fence sharing flags.
1087+
*/
1088+
const OrtKeyValuePairs* additional_options;
1089+
} OrtGraphicsInteropConfig;
1090+
10231091
/** \brief Descriptor for creating a tensor from imported external memory.
10241092
*
10251093
* \note The version field must be set to ORT_API_VERSION.
@@ -7242,6 +7310,38 @@ struct OrtApi {
72427310
* \since Version 1.25.
72437311
*/
72447312
ORT_API2_STATUS(RunOptionsDisableProfiling, _Inout_ OrtRunOptions* options);
7313+
7314+
/** \brief Initialize graphics interop for an execution provider device.
7315+
*
7316+
* This function enables D3D12/Vulkan interoperability by creating a CIG (CUDA in Graphics) context
7317+
* bound to the provided graphics command queue. Once initialized, any OrtSyncStream created for this
7318+
* ep_device via CreateSyncStreamForEpDevice will be created on the CIG context, enabling efficient
7319+
* GPU-side synchronization between ONNX Runtime inference and graphics workloads.
7320+
*
7321+
* This must be called BEFORE CreateSyncStreamForEpDevice for the same ep_device.
7322+
*
7323+
* \param[in] ep_device The OrtEpDevice to initialize graphics interop for.
7324+
* \param[in] config Configuration specifying the graphics API (D3D12/Vulkan) and required handles.
7325+
*
7326+
* \snippet{doc} snippets.dox OrtStatus Return Value
7327+
*
7328+
* \since Version 1.25.
7329+
*/
7330+
ORT_API2_STATUS(InitGraphicsInteropForEpDevice, _In_ const OrtEpDevice* ep_device,
7331+
_In_ const OrtGraphicsInteropConfig* config);
7332+
7333+
/** \brief Deinitialize graphics interop for an execution provider device.
7334+
*
7335+
* This function cleans up the CIG context that was created by InitGraphicsInteropForEpDevice.
7336+
* Should be called when graphics interop is no longer needed for the ep_device.
7337+
*
7338+
* \param[in] ep_device The OrtEpDevice to deinitialize graphics interop for.
7339+
*
7340+
* \snippet{doc} snippets.dox OrtStatus Return Value
7341+
*
7342+
* \since Version 1.25.
7343+
*/
7344+
ORT_API2_STATUS(DeinitGraphicsInteropForEpDevice, _In_ const OrtEpDevice* ep_device);
72457345
};
72467346

72477347
/*

include/onnxruntime/core/session/onnxruntime_ep_c_api.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2105,6 +2105,56 @@ struct OrtEpFactory {
21052105
*/
21062106
ORT_API2_STATUS(GetCustomOpDomains, _In_ OrtEpFactory* this_ptr,
21072107
_Out_writes_all_(num_domains) OrtCustomOpDomain** domains, _In_ size_t num_domains);
2108+
2109+
/** \brief Initialize graphics interop for the EP factory.
2110+
*
2111+
* This function sets up graphics interop context that enables synchronization between
2112+
* external graphics API workloads (D3D12, Vulkan) and ONNX Runtime inference.
2113+
*
2114+
* The factory stores the graphics context configuration and uses it when creating
2115+
* synchronization streams via CreateSyncStreamForDevice. This approach (suggested by
2116+
* Scott McKay) is more graceful than passing the command queue directly during stream creation.
2117+
*
2118+
* For CUDA-based EPs (like NvTensorRTRTX), this sets up a CUDA in Graphics (CIG) context
2119+
* using cuCtxCreate_v4 or equivalent APIs.
2120+
*
2121+
* Key design points:
2122+
* - Single init function with all required params (avoids multiple init signatures)
2123+
* - Factory stores context and uses it in stream creation
2124+
* - Paired with DeinitGraphicsInterop for cleanup
2125+
*
2126+
* \param[in] this_ptr The OrtEpFactory instance.
2127+
* \param[in] ep_device The OrtEpDevice to initialize graphics interop for.
2128+
* \param[in] config Configuration specifying the graphics API and required handles.
2129+
*
2130+
* \snippet{doc} snippets.dox OrtStatus Return Value
2131+
*
2132+
* \note Implementation of this function is optional.
2133+
* EPs that don't support graphics interop should set this to nullptr or return ORT_NOT_IMPLEMENTED.
2134+
*
2135+
* \since Version 1.25.
2136+
*/
2137+
ORT_API2_STATUS(InitGraphicsInterop, _In_ OrtEpFactory* this_ptr,
2138+
_In_ const OrtEpDevice* ep_device,
2139+
_In_ const OrtGraphicsInteropConfig* config);
2140+
2141+
/** \brief Deinitialize graphics interop for the EP factory.
2142+
*
2143+
* This function cleans up any graphics interop context that was set up by InitGraphicsInterop.
2144+
* Should be called when graphics interop is no longer needed.
2145+
*
2146+
* \param[in] this_ptr The OrtEpFactory instance.
2147+
* \param[in] ep_device The OrtEpDevice to deinitialize graphics interop for.
2148+
*
2149+
* \snippet{doc} snippets.dox OrtStatus Return Value
2150+
*
2151+
* \note Implementation of this function is optional.
2152+
* EPs that don't support graphics interop should set this to nullptr or return ORT_NOT_IMPLEMENTED.
2153+
*
2154+
* \since Version 1.25.
2155+
*/
2156+
ORT_API2_STATUS(DeinitGraphicsInterop, _In_ OrtEpFactory* this_ptr,
2157+
_In_ const OrtEpDevice* ep_device);
21082158
};
21092159

21102160
#ifdef __cplusplus

0 commit comments

Comments
 (0)