Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions tensorflow/lite/micro/kernels/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,12 @@ limitations under the License.

namespace tflite {

extern const int kMaxNumberOfAxis;
extern const int kMaxNumberOfReducedAxis;

struct OpDataReduce {
int32_t multiplier;
int shift;
int temp_buffer_idx;
int resolved_axis_idx;
int scratch_accumulator_idx;
int scratch_resolved_axis_idx;
int scratch_input_iter_idx;
int input_zp;
float input_scale;
int output_zp;
Expand Down
80 changes: 42 additions & 38 deletions tensorflow/lite/micro/kernels/reduce_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@ limitations under the License.

namespace tflite {

const int kMaxNumberOfAxis = 5;
const int kMaxNumberOfReducedAxis = 2;

namespace {

TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node,
Expand Down Expand Up @@ -80,7 +77,7 @@ void ResolveAxis(const int* axis_data, int axis_count,

template <typename T>
TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context, TfLiteNode* node,
int* temp_index, int* resolved_axis,
int* input_iter, int* resolved_axis,
int32_t* temp_sum, OpDataReduce* op_data,
bool compute_sum) {
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
Expand All @@ -96,7 +93,7 @@ TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context, TfLiteNode* node,
op_data->multiplier, op_data->shift, op_data->output_zp,
&output->dims->data[0], output->dims->size,
tflite::micro::GetTensorData<int>(axis), op_data->num_axis,
params->keep_dims, temp_index, resolved_axis, temp_sum, compute_sum);
params->keep_dims, input_iter, resolved_axis, temp_sum, compute_sum);
TF_LITE_ENSURE(context, result);

return kTfLiteOk;
Expand All @@ -105,11 +102,11 @@ TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context, TfLiteNode* node,
template <typename integer_type>
TfLiteStatus EvalIntegerMean(TfLiteContext* context, TfLiteNode* node,
int num_axis, OpDataReduce* op_data,
int* temp_index, int* resolved_axis) {
int* input_iter, int* resolved_axis) {
int32_t* temp_sum = static_cast<int32_t*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
context->GetScratchBuffer(context, op_data->scratch_accumulator_idx));

QuantizedMeanOrSum<integer_type>(context, node, temp_index, resolved_axis,
QuantizedMeanOrSum<integer_type>(context, node, input_iter, resolved_axis,
temp_sum, op_data, /*compute_sum=*/false);

return kTfLiteOk;
Expand Down Expand Up @@ -155,10 +152,10 @@ TfLiteStatus EvalMinMaxHelper(TfLiteContext* context, TfLiteNode* node,

// Interpret an axis tensor with null dimensions as a scalar
int num_axis = static_cast<int>(ElementCount(*axis->dims));
int* temp_buffer = static_cast<int*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
int* input_iter = static_cast<int*>(
context->GetScratchBuffer(context, op_data->scratch_input_iter_idx));
int* resolved_axis = static_cast<int*>(
context->GetScratchBuffer(context, op_data->resolved_axis_idx));
context->GetScratchBuffer(context, op_data->scratch_resolved_axis_idx));
switch (input->type) {
case kTfLiteFloat32: {
MinMaxReducerCompare<float> reducer(evalType);
Expand All @@ -169,7 +166,7 @@ TfLiteStatus EvalMinMaxHelper(TfLiteContext* context, TfLiteNode* node,
input->dims->size, tflite::micro::GetTensorData<float>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_buffer, resolved_axis,
params->keep_dims, input_iter, resolved_axis,
reducer.initialValue(), reducer.compare()));
} break;
case kTfLiteInt8: {
Expand All @@ -184,7 +181,7 @@ TfLiteStatus EvalMinMaxHelper(TfLiteContext* context, TfLiteNode* node,
input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_buffer, resolved_axis,
params->keep_dims, input_iter, resolved_axis,
reducer.initialValue(), reducer.compare()));
} break;
default:
Expand All @@ -211,12 +208,11 @@ TfLiteStatus PrepareMinMaxHelper(TfLiteContext* context, TfLiteNode* node,
op_data->output_zp = output->params.zero_point;
op_data->output_scale = output->params.scale;
op_data->num_output_elements = NumElements(output);

context->RequestScratchBufferInArena(context, sizeof(int) * input->dims->size,
&op_data->temp_buffer_idx);
&op_data->scratch_input_iter_idx);
context->RequestScratchBufferInArena(
context, sizeof(int) * static_cast<int>(ElementCount(*axis->dims)),
&op_data->resolved_axis_idx);
&op_data->scratch_resolved_axis_idx);

micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
Expand All @@ -236,17 +232,22 @@ TfLiteStatus PrepareMeanOrSumHelper(TfLiteContext* context, TfLiteNode* node,
QuantizeMultiplier(real_multiplier, &op_data->multiplier, &op_data->shift);
}

int output_size = NumElements(output);
op_data->num_axis = NumElements(axis);
op_data->num_output_elements = NumElements(output);

if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
context->RequestScratchBufferInArena(context, output_size * sizeof(int32_t),
&op_data->temp_buffer_idx);
context->RequestScratchBufferInArena(
context, sizeof(int32_t) * op_data->num_output_elements,
&op_data->scratch_accumulator_idx);
op_data->input_zp = input->params.zero_point;
op_data->input_scale = input->params.scale;
op_data->output_zp = output->params.zero_point;
op_data->output_scale = output->params.scale;
}
context->RequestScratchBufferInArena(context, sizeof(int) * input->dims->size,
&op_data->scratch_input_iter_idx);
context->RequestScratchBufferInArena(context, sizeof(int) * op_data->num_axis,
&op_data->scratch_resolved_axis_idx);

TF_LITE_ENSURE_OK(
context,
Expand Down Expand Up @@ -274,12 +275,11 @@ TfLiteStatus PrepareAllHelper(TfLiteContext* context, TfLiteNode* node,
op_data->output_zp = output->params.zero_point;
op_data->output_scale = output->params.scale;
op_data->num_output_elements = NumElements(output);

context->RequestScratchBufferInArena(context, sizeof(int) * input->dims->size,
&op_data->temp_buffer_idx);
&op_data->scratch_input_iter_idx);
context->RequestScratchBufferInArena(
context, sizeof(int) * static_cast<int>(ElementCount(*axis->dims)),
&op_data->resolved_axis_idx);
&op_data->scratch_resolved_axis_idx);

micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
Expand All @@ -296,8 +296,10 @@ TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);

int num_axis = static_cast<int>(ElementCount(*axis->dims));
int temp_index[kMaxNumberOfAxis];
int resolved_axis[kMaxNumberOfReducedAxis];
int* input_iter = static_cast<int*>(
context->GetScratchBuffer(context, op_data->scratch_input_iter_idx));
int* resolved_axis = static_cast<int*>(
context->GetScratchBuffer(context, op_data->scratch_resolved_axis_idx));

switch (input->type) {
case kTfLiteFloat32: {
Expand Down Expand Up @@ -326,19 +328,19 @@ TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
input->dims->size, tflite::micro::GetTensorData<float>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_index, resolved_axis,
params->keep_dims, input_iter, resolved_axis,
tflite::micro::GetTensorData<float>(output)));
}
} break;
case kTfLiteInt8: {
TF_LITE_ENSURE_OK(
context, EvalIntegerMean<int8_t>(context, node, num_axis, op_data,
temp_index, resolved_axis));
input_iter, resolved_axis));
} break;
case kTfLiteInt16: {
TF_LITE_ENSURE_OK(
context, EvalIntegerMean<int16_t>(context, node, num_axis, op_data,
temp_index, resolved_axis));
input_iter, resolved_axis));
} break;
default:
TF_LITE_ENSURE_MSG(context, false,
Expand Down Expand Up @@ -369,8 +371,10 @@ TfLiteStatus EvalSumHelper(TfLiteContext* context, TfLiteNode* node,

// Interpret an axis tensor with null dimensions as a scalar.
int num_axis = static_cast<int>(ElementCount(*axis->dims));
int temp_index[kMaxNumberOfAxis];
int resolved_axis[kMaxNumberOfReducedAxis];
int* input_iter = static_cast<int*>(
context->GetScratchBuffer(context, op_data->scratch_input_iter_idx));
int* resolved_axis = static_cast<int*>(
context->GetScratchBuffer(context, op_data->scratch_resolved_axis_idx));

switch (input->type) {
case kTfLiteFloat32: {
Expand All @@ -381,21 +385,21 @@ TfLiteStatus EvalSumHelper(TfLiteContext* context, TfLiteNode* node,
input->dims->size, tflite::micro::GetTensorData<float>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_index, resolved_axis, /*init_value=*/0.f,
params->keep_dims, input_iter, resolved_axis, /*init_value=*/0.f,
[](const float current, const float in) -> float {
return in + current;
}));
} break;
case kTfLiteInt8: {
int32_t* temp_sum = static_cast<int32_t*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
QuantizedMeanOrSum<int8_t>(context, node, temp_index, resolved_axis,
context->GetScratchBuffer(context, op_data->scratch_accumulator_idx));
QuantizedMeanOrSum<int8_t>(context, node, input_iter, resolved_axis,
temp_sum, op_data, /*compute_sum=*/true);
} break;
case kTfLiteInt16: {
int32_t* temp_sum = static_cast<int32_t*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
QuantizedMeanOrSum<int16_t>(context, node, temp_index, resolved_axis,
context->GetScratchBuffer(context, op_data->scratch_accumulator_idx));
QuantizedMeanOrSum<int16_t>(context, node, input_iter, resolved_axis,
temp_sum, op_data, /*compute_sum=*/true);
} break;
default:
Expand All @@ -416,10 +420,10 @@ TfLiteStatus EvalAllHelper(TfLiteContext* context, TfLiteNode* node,

// Interpret an axis tensor with null dimensions as a scalar
int num_axis = static_cast<int>(ElementCount(*axis->dims));
int* temp_buffer = static_cast<int*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
int* input_iter = static_cast<int*>(
context->GetScratchBuffer(context, op_data->scratch_input_iter_idx));
int* resolved_axis = static_cast<int*>(
context->GetScratchBuffer(context, op_data->resolved_axis_idx));
context->GetScratchBuffer(context, op_data->scratch_resolved_axis_idx));
switch (input->type) {
case kTfLiteBool:
TF_LITE_ENSURE(
Expand All @@ -429,7 +433,7 @@ TfLiteStatus EvalAllHelper(TfLiteContext* context, TfLiteNode* node,
input->dims->size, tflite::micro::GetTensorData<bool>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_buffer, resolved_axis, true,
params->keep_dims, input_iter, resolved_axis, true,
[](const bool current, const bool in) -> bool {
return in && current;
}));
Expand Down