Skip to content

Commit 0432e71

Browse files
authored
perftest: support plugin eps for compile_ep_context (#27121)
* Extend compile_ep_context to also support plugin eps * Adds compile_only option to skip execution, can be used when compiling for virtual devices compile_ep_context (physical device) <img width="1259" height="510" alt="image" src="https://github.com/user-attachments/assets/14650c17-0c8a-4002-a7ce-e8e4c815a516" /> compile_ep_context + compile_only (virtual device) <img width="1262" height="173" alt="image" src="https://github.com/user-attachments/assets/2f0844cc-5e83-4b2d-bf0a-0d815d9bad29" />
1 parent 40a4898 commit 0432e71

File tree

6 files changed

+137
-105
lines changed

6 files changed

+137
-105
lines changed

onnxruntime/test/perftest/command_args_parser.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ ABSL_FLAG(std::string, filter_ep_devices, "",
203203
ABSL_FLAG(bool, compile_ep_context, DefaultPerformanceTestConfig().run_config.compile_ep_context, "Generate an EP context model");
204204
ABSL_FLAG(std::string, compile_model_path, "model_ctx.onnx", "The compiled model path for saving EP context model. Overwrites if already exists");
205205
ABSL_FLAG(bool, compile_binary_embed, DefaultPerformanceTestConfig().run_config.compile_binary_embed, "Embed binary blob within EP context node");
206+
ABSL_FLAG(bool, compile_only, DefaultPerformanceTestConfig().run_config.compile_only, "Only compile EP context model without running it");
206207
ABSL_FLAG(bool, h, false, "Print program usage.");
207208

208209
namespace onnxruntime {
@@ -583,6 +584,9 @@ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int a
583584
// --compile_binary_embed
584585
test_config.run_config.compile_binary_embed = absl::GetFlag(FLAGS_compile_binary_embed);
585586

587+
// --compile_only
588+
test_config.run_config.compile_only = absl::GetFlag(FLAGS_compile_only);
589+
586590
if (positional.size() == 2) {
587591
test_config.model_info.model_file_path = ToPathString(positional[1]);
588592
test_config.run_config.f_dump_statistics = true;

onnxruntime/test/perftest/common_utils.cc

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,116 @@ std::vector<char*> CStringsFromStrings(std::vector<std::string>& utf8_args) {
9090
return utf8_argv;
9191
}
9292

93+
void AppendPluginExecutionProviders(Ort::Env& env,
94+
Ort::SessionOptions& session_options,
95+
const PerformanceTestConfig& test_config) {
96+
if (test_config.registered_plugin_eps.empty()) {
97+
return;
98+
}
99+
100+
std::vector<Ort::ConstEpDevice> ep_devices = env.GetEpDevices();
101+
// EP -> associated EP devices (All OrtEpDevice instances must be from the same execution provider)
102+
std::unordered_map<std::string, std::vector<Ort::ConstEpDevice>> added_ep_devices;
103+
std::unordered_set<int> added_ep_device_index_set;
104+
105+
auto& ep_list = test_config.machine_config.plugin_provider_type_list;
106+
std::unordered_set<std::string> ep_set(ep_list.begin(), ep_list.end());
107+
108+
// Select EP devices by provided device index
109+
if (!test_config.selected_ep_device_indices.empty()) {
110+
std::vector<int> device_list;
111+
device_list.reserve(test_config.selected_ep_device_indices.size());
112+
ParseEpDeviceIndexList(test_config.selected_ep_device_indices, device_list);
113+
for (auto index : device_list) {
114+
if (static_cast<size_t>(index) > (ep_devices.size() - 1)) {
115+
fprintf(stderr, "%s", "The device index provided is not correct. Will skip this device id.");
116+
continue;
117+
}
118+
119+
Ort::ConstEpDevice& device = ep_devices[index];
120+
if (ep_set.find(std::string(device.EpName())) != ep_set.end()) {
121+
if (added_ep_device_index_set.find(index) == added_ep_device_index_set.end()) {
122+
added_ep_devices[device.EpName()].push_back(device);
123+
added_ep_device_index_set.insert(index);
124+
fprintf(stdout, "[Plugin EP] EP Device [Index: %d, Name: %s, Type: %d] has been added to session.\n", static_cast<int>(index), device.EpName(), device.Device().Type());
125+
}
126+
} else {
127+
std::string err_msg = "[Plugin EP] [WARNING] : The EP device index and its corresponding OrtEpDevice is not created from " +
128+
test_config.machine_config.provider_type_name + ". Will skip adding this device.\n";
129+
fprintf(stderr, "%s", err_msg.c_str());
130+
}
131+
}
132+
} else if (!test_config.filter_ep_device_kv_pairs.empty()) {
133+
// Find and select the OrtEpDevice associated with the EP in "--filter_ep_devices".
134+
for (const auto& kv : test_config.filter_ep_device_kv_pairs) {
135+
for (size_t index = 0; index < ep_devices.size(); ++index) {
136+
auto device = ep_devices[index];
137+
if (ep_set.find(std::string(device.EpName())) == ep_set.end())
138+
continue;
139+
140+
// Skip if deviceid was already added
141+
if (added_ep_devices.find(device.EpName()) != added_ep_devices.end() &&
142+
std::find(added_ep_devices[device.EpName()].begin(), added_ep_devices[device.EpName()].end(), device) != added_ep_devices[device.EpName()].end())
143+
continue;
144+
145+
// Check both EP metadata and device metadata for a match
146+
auto ep_metadata_kv_pairs = device.EpMetadata().GetKeyValuePairs();
147+
auto device_metadata_kv_pairs = device.Device().Metadata().GetKeyValuePairs();
148+
auto ep_metadata_itr = ep_metadata_kv_pairs.find(kv.first);
149+
auto device_metadata_itr = device_metadata_kv_pairs.find(kv.first);
150+
151+
if ((ep_metadata_itr != ep_metadata_kv_pairs.end() && kv.second == ep_metadata_itr->second) ||
152+
(device_metadata_itr != device_metadata_kv_pairs.end() && kv.second == device_metadata_itr->second)) {
153+
added_ep_devices[device.EpName()].push_back(device);
154+
fprintf(stdout, "[Plugin EP] EP Device [Index: %d, Name: %s, Type: %d] has been added to session.\n", static_cast<int>(index), device.EpName(), device.Device().Type());
155+
break;
156+
}
157+
}
158+
}
159+
} else {
160+
// Find and select the OrtEpDevice associated with the EP in "--plugin_eps".
161+
for (size_t index = 0; index < ep_devices.size(); ++index) {
162+
Ort::ConstEpDevice& device = ep_devices[index];
163+
if (ep_set.find(std::string(device.EpName())) != ep_set.end()) {
164+
added_ep_devices[device.EpName()].push_back(device);
165+
fprintf(stdout, "[Plugin EP] EP Device [Index: %d, Name: %s] has been added to session.\n", static_cast<int>(index), device.EpName());
166+
}
167+
}
168+
}
169+
170+
if (added_ep_devices.empty()) {
171+
ORT_THROW("[ERROR] [Plugin EP]: No matching EP devices found.");
172+
}
173+
174+
std::string ep_option_string = ToUTF8String(test_config.run_config.ep_runtime_config_string);
175+
176+
// EP's associated provider option lists
177+
std::vector<std::unordered_map<std::string, std::string>> ep_options_list;
178+
ParseEpOptions(ep_option_string, ep_options_list);
179+
180+
// If user only provide the EPs' provider option lists for the first several EPs,
181+
// add empty provider option lists for the rest EPs.
182+
if (ep_options_list.size() < ep_list.size()) {
183+
for (size_t i = ep_options_list.size(); i < ep_list.size(); ++i) {
184+
ep_options_list.emplace_back(); // Adds a new empty map
185+
}
186+
} else if (ep_options_list.size() > ep_list.size()) {
187+
ORT_THROW("[ERROR] [Plugin EP]: Too many EP provider option lists provided.");
188+
}
189+
190+
// EP -> associated provider options
191+
std::unordered_map<std::string, std::unordered_map<std::string, std::string>> ep_options_map;
192+
for (size_t i = 0; i < ep_list.size(); ++i) {
193+
ep_options_map.emplace(ep_list[i], ep_options_list[i]);
194+
}
195+
196+
for (auto& ep_and_devices : added_ep_devices) {
197+
auto& ep = ep_and_devices.first;
198+
auto& devices = ep_and_devices.second;
199+
session_options.AppendExecutionProvider_V2(env, devices, ep_options_map[ep]);
200+
}
201+
}
202+
93203
} // namespace utils
94204
} // namespace perftest
95205
} // namespace onnxruntime

onnxruntime/test/perftest/main.cc

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ using namespace onnxruntime;
1414
const OrtApi* g_ort = NULL;
1515

1616
int RunPerfTest(Ort::Env& env, const perftest::PerformanceTestConfig& test_config);
17-
Ort::Status CompileEpContextModel(const Ort::Env& env, const perftest::PerformanceTestConfig& test_config);
17+
Ort::Status CompileEpContextModel(Ort::Env& env, const perftest::PerformanceTestConfig& test_config);
1818

1919
#ifdef _WIN32
2020
int real_main(int argc, wchar_t* argv[]) {
@@ -82,6 +82,12 @@ int real_main(int argc, char* argv[]) {
8282
return -1;
8383
}
8484

85+
std::cout << "Model compiled successfully to " << ToUTF8String(test_config.run_config.compile_model_path) << "\n";
86+
if (test_config.run_config.compile_only) {
87+
return 0;
88+
}
89+
90+
std::cout << "\n> Running EP context model...\n";
8591
{
8692
test_config.model_info.model_file_path = test_config.run_config.compile_model_path;
8793
status = RunPerfTest(env, test_config);
@@ -134,14 +140,20 @@ int RunPerfTest(Ort::Env& env, const perftest::PerformanceTestConfig& test_confi
134140
return 0;
135141
}
136142

137-
Ort::Status CompileEpContextModel(const Ort::Env& env, const perftest::PerformanceTestConfig& test_config) {
143+
Ort::Status CompileEpContextModel(Ort::Env& env, const perftest::PerformanceTestConfig& test_config) {
138144
auto output_ctx_model_path = test_config.run_config.compile_model_path;
139145
const auto provider_name = test_config.machine_config.provider_type_name;
140146

141147
Ort::SessionOptions session_options;
142148

143-
std::unordered_map<std::string, std::string> provider_options;
144-
session_options.AppendExecutionProvider(provider_name, provider_options);
149+
// Add EP devices if any (created by plugin EP)
150+
if (!test_config.registered_plugin_eps.empty()) {
151+
perftest::utils::AppendPluginExecutionProviders(env, session_options, test_config);
152+
} else {
153+
// Regular non-plugin EP
154+
std::unordered_map<std::string, std::string> provider_options;
155+
session_options.AppendExecutionProvider(provider_name, provider_options);
156+
}
145157

146158
// free dim override
147159
if (!test_config.run_config.free_dim_name_overrides.empty()) {

onnxruntime/test/perftest/ort_test_session.cc

Lines changed: 2 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "providers.h"
2222
#include "TestCase.h"
2323
#include "strings_helper.h"
24+
#include "utils.h"
2425

2526
#ifdef USE_OPENVINO
2627
#include "nlohmann/json.hpp"
@@ -90,107 +91,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
9091

9192
// Add EP devices if any (created by plugin EP)
9293
if (!performance_test_config.registered_plugin_eps.empty()) {
93-
std::vector<Ort::ConstEpDevice> ep_devices = env.GetEpDevices();
94-
// EP -> associated EP devices (All OrtEpDevice instances must be from the same execution provider)
95-
std::unordered_map<std::string, std::vector<Ort::ConstEpDevice>> added_ep_devices;
96-
std::unordered_set<int> added_ep_device_index_set;
97-
98-
auto& ep_list = performance_test_config.machine_config.plugin_provider_type_list;
99-
std::unordered_set<std::string> ep_set(ep_list.begin(), ep_list.end());
100-
101-
// Select EP devices by provided device index
102-
if (!performance_test_config.selected_ep_device_indices.empty()) {
103-
std::vector<int> device_list;
104-
device_list.reserve(performance_test_config.selected_ep_device_indices.size());
105-
ParseEpDeviceIndexList(performance_test_config.selected_ep_device_indices, device_list);
106-
for (auto index : device_list) {
107-
if (static_cast<size_t>(index) > (ep_devices.size() - 1)) {
108-
fprintf(stderr, "%s", "The device index provided is not correct. Will skip this device id.");
109-
continue;
110-
}
111-
112-
Ort::ConstEpDevice& device = ep_devices[index];
113-
if (ep_set.find(std::string(device.EpName())) != ep_set.end()) {
114-
if (added_ep_device_index_set.find(index) == added_ep_device_index_set.end()) {
115-
added_ep_devices[device.EpName()].push_back(device);
116-
added_ep_device_index_set.insert(index);
117-
fprintf(stdout, "[Plugin EP] EP Device [Index: %d, Name: %s, Type: %d] has been added to session.\n", static_cast<int>(index), device.EpName(), device.Device().Type());
118-
}
119-
} else {
120-
std::string err_msg = "[Plugin EP] [WARNING] : The EP device index and its corresponding OrtEpDevice is not created from " +
121-
performance_test_config.machine_config.provider_type_name + ". Will skip adding this device.\n";
122-
fprintf(stderr, "%s", err_msg.c_str());
123-
}
124-
}
125-
} else if (!performance_test_config.filter_ep_device_kv_pairs.empty()) {
126-
// Find and select the OrtEpDevice associated with the EP in "--filter_ep_devices".
127-
for (const auto& kv : performance_test_config.filter_ep_device_kv_pairs) {
128-
for (size_t index = 0; index < ep_devices.size(); ++index) {
129-
auto device = ep_devices[index];
130-
if (ep_set.find(std::string(device.EpName())) == ep_set.end())
131-
continue;
132-
133-
// Skip if deviced was already added
134-
if (added_ep_devices.find(device.EpName()) != added_ep_devices.end() &&
135-
std::find(added_ep_devices[device.EpName()].begin(), added_ep_devices[device.EpName()].end(), device) != added_ep_devices[device.EpName()].end())
136-
continue;
137-
138-
// Check both EP metadata and device metadata for a match
139-
auto ep_metadata_kv_pairs = device.EpMetadata().GetKeyValuePairs();
140-
auto device_metadata_kv_pairs = device.Device().Metadata().GetKeyValuePairs();
141-
auto ep_metadata_itr = ep_metadata_kv_pairs.find(kv.first);
142-
auto device_metadata_itr = device_metadata_kv_pairs.find(kv.first);
143-
144-
if ((ep_metadata_itr != ep_metadata_kv_pairs.end() && kv.second == ep_metadata_itr->second) ||
145-
(device_metadata_itr != device_metadata_kv_pairs.end() && kv.second == device_metadata_itr->second)) {
146-
added_ep_devices[device.EpName()].push_back(device);
147-
fprintf(stdout, "[Plugin EP] EP Device [Index: %d, Name: %s, Type: %d] has been added to session.\n", static_cast<int>(index), device.EpName(), device.Device().Type());
148-
break;
149-
}
150-
}
151-
}
152-
} else {
153-
// Find and select the OrtEpDevice associated with the EP in "--plugin_eps".
154-
for (size_t index = 0; index < ep_devices.size(); ++index) {
155-
Ort::ConstEpDevice& device = ep_devices[index];
156-
if (ep_set.find(std::string(device.EpName())) != ep_set.end()) {
157-
added_ep_devices[device.EpName()].push_back(device);
158-
fprintf(stdout, "EP Device [Index: %d, Name: %s] has been added to session.\n", static_cast<int>(index), device.EpName());
159-
}
160-
}
161-
}
162-
163-
if (added_ep_devices.empty()) {
164-
ORT_THROW("[ERROR] [Plugin EP]: No matching EP devices found.");
165-
}
166-
167-
std::string ep_option_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string);
168-
169-
// EP's associated provider option lists
170-
std::vector<std::unordered_map<std::string, std::string>> ep_options_list;
171-
ParseEpOptions(ep_option_string, ep_options_list);
172-
173-
// If user only provide the EPs' provider option lists for the first several EPs,
174-
// add empty provider option lists for the rest EPs.
175-
if (ep_options_list.size() < ep_list.size()) {
176-
for (size_t i = ep_options_list.size(); i < ep_list.size(); ++i) {
177-
ep_options_list.emplace_back(); // Adds a new empty map
178-
}
179-
} else if (ep_options_list.size() > ep_list.size()) {
180-
ORT_THROW("[ERROR] [Plugin EP]: Too many EP provider option lists provided.");
181-
}
182-
183-
// EP -> associated provider options
184-
std::unordered_map<std::string, std::unordered_map<std::string, std::string>> ep_options_map;
185-
for (size_t i = 0; i < ep_list.size(); ++i) {
186-
ep_options_map.emplace(ep_list[i], ep_options_list[i]);
187-
}
188-
189-
for (auto& ep_and_devices : added_ep_devices) {
190-
auto& ep = ep_and_devices.first;
191-
auto& devices = ep_and_devices.second;
192-
session_options.AppendExecutionProvider_V2(env, devices, ep_options_map[ep]);
193-
}
94+
perftest::utils::AppendPluginExecutionProviders(env, session_options, performance_test_config);
19495
}
19596

19697
provider_name_ = performance_test_config.machine_config.provider_type_name;

onnxruntime/test/perftest/test_configuration.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ struct RunConfig {
7777
bool compile_ep_context{false};
7878
std::basic_string<ORTCHAR_T> compile_model_path;
7979
bool compile_binary_embed{false};
80+
bool compile_only{false};
8081
struct CudaMempoolArenaConfig {
8182
std::string release_threshold;
8283
std::string bytes_to_keep;

onnxruntime/test/perftest/utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ void UnregisterExecutionProviderLibrary(Ort::Env& env, PerformanceTestConfig& te
3333

3434
void ListEpDevices(const Ort::Env& env);
3535

36+
void AppendPluginExecutionProviders(Ort::Env& env,
37+
Ort::SessionOptions& session_options,
38+
const PerformanceTestConfig& test_config);
39+
3640
} // namespace utils
3741
} // namespace perftest
3842
} // namespace onnxruntime

0 commit comments

Comments
 (0)