11// Copyright (c) Microsoft Corporation. All rights reserved.
22// Licensed under the MIT License.
33
4- #include " transformer_memcpy.h"
4+ #include " core/optimizer/transformer_memcpy.h"
5+
56#include " core/common/logging/logging.h"
67#include " core/framework/kernel_registry_manager.h"
78#include " core/framework/execution_providers.h"
1213using namespace ONNX_NAMESPACE ;
1314namespace onnxruntime {
1415
16+ static ProviderTypeToProviderMap GetProvidersByType (
17+ const InlinedVector<gsl::not_null<const IExecutionProvider*>>& providers) {
18+ ProviderTypeToProviderMap providers_by_type{};
19+ for (const auto provider : providers) {
20+ providers_by_type.emplace (provider->Type (), provider);
21+ }
22+ return providers_by_type;
23+ }
24+
25+ MemcpyTransformer::MemcpyTransformer (InlinedVector<gsl::not_null<const IExecutionProvider*>> providers,
26+ const KernelRegistryManager& registry_manager)
27+ : GraphTransformer(" MemcpyTransformer" ),
28+ providers_ (std::move(providers)),
29+ providers_by_type_(GetProvidersByType(providers_)),
30+ registry_manager_(std::cref(registry_manager)) {
31+ }
32+
1533// implements MemCpy node insertion in graph transform
1634// note that GraphTransformer::Apply() is supposed to be stateless, so this cannot derive from GraphTransformer
1735class TransformerMemcpyImpl {
1836 public:
19- TransformerMemcpyImpl (onnxruntime::Graph& graph, const std::string& provider)
20- : graph_(graph), provider_(provider) {}
37+ TransformerMemcpyImpl (onnxruntime::Graph& graph, const IExecutionProvider& provider,
38+ const ProviderTypeToProviderMap& providers_by_type)
39+ : graph_(graph), provider_(provider), providers_by_type_(providers_by_type) {
40+ }
2141
2242 bool ModifyGraph (const KernelRegistryManager& schema_registries,
2343 const logging::Logger& logger,
2444 int & copy_node_counter);
2545
2646 private:
47+ bool IsNodeCompatibleWithProvider (const onnxruntime::Node& node) const ;
48+
2749 void ProcessDefs (onnxruntime::Node& node,
2850 const KernelRegistryManager& kernel_registries,
2951 InitializedTensorSet& initializers_consumed,
3052 const logging::Logger& logger);
3153 void BuildDefsMapping (const onnxruntime::NodeArg* arg,
3254 const KernelRegistryManager& kernel_registries,
3355 const logging::Logger& logger);
34- void AddCopyNode (onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger);
56+ void AddCopyNode (onnxruntime::NodeArg* arg,
57+ bool is_input,
58+ const logging::Logger& logger);
3559 bool ProcessInitializers (const KernelRegistryManager& kernel_registries,
3660 const InitializedTensorSet& initializers_consumed,
3761 const logging::Logger& logger);
@@ -55,7 +79,8 @@ class TransformerMemcpyImpl {
5579 std::map<const onnxruntime::NodeArg*, std::set<onnxruntime::Node*, NodeCompare>> provider_output_nodes_;
5680
5781 onnxruntime::Graph& graph_;
58- std::string provider_;
82+ const IExecutionProvider& provider_;
83+ const ProviderTypeToProviderMap& providers_by_type_;
5984};
6085
6186/* * Helper that returns a pointer to the corresponding TensorProto for a name if it is an initializer.
@@ -73,17 +98,18 @@ static const onnx::TensorProto* GetInitializer(const Graph& graph, const std::st
7398
7499// very simple GraphTransformer that uses TransformerMemcpyImpl for each graph
75100// and mainly provides the subgraph recursion functionality
76- common::Status MemcpyTransformer::ApplyImpl (Graph& graph, bool & modified, int graph_level,
77- const logging::Logger& logger) const {
78- for (auto & provider : provider_types_) {
79- if (!utils::ProviderIsCpuBased (provider)) {
80- TransformerMemcpyImpl copy_impl (graph, provider);
101+ Status MemcpyTransformer::ApplyImpl (Graph& graph, bool & modified, int graph_level,
102+ const logging::Logger& logger) const {
103+ for (const auto provider : providers_) {
104+ const auto & provider_type = provider->Type ();
105+ if (!utils::ProviderIsCpuBased (*provider)) {
106+ TransformerMemcpyImpl copy_impl (graph, *provider, providers_by_type_);
81107
82108 int copy_node_counter = 0 ;
83109 auto current_modified = copy_impl.ModifyGraph (registry_manager_, logger, copy_node_counter);
84- if (copy_node_counter > 0 && provider == kCudaExecutionProvider ) {
110+ if (copy_node_counter > 0 && provider_type == kCudaExecutionProvider ) {
85111 LOGS (logger, WARNING) << copy_node_counter << " Memcpy nodes are added to the graph " << graph.Name ()
86- << " for " << provider
112+ << " for " << provider_type
87113 << " . It might have negative impact on performance (including unable to run CUDA graph). "
88114 << " Set session_options.log_severity_level=1 to see the detail logs before this message." ;
89115 }
@@ -213,15 +239,42 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi
213239 return modified;
214240}
215241
242+ static const IExecutionProvider* FindProviderByType (ProviderTypeToProviderMap providers_by_type,
243+ std::string_view provider_type) {
244+ const auto it = providers_by_type.find (provider_type);
245+ if (it != providers_by_type.end ()) {
246+ return &*it->second ;
247+ }
248+ return nullptr ;
249+ }
250+
251+ bool TransformerMemcpyImpl::IsNodeCompatibleWithProvider (const onnxruntime::Node& node) const {
252+ const auto & node_provider_type = node.GetExecutionProviderType ();
253+ const auto * node_provider = FindProviderByType (providers_by_type_, node_provider_type);
254+ ORT_ENFORCE (node_provider != nullptr , " Unable to get provider associated with provider type " , node_provider_type);
255+
256+ // Same provider?
257+ if (node_provider->Type () == provider_.Type ()) {
258+ return true ;
259+ }
260+
261+ const auto & node_provider_device = node_provider->GetDevice ();
262+ const auto & provider_device = provider_.GetDevice ();
263+
264+ // Same provider device type and vendor?
265+ if (node_provider_device.Type () == provider_device.Type () &&
266+ node_provider_device.Vendor () == provider_device.Vendor ()) {
267+ return true ;
268+ }
269+
270+ return false ;
271+ }
272+
216273void TransformerMemcpyImpl::ProcessDefs (onnxruntime::Node& node,
217274 const KernelRegistryManager& kernel_registries,
218275 InitializedTensorSet& initializers_consumed,
219276 const logging::Logger& logger) {
220- auto node_provider_type = node.GetExecutionProviderType ();
221- if ((node_provider_type == provider_) ||
222- (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) ||
223- (node_provider_type == kCudaExecutionProvider && kNvTensorRTRTXExecutionProvider == provider_) ||
224- (node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) {
277+ if (IsNodeCompatibleWithProvider (node)) {
225278 provider_nodes_.insert (&node);
226279 // note KernelCreateInfo might be nullptr for custom kernel
227280 const KernelCreateInfo* kci = nullptr ;
@@ -268,9 +321,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node,
268321 else
269322 provider_output_defs_.insert (arg);
270323 }
271- } else if (node_provider_type != kCudaExecutionProvider && node_provider_type != kTensorrtExecutionProvider &&
272- node_provider_type != kCudaExecutionProvider && node_provider_type != kNvTensorRTRTXExecutionProvider &&
273- node_provider_type != kRocmExecutionProvider && node_provider_type != kMIGraphXExecutionProvider ) {
324+ } else {
274325 for (const auto * arg : node.InputDefs ()) {
275326 if (arg->Exists ())
276327 non_provider_input_defs_.insert (arg);
@@ -297,7 +348,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg,
297348 const KernelRegistryManager& kernel_registries,
298349 const logging::Logger& logger) {
299350 for (auto & it : graph_.Nodes ()) {
300- if (it. OpType () == " MemcpyFromHost " || it. OpType () == " MemcpyToHost " ) continue ;
351+ if (utils::IsMemcpyNode (it) ) continue ;
301352 auto input_it =
302353 std::find (it.MutableInputDefs ().begin (), it.MutableInputDefs ().end (), const_cast <onnxruntime::NodeArg*>(arg));
303354 auto output_it =
@@ -309,10 +360,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg,
309360 if (arg_input_index == -1 && arg_output_index == -1 )
310361 continue ;
311362 auto node_provider_type = it.GetExecutionProviderType ();
312- if ((node_provider_type == provider_) ||
313- (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) ||
314- (node_provider_type == kCudaExecutionProvider && kNvTensorRTRTXExecutionProvider == provider_) ||
315- (node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) {
363+ if (IsNodeCompatibleWithProvider (it)) {
316364 const KernelCreateInfo* kci = nullptr ;
317365 ORT_IGNORE_RETURN_VALUE (kernel_registries.SearchKernelRegistry (it, logger, &kci));
318366 if (arg_input_index != -1 ) {
@@ -325,9 +373,11 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg,
325373 }
326374}
327375
328- void TransformerMemcpyImpl::AddCopyNode (onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger) {
376+ void TransformerMemcpyImpl::AddCopyNode (onnxruntime::NodeArg* arg,
377+ bool is_input,
378+ const logging::Logger& logger) {
329379 // create unique name for new def
330- std::string new_def_name = graph_.GenerateNodeArgName (arg->Name () + " _" + provider_);
380+ std::string new_def_name = graph_.GenerateNodeArgName (arg->Name () + " _" + provider_. Type () );
331381
332382 auto * new_arg = &graph_.GetOrCreateNodeArg (new_def_name, arg->TypeAsProto ());
333383 auto * src_arg = is_input ? arg : new_arg;
@@ -338,12 +388,14 @@ void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input
338388
339389 const auto op_name = is_input ? " MemcpyFromHost" : " MemcpyToHost" ;
340390 LOGS (logger, INFO) << " Add " << op_name << (is_input ? " after " : " before " ) << arg->Name ()
341- << " for " << provider_;
391+ << " for " << provider_. Type () ;
342392
343393 auto & new_node = graph_.AddNode (new_node_name, op_name, " Copy from/to host memory" ,
344394 std::vector<onnxruntime::NodeArg*>{src_arg},
345395 std::vector<onnxruntime::NodeArg*>{dst_arg});
346- new_node.SetExecutionProviderType (provider_);
396+
397+ new_node.SetExecutionProviderType (provider_.Type ());
398+
347399 std::map<const onnxruntime::NodeArg*, onnxruntime::NodeArg*> map = {{arg, new_arg}};
348400 auto it = provider_input_nodes_.find (arg);
349401 if (it != provider_input_nodes_.end ()) {
0 commit comments