@@ -153,50 +153,99 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
153153 break ;
154154 }
155155 case kTfLiteInt8 : {
156+ if (data.is_per_channel ) {
156157#if ESP_NN
157- const RuntimeShape& filter_shape = tflite::micro::GetTensorShape (filter);
158- const RuntimeShape& output_shape = tflite::micro::GetTensorShape (output);
159-
160- TFLITE_DCHECK_GE (filter_shape.DimensionsCount (), 2 );
161- TFLITE_DCHECK_GE (output_shape.DimensionsCount (), 1 );
162- const int filter_dim_count = filter_shape.DimensionsCount ();
163- const int output_dim_count = output_shape.DimensionsCount ();
164- const int batches = FlatSizeSkipDim (output_shape, output_dim_count - 1 );
165- const int output_depth = output_shape.Dims (output_dim_count - 1 );
166- TFLITE_DCHECK_LE (output_depth, filter_shape.Dims (filter_dim_count - 2 ));
167- const int accum_depth = filter_shape.Dims (filter_dim_count - 1 );
168-
169- const int32_t * bias_data =
170- tflite::micro::GetOptionalTensorData<int32_t >(bias);
171-
172- const int8_t *input_data = tflite::micro::GetTensorData<int8_t >(input);
173- int8_t *output_data = tflite::micro::GetTensorData<int8_t >(output);
174- const int8_t *filter_data = tflite::micro::GetTensorData<int8_t >(filter);
175-
176- for (int b = 0 ; b < batches; ++b) {
177- esp_nn_fully_connected_s8 (input_data, -data.input_zero_point ,
178- accum_depth,
179- filter_data, -data.filter_zero_point ,
180- bias_data, output_data, output_depth,
181- data.output_zero_point ,
182- data.output_shift , data.output_multiplier ,
183- data.output_activation_min ,
184- data.output_activation_max );
185- input_data += accum_depth;
186- output_data += output_depth;
187- }
158+ const RuntimeShape& filter_shape = tflite::micro::GetTensorShape (filter);
159+ const RuntimeShape& output_shape = tflite::micro::GetTensorShape (output);
160+
161+ TFLITE_DCHECK_GE (filter_shape.DimensionsCount (), 2 );
162+ TFLITE_DCHECK_GE (output_shape.DimensionsCount (), 1 );
163+ const int filter_dim_count = filter_shape.DimensionsCount ();
164+ const int output_dim_count = output_shape.DimensionsCount ();
165+ const int batches = FlatSizeSkipDim (output_shape, output_dim_count - 1 );
166+ const int output_depth = output_shape.Dims (output_dim_count - 1 );
167+ TFLITE_DCHECK_LE (output_depth, filter_shape.Dims (filter_dim_count - 2 ));
168+ const int accum_depth = filter_shape.Dims (filter_dim_count - 1 );
169+
170+ const int32_t * bias_data =
171+ tflite::micro::GetOptionalTensorData<int32_t >(bias);
172+
173+ const int8_t *input_data = tflite::micro::GetTensorData<int8_t >(input);
174+ int8_t *output_data = tflite::micro::GetTensorData<int8_t >(output);
175+ const int8_t *filter_data = tflite::micro::GetTensorData<int8_t >(filter);
176+
177+ for (int b = 0 ; b < batches; ++b) {
178+ esp_nn_fully_connected_per_ch_s8 (input_data, -data.input_zero_point ,
179+ accum_depth,
180+ filter_data, -data.filter_zero_point ,
181+ bias_data, output_data, output_depth,
182+ data.output_zero_point ,
183+ data.per_channel_output_shift , data.per_channel_output_multiplier ,
184+ data.output_activation_min ,
185+ data.output_activation_max );
186+ input_data += accum_depth;
187+ output_data += output_depth;
188+ }
188189#else
189- tflite::reference_integer_ops::FullyConnected (
190- FullyConnectedParamsQuantized (data),
191- tflite::micro::GetTensorShape (input),
192- tflite::micro::GetTensorData<int8_t >(input),
193- tflite::micro::GetTensorShape (filter),
194- tflite::micro::GetTensorData<int8_t >(filter),
195- tflite::micro::GetTensorShape (bias),
196- tflite::micro::GetOptionalTensorData<int32_t >(bias),
197- tflite::micro::GetTensorShape (output),
198- tflite::micro::GetTensorData<int8_t >(output));
190+ tflite::reference_integer_ops::FullyConnectedPerChannel (
191+ FullyConnectedParamsQuantized (data),
192+ data.per_channel_output_multiplier ,
193+ reinterpret_cast <const int *>(data.per_channel_output_shift ),
194+ tflite::micro::GetTensorShape (input),
195+ tflite::micro::GetTensorData<int8_t >(input),
196+ tflite::micro::GetTensorShape (filter),
197+ tflite::micro::GetTensorData<int8_t >(filter),
198+ tflite::micro::GetTensorShape (bias),
199+ tflite::micro::GetOptionalTensorData<int32_t >(bias),
200+ tflite::micro::GetTensorShape (output),
201+ tflite::micro::GetTensorData<int8_t >(output));
199202#endif
203+ } else {
204+ #if ESP_NN
205+ const RuntimeShape& filter_shape = tflite::micro::GetTensorShape (filter);
206+ const RuntimeShape& output_shape = tflite::micro::GetTensorShape (output);
207+
208+ TFLITE_DCHECK_GE (filter_shape.DimensionsCount (), 2 );
209+ TFLITE_DCHECK_GE (output_shape.DimensionsCount (), 1 );
210+ const int filter_dim_count = filter_shape.DimensionsCount ();
211+ const int output_dim_count = output_shape.DimensionsCount ();
212+ const int batches = FlatSizeSkipDim (output_shape, output_dim_count - 1 );
213+ const int output_depth = output_shape.Dims (output_dim_count - 1 );
214+ TFLITE_DCHECK_LE (output_depth, filter_shape.Dims (filter_dim_count - 2 ));
215+ const int accum_depth = filter_shape.Dims (filter_dim_count - 1 );
216+
217+ const int32_t * bias_data =
218+ tflite::micro::GetOptionalTensorData<int32_t >(bias);
219+
220+ const int8_t *input_data = tflite::micro::GetTensorData<int8_t >(input);
221+ int8_t *output_data = tflite::micro::GetTensorData<int8_t >(output);
222+ const int8_t *filter_data = tflite::micro::GetTensorData<int8_t >(filter);
223+
224+ for (int b = 0 ; b < batches; ++b) {
225+ esp_nn_fully_connected_s8 (input_data, -data.input_zero_point ,
226+ accum_depth,
227+ filter_data, -data.filter_zero_point ,
228+ bias_data, output_data, output_depth,
229+ data.output_zero_point ,
230+ data.output_shift , data.output_multiplier ,
231+ data.output_activation_min ,
232+ data.output_activation_max );
233+ input_data += accum_depth;
234+ output_data += output_depth;
235+ }
236+ #else
237+ tflite::reference_integer_ops::FullyConnected (
238+ FullyConnectedParamsQuantized (data),
239+ tflite::micro::GetTensorShape (input),
240+ tflite::micro::GetTensorData<int8_t >(input),
241+ tflite::micro::GetTensorShape (filter),
242+ tflite::micro::GetTensorData<int8_t >(filter),
243+ tflite::micro::GetTensorShape (bias),
244+ tflite::micro::GetOptionalTensorData<int32_t >(bias),
245+ tflite::micro::GetTensorShape (output),
246+ tflite::micro::GetTensorData<int8_t >(output));
247+ #endif
248+ }
200249 break ;
201250 }
202251 default : {
0 commit comments