1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/internal/reference/pad.h"
16
17 #include <string.h>
18
19 #include "tensorflow/lite/c/builtin_op_data.h"
20 #include "tensorflow/lite/c/common.h"
21 #include "tensorflow/lite/kernels/internal/portable_tensor.h"
22 #include "tensorflow/lite/kernels/internal/types.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24 #include "tensorflow/lite/kernels/op_macros.h"
25 #include "tensorflow/lite/micro/kernels/kernel_util.h"
26
27 namespace tflite {
28 namespace ops {
29 namespace micro {
30 namespace pad {
31 namespace {
32
33 struct OpData {
34 PadParams params;
35 int32_t output_zero_point;
36 };
37
38 } // namespace
39
Init(TfLiteContext * context,const char * buffer,size_t length)40 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
41 TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
42 return context->AllocatePersistentBuffer(context, sizeof(OpData));
43 }
44
Prepare(TfLiteContext * context,TfLiteNode * node)45 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
46 TFLITE_DCHECK(node->user_data != nullptr);
47 OpData* data = static_cast<OpData*>(node->user_data);
48
49 TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3);
50 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
51
52 const TfLiteTensor* input = GetInput(context, node, /*index=*/0);
53 TF_LITE_ENSURE(context, input != nullptr);
54 const TfLiteTensor* paddings = GetInput(context, node, /*index=*/1);
55 TF_LITE_ENSURE(context, paddings != nullptr);
56 const TfLiteTensor* constant_values =
57 NumInputs(node) == 3 ? GetInput(context, node, /*index=*/2) : nullptr;
58 TfLiteTensor* output = GetOutput(context, node, /*index=*/0);
59 TF_LITE_ENSURE(context, output != nullptr);
60
61 TF_LITE_ENSURE_EQ(context, input->type, output->type);
62
63 // Current implementations rely on the inputs being <= 4D.
64 TF_LITE_ENSURE(context, NumDimensions(input) <=
65 reference_ops::PadKernelMaxDimensionCount());
66
67 if (constant_values != nullptr) {
68 TF_LITE_ENSURE_EQ(context, input->type, constant_values->type);
69 // Ensure that constant_values is a scalar.
70 TF_LITE_ENSURE_EQ(context, NumElements(constant_values), 1);
71 }
72
73 // There must be a pair of paddings for each output dimension.
74 TF_LITE_ENSURE_EQ(context, GetTensorShape(paddings).FlatSize(),
75 output->dims->size * 2);
76
77 // On Micro, outputs must be properly sized by the converter.
78 // NOTE: This data is only available because the paddings buffer is stored in
79 // the flatbuffer:
80 TF_LITE_ENSURE(context, IsConstantTensor(paddings));
81 const int32_t* paddings_data = GetTensorData<int32_t>(paddings);
82 for (int i = 0; i < output->dims->size; i++) {
83 int output_dim = output->dims->data[i];
84 int expected_dim =
85 input->dims->data[i] + paddings_data[i * 2] + paddings_data[i * 2 + 1];
86 TF_LITE_ENSURE_EQ(context, output_dim, expected_dim);
87 }
88
89 // Calculate OpData:
90 data->params.resizing_category = ResizingCategory::kGenericResize;
91 const int paddings_total = GetTensorShape(paddings).FlatSize();
92 if (paddings_total == 8 && (paddings_data[0] == 0 && paddings_data[1] == 0) &&
93 (paddings_data[6] == 0 && paddings_data[7] == 0)) {
94 data->params.resizing_category = ResizingCategory::kImageStyle;
95 }
96
97 const int num_input_dimensions = NumDimensions(input);
98 data->params.left_padding_count = num_input_dimensions;
99 data->params.right_padding_count = num_input_dimensions;
100
101 for (int idx = num_input_dimensions - 1; idx >= 0; --idx) {
102 data->params.left_padding[idx] = paddings_data[idx * 2];
103 data->params.right_padding[idx] = paddings_data[idx * 2 + 1];
104 }
105
106 if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8) {
107 if (constant_values == nullptr) {
108 // Quantized Pad requires that 0 is represented in the quantized
109 // range.
110 if (input->type == kTfLiteUInt8) {
111 TF_LITE_ENSURE(context, output->params.zero_point >=
112 std::numeric_limits<uint8_t>::min());
113 TF_LITE_ENSURE(context, output->params.zero_point <=
114 std::numeric_limits<uint8_t>::max());
115 } else {
116 TF_LITE_ENSURE(context, output->params.zero_point >=
117 std::numeric_limits<int8_t>::min());
118 TF_LITE_ENSURE(context, output->params.zero_point <=
119 std::numeric_limits<int8_t>::max());
120 }
121 } else {
122 // Quantized Pad requires that 'constant_values' is represented in the
123 // same quantized range as the input and output tensors.
124 TF_LITE_ENSURE_EQ(context, output->params.zero_point,
125 constant_values->params.zero_point);
126 TF_LITE_ENSURE_EQ(context, static_cast<double>(output->params.scale),
127 static_cast<double>(constant_values->params.scale));
128 }
129 data->output_zero_point = output->params.zero_point;
130 }
131
132 return kTfLiteOk;
133 }
134
Eval(TfLiteContext * context,TfLiteNode * node)135 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
136 TFLITE_DCHECK(node->user_data != nullptr);
137 const OpData* data = static_cast<const OpData*>(node->user_data);
138
139 const TfLiteEvalTensor* input =
140 tflite::micro::GetEvalInput(context, node, /*index=*/0);
141 const TfLiteEvalTensor* constant_values =
142 NumInputs(node) == 3
143 ? tflite::micro::GetEvalInput(context, node, /*index=*/2)
144 : nullptr;
145 TfLiteEvalTensor* output =
146 tflite::micro::GetEvalOutput(context, node, /*index=*/0);
147
148 switch (input->type) {
149 case kTfLiteFloat32: {
150 float pad_value =
151 constant_values == nullptr
152 ? 0.f
153 : *tflite::micro::GetTensorData<float>(constant_values);
154 if (data->params.resizing_category == ResizingCategory::kImageStyle) {
155 reference_ops::PadImageStyle(
156 data->params, tflite::micro::GetTensorShape(input),
157 tflite::micro::GetTensorData<float>(input), &pad_value,
158 tflite::micro::GetTensorShape(output),
159 tflite::micro::GetTensorData<float>(output));
160 } else {
161 reference_ops::Pad(data->params, tflite::micro::GetTensorShape(input),
162 tflite::micro::GetTensorData<float>(input),
163 &pad_value, tflite::micro::GetTensorShape(output),
164 tflite::micro::GetTensorData<float>(output));
165 }
166 } break;
167 case kTfLiteUInt8: {
168 uint8_t pad_value;
169 if (constant_values == nullptr) {
170 pad_value = static_cast<uint8_t>(data->output_zero_point);
171 } else {
172 pad_value = *tflite::micro::GetTensorData<uint8_t>(constant_values);
173 }
174 if (data->params.resizing_category == ResizingCategory::kImageStyle) {
175 reference_ops::PadImageStyle(
176 data->params, tflite::micro::GetTensorShape(input),
177 tflite::micro::GetTensorData<uint8_t>(input), &pad_value,
178 tflite::micro::GetTensorShape(output),
179 tflite::micro::GetTensorData<uint8_t>(output));
180 } else {
181 reference_ops::Pad(data->params, tflite::micro::GetTensorShape(input),
182 tflite::micro::GetTensorData<uint8_t>(input),
183 &pad_value, tflite::micro::GetTensorShape(output),
184 tflite::micro::GetTensorData<uint8_t>(output));
185 }
186 } break;
187 case kTfLiteInt8: {
188 int8_t pad_value;
189 if (constant_values == nullptr) {
190 pad_value = static_cast<uint8_t>(data->output_zero_point);
191 } else {
192 pad_value = *tflite::micro::GetTensorData<int8_t>(constant_values);
193 }
194 if (data->params.resizing_category == ResizingCategory::kImageStyle) {
195 reference_ops::PadImageStyle(
196 data->params, tflite::micro::GetTensorShape(input),
197 tflite::micro::GetTensorData<int8_t>(input), &pad_value,
198 tflite::micro::GetTensorShape(output),
199 tflite::micro::GetTensorData<int8_t>(output));
200 } else {
201 reference_ops::Pad(data->params, tflite::micro::GetTensorShape(input),
202 tflite::micro::GetTensorData<int8_t>(input),
203 &pad_value, tflite::micro::GetTensorShape(output),
204 tflite::micro::GetTensorData<int8_t>(output));
205 }
206 } break;
207 case kTfLiteInt32: {
208 int32_t pad_value =
209 constant_values == nullptr
210 ? 0
211 : *tflite::micro::GetTensorData<int32_t>(constant_values);
212 reference_ops::Pad(data->params, tflite::micro::GetTensorShape(input),
213 tflite::micro::GetTensorData<int32_t>(input),
214 &pad_value, tflite::micro::GetTensorShape(output),
215 tflite::micro::GetTensorData<int32_t>(output));
216 } break;
217 default:
218
219 TF_LITE_KERNEL_LOG(context, "Type %s not currently supported by Pad.",
220 TfLiteTypeGetName(input->type));
221 return kTfLiteError;
222 }
223 #undef TF_LITE_PAD
224 return kTfLiteOk;
225 }
226
227 } // namespace pad
228
Register_PAD()229 TfLiteRegistration Register_PAD() {
230 return {/*init=*/pad::Init,
231 /*free=*/nullptr,
232 /*prepare=*/pad::Prepare,
233 /*invoke=*/pad::Eval,
234 /*profiling_string=*/nullptr,
235 /*builtin_code=*/0,
236 /*custom_name=*/nullptr,
237 /*version=*/0};
238 }
239
240 // Also register Pad as PadV2.
Register_PADV2()241 TfLiteRegistration Register_PADV2() {
242 return {/*init=*/pad::Init,
243 /*free=*/nullptr,
244 /*prepare=*/pad::Prepare,
245 /*invoke=*/pad::Eval,
246 /*profiling_string=*/nullptr,
247 /*builtin_code=*/0,
248 /*custom_name=*/nullptr,
249 /*version=*/0};
250 }
251
252 } // namespace micro
253 } // namespace ops
254 } // namespace tflite
255