#pragma once

#include "dl_module_base.hpp"
#include <cmath>

namespace dl {
namespace module {

class ReduceBase : public Module {
protected:
    int m_keepdims;                       /*!< Keep the reduced dimensions or not. */
    std::vector<bool> m_axes_reduce_flag; /*!< A bool list with the same dims as input0, indicating whether to
                                               perform reduction on each axis. */
    std::string m_op_type;                /*!< Reduce operation type. */

public:
    /**
     * @brief Construct a new ReduceBase object.
     *
     * @param keepdims whether to keep the reduced dimensions as size 1 (non-zero) or remove them (zero).
     * @param axes_reduce_flag per-axis flags with the same rank as input0; true marks an axis to reduce.
     * @param op_type reduce operation type.
     * @param name name of module.
     * @param inplace inplace type.
     * @param quant_type quant type.
     */
    ReduceBase(int keepdims,
               std::vector<bool> axes_reduce_flag,
               std::string op_type,
               const char *name = NULL,
               module_inplace_t inplace = MODULE_NON_INPLACE,
               quant_type_t quant_type = QUANT_TYPE_NONE) :
        Module(name, inplace, quant_type),
        m_keepdims(keepdims),
        m_axes_reduce_flag(axes_reduce_flag),
        m_op_type(op_type)
    {
    }

    /**
     * @brief Destroy the ReduceBase object.
     */
    ~ReduceBase() {}

    /**
     * @brief Derive the output shape from the input shape and the per-axis reduce flags.
     */
    std::vector<std::vector<int>> get_output_shape(std::vector<std::vector<int>> &input_shapes)
    {
        std::vector<int> input_shape = input_shapes[0];
        std::vector<int> output_shape;
        if (m_keepdims) {
            // Reduced axes are kept with size 1.
            for (int i = 0; i < (int)input_shape.size(); i++) {
                if (m_axes_reduce_flag[i]) {
                    output_shape.push_back(1);
                } else {
                    output_shape.push_back(input_shape[i]);
                }
            }
        } else {
            // Reduced axes are squeezed out of the output shape.
            uint32_t reduce_dims_count = 0;
            for (int i = 0; i < (int)input_shape.size(); i++) {
                if (m_axes_reduce_flag[i]) {
                    reduce_dims_count++;
                } else {
                    output_shape.push_back(input_shape[i]);
                }
            }
            // Reducing every axis collapses the output to a single element.
            if (reduce_dims_count == input_shape.size()) {
                output_shape = {1};
            }
        }
        return {output_shape};
    }
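
    // Illustrative example: with input shape {2, 3, 4} and reduction on axis 1,
    // the output shape is {2, 1, 4} when keepdims is non-zero and {2, 4} when it
    // is zero; reducing all three axes with keepdims == 0 yields {1}.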

    // Accumulation functors plugged into the reduce helpers below. The unused
    // `arg` parameter leaves room for ops that need extra state.
    template <typename V_T, typename T>
    struct reduce_op_add {
        // Plain accumulation, e.g. for sum/mean style reductions.
        V_T operator()(const V_T &x, const T &y, void *arg) const { return x + y; }
    };

    template <typename V_T, typename T>
    struct reduce_op_square_add {
        // Accumulate squared values, e.g. for sum-of-squares style reductions.
        V_T operator()(const V_T &x, const T &y, void *arg) const { return x + y * y; }
    };

    /**
     * @brief Fold `size1 x size0` strided elements into the accumulator `v0` with
     *        the accumulation functor `Op`. The inner loop walks `size0` elements
     *        `stride0` apart; the outer loop repeats it `size1` times, advancing
     *        the base pointer by `stride1` each iteration.
     */
    template <typename Op, typename V_T, typename T>
    static V_T reduce(V_T v0, const T *ptr, int size0, int stride0, int size1, int stride1, void *arg)
    {
        Op op;
        V_T sum = v0;

        for (int i = 0; i < size1; i++) {
            const T *ptr0 = ptr;
            for (int j = 0; j < size0; j++) {
                sum = op(sum, *ptr0, arg);
                ptr0 += stride0;
            }
            ptr += stride1;
        }

        return sum;
    }
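
    // Illustrative example: to sum column j of a row-major M x N matrix, call
    // reduce<reduce_op_add<int, int8_t>>(0, base + j, /*size0=*/M, /*stride0=*/N,
    // /*size1=*/1, /*stride1=*/0, nullptr) -- the same access pattern
    // forward_template uses below for its 2-D "reduce the outer axis" case.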

    /**
     * @brief Same traversal as above, but dequantizes each element with the
     *        input exponent before accumulating.
     */
    template <typename Op, typename V_T, typename T>
    static V_T reduce(
        int input_exponent, V_T v0, const T *ptr, int size0, int stride0, int size1, int stride1, void *arg)
    {
        Op op;
        V_T sum = v0;
        float input_scale = DL_SCALE(input_exponent);
        for (int i = 0; i < size1; i++) {
            const T *ptr0 = ptr;
            for (int j = 0; j < size0; j++) {
                // Convert the quantized value to float before applying the op.
                float tmp = (*ptr0) * input_scale;
                sum = op(sum, tmp, arg);
                ptr0 += stride0;
            }
            ptr += stride1;
        }

        return sum;
    }
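
    // Assuming ESP-DL's power-of-two quantization, where DL_SCALE(e) evaluates
    // to 2^e, an int8 value q = 25 with input exponent -4 dequantizes to
    // 25 * 2^-4 = 1.5625 before it reaches the accumulation functor.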

    /**
     * @brief Shared forward pass: merges adjacent axes that share the same reduce
     *        flag, then dispatches on the merged rank (1 to 4 dims) and calls
     *        `reduce_fn` once per output element.
     */
    template <typename V_T, typename T, typename ReduceFn>
    void forward_template(ModelContext *context, runtime_mode_t mode, V_T v0, ReduceFn &&reduce_fn, void *arg)
    {
        TensorBase *input = context->get_tensor(m_inputs_index[0]);
        TensorBase *output = context->get_tensor(m_outputs_index[0]);
        int merged_dims = input->get_shape().size();
        int i_exp = input->get_exponent();
        int o_exp = output->get_exponent();
        std::vector<int> new_input_shape = input->get_shape(); // NCHW
        std::vector<bool> new_reduce_flag = m_axes_reduce_flag;
        T *input_ptr = input->get_element_ptr<T>();
        T *output_ptr = output->get_element_ptr<T>();
        int stride0 = 1;
        int size1 = 1;
        int stride1 = 0;

        // Merge adjacent axes that share the same reduce flag, so the dispatch
        // below only has to handle alternating reduce/keep patterns.
        if (new_reduce_flag.size() > 1) {
            for (int i = 0; i < (int)new_reduce_flag.size() - 1; ++i) {
                if (new_reduce_flag[i] == new_reduce_flag[i + 1]) {
                    new_input_shape[i] *= new_input_shape[i + 1];
                    new_input_shape.erase(new_input_shape.begin() + i + 1);
                    new_reduce_flag.erase(new_reduce_flag.begin() + i + 1);
                    // Since an element was removed, step back one position and continue checking.
                    --i;
                }
            }
            merged_dims = new_input_shape.size();
        }
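        // Illustrative example: input shape {2, 3, 4, 5} with reduce flags
        // {true, true, false, false} merges to shape {6, 20} with flags
        // {true, false}, which the 2-D branch below handles.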
        assert(new_input_shape.size() == new_reduce_flag.size());

        if (merged_dims == 1) {
            // Reduce the whole tensor to a single element.
            output_ptr[0] =
                reduce_fn(m_op_type, i_exp, o_exp, v0, input_ptr, input->get_size(), stride0, size1, stride1, arg);
        } else if (merged_dims == 2) {
            if (!new_reduce_flag[0] && new_reduce_flag[1]) {
                // Reduce the inner (contiguous) axis once per outer index.
                T *input_ptr_tmp = input_ptr;
                for (int i = 0; i < new_input_shape[0]; i++) {
                    output_ptr[i] = reduce_fn(
                        m_op_type, i_exp, o_exp, v0, input_ptr_tmp, new_input_shape[1], stride0, size1, stride1, arg);
                    input_ptr_tmp += new_input_shape[1];
                }
            } else if (new_reduce_flag[0] && !new_reduce_flag[1]) {
                // Reduce the outer axis: walk each column with stride new_input_shape[1].
                for (int i = 0; i < new_input_shape[1]; i++) {
                    output_ptr[i] = reduce_fn(m_op_type,
                                              i_exp,
                                              o_exp,
                                              v0,
                                              input_ptr + i,
                                              new_input_shape[0],
                                              new_input_shape[1],
                                              size1,
                                              stride1,
                                              arg);
                }
            }
        } else if (merged_dims == 3) {
            if (new_reduce_flag[0] && !new_reduce_flag[1] && new_reduce_flag[2]) {
                // Reduce axes 0 and 2, keep axis 1: each output element folds
                // new_input_shape[0] rows of new_input_shape[2] contiguous values.
                T *input_ptr_tmp = input_ptr;
                int stride = new_input_shape[1] * new_input_shape[2];
                for (int i = 0; i < new_input_shape[1]; i++) {
                    output_ptr[i] = reduce_fn(m_op_type,
                                              i_exp,
                                              o_exp,
                                              v0,
                                              input_ptr_tmp,
                                              new_input_shape[2],
                                              1,
                                              new_input_shape[0],
                                              stride,
                                              arg);
                    input_ptr_tmp += new_input_shape[2];
                }
            } else if (!new_reduce_flag[0] && new_reduce_flag[1] && !new_reduce_flag[2]) {
                // Reduce the middle axis, keep axes 0 and 2.
                int offset = new_input_shape[1] * new_input_shape[2];
                T *input_ptr_tmp = input_ptr;
                T *output_ptr_tmp = output_ptr;
                for (int i = 0; i < new_input_shape[0]; i++) {
                    for (int j = 0; j < new_input_shape[2]; j++) {
                        output_ptr_tmp[j] = reduce_fn(m_op_type,
                                                      i_exp,
                                                      o_exp,
                                                      v0,
                                                      input_ptr_tmp + j,
                                                      new_input_shape[1],
                                                      new_input_shape[2],
                                                      size1,
                                                      stride1,
                                                      arg);
                    }
                    input_ptr_tmp += offset;
                    output_ptr_tmp += new_input_shape[2];
                }
            }
        } else if (merged_dims == 4) {
            if (!new_reduce_flag[0] && new_reduce_flag[1] && !new_reduce_flag[2] && new_reduce_flag[3]) {
                // Reduce axes 1 and 3, keep axes 0 and 2.
                int offset0 = new_input_shape[1] * new_input_shape[2] * new_input_shape[3];
                int offset1 = new_input_shape[3];
                int stride = new_input_shape[2] * new_input_shape[3];
                T *input_ptr_tmp0 = input_ptr;
                T *output_ptr_tmp = output_ptr;
                for (int i = 0; i < new_input_shape[0]; i++) {
                    T *input_ptr_tmp1 = input_ptr_tmp0;
                    for (int j = 0; j < new_input_shape[2]; j++) {
                        output_ptr_tmp[j] = reduce_fn(m_op_type,
                                                      i_exp,
                                                      o_exp,
                                                      v0,
                                                      input_ptr_tmp1,
                                                      new_input_shape[3],
                                                      1,
                                                      new_input_shape[1],
                                                      stride,
                                                      arg);
                        input_ptr_tmp1 += offset1;
                    }
                    input_ptr_tmp0 += offset0;
                    output_ptr_tmp += new_input_shape[2];
                }
            } else if (new_reduce_flag[0] && !new_reduce_flag[1] && new_reduce_flag[2] && !new_reduce_flag[3]) {
                // Reduce axes 0 and 2, keep axes 1 and 3.
                int offset = new_input_shape[2] * new_input_shape[3];
                int stride = new_input_shape[1] * new_input_shape[2] * new_input_shape[3];
                T *input_ptr_tmp0 = input_ptr;
                T *output_ptr_tmp = output_ptr;
                for (int i = 0; i < new_input_shape[1]; i++) {
                    T *input_ptr_tmp1 = input_ptr_tmp0;
                    for (int j = 0; j < new_input_shape[3]; j++) {
                        output_ptr_tmp[j] = reduce_fn(m_op_type,
                                                      i_exp,
                                                      o_exp,
                                                      v0,
                                                      input_ptr_tmp1 + j,
                                                      new_input_shape[2],
                                                      new_input_shape[3],
                                                      new_input_shape[0],
                                                      stride,
                                                      arg);
                    }
                    input_ptr_tmp0 += offset;
                    output_ptr_tmp += new_input_shape[3];
                }
            }
        }
    }
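
    // A hypothetical subclass sketch (not part of this header): a sum-style
    // forward pass could invoke forward_template with a lambda that forwards
    // to the static reduce helpers above, e.g.
    //
    //   forward_template<float, int8_t>(context, mode, 0.f,
    //       [](const std::string &, int i_exp, int /*o_exp*/, float v0,
    //          const int8_t *p, int s0, int st0, int s1, int st1, void *arg) {
    //           return reduce<reduce_op_add<float, float>, float, int8_t>(
    //               i_exp, v0, p, s0, st0, s1, st1, arg);
    //       },
    //       nullptr);
    //
    // A real implementation would also requantize the accumulated float with
    // o_exp (and divide by the element count for a mean) before it is stored.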

    /**
     * @brief Parse keepdims, the axes input, and noop_with_empty_axes from the
     *        flatbuffers model, and expand them into per-axis reduce flags.
     */
    static void get_attributes(fbs::FbsModel *fbs_model,
                               std::string node_name,
                               int &keepdims,
                               std::vector<bool> &axes_reduce_flag,
                               quant_type_t &quant_type)
    {
        int noop_with_empty_axes = 0;
        std::vector<int> input0_shape;

        TensorBase *axes = fbs_model->get_operation_parameter(node_name, 1);
        fbs_model->get_operation_attribute(node_name, "quant_type", quant_type);
        fbs_model->get_operation_attribute(node_name, "keepdims", keepdims);
        fbs_model->get_operation_attribute(node_name, "noop_with_empty_axes", noop_with_empty_axes);
        fbs_model->get_operation_input_shape(node_name, 0, input0_shape);

        std::vector<bool> axes_reduce_flag_tmp(input0_shape.size(), false);
        if (axes && axes->get_size() > 0) {
            for (int i = 0; i < axes->get_size(); i++) {
                // Normalize negative axes (ONNX semantics) to [0, rank).
                int axis = static_cast<int>(axes->get_element<int64_t>(i));
                if (axis < 0) {
                    axis += input0_shape.size();
                }
                axes_reduce_flag_tmp[axis] = true;
            }
        } else if (!noop_with_empty_axes) {
            // An empty axes list reduces over all axes unless noop_with_empty_axes is set.
            for (int i = 0; i < (int)axes_reduce_flag_tmp.size(); i++) {
                axes_reduce_flag_tmp[i] = true;
            }
        }
        delete axes;
        axes_reduce_flag = axes_reduce_flag_tmp;
    }
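
    // Illustrative example: for a 4-D input, an ONNX axes tensor {1, -1}
    // normalizes -1 to 3 and yields flags {false, true, false, true}; a missing
    // or empty axes tensor with noop_with_empty_axes == 0 flags every axis.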

    void print(std::string tag)
    {
        ESP_LOGI(tag.c_str(),
                 "quant_type: %s, op_type: %s, keepdims: %d, axes_reduce_flag: %s.",
                 quant_type_to_string(quant_type),
                 m_op_type.c_str(),
                 m_keepdims,
                 vector_to_string(m_axes_reduce_flag).c_str());
    }

    virtual void print() { print("ReduceBase"); }
};
} // namespace module
} // namespace dl