Skip to content

Commit 395e20d

Browse files
committed
add option to force reload TRT EP library
1 parent b085a7d commit 395e20d

File tree

5 files changed

+16
-3
lines changed

5 files changed

+16
-3
lines changed

onnxruntime/core/providers/tensorrt/tensorrt_provider_factory_creator.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@ struct TensorrtProviderFactoryCreator {
1616
static std::shared_ptr<IExecutionProviderFactory> Create(int device_id);
1717
static std::shared_ptr<IExecutionProviderFactory> Create(const OrtTensorRTProviderOptions* provider_options);
1818
static std::shared_ptr<IExecutionProviderFactory> Create(const OrtTensorRTProviderOptionsV2* provider_options);
19+
static void UnloadLibrary();
1920
};
2021
} // namespace onnxruntime

onnxruntime/core/session/provider_bridge_ort.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2214,6 +2214,10 @@ std::shared_ptr<IExecutionProviderFactory> TensorrtProviderFactoryCreator::Creat
22142214
return nullptr;
22152215
}
22162216

2217+
void TensorrtProviderFactoryCreator::UnloadLibrary() {
2218+
s_library_tensorrt.Unload();
2219+
}
2220+
22172221
std::shared_ptr<IExecutionProviderFactory> NvProviderFactoryCreator::Create(int device_id) try {
22182222
return s_library_nv.Get().CreateExecutionProviderFactory(device_id);
22192223
} catch (const std::exception& exception) {

onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1425,7 +1425,11 @@ TEST(TensorrtExecutionProviderTest, PartiallySupportedModel_MemcpyOpsOnCPU_Infer
14251425
OrtTensorRTProviderOptionsV2 params;
14261426
params.trt_use_cpu_ep_memcpy_kernels = true;
14271427
params.trt_op_types_to_exclude = "Mul";
1428-
std::unique_ptr<IExecutionProvider> execution_provider = TensorrtExecutionProviderWithOptions(&params);
1428+
1429+
// Previous unit tests alread loaded TRT EP library that registered the Memcpy kernels to TRT EP, we need
1430+
// to reload the TRT EP library after setting the trt_use_cpu_ep_memcpy_kernels option to make sure the
1431+
// Memcpy kernels are registered to CPU EP instead of TRT EP.
1432+
std::unique_ptr<IExecutionProvider> execution_provider = TensorrtExecutionProviderWithOptions(&params, /*force_reload_library*/ true);
14291433
EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
14301434

14311435
auto status = session_object.Load(ORT_TSTR("testdata/add_mul_add.onnx"));

onnxruntime/test/util/default_providers.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,12 @@ std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const O
7070
return nullptr;
7171
}
7272

73-
std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptionsV2* params) {
73+
std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptionsV2* params,
74+
bool force_reload_library) {
7475
#ifdef USE_TENSORRT
76+
if (force_reload_library) {
77+
TensorrtProviderFactoryCreator::UnloadLibrary();
78+
}
7579
if (auto factory = TensorrtProviderFactoryCreator::Create(params))
7680
return factory->CreateProvider();
7781
#else

onnxruntime/test/util/include/default_providers.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ std::unique_ptr<IExecutionProvider> DnnlExecutionProviderWithOptions(const OrtDn
4646
std::unique_ptr<IExecutionProvider> DefaultTensorrtExecutionProvider();
4747
std::unique_ptr<IExecutionProvider> DefaultNvTensorRTRTXExecutionProvider();
4848
std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptions* params);
49-
std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptionsV2* params);
49+
std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptionsV2* params, bool force_reload_library = false);
5050
std::unique_ptr<IExecutionProvider> DefaultMIGraphXExecutionProvider();
5151
std::unique_ptr<IExecutionProvider> MIGraphXExecutionProviderWithOptions(const OrtMIGraphXProviderOptions* params);
5252
std::unique_ptr<IExecutionProvider> OpenVINOExecutionProviderWithOptions(const ProviderOptions* params, const SessionOptions* session_options = nullptr);

0 commit comments

Comments
 (0)