Skip to content

Commit 727db0d

Browse files
authored
Engine compatibility validity API implementation (#26774)
Added support for an engine-compatibility validation check for EP Context models. ### Motivation and Context We wanted to implement support for the GetModelCompatibilityForEpDevices() API and thereby provide an end-user-facing API for validating engine compatibility of EP Context models. This change adds that support along with the necessary function implementations.
1 parent e1236ca commit 727db0d

File tree

6 files changed

+291
-1
lines changed

6 files changed

+291
-1
lines changed

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2578,6 +2578,84 @@ const InlinedVector<const Node*> NvExecutionProvider::GetEpContextNodes() const
25782578
return ep_context_nodes;
25792579
}
25802580

2581+
std::string NvExecutionProvider::GetCompiledModelCompatibilityInfo(
2582+
const onnxruntime::GraphViewer& graph_viewer) const {
2583+
ORT_UNUSED_PARAMETER(graph_viewer);
2584+
2585+
// Protect read access to engine_headers_ for thread safety
2586+
auto lock = GetApiLock();
2587+
2588+
// Compatibility info is only supported when there is exactly one engine.
2589+
// If multiple EPContext nodes/engines exist, return empty so validation is not applicable.
2590+
if (engine_headers_.size() > 1) {
2591+
return std::string();
2592+
}
2593+
2594+
// If we have stored engine headers, return the first one found
2595+
// (typically there's only one per EP context)
2596+
if (!engine_headers_.empty()) {
2597+
return engine_headers_.begin()->second;
2598+
}
2599+
2600+
// No headers available - validation not supported for this model
2601+
return std::string();
2602+
}
2603+
2604+
common::Status NvExecutionProvider::ValidateCompiledModelCompatibilityInfo(
    const std::string& compatibility_info,
    OrtCompiledModelCompatibility& model_compatibility) const {
  // Validates a hex-encoded TensorRT RTX engine header (as produced by
  // GetCompiledModelCompatibilityInfo) against the current system and maps the
  // TensorRT validity result onto the ORT compatibility enum.
  // Always returns Status::OK(); problems are reported via model_compatibility.

  // If no compatibility info provided, validation not applicable
  if (compatibility_info.empty()) {
    model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
    return Status::OK();
  }

  // Decode hex string to binary
  std::vector<uint8_t> engine_header;
  try {
    engine_header = HexStringToBinary(compatibility_info);
  } catch (const std::exception& ex) {
    LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Failed to decode engine header: " << ex.what();
    model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
    return Status::OK();
  }

  // Reject headers of unexpected size before handing them to TensorRT.
  // Keeps this path consistent with NvTensorRtRtxEpFactory::ValidateCompiledModelCompatibilityInfoImpl,
  // which performs the same check.
  if (engine_header.size() != kTensorRTEngineHeaderSize) {
    LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Invalid header size: " << engine_header.size()
                          << " bytes (expected " << kTensorRTEngineHeaderSize << ")";
    model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
    return Status::OK();
  }

  // Use TensorRT RTX's getEngineValidity to check compatibility
  uint64_t diagnostics = 0;
  nvinfer1::EngineValidity validity = runtime_->getEngineValidity(
      engine_header.data(),
      engine_header.size(),
      &diagnostics);

  // Map TensorRT RTX validity to ORT compatibility status
  switch (validity) {
    case nvinfer1::EngineValidity::kVALID:
      LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Engine is fully compatible with this system";
      model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL;
      break;

    case nvinfer1::EngineValidity::kSUBOPTIMAL:
      LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine is compatible but recompilation recommended "
                            << "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")";
      model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION;
      break;

    case nvinfer1::EngineValidity::kINVALID:
      LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine is incompatible with this system "
                            << "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")";
      model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
      break;

    default:
      LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Unknown TensorRT validity status: "
                            << static_cast<int>(validity);
      model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
      break;
  }

  return Status::OK();
}
2658+
25812659
Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& graph_body_viewer,
25822660
const Node& fused_node,
25832661
std::unordered_map<std::string, size_t>& input_map,
@@ -2854,6 +2932,18 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
28542932
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
28552933
"NvTensorRTRTX EP failed to create engine from network for fused node: " + fused_node.Name());
28562934
}
2935+
2936+
// Capture engine header (first 64 bytes) for compatibility validation
2937+
if (serialized_engine->size() >= kTensorRTEngineHeaderSize) {
2938+
std::string engine_header_hex = BinaryToHexString(
2939+
serialized_engine->data(),
2940+
kTensorRTEngineHeaderSize);
2941+
engine_headers_[fused_node.Name()] = engine_header_hex;
2942+
} else {
2943+
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine too small to capture header for validation: "
2944+
<< serialized_engine->size() << " bytes";
2945+
}
2946+
28572947
trt_engine = std::unique_ptr<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size()));
28582948
if (trt_engine == nullptr) {
28592949
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,13 @@ class NvExecutionProvider : public IExecutionProvider {
345345

346346
const InlinedVector<const Node*> GetEpContextNodes() const override;
347347

348+
// Engine compatibility validation methods
349+
std::string GetCompiledModelCompatibilityInfo(const onnxruntime::GraphViewer& graph_viewer) const override;
350+
351+
common::Status ValidateCompiledModelCompatibilityInfo(
352+
const std::string& compatibility_info,
353+
OrtCompiledModelCompatibility& model_compatibility) const override;
354+
348355
private:
349356
mutable NvExecutionProviderInfo info_;
350357
bool external_stream_ = false;
@@ -424,6 +431,10 @@ class NvExecutionProvider : public IExecutionProvider {
424431
std::unordered_map<std::string, std::vector<nvinfer1::IOptimizationProfile*>> profiles_;
425432
std::unordered_map<std::string, DDSOutputAllocatorMap> dds_output_allocator_maps_;
426433

434+
// Storage for engine headers (64 bytes) for compatibility validation
435+
// Maps fused_node_name -> hex-encoded engine header
436+
mutable std::unordered_map<std::string, std::string> engine_headers_;
437+
427438
// for external stream, we need to create its cudnn/cublass handle before cuda EP enable cuda graph capture
428439
cudnnHandle_t external_cudnn_handle_ = nullptr;
429440
cublasHandle_t external_cublas_handle_ = nullptr;

onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "core/providers/cuda/shared_inc/cuda_call.h"
1414
#include "core/providers/cuda/cuda_stream_handle.h"
1515

16+
#include "onnx_ctx_model_helper.h"
1617
#include "nv_provider_factory.h"
1718
#include "nv_execution_provider.h"
1819
#include "nv_provider_factory_creator.h"
@@ -21,6 +22,11 @@
2122

2223
using namespace onnxruntime;
2324

25+
// External declarations
26+
namespace onnxruntime {
27+
extern TensorrtLogger& GetTensorrtLogger(bool verbose_log);
28+
}
29+
2430
namespace onnxruntime {
2531

2632
void InitializeRegistry();
@@ -541,7 +547,7 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory {
541547

542548
IsStreamAware = IsStreamAwareImpl;
543549
CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl;
544-
550+
ValidateCompiledModelCompatibilityInfo = ValidateCompiledModelCompatibilityInfoImpl;
545551
ort_version_supported = ORT_API_VERSION; // Set to the ORT version we were compiled with.
546552
}
547553

@@ -681,6 +687,7 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory {
681687

682688
RETURN_IF_ERROR(factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, ep_metadata, ep_options,
683689
&ep_devices[num_ep_devices]));
690+
684691
factory->ort_api.ReleaseKeyValuePairs(ep_options);
685692
factory->ort_api.ReleaseKeyValuePairs(ep_metadata);
686693

@@ -755,6 +762,120 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory {
755762
return nullptr;
756763
}
757764

765+
/**
766+
* This function is called by the public C API GetModelCompatibilityForEpDevices.
767+
* It uses TensorRT RTX runtime directly to call runtime->getEngineValidity() to check the 64-byte engine header.
768+
*
769+
* @param this_ptr Factory instance pointer
770+
* @param devices Hardware devices (not used, validation is done against current system)
771+
* @param num_devices Number of devices
772+
* @param compatibility_info Hex-encoded 64-byte TensorRT RTX engine header (128 hex characters)
773+
* @param model_compatibility Output parameter for compatibility status
774+
* @return OrtStatus* nullptr on success, error status on failure
775+
*/
776+
static OrtStatus* ORT_API_CALL ValidateCompiledModelCompatibilityInfoImpl(
777+
OrtEpFactory* this_ptr,
778+
const OrtHardwareDevice* const* devices,
779+
size_t num_devices,
780+
const char* compatibility_info,
781+
OrtCompiledModelCompatibility* model_compatibility) noexcept {
782+
auto& factory = *static_cast<NvTensorRtRtxEpFactory*>(this_ptr);
783+
784+
// Validate input parameters
785+
if (compatibility_info == nullptr || model_compatibility == nullptr) {
786+
return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT,
787+
"[NvTensorRTRTX EP] Invalid arguments: compatibility_info or model_compatibility is null");
788+
}
789+
790+
// Device parameters not used for header validation
791+
ORT_UNUSED_PARAMETER(devices);
792+
ORT_UNUSED_PARAMETER(num_devices);
793+
794+
try {
795+
// If no compatibility info provided, validation not applicable
796+
if (compatibility_info[0] == '\0') {
797+
*model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
798+
return nullptr;
799+
}
800+
801+
// Decode hex string to binary
802+
std::vector<uint8_t> engine_header;
803+
try {
804+
engine_header = HexStringToBinary(std::string(compatibility_info));
805+
} catch (const std::exception& ex) {
806+
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Failed to decode engine header: " << ex.what();
807+
*model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
808+
return nullptr;
809+
}
810+
811+
// Validate header size (keep in sync with TensorRT engine header size)
812+
if (engine_header.size() != kTensorRTEngineHeaderSize) {
813+
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Invalid header size: " << engine_header.size()
814+
<< " bytes (expected " << kTensorRTEngineHeaderSize << ")";
815+
*model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
816+
return nullptr;
817+
}
818+
819+
// Create TensorRT runtime for validation
820+
static std::mutex runtime_creation_mutex;
821+
std::unique_ptr<nvinfer1::IRuntime> runtime;
822+
{
823+
std::lock_guard<std::mutex> lock(runtime_creation_mutex);
824+
TensorrtLogger& trt_logger = GetTensorrtLogger(false);
825+
runtime.reset(nvinfer1::createInferRuntime(trt_logger));
826+
}
827+
828+
if (!runtime) {
829+
LOGS_DEFAULT(ERROR) << "[NvTensorRTRTX EP] Failed to create TensorRT runtime";
830+
return factory.ort_api.CreateStatus(ORT_FAIL,
831+
"[NvTensorRTRTX EP] Failed to create TensorRT runtime");
832+
}
833+
834+
// Use TensorRT's getEngineValidity to check compatibility
835+
uint64_t diagnostics = 0;
836+
nvinfer1::EngineValidity validity = runtime->getEngineValidity(
837+
engine_header.data(),
838+
engine_header.size(),
839+
&diagnostics);
840+
841+
// Map TensorRT validity to ORT compatibility status
842+
switch (validity) {
843+
case nvinfer1::EngineValidity::kVALID:
844+
*model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL;
845+
break;
846+
847+
case nvinfer1::EngineValidity::kSUBOPTIMAL:
848+
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine compatible but recompilation recommended "
849+
<< "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")";
850+
*model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION;
851+
break;
852+
853+
case nvinfer1::EngineValidity::kINVALID:
854+
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine incompatible with this system "
855+
<< "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")";
856+
*model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
857+
break;
858+
859+
default:
860+
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Unknown validity status: "
861+
<< static_cast<int>(validity);
862+
*model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
863+
break;
864+
}
865+
866+
return nullptr;
867+
868+
} catch (const std::exception& ex) {
869+
std::string error_msg = std::string("[NvTensorRTRTX EP] Exception during validation: ") + ex.what();
870+
LOGS_DEFAULT(ERROR) << error_msg;
871+
return factory.ort_api.CreateStatus(ORT_FAIL, error_msg.c_str());
872+
} catch (...) {
873+
LOGS_DEFAULT(ERROR) << "[NvTensorRTRTX EP] Unknown exception during validation";
874+
return factory.ort_api.CreateStatus(ORT_FAIL,
875+
"[NvTensorRTRTX EP] Unknown exception during validation");
876+
}
877+
}
878+
758879
OrtStatus* CreateMemoryInfoForDevices(int num_devices) {
759880
gpu_memory_infos.reserve(num_devices);
760881
host_accessible_memory_infos.reserve(num_devices);

onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,53 @@
1414
namespace onnxruntime {
1515
extern TensorrtLogger& GetTensorrtLogger(bool verbose_log);
1616

17+
/*
 * Render a byte buffer as a lowercase hex string (two characters per byte).
 */
std::string BinaryToHexString(const void* data, size_t size) {
  constexpr char kHexDigits[] = "0123456789abcdef";
  const auto* bytes = static_cast<const uint8_t*>(data);

  // Pre-size the output: each input byte becomes exactly two hex characters.
  std::string hex(size * 2, '\0');
  for (size_t i = 0; i < size; ++i) {
    hex[2 * i] = kHexDigits[bytes[i] >> 4];
    hex[2 * i + 1] = kHexDigits[bytes[i] & 0xF];
  }
  return hex;
}
32+
33+
/*
34+
* Convert hex string back to binary
35+
*/
36+
std::vector<uint8_t> HexStringToBinary(const std::string& hex) {
37+
if (hex.size() % 2 != 0) {
38+
ORT_THROW("Hex string must have even length");
39+
}
40+
41+
std::vector<uint8_t> result;
42+
result.reserve(hex.size() / 2);
43+
44+
for (size_t i = 0; i < hex.size(); i += 2) {
45+
uint8_t byte = 0;
46+
47+
// High nibble
48+
char c = hex[i];
49+
byte |= (c >= '0' && c <= '9') ? static_cast<uint8_t>((c - '0') << 4) : (c >= 'a' && c <= 'f') ? static_cast<uint8_t>((c - 'a' + 10) << 4)
50+
: (c >= 'A' && c <= 'F') ? static_cast<uint8_t>((c - 'A' + 10) << 4)
51+
: 0;
52+
53+
// Low nibble
54+
c = hex[i + 1];
55+
byte |= (c >= '0' && c <= '9') ? static_cast<uint8_t>(c - '0') : (c >= 'a' && c <= 'f') ? static_cast<uint8_t>(c - 'a' + 10)
56+
: (c >= 'A' && c <= 'F') ? static_cast<uint8_t>(c - 'A' + 10)
57+
: 0;
58+
59+
result.push_back(byte);
60+
}
61+
return result;
62+
}
63+
1764
/*
1865
* Check whether the graph has the EP context contrib op.
1966
* The op can contain the precompiled engine info for TRT EP to directly load the engine.

onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ static const std::string PARTITION_NAME = "partition_name";
2424
static const std::string SDK_VERSION = "ep_sdk_version";
2525
static const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft";
2626

27+
// TensorRT does not currently expose a header size define; keep in sync with TRT engine serialization header size.
28+
constexpr size_t kTensorRTEngineHeaderSize = 64;
29+
// Helper functions for engine header validation
30+
std::string BinaryToHexString(const void* data, size_t size);
31+
std::vector<uint8_t> HexStringToBinary(const std::string& hex);
32+
2733
bool GraphHasCtxNode(const GraphViewer& graph_viewer, size_t& node_idx);
2834
const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer);
2935
std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path);

onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,21 @@ class ProviderBridgeEpFactory : public EpFactoryInternalImpl {
8686
return ep_factory_.GetHardwareDeviceIncompatibilityDetails(&ep_factory_, hw, details);
8787
}
8888

89+
// Delegates compatibility validation to the wrapped OrtEpFactory when the
// factory implements the optional entry point; otherwise reports that
// validation is not applicable for this EP.
OrtStatus* ValidateCompiledModelCompatibilityInfo(
    const OrtHardwareDevice* const* devices,
    size_t num_devices,
    const char* compatibility_info,
    OrtCompiledModelCompatibility* model_compatibility) noexcept override {
  if (ep_factory_.ValidateCompiledModelCompatibilityInfo == nullptr) {
    // The wrapped factory predates this API; signal "not applicable".
    *model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
    return nullptr;
  }
  return ep_factory_.ValidateCompiledModelCompatibilityInfo(
      &ep_factory_, devices, num_devices, compatibility_info, model_compatibility);
}
103+
89104
OrtEpFactory& ep_factory_;
90105
ProviderLibrary& provider_library_;
91106
std::optional<std::filesystem::path> library_path_;

0 commit comments

Comments
 (0)