Skip to content

Commit 06fe9a4

Browse files
adrastogi, Aditya Rastogi, github-actions[bot]
authored
Implement compiled model compatibility APIs in example plugin EP and add tests (#27088)
### Description
<!-- Describe your changes. -->
Add tests for the compiled model compatibility APIs (`GetCompiledModelCompatibilityInfo` and `ValidateCompiledModelCompatibilityInfo`) using the example plugin EP.

Changes:
- Implement the compatibility methods in example_plugin_ep to enable testing
- Update `GetCapabilityImpl` and `CompileImpl` to handle EPContext nodes (required for loading compiled models in tests)
- Add `PluginEp_CompatibilityInfo_WrittenToMetadata` test - verifies that compatibility info is written to the compiled model's metadata
- Add `PluginEp_CompatibilityInfo_ValidatedOnLoad` test - verifies the round-trip: compile → load → validate compatibility

### Motivation and Context
<!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
The compiled model compatibility APIs need test coverage to ensure they work correctly end-to-end. This exercises the full flow of writing compatibility metadata during compilation and validating it when loading a pre-compiled model. Additionally, adding this support to the plugin EP enables us to add scenario tests that exercise this flow in the context of Windows ML.

---------

Co-authored-by: Aditya Rastogi <adityar@ntdev.microsoft.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 040ae1f commit 06fe9a4

File tree

5 files changed

+560
-45
lines changed

5 files changed

+560
-45
lines changed

onnxruntime/test/autoep/library/example_plugin_ep/ep.cc

Lines changed: 215 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,35 @@ OrtStatus* MulKernel::Compute(OrtKernelContext* kernel_ctx) {
107107
return nullptr;
108108
}
109109

110+
/// <summary>
/// Compute entry point for EPContext nodes. This example EP does not implement
/// EPContext inference, so every call returns an ORT_NOT_IMPLEMENTED status.
/// </summary>
/// <param name="kernel_ctx">Kernel context supplied by ORT; unused here.</param>
/// <returns>A newly created ORT_NOT_IMPLEMENTED status describing the limitation.</returns>
OrtStatus* EpContextKernel::Compute(OrtKernelContext* /*kernel_ctx*/) {
  // What a production EP would do instead:
  //   1. During Compile, deserialize its state from the ep_cache_context attribute.
  //   2. Here, run the actual computation using that restored state.
  //
  // For this example, creating a session still works (enough for metadata access and
  // compatibility testing); only running inference is rejected.
  const char* const not_implemented_message =
      "EPContext inference is not fully implemented in this example EP. "
      "Session creation succeeds for metadata access and compatibility testing, "
      "but inference requires deserializing ep_cache_context (not implemented). "
      "A production EP would restore compiled state from the EPContext node's attributes.";

  return ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, not_implemented_message);
}
125+
126+
/// <summary>
127+
/// Intermediate base class with virtual destructor for proper polymorphic deletion.
128+
/// This allows ReleaseNodeComputeInfosImpl to delete any derived type correctly
129+
/// without manual type dispatch.
130+
/// </summary>
131+
struct NodeComputeInfoBase : OrtNodeComputeInfo {
132+
virtual ~NodeComputeInfoBase() = default;
133+
};
134+
110135
/// <summary>
111136
/// Example OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph.
112137
/// </summary>
113-
struct ExampleNodeComputeInfo : OrtNodeComputeInfo {
138+
struct ExampleNodeComputeInfo : NodeComputeInfoBase {
114139
explicit ExampleNodeComputeInfo(ExampleEp& ep);
115140

116141
static OrtStatus* ORT_API_CALL CreateStateImpl(OrtNodeComputeInfo* this_ptr,
@@ -123,6 +148,22 @@ struct ExampleNodeComputeInfo : OrtNodeComputeInfo {
123148
ExampleEp& ep;
124149
};
125150

151+
/// <summary>
152+
/// OrtNodeComputeInfo for EPContext nodes - delegates to EpContextKernel.
153+
/// </summary>
154+
struct EpContextNodeComputeInfo : NodeComputeInfoBase {
155+
explicit EpContextNodeComputeInfo(ExampleEp& ep);
156+
157+
static OrtStatus* ORT_API_CALL CreateStateImpl(OrtNodeComputeInfo* this_ptr,
158+
OrtNodeComputeContext* compute_context,
159+
void** compute_state);
160+
static OrtStatus* ORT_API_CALL ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state,
161+
OrtKernelContext* kernel_context);
162+
static void ORT_API_CALL ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state);
163+
164+
ExampleEp& ep;
165+
};
166+
126167
ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const Config& config, const OrtLogger& logger)
127168
: OrtEp{}, // explicitly call the struct ctor to ensure all optional values are default initialized
128169
ApiPtrs{static_cast<const ApiPtrs&>(factory)},
@@ -137,8 +178,9 @@ ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const C
137178
GetCapability = GetCapabilityImpl;
138179
Compile = CompileImpl;
139180
ReleaseNodeComputeInfos = ReleaseNodeComputeInfosImpl;
140-
CreateAllocator = CreateAllocatorImpl; // optional. can be nullptr
141-
CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; // optional. can be nullptr
181+
CreateAllocator = CreateAllocatorImpl; // optional. can be nullptr
182+
CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; // optional. can be nullptr
183+
GetCompiledModelCompatibilityInfo = GetCompiledModelCompatibilityInfoImpl; // compatibility info for compiled models
142184

143185
IGNORE_ORTSTATUS(ort_api.Logger_LogMessage(&logger_,
144186
OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
@@ -206,12 +248,32 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG
206248
return nullptr; // No nodes to process
207249
}
208250

251+
// Single array for all supported node types.
252+
// This EP only supports compiling one node at a time (a documented limitation).
209253
std::vector<Ort::ConstNode> supported_nodes;
210254

211255
for (const auto& node : nodes) {
212256
auto op_type = node.GetOperatorType();
213257
auto domain = node.GetDomain();
214258

259+
// Check for EPContext nodes that belong to this EP (from compiled models).
260+
// This is needed to handle loading pre-compiled models with EPContext nodes.
261+
if (op_type == "EPContext" && domain == "com.microsoft") {
262+
// Check if this EPContext node belongs to this EP via the "source" attribute
263+
Ort::ConstOpAttr source_attr;
264+
Ort::Status status = node.GetAttributeByName("source", source_attr);
265+
if (status.IsOK()) {
266+
std::string source_value;
267+
status = source_attr.GetValue(source_value);
268+
if (status.IsOK() && source_value == ep->name_) {
269+
// This EPContext node was created by this EP - add to supported nodes
270+
supported_nodes.push_back(node);
271+
break; // Only support one node at a time
272+
}
273+
}
274+
continue; // Don't process further, EPContext is a special case
275+
}
276+
215277
if (op_type == "Mul") {
216278
// Check that Mul has inputs/output of type float
217279
std::vector<Ort::ConstValueInfo> inputs = node.GetInputs();
@@ -241,19 +303,29 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG
241303
}
242304
}
243305

244-
supported_nodes.push_back(node); // Only support a single Mul for now.
245-
break;
306+
supported_nodes.push_back(node);
307+
break; // Only support a single Mul for now.
246308
} else if (op_type == "Custom_Mul" && domain == "test") {
247309
supported_nodes.push_back(node);
310+
break; // Only support one node at a time (consistent with Mul/EPContext handling).
248311
}
249312
}
250313

314+
// Return early if no supported nodes
251315
if (supported_nodes.empty()) {
252316
return nullptr;
253317
}
254318

255-
if (supported_nodes[0].GetOperatorType() == "Mul") {
256-
// Create (optional) fusion options for the supported nodes to fuse.
319+
// Unified dispatch based on node type
320+
const auto& node = supported_nodes[0];
321+
auto op_type = node.GetOperatorType();
322+
323+
if (op_type == "Custom_Mul") {
324+
// Custom_Mul has concrete kernel implementation - no fusion needed.
325+
// Calls EpGraphSupportInfo_AddSingleNode() to inform ORT that the custom node should NOT be fused or compiled.
326+
RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddSingleNode(graph_support_info, node));
327+
} else {
328+
// Both EPContext and Mul use AddNodesToFuse
257329
OrtNodeFusionOptions node_fusion_options = {};
258330
node_fusion_options.ort_version_supported = ORT_API_VERSION;
259331

@@ -262,14 +334,11 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG
262334
// This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use
263335
// during inference.
264336
node_fusion_options.drop_constant_initializers = true;
265-
RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info,
266-
reinterpret_cast<const OrtNode* const*>(supported_nodes.data()),
267-
supported_nodes.size(),
268-
&node_fusion_options));
269-
} else if (supported_nodes[0].GetOperatorType() == "Custom_Mul") {
270-
// Calls EpGraphSupportInfo_AddSingleNode() to inform ORT that the custom node should NOT be fused or compiled,
271-
// as CustomMul has the concrete kernel implementation.
272-
RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddSingleNode(graph_support_info, supported_nodes[0]));
337+
RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(
338+
graph_support_info,
339+
reinterpret_cast<const OrtNode* const*>(supported_nodes.data()),
340+
supported_nodes.size(),
341+
&node_fusion_options));
273342
}
274343

275344
} catch (const Ort::Exception& ex) {
@@ -305,21 +374,32 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const
305374

306375
std::vector<Ort::ConstNode> nodes = graph.GetNodes();
307376
if (nodes.size() != 1) {
308-
Ort::Status status("Expected to compile a single Mul node", ORT_EP_FAIL);
377+
Ort::Status status("Expected to compile a single node", ORT_EP_FAIL);
309378
return status.release();
310379
}
311380

312381
auto node_op_type = nodes[0].GetOperatorType();
313-
if (node_op_type != "Mul") {
314-
Ort::Status status("Expected to compile a single Mul node", ORT_EP_FAIL);
382+
auto node_domain = nodes[0].GetDomain();
383+
384+
// Check if this is an EPContext node (from loading a pre-compiled model)
385+
bool is_ep_context_node = (node_op_type == "EPContext" && node_domain == "com.microsoft");
386+
387+
// Validate configuration: cannot enable EPContext generation when loading a compiled model.
388+
// This is a configuration error - you cannot re-compile an already compiled model.
389+
if (ep->config_.enable_ep_context && is_ep_context_node) {
390+
Ort::Status status(
391+
"Invalid configuration: 'enable_ep_context' is true but model already contains "
392+
"EPContext nodes. Cannot re-compile an already compiled model. Either:\n"
393+
" 1. Use the original (uncompiled) model as input, or\n"
394+
" 2. Disable ep_context generation when loading a compiled model.",
395+
ORT_INVALID_ARGUMENT);
315396
return status.release();
316397
}
317398

318-
// Now we know we're compiling a single Mul node. Create a computation kernel.
319-
std::vector<Ort::ConstValueInfo> node_inputs = nodes[0].GetInputs();
320-
std::array<std::string, 2> node_input_names;
321-
node_input_names[0] = node_inputs[0].GetName();
322-
node_input_names[1] = node_inputs[1].GetName();
399+
if (node_op_type != "Mul" && !is_ep_context_node) {
400+
Ort::Status status("Expected to compile a Mul node or EPContext node", ORT_EP_FAIL);
401+
return status.release();
402+
}
323403

324404
Ort::ConstNode fused_node{fused_nodes[0]};
325405
auto ep_name = fused_node.GetEpName();
@@ -328,22 +408,42 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const
328408
return status.release();
329409
}
330410

331-
// Associate the name of the fused node with our MulKernel.
332411
auto fused_node_name = fused_node.GetName();
333-
ep->kernels_.emplace(std::move(fused_node_name), std::make_unique<MulKernel>(ep->ort_api, ep->logger_,
334-
ep->float_initializers_,
335-
node_input_names[0],
336-
node_input_names[1]));
337-
338-
// Update the OrtNodeComputeInfo associated with the graph.
339-
auto node_compute_info = std::make_unique<ExampleNodeComputeInfo>(*ep);
340-
node_compute_infos[0] = node_compute_info.release();
341-
342-
// Create EpContext nodes for the fused nodes we compiled.
343-
if (ep->config_.enable_ep_context) {
344-
assert(ep_context_nodes != nullptr);
345-
RETURN_IF_ERROR(ep->CreateEpContextNodes(gsl::span<const OrtNode*>(fused_nodes, count),
346-
gsl::span<OrtNode*>(ep_context_nodes, count)));
412+
413+
if (is_ep_context_node) {
414+
// Create EpContextKernel for EPContext nodes - clearly separates from MulKernel
415+
ep->ep_context_kernels_.emplace(fused_node_name,
416+
std::make_unique<EpContextKernel>(ep->ort_api, ep->logger_));
417+
418+
// Use EpContextNodeComputeInfo for EPContext nodes
419+
auto node_compute_info = std::make_unique<EpContextNodeComputeInfo>(*ep);
420+
node_compute_infos[0] = node_compute_info.release();
421+
} else {
422+
// For Mul nodes during initial compilation, we need exactly 2 inputs
423+
std::vector<Ort::ConstValueInfo> node_inputs = nodes[0].GetInputs();
424+
if (node_inputs.size() != 2) {
425+
std::string err_msg = "Mul node should have 2 inputs, got " + std::to_string(node_inputs.size());
426+
Ort::Status status(err_msg.c_str(), ORT_EP_FAIL);
427+
return status.release();
428+
}
429+
430+
// Create MulKernel for Mul nodes
431+
ep->mul_kernels_.emplace(fused_node_name,
432+
std::make_unique<MulKernel>(ep->ort_api, ep->logger_,
433+
ep->float_initializers_,
434+
node_inputs[0].GetName(),
435+
node_inputs[1].GetName()));
436+
437+
// Use ExampleNodeComputeInfo for Mul nodes
438+
auto node_compute_info = std::make_unique<ExampleNodeComputeInfo>(*ep);
439+
node_compute_infos[0] = node_compute_info.release();
440+
441+
// Create EpContext nodes for the fused nodes we compiled (only for Mul, not EPContext).
442+
if (ep->config_.enable_ep_context) {
443+
assert(ep_context_nodes != nullptr);
444+
RETURN_IF_ERROR(ep->CreateEpContextNodes(gsl::span<const OrtNode*>(fused_nodes, count),
445+
gsl::span<OrtNode*>(ep_context_nodes, count)));
446+
}
347447
}
348448
} catch (const Ort::Exception& ex) {
349449
Ort::Status status(ex);
@@ -362,7 +462,9 @@ void ORT_API_CALL ExampleEp::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr,
362462
size_t num_node_compute_infos) noexcept {
363463
(void)this_ptr;
364464
for (size_t i = 0; i < num_node_compute_infos; i++) {
365-
delete static_cast<ExampleNodeComputeInfo*>(node_compute_infos[i]);
465+
// All node compute info types derive from NodeComputeInfoBase which has a virtual destructor.
466+
// This ensures correct polymorphic deletion without manual type dispatch.
467+
delete static_cast<NodeComputeInfoBase*>(node_compute_infos[i]);
366468
}
367469
}
368470

@@ -497,9 +599,9 @@ OrtStatus* ExampleNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr,
497599
ExampleEp& ep = node_compute_info->ep;
498600

499601
std::string fused_node_name = ep.ep_api.NodeComputeContext_NodeName(compute_context);
500-
auto kernel_it = ep.Kernels().find(fused_node_name);
501-
if (kernel_it == ep.Kernels().end()) {
502-
std::string message = "Unable to get kernel for fused node with name " + fused_node_name;
602+
auto kernel_it = ep.MulKernels().find(fused_node_name);
603+
if (kernel_it == ep.MulKernels().end()) {
604+
std::string message = "Unable to get MulKernel for fused node with name " + fused_node_name;
503605
return ep.ort_api.CreateStatus(ORT_EP_FAIL, message.c_str());
504606
}
505607

@@ -521,3 +623,74 @@ void ExampleNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void
521623
(void)kernel;
522624
// Do nothing for this example.
523625
}
626+
627+
//
628+
// Implementation of EpContextNodeComputeInfo
629+
//
630+
EpContextNodeComputeInfo::EpContextNodeComputeInfo(ExampleEp& ep) : ep(ep) {
631+
ort_version_supported = ORT_API_VERSION;
632+
CreateState = CreateStateImpl;
633+
Compute = ComputeImpl;
634+
ReleaseState = ReleaseStateImpl;
635+
}
636+
637+
/// <summary>
/// Finds the EpContextKernel registered under the fused node's name and hands a
/// pointer to it back to ORT as the opaque compute state.
/// </summary>
/// <param name="this_ptr">The EpContextNodeComputeInfo instance (passed as its C base type).</param>
/// <param name="compute_context">Context from which the fused node name is read.</param>
/// <param name="compute_state">Receives a pointer to the kernel on success.</param>
/// <returns>nullptr on success; an ORT_EP_FAIL status if no kernel exists for the node.</returns>
OrtStatus* EpContextNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr,
                                                     OrtNodeComputeContext* compute_context,
                                                     void** compute_state) {
  ExampleEp& owner = static_cast<EpContextNodeComputeInfo*>(this_ptr)->ep;

  const std::string node_name = owner.ep_api.NodeComputeContext_NodeName(compute_context);

  auto&& kernels = owner.EpContextKernels();
  const auto found = kernels.find(node_name);
  if (found != kernels.end()) {
    // The kernel stays owned by the EP's map; only a pointer to it is handed out.
    *compute_state = found->second.get();
    return nullptr;
  }

  const std::string message = "Unable to get EpContextKernel for fused node with name " + node_name;
  return owner.ort_api.CreateStatus(ORT_EP_FAIL, message.c_str());
}
654+
655+
/// <summary>
/// Forwards the compute request to the EpContextKernel that CreateStateImpl
/// stored in <c>compute_state</c>.
/// </summary>
/// <param name="compute_state">Pointer to the EpContextKernel for this node.</param>
/// <param name="kernel_context">Kernel context passed through to the kernel.</param>
/// <returns>Whatever status the kernel's Compute returns.</returns>
OrtStatus* EpContextNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* /*this_ptr*/, void* compute_state,
                                                 OrtKernelContext* kernel_context) {
  auto* kernel = static_cast<EpContextKernel*>(compute_state);
  return kernel->Compute(kernel_context);
}
661+
662+
void EpContextNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state) {
663+
(void)this_ptr;
664+
(void)compute_state;
665+
// Do nothing for this example.
666+
}
667+
668+
//
669+
// Implementation of GetCompiledModelCompatibilityInfo
670+
//
671+
/*static*/
672+
const char* ORT_API_CALL ExampleEp::GetCompiledModelCompatibilityInfoImpl(OrtEp* this_ptr,
673+
const OrtGraph* graph) noexcept {
674+
// Suppress unused parameter warning. The ORT_UNUSED_PARAMETER macro is in internal headers
675+
// (core/common/common.h) which are not available to plugin EPs using only public APIs.
676+
// A real EP would inspect the graph for model-specific compatibility info.
677+
(void)graph;
678+
auto* ep = static_cast<ExampleEp*>(this_ptr);
679+
680+
// Generate a compatibility string that includes:
681+
// - EP name
682+
// - EP version (from factory)
683+
// - ORT API version
684+
//
685+
// In a real EP, this might include driver versions, hardware IDs, etc.
686+
// The string format is EP-defined and should be parseable by ValidateCompiledModelCompatibilityInfo.
687+
ep->compatibility_info_ = ep->name_ + ";version=" + ep->factory_.GetEpVersionString() + ";ort_api_version=" +
688+
std::to_string(ORT_API_VERSION);
689+
690+
IGNORE_ORTSTATUS(ep->ort_api.Logger_LogMessage(&ep->logger_,
691+
OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
692+
("GetCompiledModelCompatibilityInfo returning: " + ep->compatibility_info_).c_str(),
693+
ORT_FILE, __LINE__, __FUNCTION__));
694+
695+
return ep->compatibility_info_.c_str();
696+
}

0 commit comments

Comments
 (0)