Commit 431b274

Merge branch 'add_reduce_op' into 'master'
Add reduce op. See merge request ai/esp-dl!201
2 parents 48f5306 + 39adbd2 commit 431b274

20 files changed: +2328 additions, -4 deletions

esp-dl/dl/module/include/dl_module_creator.hpp

Lines changed: 20 additions & 0 deletions
@@ -27,6 +27,16 @@
 #include "dl_module_mul.hpp"
 #include "dl_module_pad.hpp"
 #include "dl_module_prelu.hpp"
+#include "dl_module_reduce_l1.hpp"
+#include "dl_module_reduce_l2.hpp"
+#include "dl_module_reduce_log_sum.hpp"
+#include "dl_module_reduce_log_sum_exp.hpp"
+#include "dl_module_reduce_max.hpp"
+#include "dl_module_reduce_mean.hpp"
+#include "dl_module_reduce_min.hpp"
+#include "dl_module_reduce_prod.hpp"
+#include "dl_module_reduce_sum.hpp"
+#include "dl_module_reduce_sum_square.hpp"
 #include "dl_module_relu.hpp"
 #include "dl_module_requantize_linear.hpp"
 #include "dl_module_reshape.hpp"
@@ -154,6 +164,16 @@ class ModuleCreator {
         this->register_module("ReverseSequence", ReverseSequence::deserialize);
         this->register_module("Identity", Identity::deserialize);
         this->register_module("Swish", Swish::deserialize);
+        this->register_module("ReduceL1", ReduceL1::deserialize);
+        this->register_module("ReduceL2", ReduceL2::deserialize);
+        this->register_module("ReduceMin", ReduceMin::deserialize);
+        this->register_module("ReduceMax", ReduceMax::deserialize);
+        this->register_module("ReduceSum", ReduceSum::deserialize);
+        this->register_module("ReduceProd", ReduceProd::deserialize);
+        this->register_module("ReduceMean", ReduceMean::deserialize);
+        this->register_module("ReduceSumSquare", ReduceSumSquare::deserialize);
+        this->register_module("ReduceLogSum", ReduceLogSum::deserialize);
+        this->register_module("ReduceLogSumExp", ReduceLogSumExp::deserialize);
     }
 }
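
For orientation (not part of the diff): each register_module call binds an ONNX operator name such as "ReduceSum" to that module's static deserialize factory, so the creator can instantiate the matching module when it meets that op type in a serialized model. A minimal, self-contained sketch of this string-to-factory pattern, using stand-in types rather than esp-dl's real signatures, could look like:

    #include <map>
    #include <string>

    struct Module {};    // stand-in for dl::module::Module
    struct FbsModel {};  // stand-in for fbs::FbsModel
    using deserialize_fn = Module *(*)(FbsModel *model, std::string node_name);

    class ModuleRegistry {
        std::map<std::string, deserialize_fn> m_creators;

    public:
        // Bind an op-type string to a factory, mirroring register_module above.
        void register_module(const std::string &op_type, deserialize_fn fn) { m_creators[op_type] = fn; }

        // Look up the factory for op_type and build the module, or return nullptr if unknown.
        Module *create(FbsModel *model, const std::string &op_type, const std::string &node_name)
        {
            auto it = m_creators.find(op_type);
            return it == m_creators.end() ? nullptr : it->second(model, node_name);
        }
    };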

Lines changed: 313 additions & 0 deletions
@@ -0,0 +1,313 @@
#pragma once

#include "dl_module_base.hpp"
#include <cmath>

namespace dl {
namespace module {

class ReduceBase : public Module {
protected:
    int m_keepdims;                       /*!< Keep the reduced dimensions (1) or squeeze them (0). */
    std::vector<bool> m_axes_reduce_flag; /*!< One flag per input0 dimension, indicating whether that axis
                                               is reduced. */
    std::string m_op_type;                /*!< Reduce operation type. */

public:
    /**
     * @brief Construct a new ReduceBase object.
     *
     * @param keepdims         keep the reduced dimensions (1) or squeeze them (0).
     * @param axes_reduce_flag per-axis flags marking which axes to reduce.
     * @param op_type          reduce operation type.
     * @param name             name of module.
     * @param inplace          inplace type.
     * @param quant_type       quant type.
     */
    ReduceBase(int keepdims,
               std::vector<bool> axes_reduce_flag,
               std::string op_type,
               const char *name = NULL,
               module_inplace_t inplace = MODULE_NON_INPLACE,
               quant_type_t quant_type = QUANT_TYPE_NONE) :
        Module(name, inplace, quant_type),
        m_keepdims(keepdims),
        m_axes_reduce_flag(axes_reduce_flag),
        m_op_type(op_type)
    {
    }

    /**
     * @brief Destroy the ReduceBase object.
     */
    ~ReduceBase() {}

    std::vector<std::vector<int>> get_output_shape(std::vector<std::vector<int>> &input_shapes)
    {
        std::vector<int> input_shape = input_shapes[0];
        std::vector<int> output_shape;
        if (m_keepdims) {
            for (int i = 0; i < input_shape.size(); i++) {
                if (m_axes_reduce_flag[i]) {
                    output_shape.push_back(1);
                } else {
                    output_shape.push_back(input_shape[i]);
                }
            }
        } else {
            uint32_t reduce_dims_count = 0;
            for (int i = 0; i < input_shape.size(); i++) {
                if (m_axes_reduce_flag[i]) {
                    reduce_dims_count++;
                    continue;
                } else {
                    output_shape.push_back(input_shape[i]);
                }
            }
            if (reduce_dims_count == input_shape.size()) {
                output_shape = {1};
            }
        }
        return {output_shape};
    }

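    // Annotation (not part of the commit): worked example of the shape rule above.
    // For input shape {2, 3, 4} with axis 1 reduced: keepdims = 1 gives {2, 1, 4},
    // keepdims = 0 gives {2, 4}; reducing every axis with keepdims = 0 yields {1}.
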
    template <typename V_T, typename T>
    struct reduce_op_add {
        V_T operator()(const V_T &x, const T &y, void *arg) const { return x + y; }
    };

    template <typename V_T, typename T>
    struct reduce_op_square_add {
        V_T operator()(const V_T &x, const T &y, void *arg) const { return x + y * y; }
    };

    template <typename Op, typename V_T, typename T>
    static V_T reduce(V_T v0, const T *ptr, int size0, int stride0, int size1, int stride1, void *arg)
    {
        Op op;
        V_T sum = v0;

        for (int i = 0; i < size1; i++) {
            const T *ptr0 = ptr;
            for (int j = 0; j < size0; j++) {
                sum = op(sum, *ptr0, arg);
                ptr0 += stride0;
            }
            ptr += stride1;
        }

        return sum;
    }

    template <typename Op, typename V_T, typename T>
    static V_T reduce(
        int input_exponent, V_T v0, const T *ptr, int size0, int stride0, int size1, int stride1, void *arg)
    {
        Op op;
        V_T sum = v0;
        float input_scale = DL_SCALE(input_exponent);
        for (int i = 0; i < size1; i++) {
            const T *ptr0 = ptr;
            for (int j = 0; j < size0; j++) {
                float tmp = (*ptr0) * input_scale;
                sum = op(sum, tmp, arg);
                ptr0 += stride0;
            }
            ptr += stride1;
        }

        return sum;
    }

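    // Annotation (not part of the commit): both reduce() overloads accumulate size1 runs of
    // size0 elements each; the inner loop steps by stride0 within a run and the outer loop
    // advances the run start by stride1. The quantized overload additionally dequantizes each
    // element with DL_SCALE(input_exponent) before accumulating. forward_template below merges
    // adjacent axes that share a reduce flag, then dispatches on the merged rank (1 to 4 dims).
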
    template <typename V_T, typename T, typename ReduceFn>
    void forward_template(ModelContext *context, runtime_mode_t mode, V_T v0, ReduceFn &&reduce_fn, void *arg)
    {
        TensorBase *input = context->get_tensor(m_inputs_index[0]);
        TensorBase *output = context->get_tensor(m_outputs_index[0]);
        int merged_dims = input->get_shape().size();
        int i_exp = input->get_exponent();
        int o_exp = output->get_exponent();
        std::vector<int> new_input_shape = input->get_shape(); // NCHW
        std::vector<bool> new_reduce_flag = m_axes_reduce_flag;
        T *input_ptr = input->get_element_ptr<T>();
        T *output_ptr = output->get_element_ptr<T>();
        int stride0 = 1;
        int size1 = 1;
        int stride1 = 0;

        // Merge input shape and reduce flags.
        if (new_reduce_flag.size() > 1) {
            for (int i = 0; i < new_reduce_flag.size() - 1; ++i) {
                if (new_reduce_flag[i] == new_reduce_flag[i + 1]) {
                    new_input_shape[i] *= new_input_shape[i + 1];
                    new_input_shape.erase(new_input_shape.begin() + i + 1);
                    new_reduce_flag.erase(new_reduce_flag.begin() + i + 1);
                    // Since an element was removed, we need to step back one position and continue checking.
                    --i;
                }
            }
            merged_dims = new_input_shape.size();
        }
        assert(new_input_shape.size() == new_reduce_flag.size());

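        // Annotation (not part of the commit): e.g. shape {2, 3, 4, 5} with reduce flags
        // {false, true, true, false} merges to shape {2, 12, 5} with flags {false, true, false},
        // so the 4-D reduction is handled by the merged_dims == 3 branch below.
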
        if (merged_dims == 1) {
            output_ptr[0] =
                reduce_fn(m_op_type, i_exp, o_exp, v0, input_ptr, input->get_size(), stride0, size1, stride1, arg);
        } else if (merged_dims == 2) {
            if (!new_reduce_flag[0] && new_reduce_flag[1]) {
                T *input_ptr_tmp = input_ptr;
                for (int i = 0; i < new_input_shape[0]; i++) {
                    output_ptr[i] = reduce_fn(
                        m_op_type, i_exp, o_exp, v0, input_ptr_tmp, new_input_shape[1], stride0, size1, stride1, arg);
                    input_ptr_tmp += new_input_shape[1];
                }
            } else if (new_reduce_flag[0] && !new_reduce_flag[1]) {
                for (int i = 0; i < new_input_shape[1]; i++) {
                    output_ptr[i] = reduce_fn(m_op_type,
                                              i_exp,
                                              o_exp,
                                              v0,
                                              input_ptr + i,
                                              new_input_shape[0],
                                              new_input_shape[1],
                                              size1,
                                              stride1,
                                              arg);
                }
            }
        } else if (merged_dims == 3) {
            if (new_reduce_flag[0] && !new_reduce_flag[1] && new_reduce_flag[2]) {
                T *input_ptr_tmp = input_ptr;
                int stride = new_input_shape[1] * new_input_shape[2];
                for (int i = 0; i < new_input_shape[1]; i++) {
                    output_ptr[i] = reduce_fn(m_op_type,
                                              i_exp,
                                              o_exp,
                                              v0,
                                              input_ptr_tmp,
                                              new_input_shape[2],
                                              1,
                                              new_input_shape[0],
                                              stride,
                                              arg);
                    input_ptr_tmp += new_input_shape[2];
                }
            } else if (!new_reduce_flag[0] && new_reduce_flag[1] && !new_reduce_flag[2]) {
                int offset = new_input_shape[1] * new_input_shape[2];
                T *input_ptr_tmp = input_ptr;
                T *output_ptr_tmp = output_ptr;
                for (int i = 0; i < new_input_shape[0]; i++) {
                    for (int j = 0; j < new_input_shape[2]; j++) {
                        output_ptr_tmp[j] = reduce_fn(m_op_type,
                                                      i_exp,
                                                      o_exp,
                                                      v0,
                                                      input_ptr_tmp + j,
                                                      new_input_shape[1],
                                                      new_input_shape[2],
                                                      size1,
                                                      stride1,
                                                      arg);
                    }
                    input_ptr_tmp += offset;
                    output_ptr_tmp += new_input_shape[2];
                }
            }
        } else if (merged_dims == 4) {
            if (!new_reduce_flag[0] && new_reduce_flag[1] && !new_reduce_flag[2] && new_reduce_flag[3]) {
                int offset0 = new_input_shape[1] * new_input_shape[2] * new_input_shape[3];
                int offset1 = new_input_shape[3];
                int stride = new_input_shape[2] * new_input_shape[3];
                T *input_ptr_tmp0 = input_ptr;
                T *output_ptr_tmp = output_ptr;
                for (int i = 0; i < new_input_shape[0]; i++) {
                    T *input_ptr_tmp1 = input_ptr_tmp0;
                    for (int j = 0; j < new_input_shape[2]; j++) {
                        output_ptr_tmp[j] = reduce_fn(m_op_type,
                                                      i_exp,
                                                      o_exp,
                                                      v0,
                                                      input_ptr_tmp1,
                                                      new_input_shape[3],
                                                      1,
                                                      new_input_shape[1],
                                                      stride,
                                                      arg);
                        input_ptr_tmp1 += offset1;
                    }
                    input_ptr_tmp0 += offset0;
                    output_ptr_tmp += new_input_shape[2];
                }
            } else if (new_reduce_flag[0] && !new_reduce_flag[1] && new_reduce_flag[2] && !new_reduce_flag[3]) {
                int offset = new_input_shape[2] * new_input_shape[3];
                int stride = new_input_shape[1] * new_input_shape[2] * new_input_shape[3];
                T *input_ptr_tmp0 = input_ptr;
                T *output_ptr_tmp = output_ptr;
                for (int i = 0; i < new_input_shape[1]; i++) {
                    T *input_ptr_tmp1 = input_ptr_tmp0;
                    for (int j = 0; j < new_input_shape[3]; j++) {
                        output_ptr_tmp[j] = reduce_fn(m_op_type,
                                                      i_exp,
                                                      o_exp,
                                                      v0,
                                                      input_ptr_tmp1 + j,
                                                      new_input_shape[2],
                                                      new_input_shape[3],
                                                      new_input_shape[0],
                                                      stride,
                                                      arg);
                    }
                    input_ptr_tmp0 += offset;
                    output_ptr_tmp += new_input_shape[3];
                }
            }
        }
    }

    static void get_attributes(fbs::FbsModel *fbs_model,
                               std::string node_name,
                               int &keepdims,
                               std::vector<bool> &axes_reduce_flag,
                               quant_type_t &quant_type)
    {
        int noop_with_empty_axes = 0;
        std::vector<int> input0_shape;

        TensorBase *axes = fbs_model->get_operation_parameter(node_name, 1);
        fbs_model->get_operation_attribute(node_name, "quant_type", quant_type);
        fbs_model->get_operation_attribute(node_name, "keepdims", keepdims);
        fbs_model->get_operation_attribute(node_name, "noop_with_empty_axes", noop_with_empty_axes);
        fbs_model->get_operation_input_shape(node_name, 0, input0_shape);

        std::vector<bool> axes_reduce_flag_tmp(input0_shape.size(), false);
        if (axes && axes->get_size() > 0) {
            for (int i = 0; i < axes->get_size(); i++) {
                int axis = static_cast<int>(axes->get_element<int64_t>(i));
                if (axis < 0) {
                    axis += input0_shape.size();
                }
                axes_reduce_flag_tmp[axis] = true;
            }
        } else {
            if (!noop_with_empty_axes) {
                for (int i = 0; i < axes_reduce_flag_tmp.size(); i++) {
                    axes_reduce_flag_tmp[i] = true;
                }
            }
        }
        delete axes;
        axes_reduce_flag = axes_reduce_flag_tmp;
    }

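    // Annotation (not part of the commit): this mirrors the ONNX Reduce* semantics, where axes
    // arrive as an optional second input, negative axis values count from the last dimension,
    // and an empty or absent axes list means "reduce all axes" unless noop_with_empty_axes is set.
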
    void print(std::string tag)
    {
        ESP_LOGI(tag.c_str(),
                 "quant_type: %s, op_type: %s, keepdims: %d, axes_reduce_flag: %s.",
                 quant_type_to_string(quant_type),
                 m_op_type.c_str(),
                 m_keepdims,
                 vector_to_string(m_axes_reduce_flag).c_str());
    }

    virtual void print() { print("ReduceBase"); }
};
} // namespace module
} // namespace dl
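
To make the stride bookkeeping concrete, the following is a standalone sketch (plain C++, not esp-dl code) that restates the reduce() loop structure and reproduces what the merged_dims == 2, reduce-axis-0 branch computes on a small matrix:

    #include <cstdio>

    // Same loop structure as ReduceBase::reduce: accumulate size1 runs of size0 elements;
    // the inner loop steps by stride0 within a run, the outer loop jumps by stride1 between runs.
    static float reduce_sum(float v0, const float *ptr, int size0, int stride0, int size1, int stride1)
    {
        float sum = v0;
        for (int i = 0; i < size1; i++) {
            const float *p = ptr;
            for (int j = 0; j < size0; j++) {
                sum += *p;
                p += stride0;
            }
            ptr += stride1;
        }
        return sum;
    }

    int main()
    {
        // A 2 x 3 row-major matrix; reducing axis 0 sums each column. Per output element i:
        // start at x + i, size0 = 2 rows, stride0 = 3 (row pitch), size1 = 1, stride1 = 0.
        float x[6] = {1, 2, 3, 4, 5, 6};
        for (int i = 0; i < 3; i++) {
            printf("%g ", reduce_sum(0.0f, x + i, 2, 3, 1, 0)); // prints: 5 7 9
        }
        printf("\n");
        return 0;
    }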
