Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 127 additions & 29 deletions onnxruntime/test/autoep/library/example_plugin_ep/ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,9 @@ ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const C
GetCapability = GetCapabilityImpl;
Compile = CompileImpl;
ReleaseNodeComputeInfos = ReleaseNodeComputeInfosImpl;
CreateAllocator = CreateAllocatorImpl; // optional. can be nullptr
CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; // optional. can be nullptr
CreateAllocator = CreateAllocatorImpl; // optional. can be nullptr
CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; // optional. can be nullptr
GetCompiledModelCompatibilityInfo = GetCompiledModelCompatibilityInfoImpl; // compatibility info for compiled models

IGNORE_ORTSTATUS(ort_api.Logger_LogMessage(&logger_,
OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
Expand Down Expand Up @@ -207,11 +208,29 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG
}

std::vector<Ort::ConstNode> supported_nodes;
std::vector<Ort::ConstNode> ep_context_nodes;

for (const auto& node : nodes) {
auto op_type = node.GetOperatorType();
auto domain = node.GetDomain();

// Check for EPContext nodes that belong to this EP (from compiled models).
// This is needed to handle loading pre-compiled models with EPContext nodes.
if (op_type == "EPContext" && domain == "com.microsoft") {
// Check if this EPContext node belongs to this EP via the "source" attribute
Ort::ConstOpAttr source_attr;
Ort::Status status = node.GetAttributeByName("source", source_attr);
if (status.IsOK()) {
std::string source_value;
status = source_attr.GetValue(source_value);
if (status.IsOK() && source_value == ep->name_) {
// This EPContext node was created by this EP - collect it for fusion
ep_context_nodes.push_back(node);
}
}
continue; // Don't process further, EPContext is a special case
}

if (op_type == "Mul") {
// Check that Mul has inputs/output of type float
std::vector<Ort::ConstValueInfo> inputs = node.GetInputs();
Expand Down Expand Up @@ -248,28 +267,45 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG
}
}

if (supported_nodes.empty()) {
return nullptr;
}

if (supported_nodes[0].GetOperatorType() == "Mul") {
// Create (optional) fusion options for the supported nodes to fuse.
// Handle EPContext nodes first - these are from loading compiled models
// Each EPContext node is fused individually so it gets its own compiled node
for (const auto& ep_ctx_node : ep_context_nodes) {
std::vector<Ort::ConstNode> single_node = {ep_ctx_node};
OrtNodeFusionOptions node_fusion_options = {};
node_fusion_options.ort_version_supported = ORT_API_VERSION;

// Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers
// as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers.
// This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use
// during inference.
node_fusion_options.drop_constant_initializers = true;
RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info,
reinterpret_cast<const OrtNode* const*>(supported_nodes.data()),
supported_nodes.size(),
reinterpret_cast<const OrtNode* const*>(single_node.data()),
single_node.size(),
&node_fusion_options));
} else if (supported_nodes[0].GetOperatorType() == "Custom_Mul") {
// Calls EpGraphSupportInfo_AddSingleNode() to inform ORT that the custom node should NOT be fused or compiled,
// as CustomMul has the concrete kernel implementation.
RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddSingleNode(graph_support_info, supported_nodes[0]));
}

// Return early if no supported nodes (but not if we have EPContext nodes)
if (supported_nodes.empty() && ep_context_nodes.empty()) {
return nullptr;
}

// Handle regular nodes
if (!supported_nodes.empty()) {
if (supported_nodes[0].GetOperatorType() == "Mul") {
// Create (optional) fusion options for the supported nodes to fuse.
OrtNodeFusionOptions node_fusion_options = {};
node_fusion_options.ort_version_supported = ORT_API_VERSION;

// Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers
// as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers.
// This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use
// during inference.
node_fusion_options.drop_constant_initializers = true;
RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info,
reinterpret_cast<const OrtNode* const*>(supported_nodes.data()),
supported_nodes.size(),
&node_fusion_options));
} else if (supported_nodes[0].GetOperatorType() == "Custom_Mul") {
// Calls EpGraphSupportInfo_AddSingleNode() to inform ORT that the custom node should NOT be fused or compiled,
// as CustomMul has the concrete kernel implementation.
RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddSingleNode(graph_support_info, supported_nodes[0]));
}
}

} catch (const Ort::Exception& ex) {
Expand Down Expand Up @@ -305,29 +341,60 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const

std::vector<Ort::ConstNode> nodes = graph.GetNodes();
if (nodes.size() != 1) {
Ort::Status status("Expected to compile a single Mul node", ORT_EP_FAIL);
Ort::Status status("Expected to compile a single node", ORT_EP_FAIL);
return status.release();
}

auto node_op_type = nodes[0].GetOperatorType();
if (node_op_type != "Mul") {
Ort::Status status("Expected to compile a single Mul node", ORT_EP_FAIL);
auto node_domain = nodes[0].GetDomain();

// Check if this is an EPContext node (from loading a pre-compiled model)
bool is_ep_context_node = (node_op_type == "EPContext" && node_domain == "com.microsoft");

if (node_op_type != "Mul" && !is_ep_context_node) {
Ort::Status status("Expected to compile a Mul node or EPContext node", ORT_EP_FAIL);
return status.release();
}

// Now we know we're compiling a single Mul node. Create a computation kernel.
std::vector<Ort::ConstValueInfo> node_inputs = nodes[0].GetInputs();
std::array<std::string, 2> node_input_names;
node_input_names[0] = node_inputs[0].GetName();
node_input_names[1] = node_inputs[1].GetName();

Ort::ConstNode fused_node{fused_nodes[0]};
auto ep_name = fused_node.GetEpName();
if (ep_name != ep->name_) {
Ort::Status status("The fused node is expected to assigned to this EP to run on", ORT_EP_FAIL);
return status.release();
}

// Get input names for the kernel
// For both EPContext and Mul nodes, we use the inner node's inputs from the graph
// Note: EPContext nodes from compiled models may have fewer inputs if constant initializers were dropped
std::array<std::string, 2> node_input_names = {"", ""};
std::vector<Ort::ConstValueInfo> node_inputs = nodes[0].GetInputs();

if (is_ep_context_node) {
// This example EP does *not* fully support executing EPContext nodes.
//
// When a model is compiled with this EP, constant initializers may be dropped from the EPContext
// node's inputs. A production EP would serialize initializer data and compiled state into the
// `ep_cache_context` attribute and deserialize it here. This example EP does not do that.
//
// As a result:
// - Session creation with a compiled model will succeed (for metadata access, compatibility testing)
// - Inference may fail at runtime if MulKernel::Compute cannot find expected inputs/initializers
//
// To fully support EPContext execution, deserialize `ep_cache_context` and restore initializers.
for (size_t i = 0; i < node_inputs.size() && i < 2; ++i) {
node_input_names[i] = node_inputs[i].GetName();
}
} else {
// For Mul nodes during initial compilation, we need exactly 2 inputs
if (node_inputs.size() != 2) {
std::string err_msg = "Mul node should have 2 inputs, got " + std::to_string(node_inputs.size());
Ort::Status status(err_msg.c_str(), ORT_EP_FAIL);
return status.release();
}
node_input_names[0] = node_inputs[0].GetName();
node_input_names[1] = node_inputs[1].GetName();
}

// Associate the name of the fused node with our MulKernel.
auto fused_node_name = fused_node.GetName();
ep->kernels_.emplace(std::move(fused_node_name), std::make_unique<MulKernel>(ep->ort_api, ep->logger_,
Expand All @@ -340,7 +407,8 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const
node_compute_infos[0] = node_compute_info.release();

// Create EpContext nodes for the fused nodes we compiled.
if (ep->config_.enable_ep_context) {
// Don't create new EPContext nodes if we're already processing an EPContext node!
if (ep->config_.enable_ep_context && !is_ep_context_node) {
assert(ep_context_nodes != nullptr);
RETURN_IF_ERROR(ep->CreateEpContextNodes(gsl::span<const OrtNode*>(fused_nodes, count),
gsl::span<OrtNode*>(ep_context_nodes, count)));
Expand Down Expand Up @@ -521,3 +589,33 @@ void ExampleNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void
(void)kernel;
// Do nothing for this example.
}

//
// Implementation of GetCompiledModelCompatibilityInfo
//
/*static*/
const char* ORT_API_CALL ExampleEp::GetCompiledModelCompatibilityInfoImpl(OrtEp* this_ptr,
const OrtGraph* graph) noexcept {
// Suppress unused parameter warning. The ORT_UNUSED_PARAMETER macro is in internal headers
// (core/common/common.h) which are not available to plugin EPs using only public APIs.
// A real EP would inspect the graph for model-specific compatibility info.
(void)graph;
auto* ep = static_cast<ExampleEp*>(this_ptr);

// Generate a compatibility string that includes:
// - EP name
// - EP version (from factory)
// - ORT API version
//
// In a real EP, this might include driver versions, hardware IDs, etc.
// The string format is EP-defined and should be parseable by ValidateCompiledModelCompatibilityInfo.
ep->compatibility_info_ = ep->name_ + ";version=" + ep->factory_.GetEpVersionString() + ";ort_api_version=" +
std::to_string(ORT_API_VERSION);

IGNORE_ORTSTATUS(ep->ort_api.Logger_LogMessage(&ep->logger_,
OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
("GetCompiledModelCompatibilityInfo returning: " + ep->compatibility_info_).c_str(),
ORT_FILE, __LINE__, __FUNCTION__));

return ep->compatibility_info_.c_str();
}
4 changes: 4 additions & 0 deletions onnxruntime/test/autoep/library/example_plugin_ep/ep.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ class ExampleEp : public OrtEp, public ApiPtrs {
OrtNodeComputeInfo** node_compute_infos,
size_t num_node_compute_infos) noexcept;

static const char* ORT_API_CALL GetCompiledModelCompatibilityInfoImpl(OrtEp* this_ptr,
const OrtGraph* graph) noexcept;

OrtStatus* CreateEpContextNodes(gsl::span<const OrtNode*> fused_nodes,
/*out*/ gsl::span<OrtNode*> ep_context_nodes);

Expand All @@ -89,4 +92,5 @@ class ExampleEp : public OrtEp, public ApiPtrs {
const OrtLogger& logger_;
std::unordered_map<std::string, std::unique_ptr<MulKernel>> kernels_;
std::unordered_map<std::string, FloatInitializer> float_initializers_;
std::string compatibility_info_; // Cached compatibility string returned by GetCompiledModelCompatibilityInfo
};
73 changes: 73 additions & 0 deletions onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL

GetNumCustomOpDomains = GetNumCustomOpDomainsImpl;
GetCustomOpDomains = GetCustomOpDomainsImpl;
ValidateCompiledModelCompatibilityInfo = ValidateCompiledModelCompatibilityInfoImpl;

// setup the OrtMemoryInfo instances required by the EP.
// We pretend the device the EP is running on is GPU.
Expand Down Expand Up @@ -417,3 +418,75 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateExternalResourceImporterForDevic

return nullptr;
}

OrtStatus* ORT_API_CALL ExampleEpFactory::ValidateCompiledModelCompatibilityInfoImpl(
OrtEpFactory* this_ptr,
const OrtHardwareDevice* const* /*devices*/,
size_t /*num_devices*/,
const char* compatibility_info,
OrtCompiledModelCompatibility* model_compatibility) noexcept {
auto& factory = *static_cast<ExampleEpFactory*>(this_ptr);

if (model_compatibility == nullptr) {
return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "model_compatibility cannot be nullptr");
}

// Parse the compatibility info to check if it matches our current configuration.
// The expected format is "ExampleEP;version=0.1.0;ort_api_version=24".
// For this example implementation, we simply check if the string starts with our EP name.

if (compatibility_info == nullptr || compatibility_info[0] == '\0') {
*model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
return nullptr;
}

std::string info(compatibility_info);
std::string expected_prefix = factory.ep_name_ + ";";

if (info.find(expected_prefix) != 0) {
// The compatibility info doesn't match our EP
*model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
return nullptr;
}

// Parse version parts: "ExampleEP;version=X;ort_api_version=Y"
// Look for "version=" and extract the value
size_t version_pos = info.find("version=");
size_t ort_version_pos = info.find("ort_api_version=");

if (version_pos == std::string::npos) {
// Invalid format
*model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
return nullptr;
}

// Extract EP version (between "version=" and the next ";")
size_t version_start = version_pos + 8; // length of "version="
size_t version_end = info.find(';', version_start);
std::string ep_version = (version_end != std::string::npos)
? info.substr(version_start, version_end - version_start)
: info.substr(version_start);

// Check if the EP version matches our version
if (ep_version != factory.ep_version_) {
// Different EP version - might work but prefer recompilation
*model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION;
return nullptr;
}

// Check ORT API version if present
if (ort_version_pos != std::string::npos) {
size_t ort_version_start = ort_version_pos + 16; // length of "ort_api_version="
std::string ort_version = info.substr(ort_version_start);
std::string current_ort_version = std::to_string(ORT_API_VERSION);
if (ort_version != current_ort_version) {
// Different ORT version - might still work but prefer recompilation
*model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION;
return nullptr;
}
}

// Everything matches - the compiled model is fully compatible
*model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL;
return nullptr;
}
17 changes: 17 additions & 0 deletions onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs {
return arena_allocator_.get();
}

// Get the EP version string.
const std::string& GetEpVersionString() const {
return ep_version_;
}

// Get the vendor ID.
uint32_t GetVendorIdValue() const {
return vendor_id_;
}

const OrtLogger& default_logger_; // default logger for the EP factory

private:
Expand Down Expand Up @@ -89,6 +99,13 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs {
_Outptr_result_maybenull_ OrtCustomOpDomain** domains,
_Out_ size_t num_domains) noexcept;

static OrtStatus* ORT_API_CALL ValidateCompiledModelCompatibilityInfoImpl(
OrtEpFactory* this_ptr,
const OrtHardwareDevice* const* devices,
size_t num_devices,
const char* compatibility_info,
OrtCompiledModelCompatibility* model_compatibility) noexcept;

const std::string ep_name_; // EP name
const std::string vendor_{"Contoso"}; // EP vendor name
const uint32_t vendor_id_{0xB357}; // EP vendor ID
Expand Down
Loading
Loading