Commit 4545732

MemcpyFromHost and MemcpyToHost support for plugin EPs (#26088)
### Description

Add support for `MemcpyFromHost` and `MemcpyToHost` ops with plugin EPs.

- Add CPU EP fallback kernels for the memcpy ops. These are generic implementations using a data transfer manager.
- Update `SessionState::PopulateKernelCreateInfo()` to fall back to CPU memcpy kernels if a node's assigned provider doesn't have them.
- Update `MemcpyTransformer` to determine whether providers are CPU-based or compatible with other providers by looking at the device type instead of matching against a hardcoded list of provider types. This accommodates plugin EPs, where the provider type can't be hardcoded.

### Motivation and Context

Allow plugin EPs to work with models where memcpy ops are required (i.e., models where connected nodes are not fully assigned to the plugin EP).
1 parent ff66c70 · commit 4545732

23 files changed: +368 −104 lines
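
For context, the CPU fallback kernels described above are "generic" in the sense that they delegate the actual copy to the data transfer manager. A minimal sketch of what such a kernel's compute method can look like (the class name and registration are illustrative assumptions, not the exact code added in this commit):

```cpp
// Hypothetical sketch of a generic memcpy kernel. The data transfer manager
// selects an IDataTransfer that can copy between the input's device and the
// output's device (e.g., device memory <-> host memory).
class GenericMemcpy final : public OpKernel {
 public:
  explicit GenericMemcpy(const OpKernelInfo& info) : OpKernel(info) {}

  Status Compute(OpKernelContext* ctx) const override {
    const auto* X = ctx->Input<Tensor>(0);
    Tensor* Y = ctx->Output(0, X->Shape());
    // Copy X into Y using whatever IDataTransfer is registered for this device pair.
    return Info().GetDataTransferManager().CopyTensor(*X, *Y);
  }
};
```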

cmake/onnxruntime_unittests.cmake

Lines changed: 4 additions & 0 deletions
```diff
@@ -1973,6 +1973,7 @@ endif()
 if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND
     NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND
     NOT onnxruntime_MINIMAL_BUILD)
+  # example_plugin_ep
   file(GLOB onnxruntime_autoep_test_library_src "${TEST_SRC_DIR}/autoep/library/*.h"
        "${TEST_SRC_DIR}/autoep/library/*.cc")
   onnxruntime_add_shared_library_module(example_plugin_ep ${onnxruntime_autoep_test_library_src})
@@ -1995,6 +1996,9 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND
   set_property(TARGET example_plugin_ep APPEND_STRING PROPERTY LINK_FLAGS
                ${ONNXRUNTIME_AUTOEP_LIB_LINK_FLAG})
 
+  set_target_properties(example_plugin_ep PROPERTIES FOLDER "ONNXRuntimeTest")
+  source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_autoep_test_library_src})
+
   # test library
   file(GLOB onnxruntime_autoep_test_SRC "${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.h"
        "${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.cc")
```

docs/OperatorKernels.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -261,6 +261,8 @@ Do not modify directly.*
 |||[9, 12]|**T** = tensor(float)|
 |||[1, 8]|**T** = tensor(float)|
 |MelWeightMatrix|*in* num_mel_bins:**T1**<br> *in* dft_length:**T1**<br> *in* sample_rate:**T1**<br> *in* lower_edge_hertz:**T2**<br> *in* upper_edge_hertz:**T2**<br> *out* output:**T3**|17+|**T1** = tensor(int32), tensor(int64)<br/> **T2** = tensor(float)<br/> **T3** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|MemcpyFromHost|*in* X:**T**<br> *out* Y:**T**|1+|**T** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|MemcpyToHost|*in* X:**T**<br> *out* Y:**T**|1+|**T** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |Min|*in* data_0:**T**<br> *out* min:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
 |||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
 |||[8, 11]|**T** = tensor(double), tensor(float)|
```

onnxruntime/core/framework/session_state.cc

Lines changed: 16 additions & 7 deletions
```diff
@@ -226,13 +226,22 @@ Status SessionState::PopulateKernelCreateInfo(const KernelRegistryManager& kerne
   for (auto& node : graph_.Nodes()) {
     const KernelCreateInfo* kci = nullptr;
     auto status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci);
-    if (!status.IsOK() && saving_ort_format) {
-      // if we didn't find the kernel and are saving to ORT format an EP that compiles nodes is enabled.
-      // in that case we assigned the node to that EP but do not compile it into a fused node.
-      // this keeps the original node and prevents level 2 and level 3 optimizers from modifying it.
-      // we now revert to the CPU EP kernel as a fallback.
-      // at runtime when the model is loaded in a minimal build, the compiling EP will replace this node if possible.
-      // if that's not possible for some reason we can fallback to the CPU EP implementation.
+
+    // There are two cases where we allow fallback to CPU EP kernels:
+    //
+    // 1. if we didn't find the kernel and are saving to ORT format an EP that compiles nodes is enabled.
+    //    in that case we assigned the node to that EP but do not compile it into a fused node.
+    //    this keeps the original node and prevents level 2 and level 3 optimizers from modifying it.
+    //    we now revert to the CPU EP kernel as a fallback.
+    //    at runtime when the model is loaded in a minimal build, the compiling EP will replace this node if possible.
+    //    if that's not possible for some reason we can fallback to the CPU EP implementation.
+    //
+    // 2. If the node is a memcpy node.
+    //    EPs may provide their own memcpy kernels. The CPU EP provides a generic version to fall back to if the EP does
+    //    not provide one.
+    const bool allow_cpu_ep_kernel_fallback = saving_ort_format || utils::IsMemcpyNode(node);
+
+    if (!status.IsOK() && allow_cpu_ep_kernel_fallback) {
       node.SetExecutionProviderType(kCpuExecutionProvider);
       status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci);
     }
```

onnxruntime/core/framework/utils.cc

Lines changed: 8 additions & 17 deletions
```diff
@@ -46,22 +46,13 @@ void DestroyStrings(void* p_data, int64_t elements) {
     ptr[i].~string();
 }
 
-bool ProviderIsCpuBased(const std::string& provider_type) {
-  return provider_type == onnxruntime::kCpuExecutionProvider ||
-         provider_type == onnxruntime::kDnnlExecutionProvider ||
-         provider_type == onnxruntime::kVitisAIExecutionProvider ||
-         provider_type == onnxruntime::kOpenVINOExecutionProvider ||
-         provider_type == onnxruntime::kNnapiExecutionProvider ||
-         provider_type == onnxruntime::kVSINPUExecutionProvider ||
-         provider_type == onnxruntime::kAclExecutionProvider ||
-         provider_type == onnxruntime::kArmNNExecutionProvider ||
-         provider_type == onnxruntime::kRknpuExecutionProvider ||
-         provider_type == onnxruntime::kCoreMLExecutionProvider ||
-         provider_type == onnxruntime::kSnpeExecutionProvider ||
-         provider_type == onnxruntime::kQnnExecutionProvider ||
-         provider_type == onnxruntime::kXnnpackExecutionProvider ||
-         provider_type == onnxruntime::kAzureExecutionProvider ||
-         provider_type == onnxruntime::utils::kInternalTestingExecutionProvider;
+bool ProviderIsCpuBased(const IExecutionProvider& provider) {
+  return provider.GetDevice().Type() == OrtDevice::CPU;
+}
+
+bool IsMemcpyNode(const Node& node) {
+  return node.Domain() == kOnnxDomain &&
+         (node.OpType() == "MemcpyFromHost" || node.OpType() == "MemcpyToHost");
 }
 
 static common::Status AllocateHelper(const AllocatorPtr& allocator,
@@ -210,7 +201,7 @@ static Status BatchOrCopyMLValue(const SessionState& session_state,
 
 static bool HaveCpuExecutionProvidersOnly(const ExecutionProviders& execution_providers) {
   for (const auto& execution_provider : execution_providers) {
-    if (!ProviderIsCpuBased(execution_provider->Type())) {
+    if (!ProviderIsCpuBased(*execution_provider)) {
       return false;
     }
   }
```
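
The practical effect is that a plugin EP, whose provider type string can't be listed at build time, is now classified by the `OrtDevice` it reports rather than by name. A rough standalone illustration of the intent (the helper below is hypothetical and simply mirrors the new `ProviderIsCpuBased`):

```cpp
// Hypothetical helper mirroring the new device-based check: any provider that
// reports a CPU device needs no copy nodes; any other device (GPU, NPU, ...),
// including one reported by a plugin EP, gets Memcpy nodes inserted by the
// MemcpyTransformer where it exchanges tensors with CPU-based providers.
bool MayNeedMemcpyNodes(const onnxruntime::IExecutionProvider& ep) {
  return ep.GetDevice().Type() != OrtDevice::CPU;
}
```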

onnxruntime/core/framework/utils.h

Lines changed: 3 additions & 5 deletions
```diff
@@ -52,12 +52,10 @@ void DestroyStrings(void* p_data, int64_t elements);
 
 const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info);
 
-// EP used for internal testing. We define it here as it's used in ProviderIsCpuBased, but we don't want
-// it to be in the public header include/onnxruntime/core/graph/constants.h as it's purely internal.
-constexpr const char* kInternalTestingExecutionProvider = "InternalTestingExecutionProvider";
-
 // return true if the execution provider is CPU based (meaning no copies to device are required)
-bool ProviderIsCpuBased(const std::string& provider_type);
+bool ProviderIsCpuBased(const IExecutionProvider& provider);
+
+bool IsMemcpyNode(const Node& node);
 
 common::Status CopyOneInputAcrossDevices(const SessionState& session_state, const std::string& input_name,
                                          const OrtValue& orig_mlvalue, OrtValue& new_mlvalue);
```

onnxruntime/core/optimizer/transformer_memcpy.cc

Lines changed: 81 additions & 29 deletions
```diff
@@ -1,7 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include "transformer_memcpy.h"
+#include "core/optimizer/transformer_memcpy.h"
+
 #include "core/common/logging/logging.h"
 #include "core/framework/kernel_registry_manager.h"
 #include "core/framework/execution_providers.h"
@@ -12,26 +13,49 @@
 using namespace ONNX_NAMESPACE;
 namespace onnxruntime {
 
+static ProviderTypeToProviderMap GetProvidersByType(
+    const InlinedVector<gsl::not_null<const IExecutionProvider*>>& providers) {
+  ProviderTypeToProviderMap providers_by_type{};
+  for (const auto provider : providers) {
+    providers_by_type.emplace(provider->Type(), provider);
+  }
+  return providers_by_type;
+}
+
+MemcpyTransformer::MemcpyTransformer(InlinedVector<gsl::not_null<const IExecutionProvider*>> providers,
+                                     const KernelRegistryManager& registry_manager)
+    : GraphTransformer("MemcpyTransformer"),
+      providers_(std::move(providers)),
+      providers_by_type_(GetProvidersByType(providers_)),
+      registry_manager_(std::cref(registry_manager)) {
+}
+
 // implements MemCpy node insertion in graph transform
 // note that GraphTransformer::Apply() is supposed to be stateless, so this cannot derive from GraphTransformer
 class TransformerMemcpyImpl {
  public:
-  TransformerMemcpyImpl(onnxruntime::Graph& graph, const std::string& provider)
-      : graph_(graph), provider_(provider) {}
+  TransformerMemcpyImpl(onnxruntime::Graph& graph, const IExecutionProvider& provider,
+                        const ProviderTypeToProviderMap& providers_by_type)
+      : graph_(graph), provider_(provider), providers_by_type_(providers_by_type) {
+  }
 
   bool ModifyGraph(const KernelRegistryManager& schema_registries,
                    const logging::Logger& logger,
                    int& copy_node_counter);
 
  private:
+  bool IsNodeCompatibleWithProvider(const onnxruntime::Node& node) const;
+
   void ProcessDefs(onnxruntime::Node& node,
                    const KernelRegistryManager& kernel_registries,
                    InitializedTensorSet& initializers_consumed,
                    const logging::Logger& logger);
   void BuildDefsMapping(const onnxruntime::NodeArg* arg,
                         const KernelRegistryManager& kernel_registries,
                         const logging::Logger& logger);
-  void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger);
+  void AddCopyNode(onnxruntime::NodeArg* arg,
+                   bool is_input,
+                   const logging::Logger& logger);
   bool ProcessInitializers(const KernelRegistryManager& kernel_registries,
                            const InitializedTensorSet& initializers_consumed,
                            const logging::Logger& logger);
@@ -55,7 +79,8 @@ class TransformerMemcpyImpl {
   std::map<const onnxruntime::NodeArg*, std::set<onnxruntime::Node*, NodeCompare>> provider_output_nodes_;
 
   onnxruntime::Graph& graph_;
-  std::string provider_;
+  const IExecutionProvider& provider_;
+  const ProviderTypeToProviderMap& providers_by_type_;
 };
 
 /** Helper that returns a pointer to the corresponding TensorProto for a name if it is an initializer.
@@ -73,17 +98,18 @@ static const onnx::TensorProto* GetInitializer(const Graph& graph, const std::st
 
 // very simple GraphTransformer that uses TransformerMemcpyImpl for each graph
 // and mainly provides the subgraph recursion functionality
-common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level,
-                                            const logging::Logger& logger) const {
-  for (auto& provider : provider_types_) {
-    if (!utils::ProviderIsCpuBased(provider)) {
-      TransformerMemcpyImpl copy_impl(graph, provider);
+Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level,
+                                    const logging::Logger& logger) const {
+  for (const auto provider : providers_) {
+    const auto& provider_type = provider->Type();
+    if (!utils::ProviderIsCpuBased(*provider)) {
+      TransformerMemcpyImpl copy_impl(graph, *provider, providers_by_type_);
 
       int copy_node_counter = 0;
       auto current_modified = copy_impl.ModifyGraph(registry_manager_, logger, copy_node_counter);
-      if (copy_node_counter > 0 && provider == kCudaExecutionProvider) {
+      if (copy_node_counter > 0 && provider_type == kCudaExecutionProvider) {
         LOGS(logger, WARNING) << copy_node_counter << " Memcpy nodes are added to the graph " << graph.Name()
-                              << " for " << provider
+                              << " for " << provider_type
                               << ". It might have negative impact on performance (including unable to run CUDA graph). "
                               << "Set session_options.log_severity_level=1 to see the detail logs before this message.";
       }
@@ -213,15 +239,42 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi
   return modified;
 }
 
+static const IExecutionProvider* FindProviderByType(ProviderTypeToProviderMap providers_by_type,
+                                                    std::string_view provider_type) {
+  const auto it = providers_by_type.find(provider_type);
+  if (it != providers_by_type.end()) {
+    return &*it->second;
+  }
+  return nullptr;
+}
+
+bool TransformerMemcpyImpl::IsNodeCompatibleWithProvider(const onnxruntime::Node& node) const {
+  const auto& node_provider_type = node.GetExecutionProviderType();
+  const auto* node_provider = FindProviderByType(providers_by_type_, node_provider_type);
+  ORT_ENFORCE(node_provider != nullptr, "Unable to get provider associated with provider type ", node_provider_type);
+
+  // Same provider?
+  if (node_provider->Type() == provider_.Type()) {
+    return true;
+  }
+
+  const auto& node_provider_device = node_provider->GetDevice();
+  const auto& provider_device = provider_.GetDevice();
+
+  // Same provider device type and vendor?
+  if (node_provider_device.Type() == provider_device.Type() &&
+      node_provider_device.Vendor() == provider_device.Vendor()) {
+    return true;
+  }
+
+  return false;
+}
+
 void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node,
                                         const KernelRegistryManager& kernel_registries,
                                         InitializedTensorSet& initializers_consumed,
                                         const logging::Logger& logger) {
-  auto node_provider_type = node.GetExecutionProviderType();
-  if ((node_provider_type == provider_) ||
-      (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) ||
-      (node_provider_type == kCudaExecutionProvider && kNvTensorRTRTXExecutionProvider == provider_) ||
-      (node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) {
+  if (IsNodeCompatibleWithProvider(node)) {
     provider_nodes_.insert(&node);
     // note KernelCreateInfo might be nullptr for custom kernel
    const KernelCreateInfo* kci = nullptr;
@@ -268,9 +321,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node,
       else
         provider_output_defs_.insert(arg);
     }
-  } else if (node_provider_type != kCudaExecutionProvider && node_provider_type != kTensorrtExecutionProvider &&
-             node_provider_type != kCudaExecutionProvider && node_provider_type != kNvTensorRTRTXExecutionProvider &&
-             node_provider_type != kRocmExecutionProvider && node_provider_type != kMIGraphXExecutionProvider) {
+  } else {
     for (const auto* arg : node.InputDefs()) {
       if (arg->Exists())
         non_provider_input_defs_.insert(arg);
@@ -297,7 +348,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg,
                                              const KernelRegistryManager& kernel_registries,
                                              const logging::Logger& logger) {
   for (auto& it : graph_.Nodes()) {
-    if (it.OpType() == "MemcpyFromHost" || it.OpType() == "MemcpyToHost") continue;
+    if (utils::IsMemcpyNode(it)) continue;
     auto input_it =
         std::find(it.MutableInputDefs().begin(), it.MutableInputDefs().end(), const_cast<onnxruntime::NodeArg*>(arg));
     auto output_it =
@@ -309,10 +360,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg,
     if (arg_input_index == -1 && arg_output_index == -1)
       continue;
     auto node_provider_type = it.GetExecutionProviderType();
-    if ((node_provider_type == provider_) ||
-        (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) ||
-        (node_provider_type == kCudaExecutionProvider && kNvTensorRTRTXExecutionProvider == provider_) ||
-        (node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) {
+    if (IsNodeCompatibleWithProvider(it)) {
       const KernelCreateInfo* kci = nullptr;
       ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, logger, &kci));
       if (arg_input_index != -1) {
@@ -325,9 +373,11 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg,
   }
 }
 
-void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger) {
+void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg,
+                                        bool is_input,
+                                        const logging::Logger& logger) {
   // create unique name for new def
-  std::string new_def_name = graph_.GenerateNodeArgName(arg->Name() + "_" + provider_);
+  std::string new_def_name = graph_.GenerateNodeArgName(arg->Name() + "_" + provider_.Type());
 
   auto* new_arg = &graph_.GetOrCreateNodeArg(new_def_name, arg->TypeAsProto());
   auto* src_arg = is_input ? arg : new_arg;
@@ -338,12 +388,14 @@ void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input
 
   const auto op_name = is_input ? "MemcpyFromHost" : "MemcpyToHost";
   LOGS(logger, INFO) << "Add " << op_name << (is_input ? " after " : " before ") << arg->Name()
-                     << " for " << provider_;
+                     << " for " << provider_.Type();
 
   auto& new_node = graph_.AddNode(new_node_name, op_name, "Copy from/to host memory",
                                   std::vector<onnxruntime::NodeArg*>{src_arg},
                                   std::vector<onnxruntime::NodeArg*>{dst_arg});
-  new_node.SetExecutionProviderType(provider_);
+
+  new_node.SetExecutionProviderType(provider_.Type());
+
   std::map<const onnxruntime::NodeArg*, onnxruntime::NodeArg*> map = {{arg, new_arg}};
   auto it = provider_input_nodes_.find(arg);
   if (it != provider_input_nodes_.end()) {
```
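
The previously hardcoded provider pairs (CUDA with TensorRT / NV TensorRT RTX, ROCm with MIGraphX) are subsumed by the device-based check in `IsNodeCompatibleWithProvider` above: two providers are compatible when they are the same provider or target the same device type from the same vendor, so no copy is needed between nodes assigned to them. A simplified sketch of that rule (standalone and illustrative only, not the code in this commit):

```cpp
// Illustrative reduction of the compatibility rule: e.g., the CUDA and TensorRT
// EPs both report an NVIDIA GPU device, so nodes assigned to either can share
// tensors without a Memcpy node. A plugin EP gets the same treatment based on
// whatever device type and vendor it reports.
bool DevicesCompatible(const OrtDevice& a, const OrtDevice& b) {
  return a.Type() == b.Type() && a.Vendor() == b.Vendor();
}
```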
