-
Notifications
You must be signed in to change notification settings - Fork 4.4k
/
Copy pathONNXRuntime.cc
191 lines (164 loc) · 7.02 KB
/
ONNXRuntime.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
/*
* ONNXRuntime.cc
*
* Created on: Jun 28, 2019
* Author: hqu
*/
#include "PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h"

#include "FWCore/Utilities/interface/Exception.h"
#include "FWCore/Utilities/interface/thread_safety_macros.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <numeric>
namespace cms::Ort {
using namespace ::Ort;
#ifdef ONNXDebug
const Env ONNXRuntime::env_(ORT_LOGGING_LEVEL_INFO, "");
#else
const Env ONNXRuntime::env_(ORT_LOGGING_LEVEL_ERROR, "");
#endif
// Load the ONNX model at `model_path` and cache the model's input/output node
// names and static shapes for later use in run().
// If `session_options` is null, defaultSessionOptions() is used.
ONNXRuntime::ONNXRuntime(const std::string& model_path, const SessionOptions* session_options) {
  // create session
  if (session_options) {
    session_ = std::make_unique<Session>(env_, model_path.c_str(), *session_options);
  } else {
    session_ = std::make_unique<Session>(env_, model_path.c_str(), defaultSessionOptions());
  }
  AllocatorWithDefaultOptions allocator;

  // get input names and shapes
  size_t num_input_nodes = session_->GetInputCount();
  input_node_strings_.resize(num_input_nodes);
  input_node_names_.resize(num_input_nodes);
  input_node_dims_.clear();
  for (size_t i = 0; i < num_input_nodes; i++) {
    // GetInputName returns a C string owned by `allocator`; free it after
    // copying into a std::string to avoid leaking one allocation per node.
    char* raw_input_name = session_->GetInputName(i, allocator);
    std::string input_name(raw_input_name);
    allocator.Free(raw_input_name);
    input_node_strings_[i] = input_name;
    // keep a stable const char* view for the Run() call
    input_node_names_[i] = input_node_strings_[i].c_str();

    // get input shapes
    auto type_info = session_->GetInputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
    size_t num_dims = tensor_info.GetDimensionsCount();
    input_node_dims_[input_name].resize(num_dims);
    tensor_info.GetDimensions(input_node_dims_[input_name].data(), num_dims);
  }

  // get output names and shapes
  size_t num_output_nodes = session_->GetOutputCount();
  output_node_strings_.resize(num_output_nodes);
  output_node_names_.resize(num_output_nodes);
  output_node_dims_.clear();
  for (size_t i = 0; i < num_output_nodes; i++) {
    // same ownership rule as GetInputName: the buffer must be freed here
    char* raw_output_name = session_->GetOutputName(i, allocator);
    std::string output_name(raw_output_name);
    allocator.Free(raw_output_name);
    output_node_strings_[i] = output_name;
    output_node_names_[i] = output_node_strings_[i].c_str();

    // get output node types
    auto type_info = session_->GetOutputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
    size_t num_dims = tensor_info.GetDimensionsCount();
    output_node_dims_[output_name].resize(num_dims);
    tensor_info.GetDimensions(output_node_dims_[output_name].data(), num_dims);

    // the 0th dim depends on the batch size
    output_node_dims_[output_name].at(0) = -1;
  }
}
// Defined here (rather than in the header) so the unique_ptr<Session> member
// is destroyed in a translation unit where Session is a complete type.
ONNXRuntime::~ONNXRuntime() = default;
// Build the session options used when the caller does not supply any:
// single-threaded intra-op execution, optionally the CUDA execution provider,
// and profiling output when compiled with -DONNX_PROFILE.
SessionOptions ONNXRuntime::defaultSessionOptions(Backend backend) {
  SessionOptions opts;
  // limit intra-op parallelism to a single thread
  opts.SetIntraOpNumThreads(1);
  if (backend == Backend::cuda) {
    // register the CUDA execution provider with default provider options
    OrtCUDAProviderOptions cuda_opts;
    opts.AppendExecutionProvider_CUDA(cuda_opts);
  }
#ifdef ONNX_PROFILE
  opts.EnableProfiling("ONNXProf");
#endif
  return opts;
}
// Run inference on one batch.
//  - input_names/input_values: parallel lists; values are flattened row-major floats.
//  - input_shapes: optional explicit shapes (one per input, dim 0 == batch_size);
//    when empty, the model's static shapes are used with dim 0 set to batch_size.
//  - output_names: outputs to fetch; empty means all model outputs.
// Returns one flat float vector per requested output.
// Throws cms::Exception("RuntimeError") on missing inputs or size mismatches.
FloatArrays ONNXRuntime::run(const std::vector<std::string>& input_names,
                             FloatArrays& input_values,
                             const std::vector<std::vector<int64_t>>& input_shapes,
                             const std::vector<std::string>& output_names,
                             int64_t batch_size) const {
  assert(input_names.size() == input_values.size());
  assert(input_shapes.empty() || input_names.size() == input_shapes.size());
  assert(batch_size > 0);

  // create input tensor objects from data values
  std::vector<Value> input_tensors;
  auto memory_info = MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  // iterate in the model's input order; the caller may list inputs in any order
  for (const auto& name : input_node_strings_) {
    auto iter = std::find(input_names.begin(), input_names.end(), name);
    if (iter == input_names.end()) {
      throw cms::Exception("RuntimeError") << "Input " << name << " is not provided!";
    }
    auto input_pos = iter - input_names.begin();
    auto value = input_values.begin() + input_pos;
    std::vector<int64_t> input_dims;
    if (input_shapes.empty()) {
      input_dims = input_node_dims_.at(name);
      input_dims[0] = batch_size;
    } else {
      input_dims = input_shapes[input_pos];
      // rely on the given input_shapes to set the batch size
      if (input_dims[0] != batch_size) {
        throw cms::Exception("RuntimeError") << "The first element of `input_shapes` (" << input_dims[0]
                                             << ") does not match the given `batch_size` (" << batch_size << ")";
      }
    }
    // FIX: seed accumulate with int64_t{1}; a plain `1` makes the accumulator
    // an `int`, truncating/overflowing the element-count product for large tensors.
    const auto expected_len =
        std::accumulate(input_dims.begin(), input_dims.end(), int64_t{1}, std::multiplies<int64_t>());
    if (expected_len != static_cast<int64_t>(value->size())) {
      throw cms::Exception("RuntimeError")
          << "Input array " << name << " has a wrong size of " << value->size() << ", expected " << expected_len;
    }
    // the tensor is a non-owning view over the caller's float buffer
    auto input_tensor =
        Value::CreateTensor<float>(memory_info, value->data(), value->size(), input_dims.data(), input_dims.size());
    assert(input_tensor.IsTensor());
    input_tensors.emplace_back(std::move(input_tensor));
  }

  // set output node names; will get all outputs if `output_names` is not provided
  std::vector<const char*> run_output_node_names;
  if (output_names.empty()) {
    run_output_node_names = output_node_names_;
  } else {
    for (const auto& name : output_names) {
      run_output_node_names.push_back(name.c_str());
    }
  }

  // run
  auto output_tensors = session_->Run(RunOptions{nullptr},
                                      input_node_names_.data(),
                                      input_tensors.data(),
                                      input_tensors.size(),
                                      run_output_node_names.data(),
                                      run_output_node_names.size());

  // convert output tensors to flat float vectors
  FloatArrays outputs;
  for (auto& output_tensor : output_tensors) {
    assert(output_tensor.IsTensor());
    // get output shape
    auto tensor_info = output_tensor.GetTensorTypeAndShapeInfo();
    auto length = tensor_info.GetElementCount();
    auto floatarr = output_tensor.GetTensorMutableData<float>();
    outputs.emplace_back(floatarr, floatarr + length);
  }
  assert(outputs.size() == run_output_node_names.size());
  return outputs;
}
// Return the model's output node names, in model order.
// Throws if no session has been created yet.
const std::vector<std::string>& ONNXRuntime::getOutputNames() const {
  if (!session_) {
    throw cms::Exception("RuntimeError") << "Needs to call createSession() first before getting the output names!";
  }
  return output_node_strings_;
}
// Return the cached shape (recorded at construction; dim 0 is -1, i.e. the
// batch dimension) for the named output. Throws on an unknown name.
const std::vector<int64_t>& ONNXRuntime::getOutputShape(const std::string& output_name) const {
  const auto pos = output_node_dims_.find(output_name);
  if (pos != output_node_dims_.end()) {
    return pos->second;
  }
  throw cms::Exception("RuntimeError") << "Output name " << output_name << " is invalid!";
}
} /* namespace cms::Ort */