Skip to content

Commit 580a62c

Browse files
committed
Handle parallel computational providers with macros to avoid errors when CUDA, Metal, or CoreML are not compiled into the onnxruntime library
1 parent 442bae5 commit 580a62c

File tree

1 file changed

+26
-5
lines changed

1 file changed

+26
-5
lines changed

src/algorithms/machinelearning/onnxpredict.cpp

Lines changed: 26 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -59,20 +59,41 @@ void OnnxPredict::configure() {
5959
// Define environment
6060
_env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "multi_io_inference"); // {"default", "test", "multi_io_inference"}
6161

62-
// Auto-detect EPs
62+
/* Auto-detect EPs
6363
auto providers = Ort::GetAvailableProviders();
6464
if (std::find(providers.begin(), providers.end(), "CUDAExecutionProvider") != providers.end()) {
65-
OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0); // device_id = 0
65+
OrtSessionOptionsAppendExecutionProvider_CUDA(_sessionOptions, 0); // device_id = 0
6666
E_INFO("✅ Using CUDA Execution Provider");
6767
} else if (std::find(providers.begin(), providers.end(), "MetalExecutionProvider") != providers.end()) {
68-
OrtSessionOptionsAppendExecutionProvider_Metal(session_options, 0); // device_id = 0
68+
OrtSessionOptionsAppendExecutionProvider_Metal(_sessionOptions, 0); // device_id = 0
6969
E_INFO("✅ Using Metal Execution Provider");
7070
} else if (std::find(providers.begin(), providers.end(), "CoreMLExecutionProvider") != providers.end()) {
71-
OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, 0); // device_id = 0
71+
OrtSessionOptionsAppendExecutionProvider_CoreML(_sessionOptions, 0); // device_id = 0
7272
E_INFO("✅ Using Core ML Execution Provider");
73-
} else {
73+
}else {
7474
// Default = CPU - CPU is always available, no need to append explicitly
75+
}*/
76+
77+
#ifdef USE_CUDA
78+
if (std::find(providers.begin(), providers.end(), "CUDAExecutionProvider") != providers.end()) {
79+
OrtSessionOptionsAppendExecutionProvider_CUDA(_sessionOptions, 0);
80+
E_INFO("✅ Using CUDA Execution Provider");
81+
}
82+
#endif
83+
84+
#ifdef USE_METAL
85+
if (std::find(providers.begin(), providers.end(), "MetalExecutionProvider") != providers.end()) {
86+
OrtSessionOptionsAppendExecutionProvider_Metal(_sessionOptions, 0);
87+
E_INFO("✅ Using Metal Execution Provider");
88+
}
89+
#endif
90+
91+
#ifdef USE_COREML
92+
if (std::find(providers.begin(), providers.end(), "CoreMLExecutionProvider") != providers.end()) {
93+
OrtSessionOptionsAppendExecutionProvider_CoreML(_sessionOptions, 0);
94+
E_INFO("✅ Using Core ML Execution Provider");
7595
}
96+
#endif
7697

7798
// Set graph optimization level - check https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html
7899
_sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);

0 commit comments

Comments
 (0)