Add INT32 support to SUB #3037

Open · wants to merge 5 commits into main (showing changes from 2 commits)
89 changes: 63 additions & 26 deletions tensorflow/lite/micro/kernels/sub.cc
@@ -36,39 +36,76 @@ void* SubInit(TfLiteContext* context, const char* buffer, size_t length) {
return context->AllocatePersistentBuffer(context, sizeof(OpDataSub));
}

void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
const OpDataSub* data, const TfLiteEvalTensor* input1,
const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
tflite::ArithmeticParams op_params;
SetActivationParams(output_activation_min, output_activation_max, &op_params);
if (data->requires_broadcast) {
tflite::reference_ops::BroadcastSubSlow(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<float>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<float>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
} else {
tflite::reference_ops::SubWithActivation(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<float>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<float>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
TfLiteStatus EvalSub(TfLiteContext* context, TfLiteNode* node,
TfLiteSubParams* params, const OpDataSub* data,
const TfLiteEvalTensor* input1,
const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
switch (output->type) {
case kTfLiteFloat32: {
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
tflite::ArithmeticParams op_params;
SetActivationParams(output_activation_min, output_activation_max,
&op_params);
if (data->requires_broadcast) {
tflite::reference_ops::BroadcastSubSlow(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<float>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<float>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
} else {
tflite::reference_ops::SubWithActivation(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<float>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<float>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
}
} break;
case kTfLiteInt32: {
int32_t output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
tflite::ArithmeticParams op_params;
SetActivationParams(output_activation_min, output_activation_max,
&op_params);
if (data->requires_broadcast) {
tflite::reference_ops::BroadcastSubSlow(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<int32_t>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<int32_t>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int32_t>(output));
} else {
tflite::reference_ops::SubWithActivation(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<int32_t>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<int32_t>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int32_t>(output));
}
} break;
default:
MicroPrintf("Type %s (%d) not supported.",
TfLiteTypeGetName(output->type), output->type);
return kTfLiteError;
}

return kTfLiteOk;
}

TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteSubParams* params, const OpDataSub* data,
const TfLiteEvalTensor* input1,
const TfLiteEvalTensor* input2,
TfLiteEvalTensor* output) {
tflite::ArithmeticParams op_params;
tflite::ArithmeticParams op_params = {};
op_params.left_shift = data->left_shift;
op_params.input1_offset = data->input1_offset;
op_params.input1_multiplier = data->input1_multiplier;
@@ -147,7 +184,7 @@ TfLiteStatus SubEval(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
const OpDataSub& data = *(static_cast<const OpDataSub*>(node->user_data));

if (output->type == kTfLiteFloat32) {
if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
EvalSub(context, node, params, &data, input1, input2, output);
ddavis-2015 (Member) commented on Mar 4, 2025:

Could you add a TF_LITE_ENSURE_OK check here (line 188)? It will make the code more consistent.
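
For reference, a minimal sketch of what the reviewer is asking for, assuming the dispatch in SubEval stays as shown in this hunk (not part of the diff):

// Hypothetical sketch: check the status of EvalSub the same way the
// quantized path below already does.
if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
  TF_LITE_ENSURE_OK(context, EvalSub(context, node, params, &data, input1,
                                     input2, output));
} else if (output->type == kTfLiteInt8 || output->type == kTfLiteInt16) {
  TF_LITE_ENSURE_OK(context, EvalSubQuantized(context, node, params, &data,
                                              input1, input2, output));
}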

} else if (output->type == kTfLiteInt8 || output->type == kTfLiteInt16) {
TF_LITE_ENSURE_OK(context, EvalSubQuantized(context, node, params, &data,
33 changes: 33 additions & 0 deletions tensorflow/lite/micro/kernels/sub_test.cc
@@ -105,6 +105,27 @@ void TestSubFloat(int* input1_dims_data, const float* input1_data,
ElementCount(*output_dims), activation);
}

void TestSubInt32(int* input1_dims_data, const int32_t* input1_data,
Member comment (see the guard sketch after this function):

You will need #if !defined(XTENSA) around this method also to prevent the unused function warning (which is promoted to an error).

Author reply:

done @ddavis-2015.

int* input2_dims_data, const int32_t* input2_data,
int* output_dims_data, const int32_t* expected_output,
TfLiteFusedActivation activation, int32_t* output_data) {
TfLiteIntArray* input1_dims = IntArrayFromInts(input1_dims_data);
TfLiteIntArray* input2_dims = IntArrayFromInts(input2_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);

constexpr int inputs_size = 2;
constexpr int outputs_size = 1;
constexpr int tensors_size = inputs_size + outputs_size;
TfLiteTensor tensors[tensors_size] = {
CreateTensor(input1_data, input1_dims),
CreateTensor(input2_data, input2_dims),
CreateTensor(output_data, output_dims),
};

ValidateSubGoldens(tensors, tensors_size, expected_output, output_data,
ElementCount(*output_dims), activation);
}
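
Following up on the reviewer's comment above, a sketch of where the guard would go, assuming the XTENSA macro is the one the Xtensa CI builds define (not part of the diff):

// Hypothetical guard so the helper is compiled out on Xtensa builds,
// avoiding the unused-function warning that is promoted to an error.
#if !defined(XTENSA)
void TestSubInt32(int* input1_dims_data, const int32_t* input1_data,
                  int* input2_dims_data, const int32_t* input2_data,
                  int* output_dims_data, const int32_t* expected_output,
                  TfLiteFusedActivation activation, int32_t* output_data) {
  // ... body unchanged from the diff above ...
}
#endif  // !defined(XTENSA)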

template <typename T>
void TestSubQuantized(int* input1_dims_data, const float* input1_data,
T* input1_quantized, float input1_scale,
@@ -219,6 +240,18 @@ TF_LITE_MICRO_TEST(FloatSubWithScalarBroadcast) {
}
}

TF_LITE_MICRO_TEST(Int32SubNoActivation) {
int inout_shape[] = {4, 1, 2, 2, 1};
const int32_t input1_values[] = {-2, 2147483646, -1, 1146622854};
const int32_t input2_values[] = {3, 1, -2147483647, -726978367};
const int32_t golden_values[] = {-5, 2147483645, 2147483646, 1873601221};
const int kOutputDimsCount = 4;
int32_t output_data[kOutputDimsCount];
tflite::testing::TestSubInt32(inout_shape, input1_values, inout_shape,
input2_values, inout_shape, golden_values,
kTfLiteActNone, output_data);
}

Member comment:

This code will need #if !defined(XTENSA) around it in order to pass the CI tests.
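
The same pattern would wrap the test itself; a sketch under the same assumption about the XTENSA macro (not part of the diff):

#if !defined(XTENSA)
TF_LITE_MICRO_TEST(Int32SubNoActivation) {
  // ... test body as written above ...
}
#endif  // !defined(XTENSA)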

TF_LITE_MICRO_TEST(QuantizedSubNoActivationInt8) {
const float scales[] = {0.25, 0.5, 1.0};
const int zero_points[] = {-10, 4, 13};