
[NV EP] fix EP context options #24545


Merged: 7 commits, May 1, 2025
Changes from all commits
200 changes: 118 additions & 82 deletions onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc

Large diffs are not rendered by default.

@@ -58,17 +58,17 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>&
while (std::getline(extra_plugin_libs, lib, ';')) {
auto status = LoadDynamicLibrary(ToPathString(lib));
if (status == Status::OK()) {
LOGS_DEFAULT(VERBOSE) << "[Nv EP] Successfully load " << lib;
LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Successfully load " << lib;
} else {
LOGS_DEFAULT(WARNING) << "[Nv EP]" << status.ToString();
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP]" << status.ToString();
}
}
is_loaded = true;
}

try {
// Get all registered TRT plugins from registry
LOGS_DEFAULT(VERBOSE) << "[Nv EP] Getting all registered TRT plugins from TRT plugin registry ...";
LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Getting all registered TRT plugins from TRT plugin registry ...";
TensorrtLogger trt_logger = GetTensorrtLogger(false);
void* library_handle = nullptr;
const auto& env = onnxruntime::GetDefaultEnv();
@@ -79,14 +79,14 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>&
bool (*dyn_initLibNvInferPlugins)(void* logger, char const* libNamespace);
ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(library_handle, "initLibNvInferPlugins", (void**)&dyn_initLibNvInferPlugins));
dyn_initLibNvInferPlugins(&trt_logger, "");
LOGS_DEFAULT(INFO) << "[Nv EP] Default plugins successfully loaded.";
LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Default plugins successfully loaded.";

#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4996) // Ignore warning C4996: 'nvinfer1::*' was declared deprecated
#endif
} catch (const std::exception&) {
LOGS_DEFAULT(INFO) << "[Nv EP] Default plugin library is not on the path and is therefore ignored";
LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Default plugin library is not on the path and is therefore ignored";
}
try {
int num_plugin_creator = 0;
@@ -96,7 +96,7 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>&
for (int i = 0; i < num_plugin_creator; i++) {
auto plugin_creator = plugin_creators[i];
std::string plugin_name(plugin_creator->getPluginName());
LOGS_DEFAULT(VERBOSE) << "[Nv EP] " << plugin_name << ", version : " << plugin_creator->getPluginVersion();
LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] " << plugin_name << ", version : " << plugin_creator->getPluginVersion();

// plugin has different versions and we only register once
if (registered_plugin_names.find(plugin_name) != registered_plugin_names.end()) {
@@ -116,7 +116,7 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>&
custom_op_domain->domain_ = "trt.plugins";
domain_list.push_back(custom_op_domain.get());
} catch (const std::exception&) {
LOGS_DEFAULT(WARNING) << "[Nv EP] Failed to get TRT plugins from TRT plugin registration. Therefore, TRT EP can't create custom ops for TRT plugins";
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Failed to get TRT plugins from TRT plugin registration. Therefore, TRT EP can't create custom ops for TRT plugins";
}
return Status::OK();
}
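For reference, the loop in this hunk splits the EP's extra-plugin option on ';' and loads each library. Below is a minimal standalone sketch of that pattern; POSIX dlopen stands in for onnxruntime's Env::LoadDynamicLibrary, and the library names libfoo.so/libbar.so are placeholders for illustration only:

#include <dlfcn.h>
#include <iostream>
#include <sstream>
#include <string>

int main() {
  // ';'-separated list, as passed via the extra_plugin_lib_paths option.
  const std::string extra_plugin_lib_paths = "libfoo.so;libbar.so";
  std::istringstream extra_plugin_libs(extra_plugin_lib_paths);
  std::string lib;
  while (std::getline(extra_plugin_libs, lib, ';')) {
    if (dlopen(lib.c_str(), RTLD_NOW)) {
      std::cout << "[NvTensorRTRTX EP] Successfully load " << lib << "\n";  // mirrors the VERBOSE log
    } else {
      std::cerr << "[NvTensorRTRTX EP] " << dlerror() << "\n";  // mirrors the WARNING log
    }
  }
}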
onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_helper.cc
@@ -169,31 +169,31 @@
}
std::string unique_graph_name = GetUniqueGraphName(*top_level_graph);
if (subgraph_context_map_.find(unique_graph_name) == subgraph_context_map_.end()) {
LOGS_DEFAULT(ERROR) << "[Nv EP] Can't find top-level graph context. \
LOGS_DEFAULT(ERROR) << "[NvTensorRTRTX EP] Can't find top-level graph context. \

Check warning on line 172 in onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_helper.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Multi-line string ("...") found. This lint script doesn't do well with such strings, and may give bogus warnings. Use C++11 raw strings or concatenation instead. [readability/multiline_string] [5] Raw Output: onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_helper.cc:172: Multi-line string ("...") found. This lint script doesn't do well with such strings, and may give bogus warnings. Use C++11 raw strings or concatenation instead. [readability/multiline_string] [5]
Please check BuildSubGraphContext() has built the graph context correctly.";
return;
}

SubGraphContext* context = subgraph_context_map_.at(unique_graph_name).get();

LOGS_DEFAULT(VERBOSE) << "[Nv EP] Subgraph name is " << graph_build.Name();
LOGS_DEFAULT(VERBOSE) << "[Nv EP] Its parent node is " << graph.ParentNode()->Name();
LOGS_DEFAULT(VERBOSE) << "[Nv EP] Its parent node's implicit inputs:";
LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Subgraph name is " << graph_build.Name();
LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Its parent node is " << graph.ParentNode()->Name();
LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Its parent node's implicit inputs:";

// Iterate all the implicit inputs to set outer scope value for the newly built subgraph
for (const auto& input : graph.ParentNode()->ImplicitInputDefs()) {
LOGS_DEFAULT(VERBOSE) << "[Nv EP] \t" << input->Name();
LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] \t" << input->Name();

// The node arg in parent node's implicit inputs could be used for parent node's other subgraph, for example
// "If" op has two subgraphs. So we need to make sure that the node arg is used in current subgraph only.
// (GetNodeArg searches for specific node arg in all node args in the graph)
if (graph_build.GetNodeArg(input->Name())) {
graph_build.AddOuterScopeNodeArg(input->Name());
LOGS_DEFAULT(VERBOSE) << "[Nv EP] \t" << input->Name() << " is used in this subgraph";
LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] \t" << input->Name() << " is used in this subgraph";

if (context &&
(context->manually_added_graph_inputs.find(input->Name()) != context->manually_added_graph_inputs.end())) {
LOGS_DEFAULT(VERBOSE) << "[Nv EP] \t" << input->Name() << " is already been added as an explicit input to graph";
LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] \t" << input->Name() << " is already been added as an explicit input to graph";
continue;
}

@@ -213,7 +213,7 @@
type_proto->copy_from(input->TypeAsProto());
auto& n_input = top_level_graph->GetOrCreateNodeArg(name, type_proto.get());
context->manually_added_graph_inputs[n_input.Name()] = &n_input;
LOGS_DEFAULT(VERBOSE) << "[Nv EP] \t" << n_input.Name() << " is added as an explicit input into the newly built graph";
LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] \t" << n_input.Name() << " is added as an explicit input into the newly built graph";
}
}
}
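The lookup at the top of this hunk keys per-subgraph state by a unique graph name. A minimal sketch of that bookkeeping pattern follows, with a simplified stand-in for ORT's SubGraphContext (types and members reduced for illustration, not the actual classes):

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Simplified stand-in for the EP's per-subgraph bookkeeping.
struct SubGraphContext {
  std::unordered_map<std::string, int> manually_added_graph_inputs;
};

static std::unordered_map<std::string, std::unique_ptr<SubGraphContext>> subgraph_context_map_;

static SubGraphContext* FindContext(const std::string& unique_graph_name) {
  auto it = subgraph_context_map_.find(unique_graph_name);
  if (it == subgraph_context_map_.end()) {
    std::cerr << "[NvTensorRTRTX EP] Can't find top-level graph context. "
              << "Please check BuildSubGraphContext() has built the graph context correctly.\n";
    return nullptr;
  }
  return it->second.get();
}

int main() {
  subgraph_context_map_["main_graph"] = std::make_unique<SubGraphContext>();
  std::cout << (FindContext("main_graph") ? "found\n" : "missing\n");
}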
onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc
@@ -4,13 +4,15 @@
#include "core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h"
#include "core/providers/nv_tensorrt_rtx/nv_provider_options.h"

#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/common/make_string.h"
#include "core/common/parse_string.h"
#include "core/framework/provider_options_utils.h"
#include "core/providers/cuda/cuda_common.h"

namespace onnxruntime {
-NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) {
+NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options,
+                                                                     const ConfigOptions& session_options) {
NvExecutionProviderInfo info{};
void* user_compute_stream = nullptr;
void* onnx_bytestream = nullptr;
@@ -58,6 +60,25 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi
info.user_compute_stream = user_compute_stream;
info.has_user_compute_stream = (user_compute_stream != nullptr);
info.onnx_bytestream = onnx_bytestream;

+// EP context settings
+const auto embed_enable = session_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0");
+if (embed_enable == "0") {
+  info.dump_ep_context_model = false;
+} else if (embed_enable == "1") {
+  info.dump_ep_context_model = true;
+} else {
+  ORT_THROW("Invalid ", kOrtSessionOptionEpContextEnable, ", must be 0 or 1");
+}
+info.ep_context_file_path = session_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
+
+const auto embed_mode = std::stoi(session_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "1"));
+if (0 <= embed_mode && embed_mode < 2) {
+  info.ep_context_embed_mode = embed_mode;
+} else {
+  ORT_THROW("Invalid ", kOrtSessionOptionEpContextEmbedMode, ", must be 0 or 1");
+}

return info;
}

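With this change the EP-context switches come from the standard session options rather than NV-specific provider options. A hedged usage sketch with the ONNX Runtime C++ API: the string keys are the documented values of the kOrtSessionOptionEpContext* constants in onnxruntime_session_options_config_keys.h, and the output file name is a placeholder:

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions so;
  so.AddConfigEntry("ep.context_enable", "1");                  // kOrtSessionOptionEpContextEnable
  so.AddConfigEntry("ep.context_file_path", "model_ctx.onnx");  // kOrtSessionOptionEpContextFilePath
  so.AddConfigEntry("ep.context_embed_mode", "0");              // kOrtSessionOptionEpContextEmbedMode: 0 or 1
  // Append the NV TensorRT RTX EP and create the session as usual;
  // FromProviderOptions() now reads the three keys above.
  return 0;
}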
onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h
@@ -8,8 +8,9 @@
#include "core/framework/ortdevice.h"
#include "core/framework/provider_options.h"
#include "core/framework/framework_provider_common.h"
#include "core/session/onnxruntime_c_api.h"
#include "core/framework/library_handles.h"
#include "core/session/onnxruntime_c_api.h"
#include "core/providers/shared_library/provider_api.h"

#define TRT_DEFAULT_OPTIMIZER_LEVEL 3

@@ -19,18 +20,10 @@ struct NvExecutionProviderInfo {
int device_id{0};
bool has_user_compute_stream{false};
void* user_compute_stream{nullptr};
bool has_trt_options{false};
int max_partition_iterations{1000};
int min_subgraph_size{1};
size_t max_workspace_size{0};
bool fp16_enable{false};
bool int8_enable{false};
std::string int8_calibration_table_name{""};
bool int8_use_native_calibration_table{false};
bool dla_enable{false};
int dla_core{0};
bool dump_subgraphs{false};
bool engine_cache_enable{false};
std::string engine_cache_path{""};
bool weight_stripped_engine_enable{false};
std::string onnx_model_folder_path{""};
@@ -40,16 +33,10 @@
std::string engine_decryption_lib_path{""};
bool force_sequential_engine_build{false};
bool context_memory_sharing_enable{false};
bool layer_norm_fp32_fallback{false};
bool timing_cache_enable{false};
std::string timing_cache_path{""};
bool force_timing_cache{false};
bool detailed_build_log{false};
bool build_heuristics_enable{false};
bool sparsity_enable{false};
int builder_optimization_level{3};
int auxiliary_streams{-1};
std::string tactic_sources{""};
std::string extra_plugin_lib_paths{""};
std::string profile_min_shapes{""};
std::string profile_max_shapes{""};
@@ -59,10 +46,10 @@
std::string ep_context_file_path{""};
int ep_context_embed_mode{0};
std::string engine_cache_prefix{""};
bool engine_hw_compatible{false};
std::string op_types_to_exclude{""};

-static NvExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
+static NvExecutionProviderInfo FromProviderOptions(const ProviderOptions& options,
+                                                   const ConfigOptions& session_options);
static ProviderOptions ToProviderOptions(const NvExecutionProviderInfo& info);
std::vector<OrtCustomOpDomain*> custom_op_domain_list;
};
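One caveat in the new option parsing: std::stoi throws std::invalid_argument on a non-numeric config value, bypassing the intended ORT_THROW message. A standalone sketch of a stricter parse-and-validate step, using std::from_chars instead of ORT's parse_string.h utilities (helper name and message are illustrative):

#include <charconv>
#include <iostream>
#include <stdexcept>
#include <string>

// Parse an embed-mode string and enforce the documented 0/1 range.
static int ParseEmbedMode(const std::string& value) {
  int embed_mode = -1;
  auto [ptr, ec] = std::from_chars(value.data(), value.data() + value.size(), embed_mode);
  if (ec != std::errc{} || ptr != value.data() + value.size() || embed_mode < 0 || embed_mode > 1) {
    throw std::invalid_argument("Invalid ep.context_embed_mode, must be 0 or 1, got: " + value);
  }
  return embed_mode;
}

int main() {
  std::cout << ParseEmbedMode("1") << "\n";  // prints 1
  try {
    ParseEmbedMode("abc");                   // rejected instead of crashing the session
  } catch (const std::invalid_argument& e) {
    std::cout << e.what() << "\n";
  }
}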