Enable HiFi SIMD for CONV operator (#3081) #3082

Open · wants to merge 1 commit into base: main
9 changes: 9 additions & 0 deletions tensorflow/lite/micro/kernels/conv.h
@@ -129,6 +129,15 @@ inline TFLMRegistration Register_CONV_2D_INT8() { return Register_CONV_2D(); }
inline TFLMRegistration Register_CONV_2D_INT16() { return Register_CONV_2D(); }
#endif // defined(CMSIS_NN) || defined(XTENSA)

#if defined(XTENSA)
// Returns a TFLMRegistration struct for a kernel variant that supports only
// float32 activations and float32 weights and uses the latency-optimized
// implementations.
TFLMRegistration Register_CONV_2D_FLOAT32();
#else
inline TFLMRegistration Register_CONV_2D_FLOAT32() { return Register_CONV_2D(); }
#endif

} // namespace tflite

#endif // TENSORFLOW_LITE_MICRO_KERNELS_CONV_H_
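A minimal usage sketch (not part of this diff): wiring the new float32-only registration into an op resolver. It assumes the MicroMutableOpResolver::AddConv2D overload that takes an explicit TFLMRegistration, and RegisterFloat32Conv is a hypothetical helper name. On non-Xtensa builds the inline fallback above makes this equivalent to Register_CONV_2D().

#include "tensorflow/lite/micro/kernels/conv.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

// Sketch only: binds the float32-specialized CONV_2D kernel. On Xtensa builds
// with HIFI_VFPU this selects the HiFi SIMD path; elsewhere it resolves to
// the generic Register_CONV_2D().
TfLiteStatus RegisterFloat32Conv(tflite::MicroMutableOpResolver<1>& resolver) {
  return resolver.AddConv2D(tflite::Register_CONV_2D_FLOAT32());
}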
37 changes: 6 additions & 31 deletions tensorflow/lite/micro/kernels/xtensa/conv.cc
@@ -52,37 +52,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {

switch (input->type) {
case kTfLiteFloat32: {
#ifdef USE_TFLM_COMPRESSION

MicroContext* micro_context = GetMicroContext(context);

const CompressionTensorData* weights_comp_td =
micro_context->GetTensorCompressionData(node, kConvWeightsTensor);
const CompressionTensorData* bias_comp_td =
micro_context->GetTensorCompressionData(node, kConvBiasTensor);

#endif // USE_TFLM_COMPRESSION
tflite::reference_ops::Conv(
ConvParamsFloat(params, op_data.reference_op_data),
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(filter),
#ifdef USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<float>(
micro_context, filter, weights_comp_td,
op_data.reference_op_data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<float>(
micro_context, bias, bias_comp_td,
op_data.reference_op_data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<float>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<float>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output),
tflite::micro::GetTensorShape(nullptr), nullptr);
#if HIFI_VFPU && (defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
ConvEvalHifiFloat32(context, node, params, op_data, input, filter,
bias, output);
#else
return ConvReferenceEvalFloat32(context, node);
#endif
break;
}
case kTfLiteInt8: {
84 changes: 84 additions & 0 deletions tensorflow/lite/micro/kernels/xtensa/conv_float32_reference.cc
@@ -0,0 +1,84 @@
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/conv.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/kernels/xtensa/xtensa_conv.h"

namespace tflite {

TfLiteStatus ConvReferenceEvalFloat32(TfLiteContext* context,
TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
TFLITE_DCHECK(node->builtin_data != nullptr);

const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kConvInputTensor);

const auto& params =
*(reinterpret_cast<TfLiteConvParams*>(node->builtin_data));
const auto& op_data = *(reinterpret_cast<XtensaConvOpData*>(node->user_data));

TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kConvOutputTensor);
const TfLiteEvalTensor* filter =
tflite::micro::GetEvalInput(context, node, kConvWeightsTensor);
const TfLiteEvalTensor* bias =
tflite::micro::GetEvalInput(context, node, kConvBiasTensor);

#ifdef USE_TFLM_COMPRESSION

MicroContext* micro_context = GetMicroContext(context);

const CompressionTensorData* weights_comp_td =
micro_context->GetTensorCompressionData(node, kConvWeightsTensor);
const CompressionTensorData* bias_comp_td =
micro_context->GetTensorCompressionData(node, kConvBiasTensor);

#endif // USE_TFLM_COMPRESSION
tflite::reference_ops::Conv(
ConvParamsFloat(params, op_data.reference_op_data),
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(filter),
#ifdef USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<float>(
micro_context, filter, weights_comp_td,
op_data.reference_op_data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<float>(
micro_context, bias, bias_comp_td,
op_data.reference_op_data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<float>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<float>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output),
tflite::micro::GetTensorShape(nullptr), nullptr);
return kTfLiteOk;
}

} // namespace tflite
115 changes: 114 additions & 1 deletion tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/conv.h"
@@ -59,7 +60,8 @@ TfLiteStatus ConvPrepareHifi(TfLiteContext* context, TfLiteNode* node) {
inputs_and_bias_ok =
inputs_and_bias_ok &&
(input->type == kTfLiteInt8 ||
(input->type == kTfLiteInt16 && bias->type == kTfLiteInt64));
(input->type == kTfLiteInt16 && bias->type == kTfLiteInt64) ||
input->type == kTfLiteFloat32);
#else
inputs_and_bias_ok = inputs_and_bias_ok && (input->type == kTfLiteInt8);
#endif // defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
@@ -81,6 +83,7 @@ TfLiteStatus ConvPrepareHifi(TfLiteContext* context, TfLiteNode* node) {
const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
const int filter_height = filter_shape.Dims(1);
const int filter_width = filter_shape.Dims(2);
const int filter_depth = filter_shape.Dims(3);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
const int output_channels = output_shape.Dims(3);
@@ -133,6 +136,13 @@ TfLiteStatus ConvPrepareHifi(TfLiteContext* context, TfLiteNode* node) {
}
TF_LITE_ENSURE(context, required_scratch > 0);
}
#if HIFI_VFPU && (defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
if ((input->type == kTfLiteFloat32) && (input_depth == filter_depth)) {
required_scratch = xa_nn_conv2d_std_getsize(
input_height, input_depth, filter_height, filter_width, stride_height,
pad_height, output_height, output_channels, PREC_F32);
}
#endif
}
TF_LITE_ENSURE_OK(
context, context->RequestScratchBufferInArena(
@@ -400,5 +410,108 @@ TfLiteStatus ConvEvalHifiInt8(TfLiteContext* context, TfLiteNode* node,
return kTfLiteOk;
}

#if HIFI_VFPU
TfLiteStatus ConvEvalHifiFloat32(TfLiteContext* context, TfLiteNode* node,
const TfLiteConvParams& params,
const XtensaConvOpData& data,
const TfLiteEvalTensor* input,
const TfLiteEvalTensor* filter,
const TfLiteEvalTensor* bias,
TfLiteEvalTensor* output) {
const RuntimeShape &input_shape = tflite::micro::GetTensorShape(input);
const RuntimeShape &filter_shape = tflite::micro::GetTensorShape(filter);
const int stride_width = params.stride_width;
const int stride_height = params.stride_height;
const int pad_width = data.reference_op_data.padding.width;
const int pad_height = data.reference_op_data.padding.height;

const RuntimeShape &output_shape = tflite::micro::GetTensorShape(output);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
const int input_height = input_shape.Dims(1);
const int input_width = input_shape.Dims(2);
const int input_depth = input_shape.Dims(3);
const int filter_height = filter_shape.Dims(1);
const int filter_width = filter_shape.Dims(2);
const int filter_depth = filter_shape.Dims(3);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);

const float32_t* input_data = tflite::micro::GetTensorData<float32_t>(input);
const float32_t* filter_data =
tflite::micro::GetTensorData<float32_t>(filter);
const float32_t* bias_data = tflite::micro::GetTensorData<float32_t>(bias);
float32_t* output_data = tflite::micro::GetTensorData<float32_t>(output);
ConvParams op_params;
CalculateActivationRange(params.activation, &op_params.float_activation_min,
&op_params.float_activation_max);

const int output_data_format = 0;
const int out_length = output_height * output_width * output_depth;
if (filter_height == 1 && filter_width == 1) {
for (int batch = 0; batch < batches; ++batch) {
float32_t* p_out_temp = &output_data[batch * out_length];

TF_LITE_ENSURE_EQ(
context,
xa_nn_conv2d_pointwise_f32(
p_out_temp, const_cast<float32_t *>(filter_data),
const_cast<float32_t *>(&input_data[batch * input_height *
input_width * input_depth]),
const_cast<float32_t *>(bias_data), input_height, input_width,
input_depth, output_depth, output_data_format),
0);
}
xa_nn_vec_activation_min_max_f32_f32(
output_data, output_data, op_params.float_activation_min,
op_params.float_activation_max,
(batches * output_height * output_width * output_depth));
} else if ((filter_depth == input_depth) &&
((params.dilation_width_factor == 1) &&
(params.dilation_height_factor == 1))) {
void* p_scratch =
context->GetScratchBuffer(context, data.scratch_tensor_index);

for (int batch = 0; batch < batches; ++batch) {
float32_t* p_out_temp = &output_data[batch * out_length];
TF_LITE_ENSURE_EQ(
context,
xa_nn_conv2d_std_f32(
p_out_temp,
&input_data[batch * input_height * input_width * input_depth],
const_cast<float32_t *>(filter_data), // filter_data,
bias_data, input_height, input_width, input_depth, filter_height,
filter_width, output_depth, stride_width, stride_height,
pad_width, pad_height, output_height, output_width,
output_data_format, p_scratch),
0);
}
xa_nn_vec_activation_min_max_f32_f32(
output_data, output_data, op_params.float_activation_min,
op_params.float_activation_max,
(batches * output_height * output_width * output_depth));
} else {
tflite::reference_ops::Conv(
ConvParamsFloat(params, data.reference_op_data),
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<float>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<float>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output),
tflite::micro::GetTensorShape(nullptr), nullptr);
}

return kTfLiteOk;
}
#endif // HIFI_VFPU

} // namespace tflite
#endif // defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
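To summarize the branching in ConvEvalHifiFloat32 above: 1x1 filters take the pointwise NNLib kernel, filters that consume the full input depth with no dilation take the standard kernel using the scratch buffer sized in ConvPrepareHifi, and everything else falls back to the portable reference conv. A hypothetical helper (not part of this PR, which inlines the decision) that makes the selection explicit:

// Illustrative only; all names below are hypothetical.
enum class HifiFloat32ConvPath { kPointwise, kStandard, kReference };

inline HifiFloat32ConvPath SelectHifiFloat32ConvPath(
    int filter_height, int filter_width, int filter_depth, int input_depth,
    int dilation_width_factor, int dilation_height_factor) {
  if (filter_height == 1 && filter_width == 1) {
    return HifiFloat32ConvPath::kPointwise;  // xa_nn_conv2d_pointwise_f32
  }
  if (filter_depth == input_depth && dilation_width_factor == 1 &&
      dilation_height_factor == 1) {
    return HifiFloat32ConvPath::kStandard;  // xa_nn_conv2d_std_f32 + scratch
  }
  return HifiFloat32ConvPath::kReference;  // tflite::reference_ops::Conv
}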
tensorflow/lite/micro/kernels/xtensa/conv_int8_int16_float32.cc (renamed from conv_int8_int16.cc)
@@ -75,6 +75,28 @@ TfLiteStatus EvalInt16(TfLiteContext* context, TfLiteNode* node) {
#endif
}

TfLiteStatus EvalFloat32(TfLiteContext* context, TfLiteNode* node) {
#if HIFI_VFPU && (defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
const auto& op_data = *(reinterpret_cast<XtensaConvOpData*>(node->user_data));
const auto& params =
*(reinterpret_cast<TfLiteConvParams*>(node->builtin_data));

const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kConvInputTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kConvOutputTensor);
const TfLiteEvalTensor* filter =
tflite::micro::GetEvalInput(context, node, kConvWeightsTensor);
const TfLiteEvalTensor* bias =
tflite::micro::GetEvalInput(context, node, kConvBiasTensor);

return ConvEvalHifiFloat32(context, node, params, op_data, input, filter,
bias, output);
#else
return ConvReferenceEvalFloat32(context, node);
#endif
}

} // namespace

TFLMRegistration Register_CONV_2D_INT8() {
@@ -86,4 +108,9 @@ TFLMRegistration Register_CONV_2D_INT16() {
EvalInt16);
}

TFLMRegistration Register_CONV_2D_FLOAT32() {
return tflite::micro::RegisterOp(ConvInitXtensa, ConvPrepareXtensa,
EvalFloat32);
}

} // namespace tflite
12 changes: 12 additions & 0 deletions tensorflow/lite/micro/kernels/xtensa/xtensa_conv.h
@@ -59,6 +59,16 @@ TfLiteStatus ConvEvalHifiInt16(TfLiteContext* context, TfLiteNode* node,
const TfLiteEvalTensor* bias,
TfLiteEvalTensor* output);

#if HIFI_VFPU
TfLiteStatus ConvEvalHifiFloat32(TfLiteContext* context, TfLiteNode* node,
const TfLiteConvParams& params,
const XtensaConvOpData& data,
const TfLiteEvalTensor* input,
const TfLiteEvalTensor* filter,
const TfLiteEvalTensor* bias,
TfLiteEvalTensor* output);
#endif

#endif // defined(HIFI3) || defined(HIFI4) || defined(HIFI5)

#if defined(VISION_P6)
@@ -79,6 +89,8 @@ TfLiteStatus ConvReferenceEvalInt8(TfLiteContext* context, TfLiteNode* node);

TfLiteStatus ConvReferenceEvalInt16(TfLiteContext* context, TfLiteNode* node);

TfLiteStatus ConvReferenceEvalFloat32(TfLiteContext* context, TfLiteNode* node);

void* ConvInitXtensa(TfLiteContext* context, const char* buffer, size_t length);
TfLiteStatus ConvPrepareXtensa(TfLiteContext* context, TfLiteNode* node);

3 changes: 2 additions & 1 deletion tensorflow/lite/micro/tools/make/ext_libs/xtensa.inc
@@ -4,9 +4,10 @@
MICROLITE_CC_KERNEL_SRCS += \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/add_vision.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/conv_common_xtensa.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/conv_float32_reference.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/conv_int16_reference.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/conv_int8_int16.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/conv_int8_int16_float32.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/conv_int8_reference.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/conv_vision.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/depthwise_conv_hifi.cc \