// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "get_capability_utils.h"

#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>
using NodeId = size_t;

// Initializers with at most this many elements are considered cheap enough to
// keep and compute with on CPU (e.g., shape-like tensors).
constexpr int64_t kSmallInitializerThreshold = 100;

// True if the kernel definition explicitly places this value in CPU memory:
// OrtMemTypeCPUInput/OrtMemTypeCPUOutput mark values that a kernel on a non-CPU
// EP nevertheless expects to reside in CPU-accessible memory.
constexpr static inline bool MemTypeOnCpuExplicitly(OrtMemType mem_type) {
  return mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput;
}

// Gets all nodes that consume an output of the given node.
static OrtStatus* GetOutputNodes(gsl::span<Ort::ConstValueInfo const> node_outputs,
                                 std::vector<Ort::ConstNode>& result) {
  EXCEPTION_TO_RETURNED_STATUS_BEGIN
  std::vector<Ort::ConstNode> output_nodes;
  output_nodes.reserve(node_outputs.size());  // Lower bound; an output may have multiple consumers.

  // Gather the OrtNode consumers of every output.
  for (Ort::ConstValueInfo output : node_outputs) {
    if (output == nullptr) continue;  // Skip missing optional output

    auto consumers_info = output.GetConsumers();
    for (const auto& consumer : consumers_info) {
      output_nodes.push_back(consumer.node);
    }
  }

  result = std::move(output_nodes);
  return nullptr;
  EXCEPTION_TO_RETURNED_STATUS_END
}

// Returns nodes that should be assigned to the CPU EP instead of this example EP
// in order to avoid costly I/O copies between devices.
// Based on GetCpuPreferredNodes from onnxruntime/core/framework/fallback_cpu_capability.cc
OrtStatus* GetCpuPreferredNodes(const OrtGraph& ort_graph, OrtEpGraphSupportInfo& graph_support_info,
                                const OrtLogger& logger, gsl::span<const OrtNode* const> tentative_nodes,
                                /*out*/ std::unordered_set<const OrtNode*>& cpu_preferred_nodes) {
  EXCEPTION_TO_RETURNED_STATUS_BEGIN
  const OrtApi& ort_api = Ort::GetApi();
  const OrtEpApi& ep_api = Ort::GetEpApi();
  Ort::ConstGraph graph{&ort_graph};
  std::vector<Ort::ConstNode> ordered_nodes = graph.GetNodes();

  if (ordered_nodes.empty()) {
    return nullptr;
  }

  std::unordered_map<NodeId, Ort::ConstNode> node_id_to_node;
  std::unordered_map<NodeId, size_t> node_id_to_order_map;
  for (size_t i = 0, num_nodes = ordered_nodes.size(); i < num_nodes; i++) {
    NodeId node_id = ordered_nodes[i].GetId();
    node_id_to_node[node_id] = ordered_nodes[i];
    node_id_to_order_map[node_id] = i;
  }

  // Min-heap on topological position: the candidate that appears earliest in
  // the graph's node ordering is popped first.
  auto greater_order_comp = [&](const NodeId node_id1, const NodeId node_id2) {
    return node_id_to_order_map[node_id1] > node_id_to_order_map[node_id2];
  };
  std::priority_queue<NodeId, std::vector<NodeId>, decltype(greater_order_comp)> candidates(greater_order_comp);
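  // Illustrative example (hypothetical node IDs, not from any particular graph):
  // given node_id_to_order_map = {{7, 0}, {3, 1}, {9, 2}},
  //   candidates.push(9); candidates.push(7); candidates.push(3);
  //   candidates.top();  // yields 7: the node earliest in topological order pops first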
  std::unordered_set<const OrtValueInfo*> cpu_output_args;

  std::unordered_set<NodeId> provider_nodes;
  provider_nodes.reserve(tentative_nodes.size());

  std::unordered_map<NodeId, Ort::ConstKernelDef> node_to_kernel;
  node_to_kernel.reserve(tentative_nodes.size());

  for (const OrtNode* ort_node : tentative_nodes) {
    Ort::ConstNode node(ort_node);
    NodeId node_id = node.GetId();

    provider_nodes.insert(node_id);

    // Expect at least one registry to have the target EP's kernel for this node.
    const OrtKernelDef* ort_kernel_def = nullptr;
    RETURN_IF_ERROR(ep_api.EpGraphSupportInfo_LookUpKernel(&graph_support_info, node, &ort_kernel_def));
    RETURN_IF(ort_kernel_def == nullptr, ort_api, "Must have a registered kernel definition on the target EP");

    Ort::ConstKernelDef kernel_def(ort_kernel_def);
    node_to_kernel.insert({node_id, kernel_def});

    // Find all the direct consumers of CPU tensors.
    std::vector<Ort::ConstValueInfo> outputs = node.GetOutputs();
    for (size_t out_index = 0; out_index < outputs.size(); out_index++) {
      Ort::ConstValueInfo output = outputs[out_index];
      if (output == nullptr) continue;  // Skip missing optional output

      bool is_output_on_cpu = MemTypeOnCpuExplicitly(kernel_def.GetOutputMemType(out_index));
      if (is_output_on_cpu) {
        cpu_output_args.insert(output);

        auto consumer_infos = output.GetConsumers();
        for (const auto& consumer_info : consumer_infos) {
          candidates.push(consumer_info.node.GetId());
          ORT_CXX_LOGF(Ort::Logger(&logger), ORT_LOGGING_LEVEL_INFO, "Candidate for fallback CPU execution: %s\n",
                       consumer_info.node.GetName().c_str());
        }
      }
    }
  }

  std::unordered_set<NodeId> visited;
  visited.reserve(candidates.size());

  std::unordered_set<const OrtNode*> cpu_nodes;
  cpu_nodes.reserve(candidates.size());

  // The algorithm below tries to identify a subgraph that depends only on CPU tensors.
  // Typically this is a subgraph that does a shape calculation based on a GPU tensor and then
  // reshapes the tensor back. In detail: for each candidate, if one of its inputs is a CPU
  // tensor and the non-CPU kernel doesn't already mark it as a CPU input, force the node onto
  // CPU to avoid a memory copy, and add its outputs to the set of CPU tensors.
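  // A typical (hypothetical) instance, with op names for illustration only:
  //
  //   (GPU tensor) --> Shape --> Gather --> Concat --> Reshape <-- (GPU tensor)
  //
  // Shape's output is a small int64 CPU tensor, so Gather and Concat operate purely on
  // CPU data; assigning them to this EP would force device round-trips for values that
  // Reshape ultimately consumes as a CPU input anyway.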
  while (!candidates.empty()) {
    NodeId cur = candidates.top();
    candidates.pop();

    auto p = visited.insert(cur);
    if (!p.second) {
      continue;
    }

    auto node_iter = node_id_to_node.find(cur);
    RETURN_IF(node_iter == node_id_to_node.end(), ort_api, "Unable to get OrtNode for a given node ID");
    Ort::ConstNode node = node_iter->second;

    if (provider_nodes.find(cur) == provider_nodes.end()) {
      // Nodes not in provider_nodes either already have an EP assigned or have no kernel
      // registered on the target EP. We assume these nodes will fall back to CPU, so add the
      // direct consumers of all of their outputs to the candidates.
      std::string ep_name = node.GetEpName();
      if (ep_name.empty() || ep_name == "CPUExecutionProvider") {
        std::vector<Ort::ConstValueInfo> outputs = node.GetOutputs();

        for (Ort::ConstValueInfo output : outputs) {
          if (output == nullptr) continue;  // Skip missing optional output
          cpu_output_args.insert(output);
        }

        std::vector<Ort::ConstNode> output_nodes;
        RETURN_IF_ERROR(GetOutputNodes(outputs, output_nodes));

        for (Ort::ConstNode downstream_node : output_nodes) {
          candidates.push(downstream_node.GetId());
        }
      }
      continue;
    }


    std::vector<Ort::ConstValueInfo> inputs = node.GetInputs();
    bool place_in_cpu = true;

    for (size_t i = 0; i < inputs.size(); i++) {
      Ort::ConstValueInfo input = inputs[i];
      if (input == nullptr) continue;  // Skip missing optional input

      // Skip placing on CPU if the data type is float16, bfloat16, float8e4m3fn,
      // float8e4m3fnuz, float8e5m2, float8e5m2fnuz, or float4e2m1.
      Ort::ConstTypeInfo type_info = input.TypeInfo();
      auto type_shape_info = type_info.GetTensorTypeAndShapeInfo();
      auto elem_type = type_shape_info.GetElementType();
      if (elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 ||
          elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 ||
          elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN ||
          elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ ||
          elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2 ||
          elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ ||
          elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1) {
        place_in_cpu = false;
        break;
      }

      bool is_small_initializer = input.IsConstantInitializer() &&
                                  type_shape_info.GetElementCount() <= kSmallInitializerThreshold;

      // Allow placing on CPU if it's a small initializer or a graph input.
      if (is_small_initializer || input.IsRequiredGraphInput() || input.IsOptionalGraphInput()) {
        continue;
      }
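      // Worked example of the threshold (shapes hypothetical): a constant int64
      // initializer of shape {2, 50} has 100 elements and qualifies as "small",
      // so it doesn't block CPU placement; a {128} initializer does not qualify.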

      // The input is not a CPU tensor.
      if (cpu_output_args.find(input) == cpu_output_args.end()) {
        place_in_cpu = false;
        break;
      }

      // The input is a CPU tensor, but the target EP's kernel already consumes it as a
      // CPU input, so no copy would occur and there is no benefit to falling back.
      bool is_input_on_cpu = MemTypeOnCpuExplicitly(node_to_kernel[cur].GetInputMemType(i));
      if (is_input_on_cpu) {
        place_in_cpu = false;
        break;
      }
    }

    if (place_in_cpu) {
      cpu_nodes.insert(node);
      ORT_CXX_LOGF(Ort::Logger(&logger), ORT_LOGGING_LEVEL_WARNING,
                   "EP optimization: Force fallback to CPU execution for node %s because the CPU execution path "
                   "is deemed faster than the overhead involved with execution on other EPs capable of executing "
                   "this node.\n",
                   node.GetName().c_str());

      std::vector<Ort::ConstValueInfo> outputs = node.GetOutputs();
      for (Ort::ConstValueInfo output : outputs) {
        if (output == nullptr) continue;  // Skip missing optional output
        cpu_output_args.insert(output);
      }

      std::vector<Ort::ConstNode> output_nodes;
      RETURN_IF_ERROR(GetOutputNodes(outputs, output_nodes));

      for (Ort::ConstNode downstream_node : output_nodes) {
        candidates.push(downstream_node.GetId());
      }
    }
  }

  cpu_preferred_nodes = std::move(cpu_nodes);

  return nullptr;
  EXCEPTION_TO_RETURNED_STATUS_END
}
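
// --- Usage sketch (illustrative only; the variable names and claim logic below are
// assumptions, not part of this file) ---
// An EP's GetCapability implementation could combine this helper with the EP API's
// EpGraphSupportInfo_AddSingleNode roughly as follows, claiming only the tentative
// nodes that were not redirected to CPU:
//
//   std::unordered_set<const OrtNode*> cpu_preferred;
//   RETURN_IF_ERROR(GetCpuPreferredNodes(*graph, *graph_support_info, *logger,
//                                        tentative_nodes, cpu_preferred));
//
//   for (const OrtNode* node : tentative_nodes) {
//     if (cpu_preferred.count(node) == 0) {
//       RETURN_IF_ERROR(Ort::GetEpApi().EpGraphSupportInfo_AddSingleNode(graph_support_info, node));
//     }
//   }
//
// Nodes left unclaimed here fall back to the CPU EP during partitioning.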