From 3f5f75361135aac580bab9beec930736b5db1ae7 Mon Sep 17 00:00:00 2001 From: praneshgo Date: Fri, 6 Feb 2026 13:37:21 +0000 Subject: [PATCH 1/2] Adding CIG context creation in OrtFactory - Added Init and Deinit API which are to be called from application before calling any interop or ORT APIs --- cmake/CMakeLists.txt | 9 + cmake/onnxruntime_providers_nv.cmake | 4 +- .../core/session/onnxruntime_c_api.h | 100 +++++++++ .../core/session/onnxruntime_ep_c_api.h | 50 +++++ .../nv_tensorrt_rtx/nv_provider_factory.cc | 193 +++++++++++++++++- onnxruntime/core/session/onnxruntime_c_api.cc | 56 +++++ onnxruntime/core/session/ort_apis.h | 5 + .../session/plugin_ep/ep_factory_internal.cc | 2 + .../session/plugin_ep/ep_factory_internal.h | 9 + .../plugin_ep/ep_factory_internal_impl.h | 13 ++ .../plugin_ep/ep_factory_provider_bridge.h | 17 ++ .../plugin_ep/forward_to_factory_impl.h | 11 + tools/ci_build/build.py | 3 + tools/ci_build/build_args.py | 4 + 14 files changed, 472 insertions(+), 4 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 971741137cf4f..453c6c8bb81d0 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -247,6 +247,15 @@ option(onnxruntime_USE_AZURE "Build with azure inferencing support" OFF) option(onnxruntime_USE_LOCK_FREE_QUEUE "Build with lock-free task queue for threadpool." OFF) option(onnxruntime_FORCE_GENERIC_ALGORITHMS "Disable optimized arch-specific algorithms. Use only for testing and debugging generic algorithms." OFF) +# DX interop feature option +option(onnxruntime_USE_DX_INTEROP "Build with the DX Interop feature for graphics API synchronization." 
OFF) + +if (onnxruntime_USE_DX_INTEROP) + add_compile_definitions(USE_DX_INTEROP=1) +else() + add_compile_definitions(USE_DX_INTEROP=0) +endif() + option(onnxruntime_USE_TENSORRT_INTERFACE "Build ONNXRuntime shared lib which is compatible with TensorRT EP interface" OFF) option(onnxruntime_USE_NV_INTERFACE "Build ONNXRuntime shared lib which is compatible with NV EP interface" OFF) option(onnxruntime_USE_CUDA_INTERFACE "Build ONNXRuntime shared lib which is compatible with Cuda EP interface" OFF) diff --git a/cmake/onnxruntime_providers_nv.cmake b/cmake/onnxruntime_providers_nv.cmake index e59463b6b91f1..5ec45a64e46bb 100644 --- a/cmake/onnxruntime_providers_nv.cmake +++ b/cmake/onnxruntime_providers_nv.cmake @@ -146,9 +146,9 @@ endif () target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE Eigen3::Eigen onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface Eigen3::Eigen) add_dependencies(onnxruntime_providers_nv_tensorrt_rtx onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER) - target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart) + target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart CUDA::cuda_driver) else() - target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart) + target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart 
CUDA::cuda_driver) endif() target_include_directories(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${TENSORRT_RTX_INCLUDE_DIR} ${onnx_tensorrt_SOURCE_DIR} PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 221f3673f2027..83ca1fd8df869 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -1020,6 +1020,74 @@ typedef struct OrtExternalSemaphoreDescriptor { void* native_handle; /**< Platform-specific handle (e.g., Windows HANDLE) */ } OrtExternalSemaphoreDescriptor; +/** \brief Graphics API type for interop configuration. + * + * Specifies the graphics API used for GPU interop with the execution provider. + * This enables synchronization between graphics workloads (e.g., rendering, compute shaders) + * and ONNX Runtime inference. + * + * \since Version 1.25. + */ +typedef enum OrtGraphicsApi { + ORT_GRAPHICS_API_NONE = 0, /**< No graphics interop (default) */ + ORT_GRAPHICS_API_D3D12 = 1, /**< Direct3D 12 interop */ + ORT_GRAPHICS_API_VULKAN = 2, /**< Vulkan interop */ +} OrtGraphicsApi; + +/** \brief Configuration for initializing graphics interop on an EP factory. + * + * This structure contains all parameters needed to set up graphics interop between + * ONNX Runtime and an external graphics API (D3D12, Vulkan). The factory stores this + * configuration and uses it when creating synchronization streams. 
+ * + * Design rationale (following Scott McKay's suggestions): + * - Single init function with all required params to avoid multiple init signatures + * - Factory stores the context and uses it in stream creation + * - Supports extensibility via additional_options for future requirements + * + * Example usage for D3D12: + * \code + * OrtGraphicsInteropConfig config = {0}; + * config.version = ORT_API_VERSION; + * config.graphics_api = ORT_GRAPHICS_API_D3D12; + * config.command_queue = my_d3d12_command_queue; // ID3D12CommandQueue* + * config.device = my_d3d12_device; // ID3D12Device* (optional) + * status = ort_api->InitGraphicsInteropForEpDevice(ep_device, &config); + * \endcode + * + * \note The version field must be set to ORT_API_VERSION. + * This ensures forward compatibility as fields may be added in future versions. + * + * \since Version 1.25. + */ +typedef struct OrtGraphicsInteropConfig { + uint32_t version; /**< Must be ORT_API_VERSION */ + OrtGraphicsApi graphics_api; /**< The graphics API to use for interop */ + + /** \brief Command queue/submission queue for graphics workloads. + * + * For D3D12: ID3D12CommandQueue* + * For Vulkan: VkQueue (cast to void*) + * + * The factory stores this and uses it for synchronization with inference streams. + */ + void* command_queue; + + /** \brief Graphics device handle (optional, may be inferred from command_queue). + * + * For D3D12: ID3D12Device* (optional, can be obtained from command queue) + * For Vulkan: VkDevice (cast to void*) + */ + void* device; + + /** \brief Additional API-specific options (optional). + * + * Can be used for future extensibility without changing the struct layout. + * For example, Vulkan-specific queue family index, or D3D12 fence sharing flags. + */ + const OrtKeyValuePairs* additional_options; +} OrtGraphicsInteropConfig; + /** \brief Descriptor for creating a tensor from imported external memory. * * \note The version field must be set to ORT_API_VERSION. 
@@ -7242,6 +7310,38 @@ struct OrtApi { * \since Version 1.25. */ ORT_API2_STATUS(RunOptionsDisableProfiling, _Inout_ OrtRunOptions* options); + + /** \brief Initialize graphics interop for an execution provider device. + * + * This function enables D3D12/Vulkan interoperability by creating a CIG (CUDA Interop Graphics) context + * bound to the provided graphics command queue. Once initialized, any OrtSyncStream created for this + * ep_device via CreateSyncStreamForEpDevice will be created on the CIG context, enabling efficient + * GPU-side synchronization between ONNX Runtime inference and graphics workloads. + * + * This must be called BEFORE CreateSyncStreamForEpDevice for the same ep_device. + * + * \param[in] ep_device The OrtEpDevice to initialize graphics interop for. + * \param[in] config Configuration specifying the graphics API (D3D12/Vulkan) and required handles. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.25. + */ + ORT_API2_STATUS(InitGraphicsInteropForEpDevice, _In_ const OrtEpDevice* ep_device, + _In_ const OrtGraphicsInteropConfig* config); + + /** \brief Deinitialize graphics interop for an execution provider device. + * + * This function cleans up the CIG context that was created by InitGraphicsInteropForEpDevice. + * Should be called when graphics interop is no longer needed for the ep_device. + * + * \param[in] ep_device The OrtEpDevice to deinitialize graphics interop for. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.25. 
+ */ + ORT_API2_STATUS(DeinitGraphicsInteropForEpDevice, _In_ const OrtEpDevice* ep_device); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index b888d0d609e55..8b4bbe7b6a493 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -2105,6 +2105,56 @@ struct OrtEpFactory { */ ORT_API2_STATUS(GetCustomOpDomains, _In_ OrtEpFactory* this_ptr, _Out_writes_all_(num_domains) OrtCustomOpDomain** domains, _In_ size_t num_domains); + + /** \brief Initialize graphics interop for the EP factory. + * + * This function sets up graphics interop context that enables synchronization between + * external graphics API workloads (D3D12, Vulkan) and ONNX Runtime inference. + * + * The factory stores the graphics context configuration and uses it when creating + * synchronization streams via CreateSyncStreamForDevice. This approach (suggested by + * Scott McKay) is more graceful than passing the command queue directly during stream creation. + * + * For CUDA-based EPs (like NvTensorRTRTX), this sets up CUDA Interop Graphics (CIG) context + * using cuCtxCreate_v4 or equivalent APIs. + * + * Key design points: + * - Single init function with all required params (avoids multiple init signatures) + * - Factory stores context and uses it in stream creation + * - Paired with DeinitGraphicsInterop for cleanup + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[in] ep_device The OrtEpDevice to initialize graphics interop for. + * \param[in] config Configuration specifying the graphics API and required handles. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \note Implementation of this function is optional. + * EPs that don't support graphics interop should set this to nullptr or return ORT_NOT_IMPLEMENTED. + * + * \since Version 1.25. 
+ */ + ORT_API2_STATUS(InitGraphicsInterop, _In_ OrtEpFactory* this_ptr, + _In_ const OrtEpDevice* ep_device, + _In_ const OrtGraphicsInteropConfig* config); + + /** \brief Deinitialize graphics interop for the EP factory. + * + * This function cleans up any graphics interop context that was set up by InitGraphicsInterop. + * Should be called when graphics interop is no longer needed. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[in] ep_device The OrtEpDevice to deinitialize graphics interop for. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \note Implementation of this function is optional. + * EPs that don't support graphics interop should set this to nullptr or return ORT_NOT_IMPLEMENTED. + * + * \since Version 1.25. + */ + ORT_API2_STATUS(DeinitGraphicsInterop, _In_ OrtEpFactory* this_ptr, + _In_ const OrtEpDevice* ep_device); }; #ifdef __cplusplus diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc index d1e449eb58870..c1e1f169ffa9f 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc @@ -4,6 +4,7 @@ #include #include +#include #include "core/providers/shared_library/provider_api.h" #include "core/framework/provider_options.h" @@ -13,6 +14,11 @@ #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/cuda_stream_handle.h" +// D3D12 headers for graphics interop on Windows +#if defined(_WIN32) && USE_DX_INTEROP +#include +#endif + #include "onnx_ctx_model_helper.h" #include "nv_provider_factory.h" #include "nv_execution_provider.h" @@ -548,6 +554,11 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { IsStreamAware = IsStreamAwareImpl; CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; ValidateCompiledModelCompatibilityInfo = ValidateCompiledModelCompatibilityInfoImpl; + + // Graphics interop support 
(added in ORT 1.25) + InitGraphicsInterop = InitGraphicsInteropImpl; + DeinitGraphicsInterop = DeinitGraphicsInteropImpl; + ort_version_supported = ORT_API_VERSION; // Set to the ORT version we were compiled with. } @@ -750,8 +761,28 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { auto device_id = factory.ep_api.MemoryDevice_GetDeviceId(memory_device); cudaStream_t stream = nullptr; - CUDA_RETURN_IF_ERROR(cudaSetDevice(device_id)); - CUDA_RETURN_IF_ERROR(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + + // Check if we have a CIG context for this device (set up via InitGraphicsInterop) + CUcontext cig_context = factory.GetCigContext(device_id); + if (cig_context != nullptr) { + // We have a CIG context - make it current and create stream on it + // This enables CUDA-Graphics synchronization for this stream + CUresult cu_result = cuCtxSetCurrent(cig_context); + if (cu_result != CUDA_SUCCESS) { + const char* error_str = nullptr; + cuGetErrorString(cu_result, &error_str); + std::string error_msg = "[NvTensorRTRTX EP] Failed to set CIG context current: "; + error_msg += error_str ? error_str : "unknown error"; + return onnxruntime::CreateStatus(ORT_FAIL, error_msg.c_str()); + } + + // Create stream on the CIG context + CUDA_RETURN_IF_ERROR(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + } else { + // No CIG context - use default behavior + CUDA_RETURN_IF_ERROR(cudaSetDevice(device_id)); + CUDA_RETURN_IF_ERROR(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + } const OrtDevice* ort_device = static_cast(memory_device); @@ -876,6 +907,161 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { } } + // Initialize graphics interop for a device. Creates a CIG (CUDA Interop Graphics) context + // bound to the provided graphics command queue/device. + // This follows Scott's suggestion to pass in everything required so we don't end up with multiple + // init function signatures, and the factory stores the queue to utilize in stream creation. 
+  static OrtStatus* ORT_API_CALL InitGraphicsInteropImpl(OrtEpFactory* this_ptr,
+                                                         const OrtEpDevice* ep_device,
+                                                         const OrtGraphicsInteropConfig* config) noexcept {
+    // Returns nullptr on success. Every failure path returns an OrtStatus before cig_contexts_ is modified.
+    if (ep_device == nullptr) {
+      return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT,
+                                       "[NvTensorRTRTX EP] InitGraphicsInterop: ep_device is null");
+    }
+    if (config == nullptr) {
+      return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT,
+                                       "[NvTensorRTRTX EP] InitGraphicsInterop: config is null");
+    }
+
+    auto& factory = *static_cast<NvTensorRtRtxEpFactory*>(this_ptr);
+
+    // Extract device_id from OrtEpDevice options
+    const OrtKeyValuePairs* ep_options = factory.ort_api.EpDevice_EpOptions(ep_device);
+    const char* device_id_str = factory.ort_api.GetKeyValue(ep_options, "device_id");
+    if (device_id_str == nullptr) {
+      return onnxruntime::CreateStatus(ORT_FAIL,
+                                       "[NvTensorRTRTX EP] InitGraphicsInterop: device_id not found in ep_device options");
+    }
+    // std::stoi throws on non-numeric or out-of-range text. This function is noexcept, so an escaping
+    // exception would call std::terminate; report a parse failure as an error status instead.
+    int32_t device_id = 0;
+    try {
+      device_id = std::stoi(device_id_str);
+    } catch (...) {
+      return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT,
+                                       "[NvTensorRTRTX EP] InitGraphicsInterop: device_id is not a valid integer");
+    }
+
+    // Validate graphics API
+    if (config->graphics_api == ORT_GRAPHICS_API_NONE) {
+      return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT,
+                                       "[NvTensorRTRTX EP] InitGraphicsInterop: graphics_api cannot be NONE");
+    }
+
+    // Initialize CUDA driver API
+    CUresult cu_result = cuInit(0);
+    if (cu_result != CUDA_SUCCESS) {
+      const char* error_str = nullptr;
+      cuGetErrorString(cu_result, &error_str);
+      std::string error_msg = "[NvTensorRTRTX EP] Failed to initialize CUDA driver API: ";
+      error_msg += error_str ? error_str : "unknown error";
+      return onnxruntime::CreateStatus(ORT_FAIL, error_msg.c_str());
+    }
+
+    // Get CUDA device properties to retrieve LUID
+    cudaDeviceProp cuda_prop;
+    cudaError_t cuda_err = cudaGetDeviceProperties(&cuda_prop, device_id);
+    if (cuda_err != cudaSuccess) {
+      std::string error_msg = "[NvTensorRTRTX EP] Failed to get CUDA device properties: ";
+      error_msg += cudaGetErrorString(cuda_err);
+      return onnxruntime::CreateStatus(ORT_FAIL, error_msg.c_str());
+    }
+
+    // Create CIG context based on graphics API type
+    CUcontext cig_context = nullptr;
+
+    if (config->graphics_api == ORT_GRAPHICS_API_D3D12) {
+#if defined(_WIN32) && USE_DX_INTEROP
+      // Validate required parameters
+      if (config->device == nullptr) {
+        return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT,
+                                         "[NvTensorRTRTX EP] InitGraphicsInterop: D3D12 device is null");
+      }
+      if (config->command_queue == nullptr) {
+        return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT,
+                                         "[NvTensorRTRTX EP] InitGraphicsInterop: D3D12 command queue is null");
+      }
+
+      // Get LUID from CUDA device
+      if (cuda_prop.luidDeviceNodeMask == 0) {
+        return onnxruntime::CreateStatus(ORT_FAIL,
+                                         "[NvTensorRTRTX EP] CUDA device does not have a valid LUID");
+      }
+      // luid is a raw byte array. Assemble it byte-wise (little-endian, which matches a native read on
+      // Windows) instead of dereferencing it through a uint64_t*, which breaks aliasing/alignment rules.
+      uint64_t cuda_luid = 0;
+      for (size_t i = sizeof(cuda_prop.luid); i-- > 0;) {
+        cuda_luid = (cuda_luid << 8) | static_cast<unsigned char>(cuda_prop.luid[i]);
+      }
+
+      // Get LUID from D3D12 device and compare, so the context is created on the adapter the queue lives on
+      ID3D12Device* d3d12_device = reinterpret_cast<ID3D12Device*>(config->device);
+      LUID d3d12_luid = d3d12_device->GetAdapterLuid();
+      uint64_t d3d12_luid_64 = (static_cast<uint64_t>(d3d12_luid.HighPart) << 32) | d3d12_luid.LowPart;
+
+      if (d3d12_luid_64 != cuda_luid) {
+        return onnxruntime::CreateStatus(ORT_FAIL,
+                                         "[NvTensorRTRTX EP] D3D12 device LUID does not match CUDA device LUID");
+      }
+
+      // Create CIG context bound to D3D12 command queue
+      ID3D12CommandQueue* d3d12_queue = reinterpret_cast<ID3D12CommandQueue*>(config->command_queue);
+
+      CUctxCigParam cig_param = {CIG_DATA_TYPE_D3D12_COMMAND_QUEUE, d3d12_queue};
+      CUctxCreateParams ctx_params = {nullptr, 0, &cig_param};
+
+      cu_result = cuCtxCreate_v4(&cig_context, &ctx_params, 0, device_id);
+      if (cu_result != CUDA_SUCCESS) {
+        const char* error_str = nullptr;
+        cuGetErrorString(cu_result, &error_str);
+        std::string error_msg = "[NvTensorRTRTX EP] Failed to create CIG context for D3D12: ";
+        error_msg += error_str ? error_str : "unknown error";
+        return onnxruntime::CreateStatus(ORT_FAIL, error_msg.c_str());
+      }
+#else
+      return onnxruntime::CreateStatus(ORT_NOT_IMPLEMENTED,
+                                       "[NvTensorRTRTX EP] D3D12 CIG context creation not supported on this platform");
+#endif
+    } else if (config->graphics_api == ORT_GRAPHICS_API_VULKAN) {
+      // TODO: Add Vulkan CIG context support if needed
+      return onnxruntime::CreateStatus(ORT_NOT_IMPLEMENTED,
+                                       "[NvTensorRTRTX EP] Vulkan CIG context not yet implemented");
+    } else {
+      return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT,
+                                       "[NvTensorRTRTX EP] Unsupported graphics API for CIG context");
+    }
+
+    // Store the CIG context for this device
+    // NOTE(review): a second Init for the same device_id replaces the stored handle without destroying the
+    // previous context - confirm callers always Deinit first, or reject/destroy here.
+    factory.cig_contexts_[device_id] = cig_context;
+
+    return nullptr;
+  }
+
+  // Deinitialize graphics interop for a device. Cleans up CIG context and stored config.
+  static OrtStatus* ORT_API_CALL DeinitGraphicsInteropImpl(OrtEpFactory* this_ptr,
+                                                           const OrtEpDevice* ep_device) noexcept {
+    // Forgets the CIG context recorded for ep_device's device_id. Returns nullptr on success, including
+    // when no context was recorded for that device.
+    if (ep_device == nullptr) {
+      return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT,
+                                       "[NvTensorRTRTX EP] DeinitGraphicsInterop: ep_device is null");
+    }
+
+    auto& factory = *static_cast<NvTensorRtRtxEpFactory*>(this_ptr);
+
+    // Extract device_id from OrtEpDevice options
+    const OrtKeyValuePairs* ep_options = factory.ort_api.EpDevice_EpOptions(ep_device);
+    const char* device_id_str = factory.ort_api.GetKeyValue(ep_options, "device_id");
+    if (device_id_str == nullptr) {
+      return onnxruntime::CreateStatus(ORT_FAIL,
+                                       "[NvTensorRTRTX EP] DeinitGraphicsInterop: device_id not found in ep_device options");
+    }
+    // std::stoi throws on non-numeric or out-of-range text. This function is noexcept, so an escaping
+    // exception would call std::terminate; report a parse failure as an error status instead.
+    int32_t device_id = 0;
+    try {
+      device_id = std::stoi(device_id_str);
+    } catch (...) {
+      return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT,
+                                       "[NvTensorRTRTX EP] DeinitGraphicsInterop: device_id is not a valid integer");
+    }
+
+    auto it = factory.cig_contexts_.find(device_id);
+    if (it != factory.cig_contexts_.end()) {
+      // The CIG context may still be in use by TensorRT resources. The CUDA driver
+      // will clean up the context when the process exits. Just remove from our map.
+      factory.cig_contexts_.erase(it);
+    }
+
+    return nullptr;
+  }
+
+  // Get the CIG CUDA context for a device (for use in stream creation)
+  CUcontext GetCigContext(int32_t device_id) const {
+    auto it = cig_contexts_.find(device_id);
+    if (it != cig_contexts_.end()) {
+      return it->second;
+    }
+    return nullptr;
+  }
+
   OrtStatus* CreateMemoryInfoForDevices(int num_devices) {
     gpu_memory_infos.reserve(num_devices);
     host_accessible_memory_infos.reserve(num_devices);
@@ -920,6 +1106,9 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory {
   // we use a shared instance for the OrtDataTransferImpl instead of creating a new one on every call to
   NvTrtRtxDataTransferImpl data_transfer_impl;
 
+  // CIG contexts per device (keyed by device_id)
+  std::unordered_map<int32_t, CUcontext> cig_contexts_;
+
   NvTensorRtRtxEpFactory(const NvTensorRtRtxEpFactory&) = delete;
   NvTensorRtRtxEpFactory& operator=(const NvTensorRtRtxEpFactory&) = delete;
 
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 7a027c8eafb81..15c14e1e610f3 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -3837,6 +3837,47 @@ ORT_API(void, OrtApis::ReleaseSyncStream, _Frees_ptr_opt_ OrtSyncStream* ort_str
   std::unique_ptr ep_stream(reinterpret_cast(stream));
 }
 
+// Forwards a graphics interop init request to the factory that owns ep_device.
+ORT_API_STATUS_IMPL(OrtApis::InitGraphicsInteropForEpDevice, _In_ const OrtEpDevice* ep_device,
+                    _In_ const OrtGraphicsInteropConfig* config) {
+  API_IMPL_BEGIN
+  if (ep_device == nullptr || config == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device and config must be provided.");
+  }
+
+  auto* factory = ep_device->ep_factory;
+  if (factory == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device does not have an associated factory.");
+  }
+
+  // InitGraphicsInterop was added to OrtEpFactory in 1.25. A factory built against an older ORT has a
+  // shorter struct, so the member must not be read unless the factory reports a new enough version.
+  if (factory->ort_version_supported < 25 || factory->InitGraphicsInterop == nullptr) {
+    return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "The execution provider does not support graphics interop.");
+  }
+
+  return factory->InitGraphicsInterop(ep_device->GetMutableFactory(), ep_device, config);
+
+  API_IMPL_END
+}
+
+// Forwards a graphics interop deinit request to the factory that owns ep_device.
+ORT_API_STATUS_IMPL(OrtApis::DeinitGraphicsInteropForEpDevice, _In_ const OrtEpDevice* ep_device) {
+  API_IMPL_BEGIN
+  if (ep_device == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device must be provided.");
+  }
+
+  auto* factory = ep_device->ep_factory;
+  if (factory == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device does not have an associated factory.");
+  }
+
+  // DeinitGraphicsInterop was added to OrtEpFactory in 1.25; see the version note in the Init counterpart's
+  // guard: older factories have a shorter struct, so check the reported version before reading the member.
+  if (factory->ort_version_supported < 25 || factory->DeinitGraphicsInterop == nullptr) {
+    return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "The execution provider does not support graphics interop.");
+  }
+
+  return factory->DeinitGraphicsInterop(ep_device->GetMutableFactory(), ep_device);
+
+  API_IMPL_END
+}
+
 ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* env,
                     _In_reads_(num_tensors) const OrtValue* const* src_tensors,
                     _In_reads_(num_tensors) OrtValue* const* dst_tensors,
@@ -4137,6 +4178,19 @@ ORT_API(void, OrtApis::ReleaseSyncStream, _Frees_ptr_opt_ OrtSyncStream* /*ort_s
   fprintf(stderr, "OrtSyncStream is not supported in a minimal build.\n");
 }
 
+ORT_API_STATUS_IMPL(OrtApis::InitGraphicsInteropForEpDevice, _In_ const OrtEpDevice* /*ep_device*/,
+                    _In_ const OrtGraphicsInteropConfig* /*config*/) {
+  API_IMPL_BEGIN
+  return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "InitGraphicsInteropForEpDevice is not supported in a minimal build.");
+  API_IMPL_END
+}
+
+ORT_API_STATUS_IMPL(OrtApis::DeinitGraphicsInteropForEpDevice, _In_ const OrtEpDevice* /*ep_device*/) {
+  API_IMPL_BEGIN
+  return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "DeinitGraphicsInteropForEpDevice is not supported in a minimal build.");
+  API_IMPL_END
+}
+
 ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* /*env*/,
                     _In_reads_(num_tensors) const OrtValue* const* /*src_tensors*/,
                     _In_reads_(num_tensors) OrtValue* const* /*dst_tensors*/,
@@ -4807,6 +4861,8 @@ static constexpr OrtApi ort_api_1_to_25 = {
&OrtApis::RunOptionsEnableProfiling, &OrtApis::RunOptionsDisableProfiling, + &OrtApis::InitGraphicsInteropForEpDevice, + &OrtApis::DeinitGraphicsInteropForEpDevice, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 3d990909cfb41..891e14e0a0223 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -815,4 +815,9 @@ ORT_API_STATUS_IMPL(GetTensorElementTypeAndShapeDataReference, _In_ const OrtVal _Out_ ONNXTensorElementDataType* elem_type, _Outptr_result_maybenull_ const int64_t** shape_data, _Out_ size_t* shape_data_count); + +// Graphics interop +ORT_API_STATUS_IMPL(InitGraphicsInteropForEpDevice, _In_ const OrtEpDevice* ep_device, + _In_ const OrtGraphicsInteropConfig* config); +ORT_API_STATUS_IMPL(DeinitGraphicsInteropForEpDevice, _In_ const OrtEpDevice* ep_device); } // namespace OrtApis diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc index fe36c9ea0cdd1..ca39d6e750088 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc @@ -34,6 +34,8 @@ EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl OrtEpFactory::CreateSyncStreamForDevice = Forward::CreateSyncStreamForDevice; OrtEpFactory::CreateExternalResourceImporterForDevice = Forward::CreateExternalResourceImporterForDevice; OrtEpFactory::GetHardwareDeviceIncompatibilityDetails = Forward::GetHardwareDeviceIncompatibilityDetails; + OrtEpFactory::InitGraphicsInterop = Forward::InitGraphicsInterop; + OrtEpFactory::DeinitGraphicsInterop = Forward::DeinitGraphicsInterop; } InternalExecutionProviderFactory::InternalExecutionProviderFactory(EpFactoryInternal& ep_factory, diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h 
b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h index ae09c763bbcbf..9ac883da06465 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h @@ -97,6 +97,15 @@ class EpFactoryInternal : public OrtEpFactory { return impl_->GetHardwareDeviceIncompatibilityDetails(hw, details); } + OrtStatus* InitGraphicsInterop(_In_ const OrtEpDevice* ep_device, + _In_ const OrtGraphicsInteropConfig* config) noexcept { + return impl_->InitGraphicsInterop(ep_device, config); + } + + OrtStatus* DeinitGraphicsInterop(_In_ const OrtEpDevice* ep_device) noexcept { + return impl_->DeinitGraphicsInterop(ep_device); + } + // Function ORT calls to release an EP instance. void ReleaseEp(OrtEp* /*ep*/) noexcept { // we never create an OrtEp so we should never be trying to release one diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index 01f7bc67a522e..a6140b2ac260f 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -97,6 +97,19 @@ class EpFactoryInternalImpl { return nullptr; } + virtual OrtStatus* InitGraphicsInterop(_In_ const OrtEpDevice* /*ep_device*/, + _In_ const OrtGraphicsInteropConfig* /*config*/) noexcept { + // Default implementation: graphics interop not supported + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, + "InitGraphicsInterop is not implemented for this EP factory."); + } + + virtual OrtStatus* DeinitGraphicsInterop(_In_ const OrtEpDevice* /*ep_device*/) noexcept { + // Default implementation: graphics interop not supported + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, + "DeinitGraphicsInterop is not implemented for this EP factory."); + } + virtual OrtStatus* GetNumCustomOpDomains(_Out_ size_t* num_domains) const noexcept { *num_domains = 0; return nullptr; diff --git 
a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h index eb1427db87463..dd6bd6f2a8bdc 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h @@ -101,6 +101,23 @@ class ProviderBridgeEpFactory : public EpFactoryInternalImpl { return nullptr; } + OrtStatus* InitGraphicsInterop(const OrtEpDevice* ep_device, + const OrtGraphicsInteropConfig* config) noexcept override { + if (ep_factory_.InitGraphicsInterop == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, + "InitGraphicsInterop is not implemented for this EP factory."); + } + return ep_factory_.InitGraphicsInterop(&ep_factory_, ep_device, config); + } + + OrtStatus* DeinitGraphicsInterop(const OrtEpDevice* ep_device) noexcept override { + if (ep_factory_.DeinitGraphicsInterop == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, + "DeinitGraphicsInterop is not implemented for this EP factory."); + } + return ep_factory_.DeinitGraphicsInterop(&ep_factory_, ep_device); + } + OrtEpFactory& ep_factory_; ProviderLibrary& provider_library_; std::optional library_path_; diff --git a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h index ce9d06da75cb3..d4d93152c4f80 100644 --- a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h +++ b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h @@ -95,6 +95,17 @@ struct ForwardToFactoryImpl { return static_cast(this_ptr)->GetHardwareDeviceIncompatibilityDetails(hw, details); } + static OrtStatus* ORT_API_CALL InitGraphicsInterop(_In_ OrtEpFactory* this_ptr, + _In_ const OrtEpDevice* ep_device, + _In_ const OrtGraphicsInteropConfig* config) noexcept { + return static_cast(this_ptr)->InitGraphicsInterop(ep_device, config); + } + + static OrtStatus* ORT_API_CALL DeinitGraphicsInterop(_In_ 
OrtEpFactory* this_ptr, + _In_ const OrtEpDevice* ep_device) noexcept { + return static_cast(this_ptr)->DeinitGraphicsInterop(ep_device); + } + static void ORT_API_CALL ReleaseEp(OrtEpFactory* this_ptr, OrtEp* ep) noexcept { static_cast(this_ptr)->ReleaseEp(ep); } diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index a0712af35e455..a9f7b5f744d3a 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1055,6 +1055,9 @@ def generate_build_tree( cmake_args += [f"-Donnxruntime_PREBUILT_PYTORCH_PATH={os.path.dirname(torch.__file__)}"] cmake_args += ["-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))] + if args.use_dx_interop: + cmake_args += ["-Donnxruntime_USE_DX_INTEROP=ON"] + if args.use_azure: add_default_definition(cmake_extra_defines, "onnxruntime_USE_AZURE", "ON") diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py index f32666f65cc38..a6dd9553caef2 100644 --- a/tools/ci_build/build_args.py +++ b/tools/ci_build/build_args.py @@ -827,6 +827,10 @@ def add_execution_provider_args(parser: argparse.ArgumentParser) -> None: azure_group = parser.add_argument_group("Azure Execution Provider") azure_group.add_argument("--use_azure", action="store_true", help="Enable Azure EP.") + # --- DX Interop Feature --- + dx_interop_group = parser.add_argument_group("DX Interop Feature") + dx_interop_group.add_argument("--use_dx_interop", action="store_true", help="Enable DX Interop feature for graphics API synchronization.") + def add_other_feature_args(parser: argparse.ArgumentParser) -> None: """Adds arguments for other miscellaneous features.""" From 3e21b904808fadc1825c6f808495738e5aecce9b Mon Sep 17 00:00:00 2001 From: praneshgo Date: Fri, 6 Feb 2026 19:54:04 +0000 Subject: [PATCH 2/2] Remove EP specific terms from public headers and fix line endings. 
--- include/onnxruntime/core/session/onnxruntime_c_api.h | 6 +++--- include/onnxruntime/core/session/onnxruntime_ep_c_api.h | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 83ca1fd8df869..8194ead07b5d0 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -7313,9 +7313,9 @@ struct OrtApi { /** \brief Initialize graphics interop for an execution provider device. * - * This function enables D3D12/Vulkan interoperability by creating a CIG (CUDA Interop Graphics) context + * This function enables D3D12/Vulkan interoperability by creating a graphics interop context * bound to the provided graphics command queue. Once initialized, any OrtSyncStream created for this - * ep_device via CreateSyncStreamForEpDevice will be created on the CIG context, enabling efficient + * ep_device via CreateSyncStreamForEpDevice will be created on the interop context, enabling efficient * GPU-side synchronization between ONNX Runtime inference and graphics workloads. * * This must be called BEFORE CreateSyncStreamForEpDevice for the same ep_device. @@ -7332,7 +7332,7 @@ struct OrtApi { /** \brief Deinitialize graphics interop for an execution provider device. * - * This function cleans up the CIG context that was created by InitGraphicsInteropForEpDevice. + * This function cleans up the graphics interop context that was created by InitGraphicsInteropForEpDevice. * Should be called when graphics interop is no longer needed for the ep_device. * * \param[in] ep_device The OrtEpDevice to deinitialize graphics interop for. 
diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 8b4bbe7b6a493..f3bbb8b9977e9 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -2115,8 +2115,8 @@ struct OrtEpFactory { * synchronization streams via CreateSyncStreamForDevice. This approach (suggested by * Scott McKay) is more graceful than passing the command queue directly during stream creation. * - * For CUDA-based EPs (like NvTensorRTRTX), this sets up CUDA Interop Graphics (CIG) context - * using cuCtxCreate_v4 or equivalent APIs. + * The implementation is EP-specific. EPs may create a specialized interop context using + * platform-specific APIs to enable GPU-GPU synchronization. * * Key design points: * - Single init function with all required params (avoids multiple init signatures)