Skip to content

Commit 52453f2

Browse files
Update example kernel EP to get the OrtEp from OrtKernelInfo and add a unit test
1 parent 12b8394 commit 52453f2

File tree

6 files changed

+59
-37
lines changed

6 files changed

+59
-37
lines changed

onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414
#include "ep_factory.h"
1515
#include "../plugin_ep_utils.h"
1616

17-
ExampleKernelEp::ExampleKernelEp(ExampleKernelEpFactory& factory, const OrtLogger& logger)
17+
ExampleKernelEp::ExampleKernelEp(ExampleKernelEpFactory& factory, const Config& config, const OrtLogger& logger)
1818
: OrtEp{}, // explicitly call the struct ctor to ensure all optional values are default initialized
1919
factory_{factory},
2020
ort_api_{factory.GetOrtApi()},
2121
ep_api_{factory.GetEpApi()},
2222
name_{factory.GetEpName()},
23+
config_{config},
2324
logger_{logger} {
2425
ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with.
2526

onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,16 @@ class ExampleKernelEpFactory;
1414
/// </summary>
1515
class ExampleKernelEp : public OrtEp {
1616
public:
17-
ExampleKernelEp(ExampleKernelEpFactory& factory, const OrtLogger& logger);
17+
struct Config {
18+
bool enable_prepack_weight_sharing = false;
19+
};
20+
21+
ExampleKernelEp(ExampleKernelEpFactory& factory, const Config& config, const OrtLogger& logger);
1822
~ExampleKernelEp();
1923

2024
const OrtApi& GetOrtApi() const { return ort_api_; }
2125
const OrtEpApi& GetEpApi() const { return ep_api_; }
26+
const Config& GetConfig() const { return config_; }
2227

2328
private:
2429
static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept;
@@ -34,5 +39,6 @@ class ExampleKernelEp : public OrtEp {
3439
const OrtApi& ort_api_;
3540
const OrtEpApi& ep_api_;
3641
std::string name_;
42+
Config config_;
3743
const OrtLogger& logger_;
3844
};

onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ OrtStatus* ORT_API_CALL ExampleKernelEpFactory::CreateEpImpl(OrtEpFactory* this_
176176
const OrtHardwareDevice* const* /*devices*/,
177177
const OrtKeyValuePairs* const* /*ep_metadata*/,
178178
size_t num_devices,
179-
const OrtSessionOptions* /*session_options*/,
179+
const OrtSessionOptions* session_options,
180180
const OrtLogger* logger,
181181
OrtEp** ep) noexcept {
182182
auto* factory = static_cast<ExampleKernelEpFactory*>(this_ptr);
@@ -187,7 +187,14 @@ OrtStatus* ORT_API_CALL ExampleKernelEpFactory::CreateEpImpl(OrtEpFactory* this_
187187
"ExampleKernelEpFactory only supports selection for one device.");
188188
}
189189

190-
auto actual_ep = std::make_unique<ExampleKernelEp>(*factory, *logger);
190+
std::string enable_prepack_weight_sharing;
191+
RETURN_IF_ERROR(GetSessionConfigEntryOrDefault(*session_options, "ep.examplekernelep.enable_prepack_weight_sharing",
192+
"0", enable_prepack_weight_sharing));
193+
194+
ExampleKernelEp::Config config = {};
195+
config.enable_prepack_weight_sharing = enable_prepack_weight_sharing == "1";
196+
197+
auto actual_ep = std::make_unique<ExampleKernelEp>(*factory, config, *logger);
191198
*ep = actual_ep.release();
192199

193200
return nullptr;

onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <sstream>
66
#include "mul.h"
77
#include "utils.h"
8+
#include "../ep.h"
89

910
// Defines a kernel creation function for version 14 of Mul.
1011
ONNX_OPERATOR_KERNEL_EX(
@@ -34,7 +35,7 @@ OrtStatus* Mul::Create(const OrtKernelInfo* info, void* state,
3435
/*out*/ std::unique_ptr<Mul>& result) noexcept {
3536
EXCEPTION_TO_RETURNED_STATUS_BEGIN
3637
// Note: can do basic validation or preprocessing via the OrtKernelInfo APIs.
37-
result = std::make_unique<Mul>(info, state, PrivateTag{});
38+
result = std::make_unique<Mul>(Ort::ConstKernelInfo(info), state, PrivateTag{});
3839
return nullptr;
3940
EXCEPTION_TO_RETURNED_STATUS_END
4041
}
@@ -48,7 +49,6 @@ void ORT_API_CALL Mul::ReleaseImpl(OrtKernelImpl* this_ptr) noexcept {
4849
OrtStatus* ORT_API_CALL Mul::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept {
4950
EXCEPTION_TO_RETURNED_STATUS_BEGIN
5051
Mul* mul_kernel = static_cast<Mul*>(this_ptr);
51-
static_cast<void>(mul_kernel->info_); // NOTE: Unused in this example.
5252

5353
Ort::KernelContext kernel_context(kernel_ctx);
5454

@@ -128,9 +128,11 @@ OrtStatus* ORT_API_CALL Mul::PrePackWeightImpl(OrtKernelImpl* this_ptr, const Or
128128

129129
RETURN_IF_ERROR(CopyTensor(*mul_kernel->data_transfer_impl_, original_weight, packed_weight.GetUnowned()));
130130

131-
const bool sharing_allowed = prepacked_weight_cache != nullptr;
131+
const ExampleKernelEp* ep = static_cast<const ExampleKernelEp*>(mul_kernel->info_.GetEp());
132+
const bool ep_sharing_enabled = ep->GetConfig().enable_prepack_weight_sharing;
133+
const bool ort_sharing_allowed = prepacked_weight_cache != nullptr;
132134

133-
if (sharing_allowed) {
135+
if (ort_sharing_allowed && ep_sharing_enabled) {
134136
std::array<void*, 1> buffer_data_ptrs = {weight_info.owned_data.get()};
135137
std::array<size_t, 1> buffer_data_sizes = {weight_info.num_bytes};
136138

onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class Mul : public OrtKernelImpl {
4141
size_t num_buffers, int input_index) noexcept;
4242

4343
private:
44-
const OrtKernelInfo* info_;
44+
Ort::ConstKernelInfo info_;
4545
OrtDataTransferImpl* data_transfer_impl_; // Custom state passed from OrtEp
4646
std::optional<PackedWeightInfo> packed_weight_1_info_ = std::nullopt;
4747
};

onnxruntime/test/autoep/test_execution.cc

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -245,42 +245,48 @@ TEST(OrtEpLibrary, KernelPluginEp_Inference) {
245245
example_kernel_ep));
246246
Ort::ConstEpDevice plugin_ep_device(example_kernel_ep.get());
247247

248-
// Create session with example kernel-based plugin EP
249-
Ort::SessionOptions session_options;
250-
session_options.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); // Fail if any node assigned to CPU EP.
248+
auto run_model_with_ep_options = [&](const std::unordered_map<std::string, std::string>& ep_options) {
249+
// Create session with example kernel-based plugin EP
250+
Ort::SessionOptions session_options;
251+
session_options.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); // Fail if any node assigned to CPU EP.
252+
session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options);
251253

252-
std::unordered_map<std::string, std::string> ep_options;
253-
session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options);
254+
// This model has Squeeze, Mul, and Relu nodes. The example plugin EP supports all nodes using registered kernels.
255+
Ort::Session session(*ort_env, ORT_TSTR("testdata/squeeze_mul_relu.onnx"), session_options);
254256

255-
// This model has Squeeze, Mul, and Relu nodes. The example plugin EP supports all nodes using registered kernels.
256-
Ort::Session session(*ort_env, ORT_TSTR("testdata/squeeze_mul_relu.onnx"), session_options);
257+
// Create inputs
258+
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
259+
std::array<int64_t, 3> a_shape = {3, 1, 2};
260+
std::array<int64_t, 2> b_shape = {3, 2};
257261

258-
// Create inputs
259-
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
260-
std::array<int64_t, 3> a_shape = {3, 1, 2};
261-
std::array<int64_t, 2> b_shape = {3, 2};
262+
std::array<float, 6> a_data = {1.f, -2.f, 3.f, 4.f, -5.f, 6.f};
263+
std::array<float, 6> b_data = {2.f, 3.f, 4.f, -5.f, 6.f, 7.f};
262264

263-
std::array<float, 6> a_data = {1.f, -2.f, 3.f, 4.f, -5.f, 6.f};
264-
std::array<float, 6> b_data = {2.f, 3.f, 4.f, -5.f, 6.f, 7.f};
265+
std::vector<Ort::Value> ort_inputs{};
266+
ort_inputs.emplace_back(
267+
Ort::Value::CreateTensor<float>(memory_info, a_data.data(), a_data.size(), a_shape.data(), a_shape.size()));
268+
ort_inputs.emplace_back(
269+
Ort::Value::CreateTensor<float>(memory_info, b_data.data(), b_data.size(), b_shape.data(), b_shape.size()));
265270

266-
std::vector<Ort::Value> ort_inputs{};
267-
ort_inputs.emplace_back(
268-
Ort::Value::CreateTensor<float>(memory_info, a_data.data(), a_data.size(), a_shape.data(), a_shape.size()));
269-
ort_inputs.emplace_back(
270-
Ort::Value::CreateTensor<float>(memory_info, b_data.data(), b_data.size(), b_shape.data(), b_shape.size()));
271+
std::array ort_input_names{"A", "B"};
271272

272-
std::array ort_input_names{"A", "B"};
273+
// Run session and get outputs
274+
std::array output_names{"C"};
275+
std::vector<Ort::Value> ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(),
276+
ort_inputs.size(), output_names.data(), output_names.size());
273277

274-
// Run session and get outputs
275-
std::array output_names{"C"};
276-
std::vector<Ort::Value> ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(),
277-
ort_inputs.size(), output_names.data(), output_names.size());
278+
// Check expected output values
279+
Ort::Value& ort_output = ort_outputs[0];
280+
const float* output_data = ort_output.GetTensorData<float>();
281+
gsl::span<const float> output_span(output_data, 6);
282+
EXPECT_THAT(output_span, ::testing::ElementsAre(4, 0, 24, 0, 0, 84));
283+
};
278284

279-
// Check expected output values
280-
Ort::Value& ort_output = ort_outputs[0];
281-
const float* output_data = ort_output.GetTensorData<float>();
282-
gsl::span<const float> output_span(output_data, 6);
283-
EXPECT_THAT(output_span, ::testing::ElementsAre(4, 0, 24, 0, 0, 84));
285+
run_model_with_ep_options({});
286+
287+
// Enable sharing of pre-packed weights.
288+
// This also tests the ability of the kernel implementation to retrieve the OrtEp and get its configuration.
289+
run_model_with_ep_options({{"enable_prepack_weight_sharing", "1"}});
284290
}
285291
} // namespace test
286292
} // namespace onnxruntime

0 commit comments

Comments
 (0)