Commit 1171b51

Merge branch 'fully_connected_channelwise_optimisations' into 'master'
Use ESP-NN optimisations for FullyConnected Per-Channel operation

See merge request app-frameworks/esp-tflite-micro!177

2 parents f4393a9 + 448528d

2 files changed: +91 −42 lines

2 files changed

+91
-42
lines changed

idf_component.yml

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ dependencies:
   idf:
     version: ">=4.4"
   espressif/esp-nn:
-    version: "^1.1.0"
+    version: ">=1.1.1"
 files:
   exclude:
     - scripts

tensorflow/lite/micro/kernels/esp_nn/fully_connected.cc

Lines changed: 90 additions & 41 deletions
@@ -153,50 +153,99 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
       break;
     }
     case kTfLiteInt8: {
+      if (data.is_per_channel) {
 #if ESP_NN
-      const RuntimeShape& filter_shape = tflite::micro::GetTensorShape(filter);
-      const RuntimeShape& output_shape = tflite::micro::GetTensorShape(output);
-
-      TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
-      TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
-      const int filter_dim_count = filter_shape.DimensionsCount();
-      const int output_dim_count = output_shape.DimensionsCount();
-      const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
-      const int output_depth = output_shape.Dims(output_dim_count - 1);
-      TFLITE_DCHECK_LE(output_depth, filter_shape.Dims(filter_dim_count - 2));
-      const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
-
-      const int32_t* bias_data =
-          tflite::micro::GetOptionalTensorData<int32_t>(bias);
-
-      const int8_t *input_data = tflite::micro::GetTensorData<int8_t>(input);
-      int8_t *output_data = tflite::micro::GetTensorData<int8_t>(output);
-      const int8_t *filter_data = tflite::micro::GetTensorData<int8_t>(filter);
-
-      for (int b = 0; b < batches; ++b) {
-        esp_nn_fully_connected_s8(input_data, -data.input_zero_point,
-                                  accum_depth,
-                                  filter_data, -data.filter_zero_point,
-                                  bias_data, output_data, output_depth,
-                                  data.output_zero_point,
-                                  data.output_shift, data.output_multiplier,
-                                  data.output_activation_min,
-                                  data.output_activation_max);
-        input_data += accum_depth;
-        output_data += output_depth;
-      }
+        const RuntimeShape& filter_shape = tflite::micro::GetTensorShape(filter);
+        const RuntimeShape& output_shape = tflite::micro::GetTensorShape(output);
+
+        TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
+        TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+        const int filter_dim_count = filter_shape.DimensionsCount();
+        const int output_dim_count = output_shape.DimensionsCount();
+        const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+        const int output_depth = output_shape.Dims(output_dim_count - 1);
+        TFLITE_DCHECK_LE(output_depth, filter_shape.Dims(filter_dim_count - 2));
+        const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
+
+        const int32_t* bias_data =
+            tflite::micro::GetOptionalTensorData<int32_t>(bias);
+
+        const int8_t *input_data = tflite::micro::GetTensorData<int8_t>(input);
+        int8_t *output_data = tflite::micro::GetTensorData<int8_t>(output);
+        const int8_t *filter_data = tflite::micro::GetTensorData<int8_t>(filter);
+
+        for (int b = 0; b < batches; ++b) {
+          esp_nn_fully_connected_per_ch_s8(input_data, -data.input_zero_point,
+                                           accum_depth,
+                                           filter_data, -data.filter_zero_point,
+                                           bias_data, output_data, output_depth,
+                                           data.output_zero_point,
+                                           data.per_channel_output_shift, data.per_channel_output_multiplier,
+                                           data.output_activation_min,
+                                           data.output_activation_max);
+          input_data += accum_depth;
+          output_data += output_depth;
+        }
 #else
-      tflite::reference_integer_ops::FullyConnected(
-          FullyConnectedParamsQuantized(data),
-          tflite::micro::GetTensorShape(input),
-          tflite::micro::GetTensorData<int8_t>(input),
-          tflite::micro::GetTensorShape(filter),
-          tflite::micro::GetTensorData<int8_t>(filter),
-          tflite::micro::GetTensorShape(bias),
-          tflite::micro::GetOptionalTensorData<int32_t>(bias),
-          tflite::micro::GetTensorShape(output),
-          tflite::micro::GetTensorData<int8_t>(output));
+        tflite::reference_integer_ops::FullyConnectedPerChannel(
+            FullyConnectedParamsQuantized(data),
+            data.per_channel_output_multiplier,
+            reinterpret_cast<const int*>(data.per_channel_output_shift),
+            tflite::micro::GetTensorShape(input),
+            tflite::micro::GetTensorData<int8_t>(input),
+            tflite::micro::GetTensorShape(filter),
+            tflite::micro::GetTensorData<int8_t>(filter),
+            tflite::micro::GetTensorShape(bias),
+            tflite::micro::GetOptionalTensorData<int32_t>(bias),
+            tflite::micro::GetTensorShape(output),
+            tflite::micro::GetTensorData<int8_t>(output));
 #endif
+      } else {
+#if ESP_NN
+        const RuntimeShape& filter_shape = tflite::micro::GetTensorShape(filter);
+        const RuntimeShape& output_shape = tflite::micro::GetTensorShape(output);
+
+        TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
+        TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+        const int filter_dim_count = filter_shape.DimensionsCount();
+        const int output_dim_count = output_shape.DimensionsCount();
+        const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+        const int output_depth = output_shape.Dims(output_dim_count - 1);
+        TFLITE_DCHECK_LE(output_depth, filter_shape.Dims(filter_dim_count - 2));
+        const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
+
+        const int32_t* bias_data =
+            tflite::micro::GetOptionalTensorData<int32_t>(bias);
+
+        const int8_t *input_data = tflite::micro::GetTensorData<int8_t>(input);
+        int8_t *output_data = tflite::micro::GetTensorData<int8_t>(output);
+        const int8_t *filter_data = tflite::micro::GetTensorData<int8_t>(filter);
+
+        for (int b = 0; b < batches; ++b) {
+          esp_nn_fully_connected_s8(input_data, -data.input_zero_point,
+                                    accum_depth,
+                                    filter_data, -data.filter_zero_point,
+                                    bias_data, output_data, output_depth,
+                                    data.output_zero_point,
+                                    data.output_shift, data.output_multiplier,
+                                    data.output_activation_min,
+                                    data.output_activation_max);
+          input_data += accum_depth;
+          output_data += output_depth;
+        }
+#else
+        tflite::reference_integer_ops::FullyConnected(
+            FullyConnectedParamsQuantized(data),
+            tflite::micro::GetTensorShape(input),
+            tflite::micro::GetTensorData<int8_t>(input),
+            tflite::micro::GetTensorShape(filter),
+            tflite::micro::GetTensorData<int8_t>(filter),
+            tflite::micro::GetTensorShape(bias),
+            tflite::micro::GetOptionalTensorData<int32_t>(bias),
+            tflite::micro::GetTensorShape(output),
+            tflite::micro::GetTensorData<int8_t>(output));
+#endif
+      }
       break;
     }
     default: {
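
The functional difference between the two branches above is confined to requantization: the per-tensor kernel applies a single output multiplier/shift pair to every output value, while the per-channel kernel looks up a multiplier and shift per output channel (one per row of the filter), which is why the call passes data.per_channel_output_multiplier and data.per_channel_output_shift as arrays. Below is a minimal plain-C++ sketch of that difference; it is not the ESP-NN or TFLM implementation, the Requantize() helper is a simplified double-precision stand-in for the fixed-point rounding the real kernels use, and all names are illustrative.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Simplified stand-in for the fixed-point requantization used by the real
// kernels: scales an int32 accumulator by multiplier * 2^shift, adds the
// output zero point, and clamps to the activation range.
static int8_t Requantize(int32_t acc, int32_t multiplier, int shift,
                         int32_t zero_point, int32_t act_min, int32_t act_max) {
  const double scale =
      static_cast<double>(multiplier) / (1ll << 31) * std::pow(2.0, shift);
  int32_t out = static_cast<int32_t>(std::lround(acc * scale)) + zero_point;
  return static_cast<int8_t>(std::min(std::max(out, act_min), act_max));
}

// Per-channel fully connected for one batch row: each output channel c has
// its own multiplier/shift, mirroring data.per_channel_output_multiplier and
// data.per_channel_output_shift in the kernel above.
void FullyConnectedPerChannelSketch(
    const int8_t* input, const int8_t* filter, const int32_t* bias,
    int accum_depth, int output_depth, int32_t input_offset,
    const int32_t* per_ch_multiplier, const int32_t* per_ch_shift,
    int32_t output_zero_point, int32_t act_min, int32_t act_max,
    int8_t* output) {
  for (int c = 0; c < output_depth; ++c) {
    int32_t acc = bias ? bias[c] : 0;
    for (int d = 0; d < accum_depth; ++d) {
      acc += (input[d] + input_offset) * filter[c * accum_depth + d];
    }
    // Per-channel: multiplier/shift are indexed by the output channel. The
    // per-tensor variant would use one scalar pair here instead.
    output[c] = Requantize(acc, per_ch_multiplier[c], per_ch_shift[c],
                           output_zero_point, act_min, act_max);
  }
}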

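The is_per_channel flag itself is set when the kernel is prepared, based on the filter tensor's quantization metadata. As a point of reference (this is an assumption about the Prepare logic, not an excerpt from it), the usual TFLite convention is that a per-channel filter carries a TfLiteAffineQuantization whose scale array has one entry per output channel, whereas a per-tensor filter carries a single scale:

#include "tensorflow/lite/c/common.h"

// Sketch: decide whether a quantized filter tensor is per-channel or
// per-tensor from its affine quantization parameters. Hypothetical helper,
// not the kernel's actual Prepare() code.
bool FilterIsPerChannel(const TfLiteTensor* filter, int output_depth) {
  if (filter->quantization.type != kTfLiteAffineQuantization) {
    return false;
  }
  const auto* params = static_cast<const TfLiteAffineQuantization*>(
      filter->quantization.params);
  // One scale per output channel means per-channel; a single scale means
  // per-tensor.
  return params != nullptr && params->scale != nullptr &&
         params->scale->size == output_depth && output_depth > 1;
}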
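
Application code does not change with this commit: the ESP-NN path is chosen at build time via the ESP_NN define, and the kernel is still registered as the standard FULLY_CONNECTED op. A rough usage sketch follows; g_model_data, kArenaSize, and RunOnce are placeholders for whatever the application already has, and error checking is omitted for brevity.

#include <cstdint>

#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Placeholder model buffer; in a real application this is the converted
// .tflite flatbuffer compiled into the binary.
extern const unsigned char g_model_data[];

constexpr int kArenaSize = 32 * 1024;  // placeholder, model dependent
static uint8_t tensor_arena[kArenaSize];

void RunOnce() {
  const tflite::Model* model = tflite::GetModel(g_model_data);

  // Registering FULLY_CONNECTED pulls in the kernel from
  // tensorflow/lite/micro/kernels/esp_nn/fully_connected.cc when the
  // component is built with ESP-NN enabled.
  static tflite::MicroMutableOpResolver<1> resolver;
  resolver.AddFullyConnected();

  static tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
                                              kArenaSize);
  interpreter.AllocateTensors();

  // ... fill interpreter.input(0) with quantized input data, then:
  interpreter.Invoke();
}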