Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions tflite/kernels/batch_matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
rhs_data->type == kTfLiteInt16);
// Either we have a hybrid quantization with a float32 and an int8 input,
// otherwise both inputs should be of the same type.
TF_LITE_ENSURE(context, (lhs_data->type == kTfLiteFloat32 &&
rhs_data->type == kTfLiteInt8) ||
lhs_data->type == rhs_data->type);
TF_LITE_ENSURE(
context,
(lhs_data->type == kTfLiteFloat32 && rhs_data->type == kTfLiteInt8) ||
lhs_data->type == rhs_data->type ||
(lhs_data->type == kTfLiteInt16 && rhs_data->type == kTfLiteInt8));
// Support dimensions between 2 and 5, inclusive.
TF_LITE_ENSURE(context, NumDimensions(lhs_data) >= 2);
TF_LITE_ENSURE(context, NumDimensions(lhs_data) <= 5);
Expand Down Expand Up @@ -592,6 +594,31 @@ TfLiteStatus EvalInt16(TfLiteContext* context, const OpData* data,
return kTfLiteOk;
}

// Evaluates BatchMatMul for an int16 activation tensor (lhs) against int8
// weights (rhs), writing int16 output. Requantization parameters (multiplier,
// shift, activation clamp range) come precomputed in `data`.
//
// Returns kTfLiteOk unconditionally; `context` is kept for signature
// symmetry with the sibling Eval* helpers.
TfLiteStatus EvalInt16Int8(TfLiteContext* context, const OpData* data,
                           const RuntimeShape& lhs_shape,
                           const TfLiteTensor* lhs,
                           const RuntimeShape& rhs_shape,
                           const TfLiteTensor* rhs,
                           const RuntimeShape& output_shape,
                           TfLiteTensor* output) {
  // Reuse params struct from FullyConnected Op.
  FullyConnectedParams op_params;
  op_params.input_offset = -lhs->params.zero_point;
  op_params.weights_offset = -rhs->params.zero_point;
  op_params.output_offset = output->params.zero_point;
  op_params.output_multiplier = data->output_multiplier;
  op_params.output_shift = data->output_shift;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;

  // Note: operands are swapped — the reference kernel takes the weights
  // operand (rhs, int8) first and the activation operand (lhs, int16) second.
  // Use the caller-supplied output_shape instead of recomputing it from the
  // output tensor (the previous code ignored the parameter).
  reference_ops::BatchMatMul<int8_t, int64_t, int16_t, int16_t>(
      op_params, rhs_shape, GetTensorData<int8_t>(rhs), lhs_shape,
      GetTensorData<int16_t>(lhs), output_shape,
      GetTensorData<int16_t>(output));

  return kTfLiteOk;
}

TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
OpData* data, const RuntimeShape& lhs_shape,
const TfLiteTensor* lhs,
Expand Down Expand Up @@ -627,6 +654,9 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
} else if (lhs->type == kTfLiteInt16 && rhs->type == kTfLiteInt16) {
return EvalInt16(context, data, lhs_shape, lhs, rhs_shape, rhs,
GetTensorShape(output), output);
} else if (lhs->type == kTfLiteInt16 && rhs->type == kTfLiteInt8) {
return EvalInt16Int8(context, data, lhs_shape, lhs, rhs_shape, rhs,
GetTensorShape(output), output);
} else {
TF_LITE_KERNEL_LOG(
context,
Expand Down
95 changes: 68 additions & 27 deletions tflite/kernels/batch_matmul_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -973,25 +973,11 @@ TEST(HybridSymmetricBatchMatMulOpTest, QuantizedInt8BroadcastInputs) {

class QuantizedBatchMatMulOpModel : public SingleOpModel {
public:
QuantizedBatchMatMulOpModel(int units, int batches, const TensorData& lhs,
QuantizedBatchMatMulOpModel(const TensorData& lhs, const TensorData& rhs,
const TensorData& output = {TensorType_INT8},
bool adj_x = false, bool adj_y = false)
: units_(units), batches_(batches) {
int total_input_size = 1;
for (size_t i = 0; i < lhs.shape.size(); ++i) {
total_input_size *= lhs.shape[i];
}
input_size_ = total_input_size / batches_;

int rhs_batch_size = adj_y ? units_ : input_size_;
int rhs_channels = adj_y ? input_size_ : units_;
bool adj_x = false, bool adj_y = false) {
lhs_id_ = AddInput(lhs);
rhs_id_ = AddInput({lhs.type,
{rhs_batch_size, rhs_channels},
0,
0,
GetScale(lhs_id_),
GetZeroPoint(lhs_id_)});
rhs_id_ = AddInput(rhs);

output_id_ = AddOutput(output);

Expand Down Expand Up @@ -1026,15 +1012,12 @@ class QuantizedBatchMatMulOpModel : public SingleOpModel {
int lhs_id_;
int rhs_id_;
int output_id_;
int units_;
int batches_;
int input_size_;
};

TEST(QuantizedBatchMatMulOpTest, SimpleTestQuantizedInt8) {
QuantizedBatchMatMulOpModel m(
/*units=*/3, /*batches*/ 2,
/*lhs=*/{TensorType_INT8, {2, 10}, -63.5, 64},
/*rhs=*/{TensorType_INT8, {10, 3}, -63.5, 64},
/*output=*/{TensorType_INT8, {}, -127, 128});

m.SetWeights<int8_t>({
Expand All @@ -1056,8 +1039,8 @@ TEST(QuantizedBatchMatMulOpTest, SimpleTestQuantizedInt8) {

TEST(QuantizedBatchMatMulOpTest, SimpleTestQuantizedInt8AdjRHS) {
QuantizedBatchMatMulOpModel m(
/*units=*/3, /*batches*/ 2,
/*lhs=*/{TensorType_INT8, {2, 10}, -63.5, 64},
/*rhs=*/{TensorType_INT8, {3, 10}, -63.5, 64},
/*output=*/{TensorType_INT8, {}, -127, 128}, false, true);

m.SetWeights<int8_t>({
Expand All @@ -1083,11 +1066,9 @@ TEST(QuantizedBatchMatMulOpTest, SimpleTestQuantizedInt16) {
const int32_t zero_point = 0;

QuantizedBatchMatMulOpModel m(
/*units=*/3, /*batches*/ 2,
/*lhs=*/
{TensorType_INT16, {2, 10}, 0, 0, inputs_scale, zero_point},
/*output=*/
{TensorType_INT16, {}, 0, 0, output_scale, zero_point});
/*lhs=*/{TensorType_INT16, {2, 10}, 0, 0, inputs_scale, zero_point},
/*rhs=*/{TensorType_INT16, {10, 3}, 0, 0, inputs_scale, zero_point},
/*output=*/{TensorType_INT16, {}, 0, 0, output_scale, zero_point});

m.SetWeights<int16_t>({
1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5,
Expand All @@ -1106,5 +1087,65 @@ TEST(QuantizedBatchMatMulOpTest, SimpleTestQuantizedInt16) {
EXPECT_THAT(m.GetOutput<int16_t>(), ElementsAre(23, 23, 23, 57, 57, 57));
}

TEST(QuantizedBatchMatMulOpTest, SimpleTestQuantizedInt16Int8) {
  // Unit scales and zero offsets keep the quantized values numerically equal
  // to the real values, so the expected output is just the raw accumulator.
  const float kInputScale = 1.0;
  const float kWeightScale = 1.0;
  const float kOutputScale = 1.0;
  const int32_t kZeroPoint = 0;

  QuantizedBatchMatMulOpModel model(
      /*lhs=*/{TensorType_INT16, {2, 10}, 0, 0, kInputScale, kZeroPoint},
      /*rhs=*/{TensorType_INT8, {10, 3}, 0, 0, kWeightScale, kZeroPoint},
      /*output=*/{TensorType_INT16, {}, 0, 0, kOutputScale, kZeroPoint});

  // All three output columns share the same weight per input row.
  model.SetWeights<int8_t>({
      1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5,
      6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10,
  });

  model.SetInput<int16_t>({
      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // batch 0
      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // batch 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetDequantizedOutput<int16_t>(),
              ElementsAreArray(ArrayFloatNear({23, 23, 23, 57, 57, 57})));
  EXPECT_THAT(model.GetOutput<int16_t>(),
              ElementsAre(23, 23, 23, 57, 57, 57));
}

TEST(QuantizedBatchMatMulOpTest, SimpleTestQuantizedInt16Int8WithScales) {
  // Non-unit scales so the requantization path is actually exercised:
  // combined scale = (input * weight) / output = (0.5 * 2.0) / 0.25 = 4.0.
  const float kInputScale = 0.5;
  const float kWeightScale = 2.0;
  const float kOutputScale = 0.25;
  const int32_t kZeroPoint = 0;

  QuantizedBatchMatMulOpModel model(
      /*lhs=*/{TensorType_INT16, {2, 10}, 0, 0, kInputScale, kZeroPoint},
      /*rhs=*/{TensorType_INT8, {10, 3}, 0, 0, kWeightScale, kZeroPoint},
      /*output=*/{TensorType_INT16, {}, 0, 0, kOutputScale, kZeroPoint});

  model.SetWeights<int8_t>({
      2, 2, 2, 4, 4, 4, 6, 6, 6, 8, 8, 8, 10, 10, 10,
      12, 12, 12, 14, 14, 14, 16, 16, 16, 18, 18, 18, 20, 20, 20,
  });

  model.SetInput<int16_t>({
      0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, -4.5, -5.0,  // batch 0
      0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, -4.0, 4.5, -5.0,  // batch 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  // Real-valued accumulators are 23 (batch 0) and 57 (batch 1); scaled by the
  // combined factor of 4.0 they quantize to 92 and 228 respectively.
  EXPECT_THAT(model.GetOutput<int16_t>(),
              ElementsAre(92, 92, 92, 228, 228, 228));
  EXPECT_THAT(model.GetDequantizedOutput<int16_t>(),
              ElementsAreArray(ArrayFloatNear({23, 23, 23, 57, 57, 57})));
}

} // namespace
} // namespace tflite
30 changes: 16 additions & 14 deletions tflite/kernels/internal/reference/batch_matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,13 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
}
}

template <typename T, typename AccumT>
template <typename lhsT, typename AccumT, typename rhsT = lhsT,
typename outputT = lhsT>
inline void BatchMatMul(const FullyConnectedParams& params,
const RuntimeShape& lhs_shape, const T* lhs_data,
const RuntimeShape& rhs_shape, const T* rhs_data,
const RuntimeShape& output_shape, T* output_data) {
const RuntimeShape& lhs_shape, const lhsT* lhs_data,
const RuntimeShape& rhs_shape, const rhsT* rhs_data,
const RuntimeShape& output_shape,
outputT* output_data) {
const RuntimeShape extended_lhs_shape =
RuntimeShape::ExtendedShape(5, lhs_shape);
const RuntimeShape extended_rhs_shape =
Expand Down Expand Up @@ -241,17 +243,17 @@ inline void BatchMatMul(const FullyConnectedParams& params,
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);

for (int b0 = 0; b0 < batch_dim0; ++b0) {
const T* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
const T* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
const lhsT* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
const rhsT* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
for (int b1 = 0; b1 < batch_dim1; ++b1) {
const T* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
const T* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
const lhsT* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
const rhsT* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
for (int b2 = 0; b2 < batch_dim2; ++b2) {
const T* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
const T* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
T* out_ptr = output_data +
((b0 * batch_dim1 * batch_dim2) + b1 * batch_dim2 + b2) *
lhs_rows * rhs_cols;
const lhsT* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
const rhsT* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
outputT* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
b1 * batch_dim2 + b2) *
lhs_rows * rhs_cols;

for (int j = 0; j < rhs_cols; ++j) {
for (int i = 0; i < lhs_rows; ++i) {
Expand All @@ -267,7 +269,7 @@ inline void BatchMatMul(const FullyConnectedParams& params,
total_scaled = std::max(total_scaled, output_activation_min);
total_scaled = std::min(total_scaled, output_activation_max);
const int idx = lhs_rows * j + i;
out_ptr[idx] = static_cast<T>(total_scaled);
out_ptr[idx] = static_cast<outputT>(total_scaled);
}
}
}
Expand Down
Loading