diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
index 0aad80c4ddab9..62ea895798204 100644
--- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
+++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
@@ -97,4 +97,5 @@ struct OrtTensorRTProviderOptionsV2 {
   int trt_engine_hw_compatible{0};            // Enable hardware compatibility. Default 0 = false, nonzero = true
   const char* trt_op_types_to_exclude{};      // Exclude specific ops from running on TRT.
   int trt_load_user_initializer{0};           // Save initializers locally instead of to disk. Default 0 = false, nonzero = true
+  bool trt_use_cpu_ep_memcpy_kernels{false};  // Use the MemcpyToHost and MemcpyFromHost kernel implementations from the CPU EP. Mainly intended for testing.
 };
diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc
index 1c80d83f99feb..d6fcd6f613f93 100644
--- a/onnxruntime/core/framework/allocation_planner.cc
+++ b/onnxruntime/core/framework/allocation_planner.cc
@@ -912,18 +912,38 @@ class PlannerImpl {
       OrtValueIndex index = Index(node_output->Name());
       ProcessDef(index, node_output);
       OrtDevice output_device = exec_provider->GetOrtDeviceByMemType(p_kernel_def->OutputMemoryType(i));
+
 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
+      // Two cases where we want to override the output location suggested by the kernel def:
+      // 1.
       // Downstream nodes of certain providers may require a CPU accessible location override
       // to make sure the EP does not incur an unnecessary copy.
       // We only do it for CPU based EPs. We are not likely to encounter
       // non CPU devices here since they are already taken care of by using MemCpy nodes earlier.
       // However, we still ignore them.
+      //
+      // 2.
+      // The MemcpyFromHost node provided by the CPU EP requires special handling.
+      // Its kernel registration uses the default memory type for the output, which means the output
+      // is treated as CPU memory because the kernel runs on CPU, but the node may actually produce
+      // its output on the device of its consumer node's EP.
+      // So we need to check the consumer node's EP and set the output device accordingly.
+
       if (output_device.Type() == OrtDevice::CPU) {
         const auto& output_name = node_output->Name();
         const auto consumers = graph_viewer_.GetConsumerNodes(output_name);
         for (const auto* consumer : consumers) {
           if (consumer != nullptr) {
             const auto& ep_type = consumer->GetExecutionProviderType();
+
+            if ((pnode->OpType() == "MemcpyFromHost") &&
+                pnode->GetExecutionProviderType() == kCpuExecutionProvider) {
+              // Check the consumer node's EP and set the output device accordingly.
+              output_device = execution_providers_.Get(ep_type)
+                                  ->GetOrtDeviceByMemType(p_kernel_def->OutputMemoryType(i));
+              break;
+            }
+
             auto suggested_device = execution_providers_.Get(ep_type)
                                         ->GetOrtDeviceByMemType(OrtMemType::OrtMemTypeCPUInput);
             if (suggested_device.Type() == OrtDevice::CPU) {
@@ -2097,13 +2117,39 @@ class PlannerImpl {
           OrtValueIndex output_arg_idx;
           ORT_THROW_IF_ERROR(ort_value_name_idx_map_.GetIdx(output->Name(), output_arg_idx));
           // there are two cases we need notification:
-          // 1. the consumer is not in the same stream
+          // 1. the consumer is not in the same stream.
+          //    There are typically two types of wait functions defined in the Notification
+          //    class for plugin EPs or other provider-bridge EPs (e.g., CUDA EP and TRT EP):
+          //    (1) WaitNotificationOnDevice
+          //    (2) WaitNotificationOnHost
+          //
+          //    Note: MemcpyToHost nodes provided by the CPU EP require special handling.
+          //    If a MemcpyToHost node (running on the host) consumes a tensor produced by
+          //    a device node, MemcpyToHost must use WaitNotificationOnHost, because the
+          //    CPU device does not have a stream, which is required by WaitNotificationOnDevice.
+          //
           // 2. the consumer is in the same stream(non-cpu device), but it consumes a CPU tensor from an non-shape op.
           //    for example, a resize cuda kernel consumes a tensor from MemCpyToHost cuda kernel on the same stream.
           //    in this case, the FIFO can't guarantee the cpu tensor is ready when resize kernel is launching
           const auto& output_arg_device = AllocPlan(output_arg_idx).location;
-          WaitNotificationFn wait_handle = stream_handle_registry.GetWaitHandle(stream_device,
-                                                                                output_arg_device);
+
+          auto get_wait_handle = [&](const Node& node,
+                                     const OrtDevice& src_device,
+                                     const OrtDevice& dst_device) -> WaitNotificationFn {
+            if (node.OpType() == "MemcpyToHost" &&
+                node.GetExecutionProviderType() == kCpuExecutionProvider) {
+              // The returned wait handle must be host-based rather than device-based,
+              // because WaitNotificationOnDevice requires a stream. However,
+              // the MemcpyToHost node provided by the CPU EP performs a blocking
+              // data transfer and does not use a stream.
+              return stream_handle_registry.GetWaitHandle(src_device, OrtDevice());
+            }
+
+            return stream_handle_registry.GetWaitHandle(src_device, dst_device);
+          };
+
+          WaitNotificationFn wait_handle = get_wait_handle(*it, stream_device, output_arg_device);
+
           if ((plan_.node_stream_map_[it->Index()] != i || output_arg_device.UsesCpuMemory()) && wait_handle != nullptr) {
             if (node_to_notification.find(node_index) == node_to_notification.end()) {
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
index 71ea66b0be89f..1f3fccbddf679 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
@@ -68,6 +68,8 @@ struct Tensorrt_Provider : Provider {
     info.device_id = device_id;
     info.has_trt_options = false;
 
+    InitializeRegistry();
+
     return std::make_shared<TensorrtProviderFactory>(info);
   }
 
@@ -125,6 +127,11 @@ struct Tensorrt_Provider : Provider {
     info.preview_features = options.trt_preview_features == nullptr ? "" : options.trt_preview_features;
"" : options.trt_preview_features; info.load_user_initializer = options.trt_load_user_initializer != 0; + use_cpu_ep_memcpy_kernels_ = options.trt_use_cpu_ep_memcpy_kernels; + if (!use_cpu_ep_memcpy_kernels_) { + InitializeRegistry(); + } + return std::make_shared(info); } @@ -138,13 +145,17 @@ struct Tensorrt_Provider : Provider { } void Initialize() override { - InitializeRegistry(); } void Shutdown() override { - DeleteRegistry(); + if (!use_cpu_ep_memcpy_kernels_) { + DeleteRegistry(); + } } + private: + bool use_cpu_ep_memcpy_kernels_{false}; + } g_provider; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory_creator.h b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory_creator.h index d905003fb7cc1..fd2dcf4071f3e 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory_creator.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory_creator.h @@ -16,5 +16,6 @@ struct TensorrtProviderFactoryCreator { static std::shared_ptr Create(int device_id); static std::shared_ptr Create(const OrtTensorRTProviderOptions* provider_options); static std::shared_ptr Create(const OrtTensorRTProviderOptionsV2* provider_options); + static void UnloadLibrary(); }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 6949ed0059add..87bc3da4a4fdb 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -2214,6 +2214,10 @@ std::shared_ptr TensorrtProviderFactoryCreator::Creat return nullptr; } +void TensorrtProviderFactoryCreator::UnloadLibrary() { + s_library_tensorrt.Unload(); +} + std::shared_ptr NvProviderFactoryCreator::Create(int device_id) try { return s_library_nv.Get().CreateExecutionProviderFactory(device_id); } catch (const std::exception& exception) { diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index dce0d570ec238..8ca7fe30daf1d 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -1407,5 +1407,89 @@ TEST(TensorrtExecutionProviderTest, TestSessionOutputs) { ASSERT_TRUE(output_count == 1); } } + +// #ifdef USE_CPU_MEMCPY_KERNELS_FOR_TENSORRT +TEST(TensorrtExecutionProviderTest, PartiallySupportedModel_MemcpyOpsOnCPU_Inference) { + // The model has Add -> Mul -> Add. + // TensorRT EP intentionally excludes the support for Mul so that the Mul node will be executed on CPU EP. + // Because that trt_use_cpu_ep_memcpy_kernels option is set, MemcpyToHost/MemcpyFromHost CPU implementations + // will be automaically inserted by ORT and assigned to CPU EP. + + // Use InferenceSession directly instead of Ort::Session to access the graph + SessionOptions so; + so.session_logid = "TensorrtExecutionProviderTest.PartiallySupportedModel_MemcpyOpsOnCPU_Inference"; + RunOptions run_options; + run_options.run_tag = so.session_logid; + InferenceSession session_object{so, GetEnvironment()}; + + OrtTensorRTProviderOptionsV2 params; + params.trt_use_cpu_ep_memcpy_kernels = true; + params.trt_op_types_to_exclude = "Mul"; + + // Previous unit tests alread loaded TRT EP library that registered the Memcpy kernels to TRT EP, we need + // to reload the TRT EP library after setting the trt_use_cpu_ep_memcpy_kernels option to make sure the + // Memcpy kernels are registered to CPU EP instead of TRT EP. 
+  std::unique_ptr<IExecutionProvider> execution_provider = TensorrtExecutionProviderWithOptions(&params, /*force_reload_library*/ true);
+  EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
+
+  auto status = session_object.Load(ORT_TSTR("testdata/add_mul_add.onnx"));
+  ASSERT_TRUE(status.IsOK());
+  status = session_object.Initialize();
+  ASSERT_TRUE(status.IsOK());
+
+  // Verify that MemcpyFromHost and MemcpyToHost nodes exist and are on CPU EP
+  const auto& graph_after_init = session_object.GetModel().MainGraph();
+  bool found_memcpy_from_host = false;
+  bool found_memcpy_to_host = false;
+
+  for (const auto& node : graph_after_init.Nodes()) {
+    if (node.OpType() == "MemcpyFromHost") {
+      found_memcpy_from_host = true;
+      ASSERT_EQ(node.GetExecutionProviderType(), kCpuExecutionProvider)
+          << "MemcpyFromHost should be assigned to CPU EP";
+    }
+    if (node.OpType() == "MemcpyToHost") {
+      found_memcpy_to_host = true;
+      ASSERT_EQ(node.GetExecutionProviderType(), kCpuExecutionProvider)
+          << "MemcpyToHost should be assigned to CPU EP";
+    }
+  }
+
+  ASSERT_TRUE(found_memcpy_from_host) << "MemcpyFromHost node should be inserted";
+  ASSERT_TRUE(found_memcpy_to_host) << "MemcpyToHost node should be inserted";
+
+  // Create inputs
+  auto cuda_provider = DefaultCudaExecutionProvider();
+  auto cpu_allocator = cuda_provider->CreatePreferredAllocators()[1];
+  std::vector<int64_t> shape = {3, 2};
+
+  std::vector<float> a_data{1, 2, 3, 4, 5, 6};
+  std::vector<float> b_data{2, 3, 4, 5, 6, 7};
+
+  OrtValue ml_value_a;
+  CreateMLValue<float>(cpu_allocator, shape, a_data, &ml_value_a);
+  OrtValue ml_value_b;
+  CreateMLValue<float>(cpu_allocator, shape, b_data, &ml_value_b);
+
+  NameMLValMap feeds;
+  feeds.insert(std::make_pair("A", ml_value_a));
+  feeds.insert(std::make_pair("B", ml_value_b));
+
+  // prepare outputs
+  std::vector<std::string> output_names;
+  output_names.push_back("C");
+  std::vector<OrtValue> fetches;
+
+  // Run session and verify outputs
+  status = session_object.Run(run_options, feeds, output_names, &fetches);
+  ASSERT_TRUE(status.IsOK());
+
+  // Check expected output values
+  std::vector<int64_t> expected_dims = {3, 2};
+  std::vector<float> expected_values = {7, 17, 31, 49, 71, 97};
+  VerifyOutputs(fetches, expected_dims, expected_values);
+}
+// #endif
+
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc
index 4bc300fc7263a..41d4d14986b6e 100644
--- a/onnxruntime/test/util/default_providers.cc
+++ b/onnxruntime/test/util/default_providers.cc
@@ -70,12 +70,17 @@ std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const O
   return nullptr;
 }
 
-std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptionsV2* params) {
+std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptionsV2* params,
+                                                                         bool force_reload_library) {
 #ifdef USE_TENSORRT
+  if (force_reload_library) {
+    TensorrtProviderFactoryCreator::UnloadLibrary();
+  }
   if (auto factory = TensorrtProviderFactoryCreator::Create(params))
     return factory->CreateProvider();
 #else
   ORT_UNUSED_PARAMETER(params);
+  ORT_UNUSED_PARAMETER(force_reload_library);
 #endif
   return nullptr;
 }
diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h
index fb7b168f5e158..4dfd836254102 100644
--- a/onnxruntime/test/util/include/default_providers.h
+++ b/onnxruntime/test/util/include/default_providers.h
@@ -46,7 +46,7 @@ std::unique_ptr<IExecutionProvider> DnnlExecutionProviderWithOptions(const OrtDn
 std::unique_ptr<IExecutionProvider> DefaultTensorrtExecutionProvider();
 std::unique_ptr<IExecutionProvider> DefaultNvTensorRTRTXExecutionProvider();
 std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptions* params);
-std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptionsV2* params);
+std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptionsV2* params, bool force_reload_library = false);
 std::unique_ptr<IExecutionProvider> DefaultMIGraphXExecutionProvider();
 std::unique_ptr<IExecutionProvider> MIGraphXExecutionProviderWithOptions(const OrtMIGraphXProviderOptions* params);
 std::unique_ptr<IExecutionProvider> OpenVINOExecutionProviderWithOptions(const ProviderOptions* params, const SessionOptions* session_options = nullptr);
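
Usage note (illustrative, not part of the patch): a minimal sketch of how an application might enable the new trt_use_cpu_ep_memcpy_kernels option together with trt_op_types_to_exclude, mirroring the unit test above. The field names come from the tensorrt_provider_options.h change in this diff; Ort::SessionOptions::AppendExecutionProvider_TensorRT_V2 is the existing public C++ API call, and the model path reuses the test model, so treat this as a sketch rather than a verified sample.

    // Sketch only: assumes a build that includes the TensorRT EP and the public C++ API.
    #include <onnxruntime_cxx_api.h>

    int main() {
      Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "trt_cpu_memcpy_sketch");
      Ort::SessionOptions so;

      OrtTensorRTProviderOptionsV2 trt_options{};
      trt_options.device_id = 0;
      trt_options.trt_op_types_to_exclude = "Mul";       // run excluded ops (here: Mul) on the CPU EP
      trt_options.trt_use_cpu_ep_memcpy_kernels = true;  // new option: use CPU EP Memcpy kernels (test-oriented)

      so.AppendExecutionProvider_TensorRT_V2(trt_options);

      // "testdata/add_mul_add.onnx" is the model used by the unit test above.
      Ort::Session session(env, ORT_TSTR("testdata/add_mul_add.onnx"), so);
      return 0;
    }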