@@ -97,4 +97,5 @@ struct OrtTensorRTProviderOptionsV2 {
int trt_engine_hw_compatible{0}; // Enable hardware compatibility. Default 0 = false, nonzero = true
const char* trt_op_types_to_exclude{}; // Exclude specific ops from running on TRT.
int trt_load_user_initializer{0}; // Save initializers locally instead of to disk. Default 0 = false, nonzero = true
bool trt_use_cpu_ep_memcpy_kernels{false}; // Use the MemcpyToHost and MemcpyFromHost kernel implementations from the CPU EP. Mainly intended for testing purposes.
};
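
The new trt_use_cpu_ep_memcpy_kernels flag is exercised by the unit test added later in this PR; the condensed usage sketch below borrows identifiers from that test (includes and test scaffolding assumed):

OrtTensorRTProviderOptionsV2 params;
params.trt_use_cpu_ep_memcpy_kernels = true;  // take the Memcpy kernels from the CPU EP
params.trt_op_types_to_exclude = "Mul";       // force a CPU fallback so Memcpy nodes get inserted
// Reload the TRT EP library so its kernel registry reflects the option above.
auto trt_ep = TensorrtExecutionProviderWithOptions(&params, /*force_reload_library*/ true);
EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(trt_ep)).IsOK());
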
52 changes: 49 additions & 3 deletions onnxruntime/core/framework/allocation_planner.cc
@@ -912,18 +912,38 @@ class PlannerImpl {
OrtValueIndex index = Index(node_output->Name());
ProcessDef(index, node_output);
OrtDevice output_device = exec_provider->GetOrtDeviceByMemType(p_kernel_def->OutputMemoryType(i));

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
// Two cases where we want to override the output location suggested by the kernel def:
// 1.
// Downstream nodes of certain providers may require a CPU-accessible location override
// to make sure the EP does not incur an unnecessary copy.
// We only do this for CPU-based EPs. We are unlikely to encounter
// non-CPU devices here since they are already handled by the MemCpy nodes inserted earlier;
// if we do encounter them, we simply ignore them.
//
// 2.
// The MemcpyFromHost node provided by the CPU EP requires special handling.
// Its kernel registration uses the default memory type for the output, which means CPU memory
// since the kernel runs on the CPU, but the node may actually produce its output on
// the device of its consumer node's EP.
// So we need to check the consumer node's EP and set the output device accordingly.

if (output_device.Type() == OrtDevice::CPU) {
const auto& output_name = node_output->Name();
const auto consumers = graph_viewer_.GetConsumerNodes(output_name);
for (const auto* consumer : consumers) {
if (consumer != nullptr) {
const auto& ep_type = consumer->GetExecutionProviderType();

if ((pnode->OpType() == "MemcpyFromHost") &&
pnode->GetExecutionProviderType() == kCpuExecutionProvider) {
// Check the consumer node's EP and set the output device accordingly
output_device = execution_providers_.Get(ep_type)
->GetOrtDeviceByMemType(p_kernel_def->OutputMemoryType(i));
break;
}

auto suggested_device = execution_providers_.Get(ep_type)
->GetOrtDeviceByMemType(OrtMemType::OrtMemTypeCPUInput);
if (suggested_device.Type() == OrtDevice::CPU) {
@@ -2097,13 +2117,39 @@ class PlannerImpl {
OrtValueIndex output_arg_idx;
ORT_THROW_IF_ERROR(ort_value_name_idx_map_.GetIdx(output->Name(), output_arg_idx));
// there are two cases we need notification:
// 1. the consumer is not in the same stream
// 1. the consumer is not in the same stream.
// There are typically two types of wait functions defined in the Notification
// class for plugin EPs or other provider-bridge EPs (e.g., CUDA EP and TRT EP):
// (1) WaitNotificationOnDevice
// (2) WaitNotificationOnHost
//
// Note: MemcpyToHost nodes provided by the CPU EP require special handling.
// If a MemcpyToHost node (running on the host) consumes a tensor produced by
// a device node, MemcpyToHost must use WaitNotificationOnHost, because the
// CPU device does not have a stream, which is required by WaitNotificationOnDevice.
//
// 2. the consumer is in the same stream (non-CPU device), but it consumes a CPU tensor from a non-shape op.
// for example, a Resize CUDA kernel consumes a tensor from a MemcpyToHost CUDA kernel on the same stream.
// in this case, the FIFO can't guarantee the CPU tensor is ready when the Resize kernel is launched.
const auto& output_arg_device = AllocPlan(output_arg_idx).location;
WaitNotificationFn wait_handle = stream_handle_registry.GetWaitHandle(stream_device,
output_arg_device);

auto get_wait_handle = [&](const Node& node,
const OrtDevice& src_device,
const OrtDevice& dst_device) -> WaitNotificationFn {
if (node.OpType() == "MemcpyToHost" &&
node.GetExecutionProviderType() == kCpuExecutionProvider) {
// The returned wait handle must be host-based rather than device-based,
// because WaitNotificationOnDevice requires a stream. However,
// the MemcpyToHost node provided by the CPU EP performs a blocking
// data transfer and does not use a stream.
return stream_handle_registry.GetWaitHandle(src_device, OrtDevice());
}

return stream_handle_registry.GetWaitHandle(src_device, dst_device);
};

WaitNotificationFn wait_handle = get_wait_handle(*it, stream_device, output_arg_device);

if ((plan_.node_stream_map_[it->Index()] != i || output_arg_device.UsesCpuMemory()) &&
wait_handle != nullptr) {
if (node_to_notification.find(node_index) == node_to_notification.end()) {
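The device-wait vs. host-wait distinction described in the comments above can be illustrated with a minimal, self-contained CUDA-style sketch (hypothetical helper names; not ORT's actual Notification implementation, which wraps the equivalent calls behind WaitNotificationFn):

#include <cuda_runtime.h>

// Device-side wait: order the consumer stream after the recorded event without
// blocking the host thread. A stream is required, which the CPU EP does not have.
void WaitOnDeviceExample(cudaStream_t consumer_stream, cudaEvent_t produced_event) {
  cudaStreamWaitEvent(consumer_stream, produced_event, /*flags*/ 0);
}

// Host-side wait: block the calling host thread until the event completes.
// No stream is needed, which is why a CPU-EP MemcpyToHost consumer must take this path.
void WaitOnHostExample(cudaEvent_t produced_event) {
  cudaEventSynchronize(produced_event);
}
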
15 changes: 13 additions & 2 deletions onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
@@ -68,6 +68,8 @@ struct Tensorrt_Provider : Provider {
info.device_id = device_id;
info.has_trt_options = false;

InitializeRegistry();

return std::make_shared<TensorrtProviderFactory>(info);
}

@@ -125,6 +127,11 @@ struct Tensorrt_Provider : Provider {
info.preview_features = options.trt_preview_features == nullptr ? "" : options.trt_preview_features;
info.load_user_initializer = options.trt_load_user_initializer != 0;

use_cpu_ep_memcpy_kernels_ = options.trt_use_cpu_ep_memcpy_kernels;
if (!use_cpu_ep_memcpy_kernels_) {
InitializeRegistry();
}

return std::make_shared<TensorrtProviderFactory>(info);
}

@@ -138,13 +145,17 @@
}

void Initialize() override {
InitializeRegistry();
}

void Shutdown() override {
DeleteRegistry();
if (!use_cpu_ep_memcpy_kernels_) {
DeleteRegistry();
}
}

private:
bool use_cpu_ep_memcpy_kernels_{false};

} g_provider;

} // namespace onnxruntime
@@ -16,5 +16,6 @@ struct TensorrtProviderFactoryCreator {
static std::shared_ptr<IExecutionProviderFactory> Create(int device_id);
static std::shared_ptr<IExecutionProviderFactory> Create(const OrtTensorRTProviderOptions* provider_options);
static std::shared_ptr<IExecutionProviderFactory> Create(const OrtTensorRTProviderOptionsV2* provider_options);
static void UnloadLibrary();
};
} // namespace onnxruntime
4 changes: 4 additions & 0 deletions onnxruntime/core/session/provider_bridge_ort.cc
@@ -2214,6 +2214,10 @@ std::shared_ptr<IExecutionProviderFactory> TensorrtProviderFactoryCreator::Creat
return nullptr;
}

void TensorrtProviderFactoryCreator::UnloadLibrary() {
s_library_tensorrt.Unload();
}

std::shared_ptr<IExecutionProviderFactory> NvProviderFactoryCreator::Create(int device_id) try {
return s_library_nv.Get().CreateExecutionProviderFactory(device_id);
} catch (const std::exception& exception) {
84 changes: 84 additions & 0 deletions onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
@@ -1407,5 +1407,89 @@ TEST(TensorrtExecutionProviderTest, TestSessionOutputs) {
ASSERT_TRUE(output_count == 1);
}
}

// #ifdef USE_CPU_MEMCPY_KERNELS_FOR_TENSORRT
TEST(TensorrtExecutionProviderTest, PartiallySupportedModel_MemcpyOpsOnCPU_Inference) {
// The model has Add -> Mul -> Add.
// The TensorRT EP intentionally excludes support for Mul so that the Mul node will be executed on the CPU EP.
// Because the trt_use_cpu_ep_memcpy_kernels option is set, the MemcpyToHost/MemcpyFromHost CPU implementations
// will be automatically inserted by ORT and assigned to the CPU EP.

// Use InferenceSession directly instead of Ort::Session to access the graph
SessionOptions so;
so.session_logid = "TensorrtExecutionProviderTest.PartiallySupportedModel_MemcpyOpsOnCPU_Inference";
RunOptions run_options;
run_options.run_tag = so.session_logid;
InferenceSession session_object{so, GetEnvironment()};

OrtTensorRTProviderOptionsV2 params;
params.trt_use_cpu_ep_memcpy_kernels = true;
params.trt_op_types_to_exclude = "Mul";

// Previous unit tests already loaded the TRT EP library, which registered the Memcpy kernels to the TRT EP.
// We need to reload the TRT EP library after setting the trt_use_cpu_ep_memcpy_kernels option to make sure
// the Memcpy kernels are registered to the CPU EP instead of the TRT EP.
std::unique_ptr<IExecutionProvider> execution_provider = TensorrtExecutionProviderWithOptions(&params, /*force_reload_library*/ true);
EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());

auto status = session_object.Load(ORT_TSTR("testdata/add_mul_add.onnx"));
ASSERT_TRUE(status.IsOK());
status = session_object.Initialize();
ASSERT_TRUE(status.IsOK());

// Verify that MemcpyFromHost and MemcpyToHost nodes exist and are on CPU EP
const auto& graph_after_init = session_object.GetModel().MainGraph();
bool found_memcpy_from_host = false;
bool found_memcpy_to_host = false;

for (const auto& node : graph_after_init.Nodes()) {
if (node.OpType() == "MemcpyFromHost") {
found_memcpy_from_host = true;
ASSERT_EQ(node.GetExecutionProviderType(), kCpuExecutionProvider)
<< "MemcpyFromHost should be assigned to CPU EP";
}
if (node.OpType() == "MemcpyToHost") {
found_memcpy_to_host = true;
ASSERT_EQ(node.GetExecutionProviderType(), kCpuExecutionProvider)
<< "MemcpyToHost should be assigned to CPU EP";
}
}

ASSERT_TRUE(found_memcpy_from_host) << "MemcpyFromHost node should be inserted";
ASSERT_TRUE(found_memcpy_to_host) << "MemcpyToHost node should be inserted";

// Create inputs
auto cuda_provider = DefaultCudaExecutionProvider();
auto cpu_allocator = cuda_provider->CreatePreferredAllocators()[1];
std::vector<int64_t> shape = {3, 2};

std::vector<float> a_data{1, 2, 3, 4, 5, 6};
std::vector<float> b_data{2, 3, 4, 5, 6, 7};

OrtValue ml_value_a;
CreateMLValue<float>(cpu_allocator, shape, a_data, &ml_value_a);
OrtValue ml_value_b;
CreateMLValue<float>(cpu_allocator, shape, b_data, &ml_value_b);

NameMLValMap feeds;
feeds.insert(std::make_pair("A", ml_value_a));
feeds.insert(std::make_pair("B", ml_value_b));

// prepare outputs
std::vector<std::string> output_names;
output_names.push_back("C");
std::vector<OrtValue> fetches;

// Run session and verify outputs
status = session_object.Run(run_options, feeds, output_names, &fetches);
ASSERT_TRUE(status.IsOK());

// Check expected output values
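// (Assumed graph wiring, inferred from these values: C = (A + B) * B + A, e.g. (1 + 2) * 2 + 1 = 7.)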
std::vector<int64_t> expected_dims = {3, 2};
std::vector<float> expected_values = {7, 17, 31, 49, 71, 97};
VerifyOutputs(fetches, expected_dims, expected_values);
}
// #endif

} // namespace test
} // namespace onnxruntime
7 changes: 6 additions & 1 deletion onnxruntime/test/util/default_providers.cc
@@ -70,12 +70,17 @@ std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const O
return nullptr;
}

std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptionsV2* params) {
std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptionsV2* params,
bool force_reload_library) {
#ifdef USE_TENSORRT
if (force_reload_library) {
TensorrtProviderFactoryCreator::UnloadLibrary();
}
if (auto factory = TensorrtProviderFactoryCreator::Create(params))
return factory->CreateProvider();
#else
ORT_UNUSED_PARAMETER(params);
ORT_UNUSED_PARAMETER(force_reload_library);
#endif
return nullptr;
}
2 changes: 1 addition & 1 deletion onnxruntime/test/util/include/default_providers.h
@@ -46,7 +46,7 @@ std::unique_ptr<IExecutionProvider> DnnlExecutionProviderWithOptions(const OrtDn
std::unique_ptr<IExecutionProvider> DefaultTensorrtExecutionProvider();
std::unique_ptr<IExecutionProvider> DefaultNvTensorRTRTXExecutionProvider();
std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptions* params);
std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptionsV2* params);
std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptionsV2* params, bool force_reload_library = false);
std::unique_ptr<IExecutionProvider> DefaultMIGraphXExecutionProvider();
std::unique_ptr<IExecutionProvider> MIGraphXExecutionProviderWithOptions(const OrtMIGraphXProviderOptions* params);
std::unique_ptr<IExecutionProvider> OpenVINOExecutionProviderWithOptions(const ProviderOptions* params, const SessionOptions* session_options = nullptr);