1- /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
1+ /* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
22
33Licensed under the Apache License, Version 2.0 (the "License");
44you may not use this file except in compliance with the License.
@@ -33,6 +33,13 @@ constexpr int kOutputTensor = 0;
3333
3434struct OpData {
3535 ConcatenationParams params;
36+
37+ #ifdef USE_TFLM_COMPRESSION
38+
39+ // scratch buffers for compressed tensors
40+ int scratch_indices[kMaxInputNum ];
41+
42+ #endif // USE_TFLM_COMPRESSION
3643};
3744
3845// Handles negative axis index, coerces to positive index value.
@@ -52,8 +59,6 @@ inline int CalculatePositiveAxis(int axis, const TfLiteTensor* output_tensor) {
5259inline void GetAllInputTensorShapes (const TfLiteContext* context,
5360 const TfLiteNode* node,
5461 RuntimeShape all_shapes[kMaxInputNum ]) {
55- TFLITE_DCHECK (context != nullptr );
56- TFLITE_DCHECK (node != nullptr );
5762 for (int i = 0 ; i < node->inputs ->size ; ++i) {
5863 const TfLiteEvalTensor* t = tflite::micro::GetEvalInput (context, node, i);
5964 RuntimeShape shape = tflite::micro::GetTensorShape (t);
@@ -73,12 +78,22 @@ inline void GetShapesPointers(const RuntimeShape* shapes, size_t num,
7378template <typename T>
7479inline void GetAllInputTensorData (const TfLiteContext* context,
7580 const TfLiteNode* node,
76- T* all_data[kMaxInputNum ]) {
77- TFLITE_DCHECK (context != nullptr );
78- TFLITE_DCHECK (node != nullptr );
81+ const T* all_data[kMaxInputNum ]) {
82+ #ifdef USE_TFLM_COMPRESSION
83+ const OpData* data = static_cast <const OpData*>(node->user_data );
84+ MicroContext* micro_context = GetMicroContext (context);
85+ #endif // USE_TFLM_COMPRESSION
86+
7987 for (int i = 0 ; i < node->inputs ->size ; ++i) {
8088 const TfLiteEvalTensor* t = tflite::micro::GetEvalInput (context, node, i);
89+ #ifdef USE_TFLM_COMPRESSION
90+ const CompressionTensorData* comp_td =
91+ micro_context->GetTensorCompressionData (node, i);
92+ all_data[i] = tflite::micro::GetTensorData<T>(micro_context, t, comp_td,
93+ data->scratch_indices [i]);
94+ #else // USE_TFLM_COMPRESSION
8195 all_data[i] = tflite::micro::GetTensorData<T>(t);
96+ #endif // USE_TFLM_COMPRESSION
8297 }
8398}
8499
@@ -88,16 +103,17 @@ void EvalUnquantized(TfLiteContext* context, TfLiteNode* node) {
88103 RuntimeShape inputs_shape[kMaxInputNum ];
89104 const RuntimeShape* inputs_shape_ptr[kMaxInputNum ];
90105 const data_type* inputs_data[kMaxInputNum ];
106+ TFLITE_DCHECK (context != nullptr );
107+ TFLITE_DCHECK (node != nullptr );
108+ TFLITE_DCHECK (node->user_data != nullptr );
109+ const OpData* data = static_cast <const OpData*>(node->user_data );
91110 GetAllInputTensorShapes (context, node, inputs_shape);
92111 GetShapesPointers (inputs_shape, node->inputs ->size , inputs_shape_ptr);
93112 GetAllInputTensorData (context, node, inputs_data);
94113
95114 TfLiteEvalTensor* output =
96115 tflite::micro::GetEvalOutput (context, node, kOutputTensor );
97116
98- TFLITE_DCHECK (node->user_data != nullptr );
99- const OpData* data = static_cast <const OpData*>(node->user_data );
100-
101117 reference_ops::Concatenation (data->params , inputs_shape_ptr, inputs_data,
102118 tflite::micro::GetTensorShape (output),
103119 tflite::micro::GetTensorData<data_type>(output));
@@ -126,7 +142,6 @@ TfLiteStatus ConcatenationPrepare(TfLiteContext* context, TfLiteNode* node) {
126142 TfLiteType output_type = output_tensor->type ;
127143
128144 micro_context->DeallocateTempTfLiteTensor (input_tensor);
129- micro_context->DeallocateTempTfLiteTensor (output_tensor);
130145
131146 // Check activation and input type
132147 TF_LITE_ENSURE_EQ (context, params->activation , kTfLiteActNone );
@@ -136,16 +151,22 @@ TfLiteStatus ConcatenationPrepare(TfLiteContext* context, TfLiteNode* node) {
136151 input_type == kTfLiteInt64 || input_type == kTfLiteBool );
137152
138153 // Output type must match input type
139- TF_LITE_ENSURE_EQ (context, output_type, input_type);
154+ TF_LITE_ENSURE_TYPES_EQ (context, output_type, input_type);
140155
141156 // This implementation does not support large number of input tensors
142157 const int num_inputs = NumInputs (node);
143158 TF_LITE_ENSURE (context, num_inputs <= kMaxInputNum );
144159
145- // Shapes with dimensions >4 are not yet supported with static allocation.
160+ // Calculate OpData.
161+ TFLITE_DCHECK (node->user_data != nullptr );
162+ OpData* data = static_cast <OpData*>(node->user_data );
163+
164+ // Shapes with dimensions > kMaxSmallSize are not yet supported with static
165+ // allocation.
146166 for (int i = 0 ; i < num_inputs; ++i) {
147167 TfLiteTensor* input = micro_context->AllocateTempInputTensor (node, i);
148168 TF_LITE_ENSURE (context, input != nullptr );
169+ TF_LITE_ENSURE_TYPES_EQ (context, input->type , input_type);
149170 int num_dimensions = NumDimensions (input);
150171
151172 if (num_dimensions > RuntimeShape::kMaxSmallSize ) {
@@ -155,62 +176,53 @@ TfLiteStatus ConcatenationPrepare(TfLiteContext* context, TfLiteNode* node) {
155176 RuntimeShape::kMaxSmallSize , num_dimensions);
156177 return kTfLiteError ;
157178 }
179+
180+ if (input_type == kTfLiteInt8 ) {
181+ // Make sure there is no re-scaling needed for Int8 quantized kernel. This
182+ // is a restriction we introduced to Int8 kernels.
183+ TF_LITE_ENSURE_EQ (context, static_cast <double >(input->params .scale ),
184+ static_cast <double >(output_tensor->params .scale ));
185+ TF_LITE_ENSURE_EQ (context, input->params .zero_point ,
186+ output_tensor->params .zero_point );
187+ } else if (input_type == kTfLiteInt16 ) {
188+ // Make sure that all Int16 inputs have a null zero-point.
189+ TF_LITE_ENSURE_EQ (context, input->params .zero_point , 0 );
190+ }
191+
192+ #ifdef USE_TFLM_COMPRESSION
193+
194+ // Compression scratch buffers.
195+ // These will only be allocated if the tensor is compressed.
196+ data->scratch_indices [i] =
197+ micro_context->AllocateDecompressionScratchBuffer (node, i);
198+
199+ #endif // USE_TFLM_COMPRESSION
200+
158201 micro_context->DeallocateTempTfLiteTensor (input);
159202 }
160203
161- // Calculate OpData.
162- TFLITE_DCHECK (node->user_data != nullptr );
163- OpData* data = static_cast <OpData*>(node->user_data );
164-
165- TfLiteTensor* output =
166- micro_context->AllocateTempOutputTensor (node, kOutputTensor );
167- TF_LITE_ENSURE (context, output != nullptr );
204+ if (input_type == kTfLiteInt16 ) {
205+ TF_LITE_ENSURE_EQ (context, output_tensor->params .zero_point , 0 );
206+ }
168207
169208 switch (output_type) { // Already know in/outtypes are same.
170209 case kTfLiteBool :
171210 case kTfLiteFloat32 :
211+ case kTfLiteInt8 :
172212 case kTfLiteInt16 :
173213 case kTfLiteInt32 :
174214 case kTfLiteInt64 : {
175- data->params .axis = CalculatePositiveAxis (params->axis , output);
176- data->params .inputs_count = node->inputs ->size ;
177- break ;
178- }
179- case kTfLiteInt8 : {
180- data->params .axis = CalculatePositiveAxis (params->axis , output);
215+ data->params .axis = CalculatePositiveAxis (params->axis , output_tensor);
181216 data->params .inputs_count = node->inputs ->size ;
182-
183- float * input_scales =
184- reinterpret_cast <float *>(context->AllocatePersistentBuffer (
185- context, node->inputs ->size * sizeof (float )));
186-
187- int32_t * input_zero_points =
188- reinterpret_cast <int32_t *>(context->AllocatePersistentBuffer (
189- context, node->inputs ->size * sizeof (int32_t )));
190-
191- // Allocate persistent scale and zeropoint buffers.
192- // Store input scale and zero point values in OpParams:
193- for (int i = 0 ; i < node->inputs ->size ; ++i) {
194- TfLiteTensor* t = micro_context->AllocateTempInputTensor (node, i);
195- TF_LITE_ENSURE (context, t != nullptr );
196- input_scales[i] = t->params .scale ;
197- input_zero_points[i] = t->params .zero_point ;
198- micro_context->DeallocateTempTfLiteTensor (t);
199- }
200-
201- data->params .input_scale = input_scales;
202- data->params .input_zeropoint = input_zero_points;
203- data->params .output_zeropoint = output->params .zero_point ;
204- data->params .output_scale = output->params .scale ;
205217 break ;
206218 }
207219 default :
208- MicroPrintf (" Op Concatenation does not currently support Type '%s'." ,
220+ MicroPrintf (" Op Concatenation does not currently support type '%s'." ,
209221 TfLiteTypeGetName (output_type));
210222 return kTfLiteError ;
211223 }
212224
213- micro_context->DeallocateTempTfLiteTensor (output );
225+ micro_context->DeallocateTempTfLiteTensor (output_tensor );
214226
215227 return kTfLiteOk ;
216228}
0 commit comments