Skip to content

Commit 382dd17

Browse files
committed
Support for resetting Ort::Session and Ort::SessionOption
1 parent 580a62c commit 382dd17

File tree

2 files changed

+22
-27
lines changed

2 files changed

+22
-27
lines changed

src/algorithms/machinelearning/onnxpredict.cpp

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -58,22 +58,14 @@ void OnnxPredict::configure() {
5858
try{
5959
// Define environment
6060
_env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "multi_io_inference"); // {"default", "test", "multi_io_inference"}
61-
62-
/* Auto-detect EPs
63-
auto providers = Ort::GetAvailableProviders();
64-
if (std::find(providers.begin(), providers.end(), "CUDAExecutionProvider") != providers.end()) {
65-
OrtSessionOptionsAppendExecutionProvider_CUDA(_sessionOptions, 0); // device_id = 0
66-
E_INFO("✅ Using CUDA Execution Provider");
67-
} else if (std::find(providers.begin(), providers.end(), "MetalExecutionProvider") != providers.end()) {
68-
OrtSessionOptionsAppendExecutionProvider_Metal(_sessionOptions, 0); // device_id = 0
69-
E_INFO("✅ Using Metal Execution Provider");
70-
} else if (std::find(providers.begin(), providers.end(), "CoreMLExecutionProvider") != providers.end()) {
71-
OrtSessionOptionsAppendExecutionProvider_CoreML(_sessionOptions, 0); // device_id = 0
72-
E_INFO("✅ Using Core ML Execution Provider");
73-
}else {
74-
// Default = CPU - CPU is always available, no need to append explicitly
75-
}*/
7661

62+
// Reset session
63+
_session.reset();
64+
65+
// Reset SessionOptions by constructing a fresh object
66+
_sessionOptions = Ort::SessionOptions{};
67+
68+
// Auto-detect EPs
7769
#ifdef USE_CUDA
7870
if (std::find(providers.begin(), providers.end(), "CUDAExecutionProvider") != providers.end()) {
7971
OrtSessionOptionsAppendExecutionProvider_CUDA(_sessionOptions, 0);
@@ -100,16 +92,17 @@ void OnnxPredict::configure() {
10092
_sessionOptions.SetIntraOpNumThreads(0);
10193

10294
// Initialize session
103-
_session = Ort::Session(_env, _graphFilename.c_str(), _sessionOptions);
95+
_session = std::make_unique<Ort::Session>(_env, _graphFilename.c_str(), _sessionOptions);
96+
10497
}
10598
catch (Ort::Exception oe) {
10699
throw EssentiaException(string("OnnxPredict:") + oe.what(), oe.GetOrtErrorCode());
107100
}
108101
E_INFO("OnnxPredict: Successfully loaded graph file: `" << _graphFilename << "`");
109102

110103
// get input and output info (names, type and shapes)
111-
all_input_infos = setTensorInfos(_session, _allocator, "inputs");
112-
all_output_infos = setTensorInfos(_session, _allocator, "outputs");
104+
all_input_infos = setTensorInfos(*_session, _allocator, "inputs");
105+
all_output_infos = setTensorInfos(*_session, _allocator, "outputs");
113106

114107
// read inputs and outputs as input parameter
115108
_inputs = parameter("inputs").toVectorString();
@@ -312,13 +305,13 @@ void OnnxPredict::compute() {
312305
}
313306

314307
// Run the Onnxruntime session.
315-
auto output_tensors = _session.Run(_runOptions, // Run options.
316-
input_names.data(), // Input node names.
317-
input_tensors.data(), // Input tensor values.
318-
_nInputs, // Number of inputs.
319-
output_names.data(), // Output node names.
320-
_nOutputs // Number of outputs.
321-
);
308+
auto output_tensors = _session->Run(_runOptions, // Run options.
309+
input_names.data(), // Input node names.
310+
input_tensors.data(), // Input tensor values.
311+
_nInputs, // Number of inputs.
312+
output_names.data(), // Output node names.
313+
_nOutputs // Number of outputs.
314+
);
322315

323316
// Map output tensors to pool
324317
for (size_t i = 0; i < output_tensors.size(); ++i) {

src/algorithms/machinelearning/onnxpredict.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@ class OnnxPredict : public Algorithm {
6363

6464
Ort::Env _env{nullptr};
6565
Ort::SessionOptions _sessionOptions{nullptr};
66-
Ort::Session _session{nullptr};
66+
//Ort::Session _session{nullptr};
67+
std::unique_ptr<Ort::Session> _session;
68+
6769

6870
Ort::RunOptions _runOptions;
6971
Ort::AllocatorWithDefaultOptions _allocator;
@@ -100,7 +102,7 @@ class OnnxPredict : public Algorithm {
100102
public:
101103

102104
OnnxPredict() : _env(Ort::Env(ORT_LOGGING_LEVEL_WARNING, "test")),
103-
_sessionOptions(Ort::SessionOptions()), _session(Ort::Session(nullptr)), _runOptions(NULL), _isConfigured(false){
105+
_sessionOptions(Ort::SessionOptions()), _session(nullptr), _runOptions(NULL), _isConfigured(false){
104106
declareInput(_poolIn, "poolIn", "the pool where to get the feature tensors");
105107
declareOutput(_poolOut, "poolOut", "the pool where to store the output tensors");
106108
}

0 commit comments

Comments
 (0)