diff --git a/tensorflow/lite/micro/kernels/xtensa/sub.cc b/tensorflow/lite/micro/kernels/xtensa/sub.cc
index b8308c93eaa..4400adb20c2 100644
--- a/tensorflow/lite/micro/kernels/xtensa/sub.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/sub.cc
@@ -36,14 +36,55 @@ void* SubInit(TfLiteContext* context, const char* buffer, size_t length) {
   return context->AllocatePersistentBuffer(context, sizeof(OpDataSub));
 }
 
-void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
-             const OpDataSub* data, const TfLiteEvalTensor* input1,
-             const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
+TfLiteStatus EvalSub(TfLiteContext* context, TfLiteNode* node,
+                     TfLiteSubParams* params, const OpDataSub* data,
+                     const TfLiteEvalTensor* input1,
+                     const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
   float output_activation_min, output_activation_max;
   CalculateActivationRange(params->activation, &output_activation_min,
                            &output_activation_max);
   tflite::ArithmeticParams op_params;
   SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+#if HIFI_VFPU && (defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+  const RuntimeShape extended_input1_shape =
+      RuntimeShape::ExtendedShape(5, tflite::micro::GetTensorShape(input1));
+  const RuntimeShape extended_input2_shape =
+      RuntimeShape::ExtendedShape(5, tflite::micro::GetTensorShape(input2));
+  const RuntimeShape extended_output_shape =
+      RuntimeShape::ExtendedShape(5, tflite::micro::GetTensorShape(output));
+  const int* input1_dims = extended_input1_shape.DimsData();
+  const int* input2_dims = extended_input2_shape.DimsData();
+  const int* output_dims = extended_output_shape.DimsData();
+
+  int inp1_off = 0;
+  int inp2_off = 0;
+  int out_off = output_dims[1] * output_dims[2] * output_dims[3] * output_dims[4];
+  if (input1_dims[0] > 1) {
+    inp1_off =
+        input1_dims[1] * input1_dims[2] * input1_dims[3] * input1_dims[4];
+  }
+  if (input2_dims[0] > 1) {
+    inp2_off =
+        input2_dims[1] * input2_dims[2] * input2_dims[3] * input2_dims[4];
+  }
+
+  for (int b = 0; b < output_dims[0]; b++) {
+    int err = xa_nn_elm_sub_broadcast_4D_f32xf32_f32(
+        tflite::micro::GetTensorData<float>(output) + b * out_off,
+        output_dims + 1,
+        tflite::micro::GetTensorData<float>(input1) + b * inp1_off,
+        input1_dims + 1,
+        tflite::micro::GetTensorData<float>(input2) + b * inp2_off,
+        input2_dims + 1);
+    TF_LITE_ENSURE(context, err == 0);
+  }
+
+  float* output_data = tflite::micro::GetTensorData<float>(output);
+  xa_nn_vec_activation_min_max_f32_f32(
+      output_data, output_data, op_params.float_activation_min,
+      op_params.float_activation_max, (output_dims[0] * out_off));
+#else   // HIFI_VFPU && (defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
   if (data->requires_broadcast) {
     tflite::reference_ops::BroadcastSubSlow(
         op_params, tflite::micro::GetTensorShape(input1),
@@ -61,6 +102,9 @@ void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
         tflite::micro::GetTensorShape(output),
         tflite::micro::GetTensorData<float>(output));
   }
+#endif  // HIFI_VFPU && (defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+
+  return kTfLiteOk;
 }
 
 TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,
@@ -229,7 +273,8 @@ TfLiteStatus SubEval(TfLiteContext* context, TfLiteNode* node) {
   const OpDataSub& data = *(static_cast<const OpDataSub*>(node->user_data));
 
   if (output->type == kTfLiteFloat32) {
-    EvalSub(context, node, params, &data, input1, input2, output);
+    TF_LITE_ENSURE_OK(
+        context, EvalSub(context, node, params, &data, input1, input2, output));
   } else if (output->type == kTfLiteInt8 || output->type == kTfLiteInt16) {
     TF_LITE_ENSURE_OK(context,
                       EvalSubQuantized(context, node, params, &data, input1, input2, output));
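
For context on what the new HiFi path does: `xa_nn_elm_sub_broadcast_4D_f32xf32_f32` only handles 4-D broadcasting, so the patch extends every shape to 5-D and walks the outermost (batch) dimension by hand, with a per-batch pointer offset that collapses to zero whenever that input's batch dimension is 1. The minimal sketch below reproduces that offset scheme off-device; the `naive_sub_4d` helper, the hard-coded shapes, and the equal inner dimensions are illustrative assumptions, not part of NNLib or of this patch.

```cpp
// Sketch of the batch-offset broadcasting scheme used in the HiFi path.
// A plain scalar loop stands in for xa_nn_elm_sub_broadcast_4D_f32xf32_f32.
#include <cstdio>
#include <vector>

// Stand-in for the NNLib 4-D broadcast subtract. For simplicity both
// inputs here share the same 4-D shape; the interesting broadcast in the
// patch happens one level up, on the batch dimension.
static void naive_sub_4d(float* out, const float* in1, const float* in2,
                         const int dims[4]) {
  const int n = dims[0] * dims[1] * dims[2] * dims[3];
  for (int i = 0; i < n; ++i) out[i] = in1[i] - in2[i];
}

int main() {
  // Shapes extended to 5-D, as RuntimeShape::ExtendedShape(5, ...) would do:
  // input1 has batch 2; input2 has batch 1 and is broadcast across batches.
  const int input1_dims[5] = {2, 1, 1, 2, 2};
  const int input2_dims[5] = {1, 1, 1, 2, 2};
  const int output_dims[5] = {2, 1, 1, 2, 2};

  std::vector<float> in1 = {1, 2, 3, 4, 5, 6, 7, 8};  // 2 batches of 4
  std::vector<float> in2 = {10, 20, 30, 40};          // 1 batch of 4
  std::vector<float> out(8);

  // Per-batch strides: zero when that tensor's batch dim is 1, so the
  // same data is re-read for every output batch -- this is the broadcast.
  const int out_off =
      output_dims[1] * output_dims[2] * output_dims[3] * output_dims[4];
  const int inp1_off =
      input1_dims[0] > 1
          ? input1_dims[1] * input1_dims[2] * input1_dims[3] * input1_dims[4]
          : 0;
  const int inp2_off =
      input2_dims[0] > 1
          ? input2_dims[1] * input2_dims[2] * input2_dims[3] * input2_dims[4]
          : 0;

  // Same loop structure as the patch: one 4-D kernel call per batch.
  for (int b = 0; b < output_dims[0]; ++b) {
    naive_sub_4d(out.data() + b * out_off, in1.data() + b * inp1_off,
                 in2.data() + b * inp2_off, output_dims + 1);
  }

  for (float v : out) printf("%g ", v);  // -9 -18 -27 -36 -5 -14 -23 -32
  printf("\n");
  return 0;
}
```

Because `inp2_off` is zero here, batch 1 re-reads the same `in2` data as batch 0, which is exactly how the loop in the patch broadcasts a batch-1 input across a larger batch; the real NNLib kernel additionally broadcasts within the four inner dimensions. The patch then clamps all `output_dims[0] * out_off` results in a single fused `xa_nn_vec_activation_min_max_f32_f32` call rather than clamping per batch.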