Skip to content

Commit 8e951ef

Browse files
Update weight sharing tool to support plugin EPs (#26614)
### Description - Updates the `ep_weight_sharing_ctx_gen` tool to support specifying a plugin EP configuration (via JSON). - Mark the `ep_weight_sharing_ctx_gen` tool as deprecated and add notification to README that recommends the use the public Python ORT APIs instead. - Note we no longer publish a binary for this tool [as of ORT 1.22.2](#24895). - Added an example Python script in the README. - Added a Python unit test that tests compiling models with weight sharing using an example plugin EP. #### Tool usage Create a JSON file that contains information about the plugin EP to load/use (e.g., `example_plugin_ep_config.json`): ```json { "ep_library_registration_name": "example_plugin_ep", "ep_library_path": "example_plugin_ep.dll", "selected_ep_name": "example_plugin_ep", "default_ep_options": { "option_key": "option_value" } } ``` Call the `ep_weight_sharing_ctx_gen` tool with the `-p` command-line option to specify the location of the above configuration file: ```console $ ep_weight_sharing_ctx_gen.exe -p example_plugin_ep_config.json model_1.onnx,model_2.onnx ``` ### Motivation and Context Close the functionality gap between traditional provider-bridge EPs and plugin EPs. This PR allows using plugin EPs with the tool that compiles models with weight sharing.
1 parent e8bcd0d commit 8e951ef

File tree

7 files changed

+319
-3
lines changed

7 files changed

+319
-3
lines changed

onnxruntime/test/ep_weight_sharing_ctx_gen/README.md

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# ONNXRuntime EP Context Model Generation with Weight Sharing
22

3+
> [!NOTE]
4+
> This tool is deprecated. Please use the public ONNX Runtime Python APIs to compile models with resource sharing. Refer to the example Python script at the end of this document.
5+
36
[EP context with weight sharing design doc](https://onnxruntime.ai/docs/execution-providers/EP-Context-Design.html#epcontext-with-weight-sharing)
47

58
OnnxRuntime provides the ep_weight_sharing_ctx_gen tool to automate the weight-sharing workflow. This tool handles the entire process. This tool is specifically designed for weight sharing scenarios, streamlining the EPContext model generation process.
@@ -13,6 +16,23 @@ Example: ./ep_weight_sharing_ctx_gen -e qnn -i "soc_model|60 htp_graph_finalizat
1316
1417
Options:
1518
-e [qnn|tensorrt|openvino|vitisai]: Specifies the compile based provider 'qnn', 'tensorrt', 'openvino', 'vitisai'. Default: 'qnn'.
19+
-p [plugin_ep_config_json_file]: Specify JSON configuration file for a plugin EP. Takes precedence over the '-e' and '-i' options.
20+
21+
Example JSON configuration that selects plugin EP devices via name:
22+
{
23+
"ep_library_registration_name": "example_plugin_ep",
24+
"ep_library_path": "example_plugin_ep.dll",
25+
"selected_ep_name": "example_plugin_ep",
26+
"default_ep_options": { "key": "value" }
27+
}
28+
29+
Example JSON configuration that selects plugin EP devices via index:
30+
{
31+
"ep_library_registration_name": "example_plugin_ep",
32+
"ep_library_path": "example_plugin_ep.dll",
33+
"selected_ep_device_indices": [ 0 ],
34+
"default_ep_options": { "key": "value" }
35+
}
1636
-v: Show verbose information.
1737
-C: Specify session configuration entries as key-value pairs: -C "<key1>|<value1> <key2>|<value2>"
1838
Refer to onnxruntime_session_options_config_keys.h for valid keys and values.
@@ -36,3 +56,49 @@ Options:
3656
3757
-h: help
3858
```
59+
60+
# Example: Use Python APIs to compile models with resource sharing
61+
Use of the public ORT Python APIs is now recommended for compiling models with resource (e.g., "weight") sharing.
62+
The following snippet shows an example that compiles two models using an example plugin EP.
63+
64+
```Python
65+
import onnxruntime
66+
import os
67+
68+
def main():
69+
ep_name = "example_ep"
70+
ep_lib_path = "example_plugin_ep.dll"
71+
72+
onnxruntime.register_execution_provider_library(ep_name, os.path.realpath(ep_lib_path))
73+
74+
# Find one or more EP devices that correspond to the EP of interest.
75+
# In this example, we pick the first one.
76+
ep_device = next((d for d in onnxruntime.get_ep_devices() if d.ep_name == ep_name), None)
77+
78+
# These are the names/paths to the input and output models.
79+
input_models = ["model_0.onnx", "model_1.onnx"]
80+
output_models = ["model_0_ctx.onnx", "model_1_ctx.onnx"]
81+
82+
num_models = len(input_models)
83+
session_options = onnxruntime.SessionOptions()
84+
provider_options = {} # Empty for this example
85+
86+
# Set option that tells EP to share resources (e.g., weights) across sessions.
87+
session_options.add_session_config_entry("ep.share_ep_contexts", "1")
88+
session_options.add_provider_for_devices([ep_device], provider_options)
89+
90+
# Compile individual models
91+
for i in range(len(input_models)):
92+
if i == num_models - 1:
93+
# Tell EP that this is the last compiling session that will be sharing resources.
94+
session_options.add_session_config_entry("ep.stop_share_ep_contexts", "1")
95+
96+
model_compiler = onnxruntime.ModelCompiler(
97+
session_options,
98+
input_models[i],
99+
embed_compiled_data_into_model=False,
100+
)
101+
model_compiler.compile_to_file(output_models[i])
102+
103+
onnxruntime.unregister_execution_provider_library(ep_name)
104+
```

onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "command_args_parser.h"
55

66
#include <string.h>
7+
#include <fstream>
78
#include <iostream>
89
#include <sstream>
910
#include <string_view>
@@ -21,6 +22,7 @@
2122
#include <core/platform/path_lib.h>
2223
#include <core/optimizer/graph_transformer_level.h>
2324

25+
#include "nlohmann/json.hpp"
2426
#include "test_configuration.h"
2527

2628
namespace onnxruntime {
@@ -35,6 +37,23 @@ namespace qnnctxgen {
3537
"\n"
3638
"Options:\n"
3739
"\t-e [qnn|tensorrt|openvino|vitisai]: Specifies the compile based provider 'qnn', 'tensorrt', 'openvino', 'vitisai'. Default: 'qnn'.\n"
40+
"\t-p [plugin_ep_config_json_file]: Specify JSON configuration file for a plugin EP. Takes precedence over the '-e' and '-i' options.\n"
41+
"\n"
42+
"\t Example JSON configuration that selects plugin EP devices via EP name:\n"
43+
"\t {\n"
44+
"\t \"ep_library_registration_name\": \"example_plugin_ep\",\n"
45+
"\t \"ep_library_path\": \"example_plugin_ep.dll\",\n"
46+
"\t \"selected_ep_name\": \"example_plugin_ep\",\n"
47+
"\t \"default_ep_options\": { \"key\": \"value\" }\n"
48+
"\t }\n"
49+
"\n"
50+
"\t Example JSON configuration that selects plugin EP devices via index:\n"
51+
"\t {\n"
52+
"\t \"ep_library_registration_name\": \"example_plugin_ep\",\n"
53+
"\t \"ep_library_path\": \"example_plugin_ep.dll\",\n"
54+
"\t \"selected_ep_device_indices\": [ 0 ],\n"
55+
"\t \"default_ep_options\": { \"key\": \"value\" }\n"
56+
"\t }\n"
3857
"\t-v: Show verbose information.\n"
3958
"\t-C: Specify session configuration entries as key-value pairs: -C \"<key1>|<value1> <key2>|<value2>\" \n"
4059
"\t Refer to onnxruntime_session_options_config_keys.h for valid keys and values. \n"
@@ -58,6 +77,7 @@ namespace qnnctxgen {
5877
"\n"
5978
"\t-h: help\n");
6079
}
80+
6181
#ifdef _WIN32
6282
static const ORTCHAR_T* delimiter = L",";
6383
#else
@@ -110,9 +130,63 @@ static bool ParseSessionConfigs(const std::string& configs_string,
110130
return true;
111131
}
112132

133+
static bool ParsePluginEpConfig(const std::string& json_file_path, PluginEpConfig& config_out) {
134+
using json = nlohmann::json;
135+
bool success = true;
136+
137+
ORT_TRY {
138+
std::ifstream ifs{json_file_path};
139+
if (!ifs) {
140+
std::cerr << "ERROR: Failed to open plugin EP configuration file at path: "
141+
<< json_file_path.c_str() << std::endl;
142+
return false;
143+
}
144+
145+
std::string content(std::istreambuf_iterator<char>{ifs},
146+
std::istreambuf_iterator<char>{});
147+
PluginEpConfig config{};
148+
const auto parsed_json = json::parse(content);
149+
150+
// required keys
151+
parsed_json.at("ep_library_registration_name").get_to(config.ep_library_registration_name);
152+
parsed_json.at("ep_library_path").get_to(config.ep_library_path);
153+
154+
// optional keys
155+
config.default_ep_options = parsed_json.value<decltype(config.default_ep_options)>("default_ep_options", {});
156+
config.selected_ep_name = parsed_json.value<decltype(config.selected_ep_name)>("selected_ep_name", {});
157+
config.selected_ep_device_indices =
158+
parsed_json.value<decltype(config.selected_ep_device_indices)>("selected_ep_device_indices", {});
159+
160+
if (config.selected_ep_name.empty() == config.selected_ep_device_indices.empty()) {
161+
std::cerr << "ERROR: Plugin EP configuration must specify exactly one of 'selected_ep_name' "
162+
<< "or 'selected_ep_device_indices'" << std::endl;
163+
return false;
164+
}
165+
166+
config_out = std::move(config);
167+
return success;
168+
}
169+
ORT_CATCH(const json::exception& e) {
170+
ORT_HANDLE_EXCEPTION([&]() {
171+
std::string kExampleValidJsonStr =
172+
"{\n"
173+
" \"ep_library_registration_name\": \"example_plugin_ep\",\n"
174+
" \"ep_library_path\": \"/path/to/example_plugin_ep.dll\",\n"
175+
" \"selected_ep_name\": \"example_plugin_ep\"\n"
176+
"}";
177+
178+
success = false;
179+
std::cerr << "ERROR: JSON parse error: " << e.what() << std::endl;
180+
std::cerr << "This is an example valid JSON configuration:\n"
181+
<< kExampleValidJsonStr.c_str() << std::endl;
182+
});
183+
}
184+
return success;
185+
}
186+
113187
/*static*/ bool CommandLineParser::ParseArguments(TestConfig& test_config, int argc, ORTCHAR_T* argv[]) {
114188
int ch;
115-
while ((ch = getopt(argc, argv, ORT_TSTR("e:o:u:i:C:vh"))) != -1) {
189+
while ((ch = getopt(argc, argv, ORT_TSTR("e:p:o:u:i:C:vh"))) != -1) {
116190
switch (ch) {
117191
case 'e':
118192
if (!CompareCString(optarg, ORT_TSTR("qnn"))) {
@@ -128,6 +202,20 @@ static bool ParseSessionConfigs(const std::string& configs_string,
128202
return false;
129203
}
130204
break;
205+
case 'p': {
206+
#ifdef _MSC_VER
207+
std::string plugin_ep_config_file_path = ToUTF8String(optarg);
208+
#else
209+
std::string plugin_ep_config_file_path = optarg;
210+
#endif
211+
PluginEpConfig plugin_ep_config{};
212+
if (!ParsePluginEpConfig(plugin_ep_config_file_path, plugin_ep_config)) {
213+
return false;
214+
}
215+
216+
test_config.machine_config.plugin_ep_config = std::move(plugin_ep_config);
217+
break;
218+
}
131219
case 'v':
132220
test_config.run_config.f_verbose = true;
133221
break;
@@ -202,6 +290,11 @@ static bool ParseSessionConfigs(const std::string& configs_string,
202290
argc -= optind;
203291
argv += optind;
204292

293+
if (argc == 0) {
294+
std::cerr << "ERROR: Did not specify model paths" << std::endl;
295+
return false;
296+
}
297+
205298
ParsePaths(argv[0], test_config.model_file_paths);
206299

207300
return true;
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"ep_library_registration_name": "example_plugin_ep",
3+
"ep_library_path": "example_plugin_ep.dll",
4+
"selected_ep_name": "example_plugin_ep",
5+
"default_ep_options": { "option_key": "option_value" }
6+
}

onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
// onnx dependencies
1212
#include "onnx/onnx_pb.h"
13+
#include <algorithm>
1314
#include <fstream>
1415

1516
using namespace onnxruntime;
@@ -81,6 +82,72 @@ static void UpdateEpContextModel(const std::vector<std::basic_string<ORTCHAR_T>>
8182
}
8283
}
8384

85+
using PluginEpLibraryRegistrationHandle = std::unique_ptr<void, std::function<void(void*)>>;
86+
87+
static PluginEpLibraryRegistrationHandle RegisterPluginEpLibrary(Ort::Env& env,
88+
const std::string& ep_library_registration_name,
89+
const std::basic_string<ORTCHAR_T>& ep_library_path) {
90+
env.RegisterExecutionProviderLibrary(ep_library_registration_name.c_str(), ep_library_path);
91+
92+
auto unregister_ep_library = [&env, registration_name = ep_library_registration_name](void* p) {
93+
if (p == nullptr) {
94+
return;
95+
}
96+
97+
ORT_TRY {
98+
env.UnregisterExecutionProviderLibrary(registration_name.c_str());
99+
}
100+
ORT_CATCH(const Ort::Exception& e) {
101+
ORT_HANDLE_EXCEPTION([&]() {
102+
std::cerr << "Failed to unregister EP library with name '" << registration_name << "': "
103+
<< e.what() << std::endl;
104+
});
105+
}
106+
};
107+
108+
// Set `handle_value` to something not equal to nullptr. The particular value doesn't really matter.
109+
// We are just using the unique_ptr deleter to unregister the EP library.
110+
void* const handle_value = reinterpret_cast<void*>(0x1);
111+
return PluginEpLibraryRegistrationHandle{handle_value, unregister_ep_library};
112+
}
113+
114+
static bool SetPluginEpSessionOptions(Ort::Env& env, Ort::SessionOptions& session_options,
115+
const qnnctxgen::PluginEpConfig& config,
116+
PluginEpLibraryRegistrationHandle& plugin_ep_library_registration_handle) {
117+
auto lib_registration_handle = RegisterPluginEpLibrary(env, config.ep_library_registration_name,
118+
ToPathString(config.ep_library_path));
119+
120+
std::vector<Ort::ConstEpDevice> ep_devices = env.GetEpDevices();
121+
std::vector<Ort::ConstEpDevice> selected_ep_devices{};
122+
123+
if (!config.selected_ep_device_indices.empty()) {
124+
for (const auto idx : config.selected_ep_device_indices) {
125+
if (idx >= ep_devices.size()) {
126+
std::cerr << "ERROR: Selected EP device index is out of range (max is " << ep_devices.size() - 1 << "): "
127+
<< idx << std::endl;
128+
return false;
129+
}
130+
131+
selected_ep_devices.push_back(ep_devices[idx]);
132+
}
133+
} else {
134+
std::copy_if(ep_devices.begin(), ep_devices.end(), std::back_inserter(selected_ep_devices),
135+
[&selected_ep_name = std::as_const(config.selected_ep_name)](Ort::ConstEpDevice ep_device) {
136+
return ep_device.EpName() == selected_ep_name;
137+
});
138+
}
139+
140+
if (selected_ep_devices.empty()) {
141+
std::cerr << "ERROR: No EP devices were selected" << std::endl;
142+
return false;
143+
}
144+
145+
session_options.AppendExecutionProvider_V2(env, selected_ep_devices, config.default_ep_options);
146+
plugin_ep_library_registration_handle = std::move(lib_registration_handle);
147+
148+
return true;
149+
}
150+
84151
#ifdef _WIN32
85152
int real_main(int argc, wchar_t* argv[]) {
86153
#else
@@ -98,6 +165,7 @@ int real_main(int argc, char* argv[]) {
98165
Ort::Env env(logging_level, "ep_weight_sharing");
99166

100167
ORT_TRY {
168+
PluginEpLibraryRegistrationHandle plugin_ep_library_registration_handle{};
101169
Ort::SessionOptions so;
102170
so.SetLogId("ep_weight_sharing_ctx_gen_session_logger");
103171
// Set default session option to dump EPContext model with non-embed mode
@@ -136,7 +204,14 @@ int real_main(int argc, char* argv[]) {
136204
// The context binary file generated later includes all graphs from previous models
137205
{
138206
std::string provider_name_ = test_config.machine_config.provider_type_name;
139-
if (provider_name_ == onnxruntime::kQnnExecutionProvider) {
207+
208+
if (const auto& plugin_ep_config = test_config.machine_config.plugin_ep_config; plugin_ep_config.has_value()) {
209+
if (!SetPluginEpSessionOptions(env, so, *plugin_ep_config, plugin_ep_library_registration_handle)) {
210+
std::cerr << "ERROR: Failed to initialize session for plugin EP "
211+
<< test_config.machine_config.plugin_ep_config->ep_library_path << std::endl;
212+
return 1;
213+
}
214+
} else if (provider_name_ == onnxruntime::kQnnExecutionProvider) {
140215
#ifdef USE_QNN
141216
so.AppendExecutionProvider("QNN", provider_options);
142217
#else

onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#include <map>
77
#include <cstdint>
8+
#include <optional>
89
#include <string>
910
#include <unordered_map>
1011

@@ -14,8 +15,25 @@
1415
namespace onnxruntime {
1516
namespace qnnctxgen {
1617

18+
// Configuration for initializing the dynamic plugin EP infrastructure.
19+
struct PluginEpConfig {
20+
std::string ep_library_registration_name{};
21+
std::string ep_library_path{};
22+
23+
// Note: Exactly one of `selected_ep_name` or `selected_ep_device_indices` should be set.
24+
// An empty value for either means it is unset.
25+
26+
// Specifies the EP devices matching this EP name as the selected EP devices.
27+
std::string selected_ep_name{};
28+
// Specifies the selected EP devices by their indices.
29+
std::vector<size_t> selected_ep_device_indices{};
30+
31+
std::unordered_map<std::string, std::string> default_ep_options{};
32+
};
33+
1734
struct MachineConfig {
1835
std::string provider_type_name{onnxruntime::kQnnExecutionProvider};
36+
std::optional<PluginEpConfig> plugin_ep_config = std::nullopt;
1937
};
2038

2139
struct RunConfig {

onnxruntime/test/python/helper.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import sys
23

34

45
def get_name(name):
@@ -13,3 +14,14 @@ def get_name(name):
1314
if os.path.exists(res):
1415
return res
1516
raise FileNotFoundError(f"Unable to find '{name}' or '{rel}' or '{res}'")
17+
18+
19+
def get_shared_library_filename_for_platform(base_name):
20+
if sys.platform.startswith("win"):
21+
return base_name + ".dll"
22+
23+
if sys.platform.startswith("darwin"):
24+
return "lib" + base_name + ".dylib"
25+
26+
# Else, assume linux
27+
return "lib" + base_name + ".so"

0 commit comments

Comments
 (0)