@@ -107,10 +107,35 @@ OrtStatus* MulKernel::Compute(OrtKernelContext* kernel_ctx) {
107107 return nullptr ;
108108}
109109
110+ OrtStatus* EpContextKernel::Compute (OrtKernelContext* /* kernel_ctx*/ ) {
111+ // This example EP does not fully support EPContext inference.
112+ // A production EP would:
113+ // 1. Deserialize state from ep_cache_context attribute during Compile
114+ // 2. Use that state here to perform actual computation
115+ //
116+ // Session creation succeeds for metadata access and compatibility testing,
117+ // but inference requires deserializing ep_cache_context (not implemented).
118+ return ort_api.CreateStatus (
119+ ORT_NOT_IMPLEMENTED,
120+ " EPContext inference is not fully implemented in this example EP. "
121+ " Session creation succeeds for metadata access and compatibility testing, "
122+ " but inference requires deserializing ep_cache_context (not implemented). "
123+ " A production EP would restore compiled state from the EPContext node's attributes." );
124+ }
125+
126+ // / <summary>
127+ // / Intermediate base class with virtual destructor for proper polymorphic deletion.
128+ // / This allows ReleaseNodeComputeInfosImpl to delete any derived type correctly
129+ // / without manual type dispatch.
130+ // / </summary>
131+ struct NodeComputeInfoBase : OrtNodeComputeInfo {
132+ virtual ~NodeComputeInfoBase () = default ;
133+ };
134+
110135// / <summary>
111136// / Example OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph.
112137// / </summary>
113- struct ExampleNodeComputeInfo : OrtNodeComputeInfo {
138+ struct ExampleNodeComputeInfo : NodeComputeInfoBase {
114139 explicit ExampleNodeComputeInfo (ExampleEp& ep);
115140
116141 static OrtStatus* ORT_API_CALL CreateStateImpl (OrtNodeComputeInfo* this_ptr,
@@ -123,6 +148,22 @@ struct ExampleNodeComputeInfo : OrtNodeComputeInfo {
123148 ExampleEp& ep;
124149};
125150
151+ // / <summary>
152+ // / OrtNodeComputeInfo for EPContext nodes - delegates to EpContextKernel.
153+ // / </summary>
154+ struct EpContextNodeComputeInfo : NodeComputeInfoBase {
155+ explicit EpContextNodeComputeInfo (ExampleEp& ep);
156+
157+ static OrtStatus* ORT_API_CALL CreateStateImpl (OrtNodeComputeInfo* this_ptr,
158+ OrtNodeComputeContext* compute_context,
159+ void ** compute_state);
160+ static OrtStatus* ORT_API_CALL ComputeImpl (OrtNodeComputeInfo* this_ptr, void * compute_state,
161+ OrtKernelContext* kernel_context);
162+ static void ORT_API_CALL ReleaseStateImpl (OrtNodeComputeInfo* this_ptr, void * compute_state);
163+
164+ ExampleEp& ep;
165+ };
166+
126167ExampleEp::ExampleEp (ExampleEpFactory& factory, const std::string& name, const Config& config, const OrtLogger& logger)
127168 : OrtEp{}, // explicitly call the struct ctor to ensure all optional values are default initialized
128169 ApiPtrs{static_cast <const ApiPtrs&>(factory)},
@@ -137,8 +178,9 @@ ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const C
137178 GetCapability = GetCapabilityImpl;
138179 Compile = CompileImpl;
139180 ReleaseNodeComputeInfos = ReleaseNodeComputeInfosImpl;
140- CreateAllocator = CreateAllocatorImpl; // optional. can be nullptr
141- CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; // optional. can be nullptr
181+ CreateAllocator = CreateAllocatorImpl; // optional. can be nullptr
182+ CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; // optional. can be nullptr
183+ GetCompiledModelCompatibilityInfo = GetCompiledModelCompatibilityInfoImpl; // compatibility info for compiled models
142184
143185 IGNORE_ORTSTATUS (ort_api.Logger_LogMessage (&logger_,
144186 OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
@@ -206,12 +248,32 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG
206248 return nullptr ; // No nodes to process
207249 }
208250
251+ // Single array for all supported node types.
252+ // This EP only supports compiling one node at a time (a documented limitation).
209253 std::vector<Ort::ConstNode> supported_nodes;
210254
211255 for (const auto & node : nodes) {
212256 auto op_type = node.GetOperatorType ();
213257 auto domain = node.GetDomain ();
214258
259+ // Check for EPContext nodes that belong to this EP (from compiled models).
260+ // This is needed to handle loading pre-compiled models with EPContext nodes.
261+ if (op_type == " EPContext" && domain == " com.microsoft" ) {
262+ // Check if this EPContext node belongs to this EP via the "source" attribute
263+ Ort::ConstOpAttr source_attr;
264+ Ort::Status status = node.GetAttributeByName (" source" , source_attr);
265+ if (status.IsOK ()) {
266+ std::string source_value;
267+ status = source_attr.GetValue (source_value);
268+ if (status.IsOK () && source_value == ep->name_ ) {
269+ // This EPContext node was created by this EP - add to supported nodes
270+ supported_nodes.push_back (node);
271+ break ; // Only support one node at a time
272+ }
273+ }
274+ continue ; // Don't process further, EPContext is a special case
275+ }
276+
215277 if (op_type == " Mul" ) {
216278 // Check that Mul has inputs/output of type float
217279 std::vector<Ort::ConstValueInfo> inputs = node.GetInputs ();
@@ -241,19 +303,29 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG
241303 }
242304 }
243305
244- supported_nodes.push_back (node); // Only support a single Mul for now.
245- break ;
306+ supported_nodes.push_back (node);
307+ break ; // Only support a single Mul for now.
246308 } else if (op_type == " Custom_Mul" && domain == " test" ) {
247309 supported_nodes.push_back (node);
310+ break ; // Only support one node at a time (consistent with Mul/EPContext handling).
248311 }
249312 }
250313
314+ // Return early if no supported nodes
251315 if (supported_nodes.empty ()) {
252316 return nullptr ;
253317 }
254318
255- if (supported_nodes[0 ].GetOperatorType () == " Mul" ) {
256- // Create (optional) fusion options for the supported nodes to fuse.
319+ // Unified dispatch based on node type
320+ const auto & node = supported_nodes[0 ];
321+ auto op_type = node.GetOperatorType ();
322+
323+ if (op_type == " Custom_Mul" ) {
324+ // Custom_Mul has concrete kernel implementation - no fusion needed.
325+ // Calls EpGraphSupportInfo_AddSingleNode() to inform ORT that the custom node should NOT be fused or compiled.
326+ RETURN_IF_ERROR (ep->ep_api .EpGraphSupportInfo_AddSingleNode (graph_support_info, node));
327+ } else {
328+ // Both EPContext and Mul use AddNodesToFuse
257329 OrtNodeFusionOptions node_fusion_options = {};
258330 node_fusion_options.ort_version_supported = ORT_API_VERSION;
259331
@@ -262,14 +334,11 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG
262334 // This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use
263335 // during inference.
264336 node_fusion_options.drop_constant_initializers = true ;
265- RETURN_IF_ERROR (ep->ep_api .EpGraphSupportInfo_AddNodesToFuse (graph_support_info,
266- reinterpret_cast <const OrtNode* const *>(supported_nodes.data ()),
267- supported_nodes.size (),
268- &node_fusion_options));
269- } else if (supported_nodes[0 ].GetOperatorType () == " Custom_Mul" ) {
270- // Calls EpGraphSupportInfo_AddSingleNode() to inform ORT that the custom node should NOT be fused or compiled,
271- // as CustomMul has the concrete kernel implementation.
272- RETURN_IF_ERROR (ep->ep_api .EpGraphSupportInfo_AddSingleNode (graph_support_info, supported_nodes[0 ]));
337+ RETURN_IF_ERROR (ep->ep_api .EpGraphSupportInfo_AddNodesToFuse (
338+ graph_support_info,
339+ reinterpret_cast <const OrtNode* const *>(supported_nodes.data ()),
340+ supported_nodes.size (),
341+ &node_fusion_options));
273342 }
274343
275344 } catch (const Ort::Exception& ex) {
@@ -305,21 +374,32 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const
305374
306375 std::vector<Ort::ConstNode> nodes = graph.GetNodes ();
307376 if (nodes.size () != 1 ) {
308- Ort::Status status (" Expected to compile a single Mul node" , ORT_EP_FAIL);
377+ Ort::Status status (" Expected to compile a single node" , ORT_EP_FAIL);
309378 return status.release ();
310379 }
311380
312381 auto node_op_type = nodes[0 ].GetOperatorType ();
313- if (node_op_type != " Mul" ) {
314- Ort::Status status (" Expected to compile a single Mul node" , ORT_EP_FAIL);
382+ auto node_domain = nodes[0 ].GetDomain ();
383+
384+ // Check if this is an EPContext node (from loading a pre-compiled model)
385+ bool is_ep_context_node = (node_op_type == " EPContext" && node_domain == " com.microsoft" );
386+
387+ // Validate configuration: cannot enable EPContext generation when loading a compiled model.
388+ // This is a configuration error - you cannot re-compile an already compiled model.
389+ if (ep->config_ .enable_ep_context && is_ep_context_node) {
390+ Ort::Status status (
391+ " Invalid configuration: 'enable_ep_context' is true but model already contains "
392+ " EPContext nodes. Cannot re-compile an already compiled model. Either:\n "
393+ " 1. Use the original (uncompiled) model as input, or\n "
394+ " 2. Disable ep_context generation when loading a compiled model." ,
395+ ORT_INVALID_ARGUMENT);
315396 return status.release ();
316397 }
317398
318- // Now we know we're compiling a single Mul node. Create a computation kernel.
319- std::vector<Ort::ConstValueInfo> node_inputs = nodes[0 ].GetInputs ();
320- std::array<std::string, 2 > node_input_names;
321- node_input_names[0 ] = node_inputs[0 ].GetName ();
322- node_input_names[1 ] = node_inputs[1 ].GetName ();
399+ if (node_op_type != " Mul" && !is_ep_context_node) {
400+ Ort::Status status (" Expected to compile a Mul node or EPContext node" , ORT_EP_FAIL);
401+ return status.release ();
402+ }
323403
324404 Ort::ConstNode fused_node{fused_nodes[0 ]};
325405 auto ep_name = fused_node.GetEpName ();
@@ -328,22 +408,42 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const
328408 return status.release ();
329409 }
330410
331- // Associate the name of the fused node with our MulKernel.
332411 auto fused_node_name = fused_node.GetName ();
333- ep->kernels_ .emplace (std::move (fused_node_name), std::make_unique<MulKernel>(ep->ort_api , ep->logger_ ,
334- ep->float_initializers_ ,
335- node_input_names[0 ],
336- node_input_names[1 ]));
337-
338- // Update the OrtNodeComputeInfo associated with the graph.
339- auto node_compute_info = std::make_unique<ExampleNodeComputeInfo>(*ep);
340- node_compute_infos[0 ] = node_compute_info.release ();
341-
342- // Create EpContext nodes for the fused nodes we compiled.
343- if (ep->config_ .enable_ep_context ) {
344- assert (ep_context_nodes != nullptr );
345- RETURN_IF_ERROR (ep->CreateEpContextNodes (gsl::span<const OrtNode*>(fused_nodes, count),
346- gsl::span<OrtNode*>(ep_context_nodes, count)));
412+
413+ if (is_ep_context_node) {
414+ // Create EpContextKernel for EPContext nodes - clearly separates from MulKernel
415+ ep->ep_context_kernels_ .emplace (fused_node_name,
416+ std::make_unique<EpContextKernel>(ep->ort_api , ep->logger_ ));
417+
418+ // Use EpContextNodeComputeInfo for EPContext nodes
419+ auto node_compute_info = std::make_unique<EpContextNodeComputeInfo>(*ep);
420+ node_compute_infos[0 ] = node_compute_info.release ();
421+ } else {
422+ // For Mul nodes during initial compilation, we need exactly 2 inputs
423+ std::vector<Ort::ConstValueInfo> node_inputs = nodes[0 ].GetInputs ();
424+ if (node_inputs.size () != 2 ) {
425+ std::string err_msg = " Mul node should have 2 inputs, got " + std::to_string (node_inputs.size ());
426+ Ort::Status status (err_msg.c_str (), ORT_EP_FAIL);
427+ return status.release ();
428+ }
429+
430+ // Create MulKernel for Mul nodes
431+ ep->mul_kernels_ .emplace (fused_node_name,
432+ std::make_unique<MulKernel>(ep->ort_api , ep->logger_ ,
433+ ep->float_initializers_ ,
434+ node_inputs[0 ].GetName (),
435+ node_inputs[1 ].GetName ()));
436+
437+ // Use ExampleNodeComputeInfo for Mul nodes
438+ auto node_compute_info = std::make_unique<ExampleNodeComputeInfo>(*ep);
439+ node_compute_infos[0 ] = node_compute_info.release ();
440+
441+ // Create EpContext nodes for the fused nodes we compiled (only for Mul, not EPContext).
442+ if (ep->config_ .enable_ep_context ) {
443+ assert (ep_context_nodes != nullptr );
444+ RETURN_IF_ERROR (ep->CreateEpContextNodes (gsl::span<const OrtNode*>(fused_nodes, count),
445+ gsl::span<OrtNode*>(ep_context_nodes, count)));
446+ }
347447 }
348448 } catch (const Ort::Exception& ex) {
349449 Ort::Status status (ex);
@@ -362,7 +462,9 @@ void ORT_API_CALL ExampleEp::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr,
362462 size_t num_node_compute_infos) noexcept {
363463 (void )this_ptr;
364464 for (size_t i = 0 ; i < num_node_compute_infos; i++) {
365- delete static_cast <ExampleNodeComputeInfo*>(node_compute_infos[i]);
465+ // All node compute info types derive from NodeComputeInfoBase which has a virtual destructor.
466+ // This ensures correct polymorphic deletion without manual type dispatch.
467+ delete static_cast <NodeComputeInfoBase*>(node_compute_infos[i]);
366468 }
367469}
368470
@@ -497,9 +599,9 @@ OrtStatus* ExampleNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr,
497599 ExampleEp& ep = node_compute_info->ep ;
498600
499601 std::string fused_node_name = ep.ep_api .NodeComputeContext_NodeName (compute_context);
500- auto kernel_it = ep.Kernels ().find (fused_node_name);
501- if (kernel_it == ep.Kernels ().end ()) {
502- std::string message = " Unable to get kernel for fused node with name " + fused_node_name;
602+ auto kernel_it = ep.MulKernels ().find (fused_node_name);
603+ if (kernel_it == ep.MulKernels ().end ()) {
604+ std::string message = " Unable to get MulKernel for fused node with name " + fused_node_name;
503605 return ep.ort_api .CreateStatus (ORT_EP_FAIL, message.c_str ());
504606 }
505607
@@ -521,3 +623,74 @@ void ExampleNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void
521623 (void )kernel;
522624 // Do nothing for this example.
523625}
626+
627+ //
628+ // Implementation of EpContextNodeComputeInfo
629+ //
630+ EpContextNodeComputeInfo::EpContextNodeComputeInfo (ExampleEp& ep) : ep(ep) {
631+ ort_version_supported = ORT_API_VERSION;
632+ CreateState = CreateStateImpl;
633+ Compute = ComputeImpl;
634+ ReleaseState = ReleaseStateImpl;
635+ }
636+
637+ OrtStatus* EpContextNodeComputeInfo::CreateStateImpl (OrtNodeComputeInfo* this_ptr,
638+ OrtNodeComputeContext* compute_context,
639+ void ** compute_state) {
640+ auto * node_compute_info = static_cast <EpContextNodeComputeInfo*>(this_ptr);
641+ ExampleEp& ep = node_compute_info->ep ;
642+
643+ std::string fused_node_name = ep.ep_api .NodeComputeContext_NodeName (compute_context);
644+ auto kernel_it = ep.EpContextKernels ().find (fused_node_name);
645+ if (kernel_it == ep.EpContextKernels ().end ()) {
646+ std::string message = " Unable to get EpContextKernel for fused node with name " + fused_node_name;
647+ return ep.ort_api .CreateStatus (ORT_EP_FAIL, message.c_str ());
648+ }
649+
650+ EpContextKernel& kernel = *kernel_it->second ;
651+ *compute_state = &kernel;
652+ return nullptr ;
653+ }
654+
655+ OrtStatus* EpContextNodeComputeInfo::ComputeImpl (OrtNodeComputeInfo* this_ptr, void * compute_state,
656+ OrtKernelContext* kernel_context) {
657+ (void )this_ptr;
658+ EpContextKernel& kernel = *reinterpret_cast <EpContextKernel*>(compute_state);
659+ return kernel.Compute (kernel_context);
660+ }
661+
662+ void EpContextNodeComputeInfo::ReleaseStateImpl (OrtNodeComputeInfo* this_ptr, void * compute_state) {
663+ (void )this_ptr;
664+ (void )compute_state;
665+ // Do nothing for this example.
666+ }
667+
668+ //
669+ // Implementation of GetCompiledModelCompatibilityInfo
670+ //
671+ /* static*/
672+ const char * ORT_API_CALL ExampleEp::GetCompiledModelCompatibilityInfoImpl (OrtEp* this_ptr,
673+ const OrtGraph* graph) noexcept {
674+ // Suppress unused parameter warning. The ORT_UNUSED_PARAMETER macro is in internal headers
675+ // (core/common/common.h) which are not available to plugin EPs using only public APIs.
676+ // A real EP would inspect the graph for model-specific compatibility info.
677+ (void )graph;
678+ auto * ep = static_cast <ExampleEp*>(this_ptr);
679+
680+ // Generate a compatibility string that includes:
681+ // - EP name
682+ // - EP version (from factory)
683+ // - ORT API version
684+ //
685+ // In a real EP, this might include driver versions, hardware IDs, etc.
686+ // The string format is EP-defined and should be parseable by ValidateCompiledModelCompatibilityInfo.
687+ ep->compatibility_info_ = ep->name_ + " ;version=" + ep->factory_ .GetEpVersionString () + " ;ort_api_version=" +
688+ std::to_string (ORT_API_VERSION);
689+
690+ IGNORE_ORTSTATUS (ep->ort_api .Logger_LogMessage (&ep->logger_ ,
691+ OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
692+ (" GetCompiledModelCompatibilityInfo returning: " + ep->compatibility_info_ ).c_str (),
693+ ORT_FILE, __LINE__, __FUNCTION__));
694+
695+ return ep->compatibility_info_ .c_str ();
696+ }
0 commit comments