Enable HiFi SIMD for TRANSPOSE_CONV operator #3084

Open · wants to merge 1 commit into main
78 changes: 77 additions & 1 deletion tensorflow/lite/micro/kernels/xtensa/transpose_conv.cc
@@ -244,6 +244,30 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
#endif  // #if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
  }

#if HIFI_VFPU && (defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
  if (input->type == kTfLiteFloat32) {
    TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
    const int stride_width = params->stride_width;
    const int stride_height = params->stride_height;

    const int input_height = SizeOfDimension(input, 1);
    const int input_width = SizeOfDimension(input, 2);
    const int input_depth = SizeOfDimension(input, 3);
    const int output_height = height;
    const int output_width = width;
    const int32_t scratch_buffer_size = xa_nn_transpose_conv_getsize(
        input_height, input_width, input_depth, filter_height, filter_width,
        stride_width, stride_height, output_height, output_width, num_channels,
        PREC_F32, PREC_F32);
    TFLITE_DCHECK(context->RequestScratchBufferInArena(
                      context, scratch_buffer_size,
                      &(data->scratch_buffer_index)) == kTfLiteOk);
  }
#endif  // HIFI_VFPU && (defined(HIFI3) || defined(HIFI4) || defined(HIFI5))

// All per-channel quantized tensors need valid zero point and scale arrays.
if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, filter->quantization.type,
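
For context: this follows TFLM's two-phase scratch-buffer handshake. Prepare() may only request arena memory and receives back an integer handle, while Eval() resolves that handle to a pointer on every invocation. Below is a minimal sketch of the pattern, assuming only the two TfLiteContext callbacks used above; the OpDataSketch struct and function names are illustrative, not the PR's code.

// Sketch: how a kernel threads a scratch request from Prepare() to Eval().
#include "tensorflow/lite/c/common.h"  // TfLiteContext, TfLiteStatus

struct OpDataSketch {
  int scratch_buffer_index;  // Handle assigned by the arena in Prepare().
};

TfLiteStatus PrepareSketch(TfLiteContext* context, OpDataSketch* data,
                           size_t scratch_bytes) {
  // No pointer exists yet; the arena is only sized during Prepare().
  return context->RequestScratchBufferInArena(context, scratch_bytes,
                                              &data->scratch_buffer_index);
}

void EvalSketch(TfLiteContext* context, OpDataSketch* data) {
  // The handle becomes a usable pointer only at Eval() time.
  float* scratch = static_cast<float*>(
      context->GetScratchBuffer(context, data->scratch_buffer_index));
  (void)scratch;  // Would be passed to the NNLib kernel as working memory.
}
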
@@ -334,7 +358,58 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
      CalculateActivationRange(params.activation,
                               &op_params.float_activation_min,
                               &op_params.float_activation_max);

#if HIFI_VFPU && (defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
      float* scratch_buffer = static_cast<float*>(
          context->GetScratchBuffer(context, data.scratch_buffer_index));
      const RuntimeShape& input_shape = tflite::micro::GetTensorShape(input);
      const RuntimeShape& filter_shape = tflite::micro::GetTensorShape(filter);
      const RuntimeShape& output_shape = tflite::micro::GetTensorShape(output);
      const int stride_width = data.params.stride_width;
      const int stride_height = data.params.stride_height;
      const int pad_width = data.params.padding_values.width;
      const int pad_height = data.params.padding_values.height;

      const int batches = MatchingDim(input_shape, 0, output_shape, 0);
      const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
      const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);

      const int input_height = input_shape.Dims(1);
      const int input_width = input_shape.Dims(2);
      const int filter_height = filter_shape.Dims(1);
      const int filter_width = filter_shape.Dims(2);
      const int output_height = output_shape.Dims(1);
      const int output_width = output_shape.Dims(2);
      const float* input_data = tflite::micro::GetTensorData<float>(input);
#ifdef USE_TFLM_COMPRESSION
      const float* filter_data = tflite::micro::GetTensorData<float>(
          micro_context, filter, filter_comp_td, data.filter_scratch_index);
      const float* bias_data = tflite::micro::GetOptionalTensorData<float>(
          micro_context, bias, bias_comp_td, data.bias_scratch_index);
#else
      const float* filter_data = tflite::micro::GetTensorData<float>(filter);
      // Bias is optional for TRANSPOSE_CONV, so fetch it as such.
      const float* bias_data =
          tflite::micro::GetOptionalTensorData<float>(bias);
#endif  // USE_TFLM_COMPRESSION

      float* output_data = tflite::micro::GetTensorData<float>(output);

      const int num_elements = output_shape.FlatSize();
      const int output_elements =
          batches * output_height * output_width * output_depth;

      // The NNLib kernel processes one batch per call; tensors are NHWC, so
      // batch b starts at a flat offset of b * H * W * C.
      for (int b = 0; b < batches; b++) {
        xa_nn_transpose_conv_f32(
            &output_data[b * output_height * output_width * output_depth],
            const_cast<FLOAT32*>(
                &input_data[b * input_height * input_width * input_depth]),
            const_cast<FLOAT32*>(filter_data), const_cast<FLOAT32*>(bias_data),
            stride_width, stride_height, pad_width, pad_height, input_depth,
            output_depth, input_height, input_width, filter_height,
            filter_width, output_height, output_width, num_elements / batches,
            scratch_buffer);
      }
      // Apply the fused activation clamp across all batches in one pass.
      xa_nn_vec_activation_min_max_f32_f32(
          output_data, output_data, op_params.float_activation_min,
          op_params.float_activation_max, output_elements);
#else
      reference_ops::TransposeConv(
          op_params, tflite::micro::GetTensorShape(input),
          tflite::micro::GetTensorData<float>(input),
@@ -353,6 +428,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
          tflite::micro::GetTensorShape(output),
          tflite::micro::GetTensorData<float>(output),
          tflite::micro::GetTensorShape(nullptr), nullptr);
#endif  // HIFI_VFPU && (defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
      break;
    }
    case kTfLiteInt8: {
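
One non-obvious detail in the Eval loop above is the flat-offset arithmetic: because every tensor is NHWC, batch b of the output begins at b * output_height * output_width * output_depth, and the num_elements / batches passed to the kernel is exactly one batch's worth of output elements, which is why the final activation clamp can run once over output_elements rather than once per batch. A small self-contained check of that arithmetic (the shape values are made up for illustration):

// Standalone check of the NHWC per-batch offset arithmetic used in Eval().
#include <cassert>

int main() {
  // Hypothetical output shape; any NHWC tensor behaves the same way.
  const int batches = 2, output_height = 4, output_width = 4, output_depth = 8;
  const int num_elements =
      batches * output_height * output_width * output_depth;

  // Per-batch element count handed to the NNLib kernel.
  assert(num_elements / batches ==
         output_height * output_width * output_depth);

  // Each batch's slab starts exactly where the previous one ends.
  for (int b = 0; b < batches; ++b) {
    const int offset = b * output_height * output_width * output_depth;
    assert(offset == b * (num_elements / batches));
  }
  return 0;
}
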