
Commit 9884cf1

[EP ABI] GetCpuPreferredNodes for kernel-based plugin EPs
1 parent 96fe212 commit 9884cf1

File tree: 10 files changed, +418 -21 lines

cmake/onnxruntime_unittests.cmake

Lines changed: 2 additions & 0 deletions
@@ -2100,6 +2100,8 @@ if (onnxruntime_BUILD_SHARED_LIB AND
   "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_lib_entry.cc"
   "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_factory.h"
   "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc"
+  "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/get_capability_utils.h"
+  "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/get_capability_utils.cc"
   "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep.h"
   "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep.cc"
   "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_allocator.h"

onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.cc

Lines changed: 17 additions & 7 deletions
@@ -9,9 +9,11 @@
 #include <memory>
 #include <optional>
 #include <string>
+#include <unordered_set>
 #include <vector>

 #include "ep_factory.h"
+#include "get_capability_utils.h"
 #include "../plugin_ep_utils.h"

 ExampleKernelEp::ExampleKernelEp(ExampleKernelEpFactory& factory, const OrtLogger& logger)
@@ -60,12 +62,19 @@ OrtStatus* ORT_API_CALL ExampleKernelEp::GetCapabilityImpl(OrtEp* this_ptr, cons
   }

   // Collect candidate nodes that this EP may support.
-  std::vector<Ort::ConstNode> candidate_nodes;
+  std::vector<const OrtNode*> candidate_nodes;

   for (const auto& node : all_nodes) {
     std::string op_type = node.GetOperatorType();

-    if (op_type == "Relu" || op_type == "Squeeze") {
+    const OrtKernelDef* kernel_def = nullptr;
+    RETURN_IF_ERROR(ep->ep_api_.EpGraphSupportInfo_LookUpKernel(graph_support_info, node, &kernel_def));
+
+    if (kernel_def == nullptr) {
+      continue;  // Does not have a registered kernel for this node.
+    }
+
+    if (op_type == "Relu" || op_type == "Squeeze" || op_type == "Shape" || op_type == "Reshape") {
       candidate_nodes.push_back(node);
     } else if (op_type == "Mul") {
       std::vector<Ort::ConstValueInfo> inputs = node.GetInputs();
@@ -86,12 +95,13 @@ OrtStatus* ORT_API_CALL ExampleKernelEp::GetCapabilityImpl(OrtEp* this_ptr, cons
     }
   }

-  // Mark candidate nodes as supported if we have a registered kernel.
-  for (const auto& node : candidate_nodes) {
-    const OrtKernelDef* kernel_def = nullptr;
-    RETURN_IF_ERROR(ep->ep_api_.EpGraphSupportInfo_LookUpKernel(graph_support_info, node, &kernel_def));
+  // Get subset of candidate nodes that would be better to offload to CPU.
+  std::unordered_set<const OrtNode*> cpu_nodes;
+  RETURN_IF_ERROR(GetCpuPreferredNodes(*ort_graph, *graph_support_info, ep->logger_, candidate_nodes, cpu_nodes));

-    if (kernel_def != nullptr) {
+  // Mark candidate nodes as supported.
+  for (const auto& node : candidate_nodes) {
+    if (cpu_nodes.count(node) == 0) {
       RETURN_IF_ERROR(ep->ep_api_.EpGraphSupportInfo_AddSingleNode(graph_support_info, node));
     }
   }

onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/get_capability_utils.cc

Lines changed: 226 additions & 0 deletions

@@ -0,0 +1,226 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#include "get_capability_utils.h"
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+using NodeId = size_t;
+constexpr int64_t kSmallInitializerThreshold = 100;
+
+constexpr static inline bool MemTypeOnCpuExplicitly(OrtMemType mem_type) {
+  return mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput;
+}
+
+// Get all output nodes that consume an output from the given node.
+static OrtStatus* GetOutputNodes(gsl::span<Ort::ConstValueInfo const> node_outputs, std::vector<Ort::ConstNode>& result) {
+  EXCEPTION_TO_RETURNED_STATUS_BEGIN
+  std::vector<Ort::ConstNode> output_nodes;
+  output_nodes.reserve(node_outputs.size());  // May end up with more consumers than outputs.
+
+  // Gather the OrtNode consumers of every output.
+  for (Ort::ConstValueInfo output : node_outputs) {
+    if (output == nullptr) continue;  // Skip missing optional output
+
+    auto consumers_info = output.GetConsumers();
+    for (const auto& consumer : consumers_info) {
+      output_nodes.push_back(consumer.node);
+    }
+  }
+
+  result = std::move(output_nodes);
+  return nullptr;
+  EXCEPTION_TO_RETURNED_STATUS_END
+}
+
+// Returns nodes that should be assigned to the CPU EP instead of this example EP to avoid costly I/O copies.
+// Based on GetCpuPreferredNodes from onnxruntime/core/framework/fallback_cpu_capability.cc
+OrtStatus* GetCpuPreferredNodes(const OrtGraph& ort_graph, OrtEpGraphSupportInfo& graph_support_info,
+                                const OrtLogger& logger, gsl::span<const OrtNode* const> tentative_nodes,
+                                /*out*/ std::unordered_set<const OrtNode*>& cpu_preferred_nodes) {
+  EXCEPTION_TO_RETURNED_STATUS_BEGIN
+  const OrtApi& ort_api = Ort::GetApi();
+  const OrtEpApi& ep_api = Ort::GetEpApi();
+  Ort::ConstGraph graph{&ort_graph};
+  std::vector<Ort::ConstNode> ordered_nodes = graph.GetNodes();
+
+  if (ordered_nodes.empty()) {
+    return nullptr;
+  }
+
+  std::unordered_map<NodeId, Ort::ConstNode> node_id_to_node;
+  std::unordered_map<NodeId, size_t> node_id_to_order_map;
+  for (size_t i = 0, num_nodes = ordered_nodes.size(); i < num_nodes; i++) {
+    NodeId node_id = ordered_nodes[i].GetId();
+    node_id_to_node[node_id] = ordered_nodes[i];
+    node_id_to_order_map[node_id] = i;
+  }
+
+  // If the comparator returns false, node_id1 is popped first; if it returns true, node_id2 is popped first.
+  auto greater_order_comp = [&](const NodeId node_id1, const NodeId node_id2) {
+    return node_id_to_order_map[node_id1] > node_id_to_order_map[node_id2];
+  };
+  std::priority_queue<NodeId, std::vector<NodeId>, decltype(greater_order_comp)> candidates(greater_order_comp);
+  std::unordered_set<const OrtValueInfo*> cpu_output_args;
+
+  std::unordered_set<NodeId> provider_nodes;
+  provider_nodes.reserve(tentative_nodes.size());
+
+  std::unordered_map<NodeId, Ort::ConstKernelDef> node_to_kernel;
+  node_to_kernel.reserve(tentative_nodes.size());
+
+  for (const OrtNode* ort_node : tentative_nodes) {
+    Ort::ConstNode node(ort_node);
+    NodeId node_id = node.GetId();
+
+    provider_nodes.insert(node_id);
+
+    // Expect the target EP's registry to have a kernel for this node.
+    const OrtKernelDef* ort_kernel_def = nullptr;
+    RETURN_IF_ERROR(ep_api.EpGraphSupportInfo_LookUpKernel(&graph_support_info, node, &ort_kernel_def));
+    RETURN_IF(ort_kernel_def == nullptr, ort_api, "Must have a registered kernel definition on the target EP");
+
+    Ort::ConstKernelDef kernel_def(ort_kernel_def);
+    node_to_kernel.insert({node_id, kernel_def});
+
+    // Find all the direct consumers of CPU tensors.
+    std::vector<Ort::ConstValueInfo> outputs = node.GetOutputs();
+    for (size_t out_index = 0; out_index < outputs.size(); out_index++) {
+      Ort::ConstValueInfo output = outputs[out_index];
+      if (output == nullptr) continue;  // Skip missing optional output
+
+      bool is_output_on_cpu = MemTypeOnCpuExplicitly(kernel_def.GetOutputMemType(out_index));
+      if (is_output_on_cpu) {
+        cpu_output_args.insert(output);
+
+        auto consumer_infos = output.GetConsumers();
+        for (const auto& consumer_info : consumer_infos) {
+          candidates.push(consumer_info.node.GetId());
+          ORT_CXX_LOGF(Ort::Logger(&logger), ORT_LOGGING_LEVEL_INFO, "Candidate for fallback CPU execution: %s\n",
+                       consumer_info.node.GetName().c_str());
+        }
+      }
+    }
+  }
+
+  std::unordered_set<NodeId> visited;
+  visited.reserve(candidates.size());
+
+  std::unordered_set<const OrtNode*> cpu_nodes;
+  cpu_nodes.reserve(candidates.size());
+
+  // The algorithm below tries to identify a subgraph that only depends on CPU tensors.
+  // Usually it is a subgraph that does a shape calculation based on a GPU tensor and then reshapes it back.
+  // In detail: for each candidate, if one of its inputs is a CPU tensor and the non-CPU kernel doesn't
+  // already mark it as a CPU input, force the node to CPU to avoid a memory copy and add its outputs to
+  // the set of small CPU tensors.
+  while (!candidates.empty()) {
+    NodeId cur = candidates.top();
+    candidates.pop();
+
+    auto p = visited.insert(cur);
+    if (!p.second) {
+      continue;
+    }
+
+    auto node_iter = node_id_to_node.find(cur);
+    RETURN_IF(node_iter == node_id_to_node.end(), ort_api, "Unable to get OrtNode for a given node ID");
+    Ort::ConstNode node = node_iter->second;
+
+    if (provider_nodes.find(cur) == provider_nodes.end()) {
+      // Nodes not in provider_nodes either already have an EP assigned or have no kernel on the target EP.
+      // We assume these nodes will fall back to CPU, so add all direct consumers of all outputs to the candidates.
+      std::string ep_name = node.GetEpName();
+      if (ep_name.empty() || ep_name == "CPUExecutionProvider") {
+        std::vector<Ort::ConstValueInfo> outputs = node.GetOutputs();
+
+        for (Ort::ConstValueInfo output : outputs) {
+          if (output == nullptr) continue;  // Skip missing optional output
+          cpu_output_args.insert(output);
+        }
+
+        std::vector<Ort::ConstNode> output_nodes;
+        RETURN_IF_ERROR(GetOutputNodes(outputs, output_nodes));
+
+        for (Ort::ConstNode downstream_node : output_nodes) {
+          candidates.push(downstream_node.GetId());
+        }
+      }
+      continue;
+    }
+
+    std::vector<Ort::ConstValueInfo> inputs = node.GetInputs();
+    bool place_in_cpu = true;
+
+    for (size_t i = 0; i < inputs.size(); i++) {
+      Ort::ConstValueInfo input = inputs[i];
+      if (input == nullptr) continue;  // Skip missing optional input
+
+      // Skip placing on CPU if the data type is float16, bfloat16,
+      // float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz, or float4e2m1.
+      Ort::ConstTypeInfo type_info = input.TypeInfo();
+      auto type_shape_info = type_info.GetTensorTypeAndShapeInfo();
+      auto elem_type = type_shape_info.GetElementType();
+      if (elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 ||
+          elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 ||
+          elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN ||
+          elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ ||
+          elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2 ||
+          elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ ||
+          elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1) {
+        place_in_cpu = false;
+        break;
+      }
+
+      bool is_small_initializer = input.IsConstantInitializer() &&
+                                  type_shape_info.GetElementCount() <= kSmallInitializerThreshold;
+
+      // Allow placing on CPU if it's a small initializer or a graph input.
+      if (is_small_initializer || input.IsRequiredGraphInput() || input.IsOptionalGraphInput()) {
+        continue;
+      }
+
+      // The input is not a CPU tensor.
+      if (cpu_output_args.find(input) == cpu_output_args.end()) {
+        place_in_cpu = false;
+        break;
+      }
+
+      // The input is a CPU tensor, but it's intended to be consumed as a CPU input by the target EP.
+      bool is_input_on_cpu = MemTypeOnCpuExplicitly(node_to_kernel[cur].GetOutputMemType(i));
+      if (is_input_on_cpu) {
+        place_in_cpu = false;
+        break;
+      }
+    }
+
+    if (place_in_cpu) {
+      cpu_nodes.insert(node);
+      ORT_CXX_LOGF(Ort::Logger(&logger), ORT_LOGGING_LEVEL_WARNING,
+                   "EP optimization: Force fallback to CPU execution for node %s because the CPU execution path "
+                   "is deemed faster than overhead involved with execution on other EPs capable of executing "
+                   "this node.\n",
+                   node.GetName().c_str());
+
+      std::vector<Ort::ConstValueInfo> outputs = node.GetOutputs();
+      for (Ort::ConstValueInfo output : outputs) {
+        if (output == nullptr) continue;  // Skip missing optional output
+        cpu_output_args.insert(output);
+      }
+
+      std::vector<Ort::ConstNode> output_nodes;
+      RETURN_IF_ERROR(GetOutputNodes(outputs, output_nodes));
+
+      for (Ort::ConstNode downstream_node : output_nodes) {
+        candidates.push(downstream_node.GetId());
+      }
+    }
+  }
+
+  cpu_preferred_nodes = std::move(cpu_nodes);
+
+  return nullptr;
+  EXCEPTION_TO_RETURNED_STATUS_END
+}
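
A note on traversal order in the file above: candidates is a priority queue keyed on each node's position in the topologically ordered list returned by graph.GetNodes(), so candidate nodes are popped in topological order no matter the order in which they were pushed. That way a node's forced-to-CPU outputs are recorded in cpu_output_args before any of its downstream consumers are examined. Below is a minimal standalone sketch of just that ordering trick (not part of the commit; the node IDs are made up):

#include <cstddef>
#include <iostream>
#include <queue>
#include <unordered_map>
#include <vector>

int main() {
  using NodeId = size_t;

  // Pretend these IDs came back from graph.GetNodes() in topological order.
  std::vector<NodeId> ordered_nodes = {42, 7, 13, 99};

  std::unordered_map<NodeId, size_t> node_id_to_order_map;
  for (size_t i = 0; i < ordered_nodes.size(); ++i) {
    node_id_to_order_map[ordered_nodes[i]] = i;
  }

  // Same comparator shape as greater_order_comp in GetCpuPreferredNodes.
  auto greater_order_comp = [&](NodeId a, NodeId b) {
    return node_id_to_order_map[a] > node_id_to_order_map[b];
  };
  std::priority_queue<NodeId, std::vector<NodeId>, decltype(greater_order_comp)> candidates(greater_order_comp);

  // Push in arbitrary order; pop order follows the topological order above.
  candidates.push(99);
  candidates.push(42);
  candidates.push(13);

  while (!candidates.empty()) {
    std::cout << candidates.top() << ' ';  // Prints: 42 13 99
    candidates.pop();
  }
  std::cout << '\n';
  return 0;
}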

onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/get_capability_utils.h

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <gsl/span>
+#include <unordered_set>
+#include "../plugin_ep_utils.h"
+
+// Returns nodes that should be assigned to the CPU EP instead of this example EP to avoid costly I/O copies.
+// Based on GetCpuPreferredNodes from onnxruntime/core/framework/fallback_cpu_capability.cc
+OrtStatus* GetCpuPreferredNodes(const OrtGraph& ort_graph, OrtEpGraphSupportInfo& graph_support_info,
+                                const OrtLogger& logger, gsl::span<const OrtNode* const> tentative_nodes,
+                                /*out*/ std::unordered_set<const OrtNode*>& cpu_preferred_nodes);

onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc

Lines changed: 27 additions & 9 deletions
@@ -15,7 +15,8 @@ ONNX_OPERATOR_KERNEL_EX(
     kOnnxDomain,
     /*version*/ 14,  // Equivalent to start_version: 14, end_version: 14 (inclusive)
     (Ort::KernelDefBuilder()
-         .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))
+         .AddTypeConstraint("T", GetTensorTypes({ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
+                                                 ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64}))
          .AddInputOutputMutableAlias(0, 0)),
     Relu)

@@ -36,22 +37,39 @@ OrtStatus* Relu::Create(const OrtKernelInfo* info, void* state, /*out*/ std::uni
   EXCEPTION_TO_RETURNED_STATUS_END
 }

+template <typename T>
+static OrtStatus* ApplyRelu(Ort::KernelContext kernel_context) noexcept {
+  EXCEPTION_TO_RETURNED_STATUS_BEGIN
+  gsl::span<const T> input0;
+  std::vector<int64_t> shape0;
+  RETURN_IF_ERROR(GetKernelInputDataAndShape<T>(kernel_context, 0, input0, shape0));
+
+  Ort::UnownedValue output = kernel_context.GetOutput(0, shape0);
+  T* output_data = output.GetTensorMutableData<T>();
+
+  for (size_t i = 0; i < input0.size(); ++i) {
+    output_data[i] = std::max(static_cast<T>(0), input0[i]);
+  }
+  return nullptr;
+  EXCEPTION_TO_RETURNED_STATUS_END
+}
+
 /*static*/
 OrtStatus* ORT_API_CALL Relu::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept {
   EXCEPTION_TO_RETURNED_STATUS_BEGIN
   Relu* relu_kernel = static_cast<Relu*>(this_ptr);
   Ort::KernelContext kernel_context(kernel_ctx);
   static_cast<void>(relu_kernel->info_);  // NOTE: Unused in this example.

-  gsl::span<const float> input0;
-  std::vector<int64_t> shape0;
-  RETURN_IF_ERROR(GetKernelInputDataAndShape<float>(kernel_context, 0, input0, shape0));
+  Ort::ConstValue input = kernel_context.GetInput(0);
+  auto type_shape = input.GetTensorTypeAndShapeInfo();
+  auto elem_type = type_shape.GetElementType();

-  Ort::UnownedValue output = kernel_context.GetOutput(0, shape0);
-  float* output_data = output.GetTensorMutableData<float>();
-
-  for (size_t i = 0; i < input0.size(); ++i) {
-    output_data[i] = std::max(0.0f, input0[i]);
+  if (elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
+    RETURN_IF_ERROR(ApplyRelu<float>(kernel_context));
+  } else {
+    assert(elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64);
+    RETURN_IF_ERROR(ApplyRelu<int64_t>(kernel_context));
   }

   return nullptr;
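
The change above moves the element-wise loop into a templated ApplyRelu so the same kernel can serve both float and int64 tensors. A quick standalone reference check of the same math, with the ORT kernel plumbing stripped away (the helper below is illustrative, not part of the commit):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Reference version of the templated loop: clamp negative values to zero.
template <typename T>
std::vector<T> ApplyReluRef(const std::vector<T>& input) {
  std::vector<T> output(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    output[i] = std::max(static_cast<T>(0), input[i]);
  }
  return output;
}

int main() {
  for (float v : ApplyReluRef<float>({-1.5f, 0.0f, 2.5f})) std::cout << v << ' ';  // 0 0 2.5
  std::cout << '\n';
  for (int64_t v : ApplyReluRef<int64_t>({-3, 0, 7})) std::cout << v << ' ';       // 0 0 7
  std::cout << '\n';
  return 0;
}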

onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/reshape.cc

Lines changed: 7 additions & 5 deletions
@@ -128,13 +128,15 @@ OrtStatus* ORT_API_CALL Reshape::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelC
   // Input[1] has the requested shape for the reshape operation.
   Ort::ConstValue shape_input = kernel_context.GetInput(1);
   gsl::span<const int64_t> shape_input_data;
-  std::vector<int64_t> final_shape;
+  std::vector<int64_t> shape_input_shape;

-  RETURN_IF_ERROR(GetValueDataAndShape(shape_input, shape_input_data, final_shape));
-  RETURN_IF(final_shape.size() != 1, Ort::GetApi(), "A shape tensor must have one dimension");
-  RETURN_IF_ERROR(GetRequestedShape(input_shape, reshape_kernel->allow_zero_, final_shape));
+  RETURN_IF_ERROR(GetValueDataAndShape(shape_input, shape_input_data, shape_input_shape));
+  RETURN_IF(shape_input_shape.size() != 1, Ort::GetApi(), "A shape tensor must have one dimension");

-  Ort::UnownedValue output = kernel_context.GetOutput(0, final_shape);
+  std::vector<int64_t> output_shape(shape_input_data.begin(), shape_input_data.end());
+  RETURN_IF_ERROR(GetRequestedShape(input_shape, reshape_kernel->allow_zero_, output_shape));
+
+  Ort::UnownedValue output = kernel_context.GetOutput(0, output_shape);

   // This kernel aliases the input and output, so a copy is not really necessary.
   // CopyTensor() will not do a copy if the source and destination buffers are the same.
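
For context on the shape handling above: an ONNX Reshape's requested shape may contain 0 (copy the corresponding input dimension, unless the allowzero attribute is set) and at most one -1 (inferred from the remaining element count). The commit's GetRequestedShape helper is not shown in this diff, so the sketch below is a hypothetical standalone version of that resolution logic under those standard semantics, not the actual helper:

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <vector>

// Hypothetical illustration of ONNX Reshape shape resolution; not the commit's GetRequestedShape.
std::vector<int64_t> ResolveReshapeShape(const std::vector<int64_t>& input_shape,
                                         bool allow_zero,
                                         std::vector<int64_t> requested) {
  const int64_t input_elems = std::accumulate(input_shape.begin(), input_shape.end(),
                                              int64_t{1}, std::multiplies<int64_t>());
  int64_t known_elems = 1;
  int64_t inferred_index = -1;

  for (size_t i = 0; i < requested.size(); ++i) {
    if (requested[i] == 0 && !allow_zero) {
      requested[i] = input_shape.at(i);  // 0 copies the corresponding input dimension.
    }
    if (requested[i] == -1) {
      if (inferred_index != -1) throw std::invalid_argument("At most one -1 dimension is allowed");
      inferred_index = static_cast<int64_t>(i);
    } else {
      known_elems *= requested[i];
    }
  }

  if (inferred_index != -1) {
    requested[static_cast<size_t>(inferred_index)] = input_elems / known_elems;  // Infer the remaining dimension.
  }
  return requested;
}

int main() {
  for (int64_t d : ResolveReshapeShape({2, 3, 4}, /*allow_zero=*/false, {0, -1})) {
    std::cout << d << ' ';  // Prints: 2 12
  }
  std::cout << '\n';
  return 0;
}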
