/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/pad.h"

#include <string.h>

#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/portable_tensor.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace micro {
namespace pad {
namespace {

struct OpData {
  PadParams params;
  int32_t output_zero_point;
};

}  // namespace

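// Allocates the per-node OpData from the persistent arena; TFLM kernels may
// not allocate memory dynamically at Invoke time.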
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}

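// Validates the node's tensors and precomputes PadParams so that Eval does
// not need to re-read the (constant) paddings tensor at run time.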
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  OpData* data = static_cast<OpData*>(node->user_data);

  TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input = GetInput(context, node, /*index=*/0);
  TF_LITE_ENSURE(context, input != nullptr);
  const TfLiteTensor* paddings = GetInput(context, node, /*index=*/1);
  TF_LITE_ENSURE(context, paddings != nullptr);
  const TfLiteTensor* constant_values =
      NumInputs(node) == 3 ? GetInput(context, node, /*index=*/2) : nullptr;
  TfLiteTensor* output = GetOutput(context, node, /*index=*/0);
  TF_LITE_ENSURE(context, output != nullptr);

  TF_LITE_ENSURE_EQ(context, input->type, output->type);

  // Current implementations rely on the inputs being <= 4D.
  TF_LITE_ENSURE(context, NumDimensions(input) <=
                              reference_ops::PadKernelMaxDimensionCount());

  if (constant_values != nullptr) {
    TF_LITE_ENSURE_EQ(context, input->type, constant_values->type);
    // Ensure that constant_values is a scalar.
    TF_LITE_ENSURE_EQ(context, NumElements(constant_values), 1);
  }

  // There must be a pair of paddings for each output dimension.
  TF_LITE_ENSURE_EQ(context, GetTensorShape(paddings).FlatSize(),
                    output->dims->size * 2);

  // On Micro, outputs must be properly sized by the converter.
  // NOTE: This data is only available because the paddings buffer is stored in
  // the flatbuffer:
  TF_LITE_ENSURE(context, IsConstantTensor(paddings));
  const int32_t* paddings_data = GetTensorData<int32_t>(paddings);
  for (int i = 0; i < output->dims->size; i++) {
    int output_dim = output->dims->data[i];
    int expected_dim =
        input->dims->data[i] + paddings_data[i * 2] + paddings_data[i * 2 + 1];
    TF_LITE_ENSURE_EQ(context, output_dim, expected_dim);
  }

  // Calculate OpData:
  // A 4-D input whose outermost and innermost dimensions are not padded
  // (e.g. NHWC batch and channel) can take the image-style fast path.
  data->params.resizing_category = ResizingCategory::kGenericResize;
  const int paddings_total = GetTensorShape(paddings).FlatSize();
  if (paddings_total == 8 && (paddings_data[0] == 0 && paddings_data[1] == 0) &&
      (paddings_data[6] == 0 && paddings_data[7] == 0)) {
    data->params.resizing_category = ResizingCategory::kImageStyle;
  }

  const int num_input_dimensions = NumDimensions(input);
  data->params.left_padding_count = num_input_dimensions;
  data->params.right_padding_count = num_input_dimensions;

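  // paddings holds {before_0, after_0, before_1, after_1, ...}, one pair per
  // input dimension.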
  for (int idx = num_input_dimensions - 1; idx >= 0; --idx) {
    data->params.left_padding[idx] = paddings_data[idx * 2];
    data->params.right_padding[idx] = paddings_data[idx * 2 + 1];
  }

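  // For quantized types, record the output zero point: it becomes the default
  // pad value in Eval when no constant_values tensor is supplied.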
  if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8) {
    if (constant_values == nullptr) {
      // Quantized Pad requires that 0 is represented in the quantized
      // range.
      if (input->type == kTfLiteUInt8) {
        TF_LITE_ENSURE(context, output->params.zero_point >=
                                    std::numeric_limits<uint8_t>::min());
        TF_LITE_ENSURE(context, output->params.zero_point <=
                                    std::numeric_limits<uint8_t>::max());
      } else {
        TF_LITE_ENSURE(context, output->params.zero_point >=
                                    std::numeric_limits<int8_t>::min());
        TF_LITE_ENSURE(context, output->params.zero_point <=
                                    std::numeric_limits<int8_t>::max());
      }
    } else {
      // Quantized Pad requires that 'constant_values' is represented in the
      // same quantized range as the input and output tensors.
      TF_LITE_ENSURE_EQ(context, output->params.zero_point,
                        constant_values->params.zero_point);
      TF_LITE_ENSURE_EQ(context, static_cast<double>(output->params.scale),
                        static_cast<double>(constant_values->params.scale));
    }
    data->output_zero_point = output->params.zero_point;
  }

  return kTfLiteOk;
}

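// Performs the pad at Invoke time: dispatches on the input type and forwards
// to the reference Pad / PadImageStyle implementations.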
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, /*index=*/0);
  const TfLiteEvalTensor* constant_values =
      NumInputs(node) == 3
          ? tflite::micro::GetEvalInput(context, node, /*index=*/2)
          : nullptr;
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, /*index=*/0);

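  // The pad value comes from constant_values when present; otherwise it is 0
  // for float/int32 inputs and the output zero point for quantized inputs.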
  switch (input->type) {
    case kTfLiteFloat32: {
      float pad_value =
          constant_values == nullptr
              ? 0.f
              : *tflite::micro::GetTensorData<float>(constant_values);
      if (data->params.resizing_category == ResizingCategory::kImageStyle) {
        reference_ops::PadImageStyle(
            data->params, tflite::micro::GetTensorShape(input),
            tflite::micro::GetTensorData<float>(input), &pad_value,
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<float>(output));
      } else {
        reference_ops::Pad(data->params, tflite::micro::GetTensorShape(input),
                           tflite::micro::GetTensorData<float>(input),
                           &pad_value, tflite::micro::GetTensorShape(output),
                           tflite::micro::GetTensorData<float>(output));
      }
    } break;
    case kTfLiteUInt8: {
      uint8_t pad_value;
      if (constant_values == nullptr) {
        pad_value = static_cast<uint8_t>(data->output_zero_point);
      } else {
        pad_value = *tflite::micro::GetTensorData<uint8_t>(constant_values);
      }
      if (data->params.resizing_category == ResizingCategory::kImageStyle) {
        reference_ops::PadImageStyle(
            data->params, tflite::micro::GetTensorShape(input),
            tflite::micro::GetTensorData<uint8_t>(input), &pad_value,
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<uint8_t>(output));
      } else {
        reference_ops::Pad(data->params, tflite::micro::GetTensorShape(input),
                           tflite::micro::GetTensorData<uint8_t>(input),
                           &pad_value, tflite::micro::GetTensorShape(output),
                           tflite::micro::GetTensorData<uint8_t>(output));
      }
    } break;
    case kTfLiteInt8: {
      int8_t pad_value;
      if (constant_values == nullptr) {
        pad_value = static_cast<int8_t>(data->output_zero_point);
      } else {
        pad_value = *tflite::micro::GetTensorData<int8_t>(constant_values);
      }
      if (data->params.resizing_category == ResizingCategory::kImageStyle) {
        reference_ops::PadImageStyle(
            data->params, tflite::micro::GetTensorShape(input),
            tflite::micro::GetTensorData<int8_t>(input), &pad_value,
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int8_t>(output));
      } else {
        reference_ops::Pad(data->params, tflite::micro::GetTensorShape(input),
                           tflite::micro::GetTensorData<int8_t>(input),
                           &pad_value, tflite::micro::GetTensorShape(output),
                           tflite::micro::GetTensorData<int8_t>(output));
      }
    } break;
    case kTfLiteInt32: {
      int32_t pad_value =
          constant_values == nullptr
              ? 0
              : *tflite::micro::GetTensorData<int32_t>(constant_values);
      reference_ops::Pad(data->params, tflite::micro::GetTensorShape(input),
                         tflite::micro::GetTensorData<int32_t>(input),
                         &pad_value, tflite::micro::GetTensorShape(output),
                         tflite::micro::GetTensorData<int32_t>(output));
    } break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s not currently supported by Pad.",
                         TfLiteTypeGetName(input->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace pad

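// PAD and PADV2 share the same implementation; PADV2 only adds the optional
// constant_values input, which Prepare and Eval already handle.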
TfLiteRegistration Register_PAD() {
  return {/*init=*/pad::Init,
          /*free=*/nullptr,
          /*prepare=*/pad::Prepare,
          /*invoke=*/pad::Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

// Also register Pad as PadV2.
TfLiteRegistration Register_PADV2() {
  return {/*init=*/pad::Init,
          /*free=*/nullptr,
          /*prepare=*/pad::Prepare,
          /*invoke=*/pad::Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite