From 58c184c40d2f01cee21283297f443152ee1bf730 Mon Sep 17 00:00:00 2001
From: Anupam Pandey
Date: Mon, 7 Apr 2025 10:54:18 +0530
Subject: [PATCH] Enable HiFi SIMD for FULLY_CONNECTED operator

This change enables the HiFi (HiFi3/HiFi4/HiFi5) SIMD kernels for the
FULLY_CONNECTED operator during inference of FP32xFP32 and INT16xINT8
models on Xtensa.
---
 .../lite/micro/kernels/fully_connected.h      |  12 +-
 .../micro/kernels/xtensa/fully_connected.cc   |  56 +------
 .../kernels/xtensa/fully_connected_f32.cc     | 142 +++++++++++++++++
 .../kernels/xtensa/fully_connected_int16.cc   | 146 ++++++++++++++++++
 .../kernels/xtensa/xtensa_fully_connected.h   |  10 ++
 .../lite/micro/tools/make/ext_libs/xtensa.inc |   2 +
 6 files changed, 313 insertions(+), 55 deletions(-)
 create mode 100644 tensorflow/lite/micro/kernels/xtensa/fully_connected_f32.cc
 create mode 100644 tensorflow/lite/micro/kernels/xtensa/fully_connected_int16.cc

diff --git a/tensorflow/lite/micro/kernels/fully_connected.h b/tensorflow/lite/micro/kernels/fully_connected.h
index 64213f0fb63..035a9858a80 100644
--- a/tensorflow/lite/micro/kernels/fully_connected.h
+++ b/tensorflow/lite/micro/kernels/fully_connected.h
@@ -117,14 +117,20 @@ TFLMRegistration Register_FULLY_CONNECTED_INT4();
 // define fallback implementation that allow reference kernels to still be used
 // from applications that call a more specific kernel variant.
 
-inline TFLMRegistration Register_FULLY_CONNECTED_INT16() {
+inline TFLMRegistration Register_FULLY_CONNECTED_INT4() {
   return Register_FULLY_CONNECTED();
 }
 
-inline TFLMRegistration Register_FULLY_CONNECTED_INT4() {
+#endif
+
+#if defined(XTENSA)
+// Returns a TFLMRegistration struct for kernel variant that only supports
+// float32.
+TFLMRegistration Register_FULLY_CONNECTED_FLOAT32();
+#else
+inline TFLMRegistration Register_FULLY_CONNECTED_FLOAT32() {
   return Register_FULLY_CONNECTED();
 }
-
 #endif
 
 }  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
index 511335a550f..d7ef104077d 100644
--- a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
@@ -33,8 +33,6 @@ namespace {
 
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->builtin_data != nullptr);
-  const auto* params =
-      static_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
 
   const TfLiteEvalTensor* input =
       tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor);
@@ -45,18 +43,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteEvalTensor* output =
       tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor);
 
-#ifdef USE_TFLM_COMPRESSION
-
-  MicroContext* micro_context = GetMicroContext(context);
-
-  const CompressionTensorData* weights_comp_td =
-      micro_context->GetTensorCompressionData(node,
-                                              kFullyConnectedWeightsTensor);
-  const CompressionTensorData* bias_comp_td =
-      micro_context->GetTensorCompressionData(node, kFullyConnectedBiasTensor);
-
-#endif  // USE_TFLM_COMPRESSION
-
   TFLITE_DCHECK(node->user_data != nullptr);
 
   const auto& data =
@@ -65,25 +51,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   // Checks in Prepare ensure input, output and filter types are all the same.
  switch (input->type) {
    case kTfLiteFloat32: {
-      tflite::reference_ops::FullyConnected(
-          FullyConnectedParamsFloat(params->activation),
-          tflite::micro::GetTensorShape(input),
-          tflite::micro::GetTensorData<float>(input),
-          tflite::micro::GetTensorShape(filter),
-#ifdef USE_TFLM_COMPRESSION
-          tflite::micro::GetTensorData<float>(micro_context, filter,
-                                              weights_comp_td,
-                                              data.weights_scratch_index),
-          tflite::micro::GetTensorShape(bias),
-          tflite::micro::GetOptionalTensorData<float>(
-              micro_context, bias, bias_comp_td, data.bias_scratch_index),
-#else   // USE_TFLM_COMPRESSION
-          tflite::micro::GetTensorData<float>(filter),
-          tflite::micro::GetTensorShape(bias),
-          tflite::micro::GetOptionalTensorData<float>(bias),
-#endif  // USE_TFLM_COMPRESSION
-          tflite::micro::GetTensorShape(output),
-          tflite::micro::GetTensorData<float>(output));
+      return XtensaEvalFullyConnectedF32(
+          context, node, data, input, filter, bias, output);
       break;
     }
@@ -109,25 +78,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     case kTfLiteInt16: {
       switch (filter->type) {
         case kTfLiteInt8: {
-          tflite::reference_integer_ops::FullyConnected(
-              FullyConnectedParamsQuantized(data),
-              tflite::micro::GetTensorShape(input),
-              tflite::micro::GetTensorData<int16_t>(input),
-              tflite::micro::GetTensorShape(filter),
-#ifdef USE_TFLM_COMPRESSION
-              tflite::micro::GetTensorData<int8_t>(micro_context, filter,
-                                                   weights_comp_td,
-                                                   data.weights_scratch_index),
-              tflite::micro::GetTensorShape(bias),
-              tflite::micro::GetOptionalTensorData<int64_t>(
-                  micro_context, bias, bias_comp_td, data.bias_scratch_index),
-#else   // USE_TFLM_COMPRESSION
-              tflite::micro::GetTensorData<int8_t>(filter),
-              tflite::micro::GetTensorShape(bias),
-              tflite::micro::GetOptionalTensorData<int64_t>(bias),
-#endif  // USE_TFLM_COMPRESSION
-              tflite::micro::GetTensorShape(output),
-              tflite::micro::GetTensorData<int16_t>(output));
+          return XtensaEvalFullyConnectedQuantizedInt16(
+              context, node, data, input, filter, bias, output);
           break;
         }
         default: {
diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected_f32.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected_f32.cc
new file mode 100644
index 00000000000..20b7a0c58f3
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected_f32.cc
@@ -0,0 +1,142 @@
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h"
+#include "tensorflow/lite/micro/kernels/xtensa/xtensa_fully_connected.h"
+
+namespace tflite {
+
+TfLiteStatus XtensaEvalFullyConnectedF32(
+    TfLiteContext* context, TfLiteNode* node, const OpDataFullyConnected& data,
+    const TfLiteEvalTensor* input, const TfLiteEvalTensor* filter,
+    const TfLiteEvalTensor* bias, TfLiteEvalTensor* output) {
+#if !defined(VISION_P6)
+
+#ifdef USE_TFLM_COMPRESSION
+
+  MicroContext* micro_context = GetMicroContext(context);
+
+  const CompressionTensorData* weights_comp_td =
+      micro_context->GetTensorCompressionData(node,
+                                              kFullyConnectedWeightsTensor);
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kFullyConnectedBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
+  const float* bias_data =
+#ifdef USE_TFLM_COMPRESSION
+      tflite::micro::GetOptionalTensorData<float>(
+          micro_context, bias, bias_comp_td, data.bias_scratch_index);
+#else   // USE_TFLM_COMPRESSION
+      tflite::micro::GetOptionalTensorData<float>(bias);
+#endif  // USE_TFLM_COMPRESSION
+
+  const float* filter_data =
+#ifdef USE_TFLM_COMPRESSION
+      tflite::micro::GetTensorData<float>(
+          micro_context, filter, weights_comp_td, data.weights_scratch_index);
+#else   // USE_TFLM_COMPRESSION
+      tflite::micro::GetTensorData<float>(filter);
+#endif  // USE_TFLM_COMPRESSION
+
+#endif  // !defined(VISION_P6)
+
+  const auto* node_data =
+      static_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
+  const FullyConnectedParams params =
+      FullyConnectedParamsFloat(node_data->activation);
+  const float* input_data = tflite::micro::GetTensorData<float>(input);
+  float* output_data = tflite::micro::GetTensorData<float>(output);
+
+#if HIFI_VFPU && (defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+  const RuntimeShape& output_shape = tflite::micro::GetTensorShape(output);
+  const int num_batches =
+      FlatSizeSkipDim(output_shape, output_shape.DimensionsCount() - 1);
+  const int output_depth =
+      output_shape.Dims(output_shape.DimensionsCount() - 1);
+
+  const RuntimeShape& filter_shape = tflite::micro::GetTensorShape(filter);
+  const int filter_dim_count = filter_shape.DimensionsCount();
+  const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
+  if (num_batches == 1) {
+    TF_LITE_ENSURE_EQ(
+        context,
+        xa_nn_fully_connected_f32(output_data, filter_data, input_data,
+                                  bias_data, accum_depth, output_depth),
+        0);
+  } else {
+    TF_LITE_ENSURE_EQ(context,
+                      xa_nn_matmul_f32xf32_f32(
+                          output_data, filter_data, input_data, bias_data,
+                          output_depth, accum_depth, accum_depth, num_batches,
+                          accum_depth, output_depth, 1),
+                      0);
+  }
+  TF_LITE_ENSURE_EQ(
+      context,
+      xa_nn_vec_activation_min_max_f32_f32(
+          output_data, output_data, params.float_activation_min,
+          params.float_activation_max, num_batches * output_depth),
+      0);
+#else
+  tflite::reference_ops::FullyConnected(
+      params, tflite::micro::GetTensorShape(input), input_data,
+      tflite::micro::GetTensorShape(filter), filter_data,
+      tflite::micro::GetTensorShape(bias), bias_data,
+      tflite::micro::GetTensorShape(output), output_data);
+#endif  // HIFI_VFPU && (defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+
+  return kTfLiteOk;
+}
+
+namespace {
+
+TfLiteStatus EvalFloat32(TfLiteContext* context, TfLiteNode* node) {
+  TFLITE_DCHECK(node->user_data != nullptr);
+  const auto& data =
+      *(static_cast<const OpDataFullyConnected*>(node->user_data));
+
+  const TfLiteEvalTensor* input =
+      tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor);
+  const TfLiteEvalTensor* filter =
+      tflite::micro::GetEvalInput(context, node, kFullyConnectedWeightsTensor);
+  const TfLiteEvalTensor* bias =
+      tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor);
+
+  TfLiteEvalTensor* output =
+      tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor);
+
+  return XtensaEvalFullyConnectedF32(context, node, data, input, filter, bias,
+                                     output);
+}
+
+}  // namespace
+
+TFLMRegistration Register_FULLY_CONNECTED_FLOAT32() {
+  return tflite::micro::RegisterOp(XtensaInitFullyConnected,
+                                   XtensaPrepareFullyConnected, EvalFloat32);
+}
+
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected_int16.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected_int16.cc
new file mode 100644
index 00000000000..2ca5aa11f8f
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected_int16.cc
@@ -0,0 +1,146 @@
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h"
+#include "tensorflow/lite/micro/kernels/xtensa/xtensa_fully_connected.h"
+
+namespace tflite {
+
+TfLiteStatus XtensaEvalFullyConnectedQuantizedInt16(
+    TfLiteContext* context, TfLiteNode* node, const OpDataFullyConnected& data,
+    const TfLiteEvalTensor* input, const TfLiteEvalTensor* filter,
+    const TfLiteEvalTensor* bias, TfLiteEvalTensor* output) {
+#if !defined(VISION_P6)
+
+#ifdef USE_TFLM_COMPRESSION
+
+  MicroContext* micro_context = GetMicroContext(context);
+
+  const CompressionTensorData* weights_comp_td =
+      micro_context->GetTensorCompressionData(node,
+                                              kFullyConnectedWeightsTensor);
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kFullyConnectedBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
+  const int64_t* bias_data =
+#ifdef USE_TFLM_COMPRESSION
+      tflite::micro::GetOptionalTensorData<int64_t>(
+          micro_context, bias, bias_comp_td, data.bias_scratch_index);
+#else   // USE_TFLM_COMPRESSION
+      tflite::micro::GetOptionalTensorData<int64_t>(bias);
+#endif  // USE_TFLM_COMPRESSION
+
+  const int8_t* filter_data =
+#ifdef USE_TFLM_COMPRESSION
+      tflite::micro::GetTensorData<int8_t>(
+          micro_context, filter, weights_comp_td, data.weights_scratch_index);
+#else   // USE_TFLM_COMPRESSION
+      tflite::micro::GetTensorData<int8_t>(filter);
+#endif  // USE_TFLM_COMPRESSION
+
+#endif  // !defined(VISION_P6)
+
+  int16_t* output_data = tflite::micro::GetTensorData<int16_t>(output);
+
+#if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
+  const RuntimeShape& output_shape = tflite::micro::GetTensorShape(output);
+  const int num_batches =
+      FlatSizeSkipDim(output_shape, output_shape.DimensionsCount() - 1);
+  const int output_depth =
+      output_shape.Dims(output_shape.DimensionsCount() - 1);
+
+  const RuntimeShape& filter_shape = tflite::micro::GetTensorShape(filter);
+  const int filter_dim_count = filter_shape.DimensionsCount();
+  const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
+
+  FullyConnectedParams op_params = FullyConnectedParamsQuantized(data);
+  if (num_batches == 1) {
+    TF_LITE_ENSURE_EQ(context,
+                      xa_nn_fully_connected_sym8sxsym16s_sym16s(
+                          output_data, filter_data,
+                          tflite::micro::GetTensorData<int16_t>(input),
+                          bias_data, accum_depth, output_depth,
+                          op_params.output_multiplier, op_params.output_shift),
+                      0);
+  } else {
+    TF_LITE_ENSURE_EQ(context,
+                      xa_nn_matmul_sym8sxsym16s_sym16s(
+                          tflite::micro::GetTensorData<int16_t>(output),
+                          tflite::micro::GetTensorData<int8_t>(filter),
+                          tflite::micro::GetTensorData<int16_t>(input),
+                          bias_data, output_depth, accum_depth, accum_depth,
+                          num_batches, accum_depth, output_depth, 1,
+                          op_params.input_offset, op_params.output_multiplier,
+                          op_params.output_shift, op_params.output_offset),
+                      0);
+  }
+
+  TF_LITE_ENSURE_EQ(context,
+                    xa_nn_vec_activation_min_max_16_16(
+                        output_data, output_data, data.output_activation_min,
+                        data.output_activation_max, num_batches * output_depth),
+                    0);
+#else
+  tflite::reference_integer_ops::FullyConnected(
+      FullyConnectedParamsQuantized(data), tflite::micro::GetTensorShape(input),
+      tflite::micro::GetTensorData<int16_t>(input),
+      tflite::micro::GetTensorShape(filter), filter_data,
+      tflite::micro::GetTensorShape(bias), bias_data,
+      tflite::micro::GetTensorShape(output), output_data);
+#endif  // defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
+
+  return kTfLiteOk;
+}
+
+namespace {
+
+TfLiteStatus EvalInt16(TfLiteContext* context, TfLiteNode* node) {
+  TFLITE_DCHECK(node->user_data != nullptr);
+  const auto& data =
+      *(static_cast<const OpDataFullyConnected*>(node->user_data));
+
+  const TfLiteEvalTensor* input =
+      tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor);
+  const TfLiteEvalTensor* filter =
+      tflite::micro::GetEvalInput(context, node, kFullyConnectedWeightsTensor);
+  const TfLiteEvalTensor* bias =
+      tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor);
+
+  TfLiteEvalTensor* output =
+      tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor);
+
+  return XtensaEvalFullyConnectedQuantizedInt16(context, node, data, input,
+                                                filter, bias, output);
+}
+
+}  // namespace
+
+TFLMRegistration Register_FULLY_CONNECTED_INT16() {
+  return tflite::micro::RegisterOp(XtensaInitFullyConnected,
+                                   XtensaPrepareFullyConnected, EvalInt16);
+}
+
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa/xtensa_fully_connected.h b/tensorflow/lite/micro/kernels/xtensa/xtensa_fully_connected.h
index e030f0b3bd4..e02e5d243d2 100644
--- a/tensorflow/lite/micro/kernels/xtensa/xtensa_fully_connected.h
+++ b/tensorflow/lite/micro/kernels/xtensa/xtensa_fully_connected.h
@@ -65,6 +65,16 @@ TfLiteStatus XtensaEvalFullyConnectedQuantizedInt8(
     const TfLiteEvalTensor* input, const TfLiteEvalTensor* filter,
     const TfLiteEvalTensor* bias, TfLiteEvalTensor* output);
 
+TfLiteStatus XtensaEvalFullyConnectedQuantizedInt16(
+    TfLiteContext* context, TfLiteNode* node, const OpDataFullyConnected& data,
+    const TfLiteEvalTensor* input, const TfLiteEvalTensor* filter,
+    const TfLiteEvalTensor* bias, TfLiteEvalTensor* output);
+
+TfLiteStatus XtensaEvalFullyConnectedF32(
+    TfLiteContext* context, TfLiteNode* node, const OpDataFullyConnected& data,
+    const TfLiteEvalTensor* input, const TfLiteEvalTensor* filter,
+    const TfLiteEvalTensor* bias, TfLiteEvalTensor* output);
+
 TfLiteStatus XtensaCalculateOpDataFullyConnected(
     TfLiteContext* context, TfLiteFusedActivation activation,
     TfLiteType data_type, const TfLiteTensor* input, const TfLiteTensor* filter,
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/xtensa.inc b/tensorflow/lite/micro/tools/make/ext_libs/xtensa.inc
index 70e1880c800..4680bd05517 100644
--- a/tensorflow/lite/micro/tools/make/ext_libs/xtensa.inc
+++ b/tensorflow/lite/micro/tools/make/ext_libs/xtensa.inc
@@ -13,6 +13,8 @@ MICROLITE_CC_KERNEL_SRCS += \
   $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/depthwise_conv_vision.cc \
   $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/fully_connected_common_xtensa.cc \
   $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/fully_connected_int8.cc \
+  $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/fully_connected_int16.cc \
+  $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/fully_connected_f32.cc \
  $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/fully_connected_vision.cc \
  $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/pad_vision.cc \
  $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/pooling_int8.cc \
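-- 
Reviewer note, not part of the patch itself: a minimal sketch of how an
application might opt into the new kernel variants through the existing
MicroMutableOpResolver/MicroInterpreter flow. RunOnce(), the arena size, and
model_data are hypothetical placeholders, not names introduced by this change;
only Register_FULLY_CONNECTED_INT16()/Register_FULLY_CONNECTED_FLOAT32() come
from the patch.

// Usage sketch (assumptions as above). For an int16-activation model built
// for a HiFi3/HiFi4/HiFi5 target, the INT16 registration now dispatches to
// the xa_nn_* kernels added here; a float32 model would register
// Register_FULLY_CONNECTED_FLOAT32() instead.
#include <cstdint>

#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace {
constexpr int kArenaSize = 16 * 1024;  // Placeholder; size to your model.
alignas(16) uint8_t g_tensor_arena[kArenaSize];
}  // namespace

TfLiteStatus RunOnce(const unsigned char* model_data) {
  const tflite::Model* model = tflite::GetModel(model_data);

  // Register only the FULLY_CONNECTED variant the model needs, so the
  // unused kernel paths can be dropped at link time.
  tflite::MicroMutableOpResolver<1> resolver;
  TF_LITE_ENSURE_STATUS(
      resolver.AddFullyConnected(tflite::Register_FULLY_CONNECTED_INT16()));

  tflite::MicroInterpreter interpreter(model, resolver, g_tensor_arena,
                                       kArenaSize);
  TF_LITE_ENSURE_STATUS(interpreter.AllocateTensors());
  return interpreter.Invoke();
}

Behavior on non-HiFi targets is unchanged: both new files fall back to the
reference kernels via their #else branches. On HiFi builds, the new paths use
xa_nn_fully_connected_* when num_batches == 1 and xa_nn_matmul_* otherwise,
followed by a fused activation clamp.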