1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PAD_H_
17 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PAD_H_
18 
19 #include <vector>
20 
21 #include "tensorflow/lite/kernels/internal/types.h"
22 
23 namespace tflite {
24 
25 namespace reference_ops {
26 
// Highest tensor rank the TFLite Pad kernels accept. PadImpl extends every
// shape and both padding lists to exactly this many dimensions.
constexpr int PadKernelMaxDimensionCount() {
  return 5;
}
29 
30 // There are two versions of pad: Pad and PadV2.  In PadV2 there is a second
31 // scalar input that provides the padding value.  Therefore pad_value_ptr can be
32 // equivalent to a simple input1_data.  For Pad, it should point to a zero
33 // value.
34 //
35 // Note that two typenames are required, so that T=P=int32_t is considered a
36 // specialization distinct from P=int32_t.
37 template <typename T, typename P>
PadImpl(const tflite::PadParams & op_params,const RuntimeShape & input_shape,const T * input_data,const P * pad_value_ptr,const RuntimeShape & output_shape,T * output_data)38 inline void PadImpl(const tflite::PadParams& op_params,
39                     const RuntimeShape& input_shape, const T* input_data,
40                     const P* pad_value_ptr, const RuntimeShape& output_shape,
41                     T* output_data) {
42   const RuntimeShape ext_input_shape =
43       RuntimeShape::ExtendedShape(PadKernelMaxDimensionCount(), input_shape);
44   const RuntimeShape ext_output_shape =
45       RuntimeShape::ExtendedShape(PadKernelMaxDimensionCount(), output_shape);
46   TFLITE_DCHECK_LE(op_params.left_padding_count, PadKernelMaxDimensionCount());
47   TFLITE_DCHECK_LE(op_params.right_padding_count, PadKernelMaxDimensionCount());
48 
49   // Runtime calls are currently fixed at 5 dimensions. Copy inputs so we can
50   // pad them to 5 dims (yes, we are "padding the padding").
51   int left_padding_copy[PadKernelMaxDimensionCount()];
52   for (int i = 0; i < PadKernelMaxDimensionCount(); i++) {
53     left_padding_copy[i] = 0;
54   }
55   for (int i = 0; i < op_params.left_padding_count; ++i) {
56     left_padding_copy[i + PadKernelMaxDimensionCount() -
57                       op_params.left_padding_count] = op_params.left_padding[i];
58   }
59   int right_padding_copy[PadKernelMaxDimensionCount()];
60   for (int i = 0; i < PadKernelMaxDimensionCount(); i++) {
61     right_padding_copy[i] = 0;
62   }
63   for (int i = 0; i < op_params.right_padding_count; ++i) {
64     right_padding_copy[i + PadKernelMaxDimensionCount() -
65                        op_params.right_padding_count] =
66         op_params.right_padding[i];
67   }
68 
69   const int output_batch = ext_output_shape.Dims(0);
70   const int output_plane = ext_output_shape.Dims(1);
71   const int output_height = ext_output_shape.Dims(2);
72   const int output_width = ext_output_shape.Dims(3);
73   const int output_depth = ext_output_shape.Dims(4);
74 
75   const int left_b_padding = left_padding_copy[0];
76   const int left_p_padding = left_padding_copy[1];
77   const int left_h_padding = left_padding_copy[2];
78   const int left_w_padding = left_padding_copy[3];
79   const int left_d_padding = left_padding_copy[4];
80 
81   const int right_b_padding = right_padding_copy[0];
82   const int right_p_padding = right_padding_copy[1];
83   const int right_h_padding = right_padding_copy[2];
84   const int right_w_padding = right_padding_copy[3];
85   const int right_d_padding = right_padding_copy[4];
86 
87   const T pad_value = *pad_value_ptr;
88 
89   const T* in_ptr = input_data;
90   T* out_ptr = output_data;
91   for (int out_b = 0; out_b < output_batch; ++out_b) {
92     for (int out_p = 0; out_p < output_plane; ++out_p) {
93       for (int out_h = 0; out_h < output_height; ++out_h) {
94         for (int out_w = 0; out_w < output_width; ++out_w) {
95           for (int out_d = 0; out_d < output_depth; ++out_d) {
96             if (out_b < left_b_padding ||
97                 out_b >= output_batch - right_b_padding ||
98                 out_p < left_p_padding ||
99                 out_p >= output_plane - right_p_padding ||
100                 out_h < left_h_padding ||
101                 out_h >= output_height - right_h_padding ||
102                 out_w < left_w_padding ||
103                 out_w >= output_width - right_w_padding ||
104                 out_d < left_d_padding ||
105                 out_d >= output_depth - right_d_padding) {
106               *out_ptr++ = pad_value;
107             } else {
108               *out_ptr++ = *in_ptr++;
109             }
110           }
111         }
112       }
113     }
114   }
115 }
116 
117 template <typename T, typename P>
Pad(const tflite::PadParams & op_params,const RuntimeShape & input_shape,const T * input_data,const P * pad_value_ptr,const RuntimeShape & output_shape,T * output_data)118 inline void Pad(const tflite::PadParams& op_params,
119                 const RuntimeShape& input_shape, const T* input_data,
120                 const P* pad_value_ptr, const RuntimeShape& output_shape,
121                 T* output_data) {
122   PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
123           output_data);
124 }
125 
126 // The second (pad-value) input can be int32_t when, say, the first is uint8_t.
127 template <typename T>
Pad(const tflite::PadParams & op_params,const RuntimeShape & input_shape,const T * input_data,const int32_t * pad_value_ptr,const RuntimeShape & output_shape,T * output_data)128 inline void Pad(const tflite::PadParams& op_params,
129                 const RuntimeShape& input_shape, const T* input_data,
130                 const int32_t* pad_value_ptr, const RuntimeShape& output_shape,
131                 T* output_data) {
132   const T converted_pad_value = static_cast<T>(*pad_value_ptr);
133   PadImpl(op_params, input_shape, input_data, &converted_pad_value,
134           output_shape, output_data);
135 }
136 
137 // This version avoids conflicting template matching.
138 template <>
Pad(const tflite::PadParams & op_params,const RuntimeShape & input_shape,const int32_t * input_data,const int32_t * pad_value_ptr,const RuntimeShape & output_shape,int32_t * output_data)139 inline void Pad(const tflite::PadParams& op_params,
140                 const RuntimeShape& input_shape, const int32_t* input_data,
141                 const int32_t* pad_value_ptr, const RuntimeShape& output_shape,
142                 int32_t* output_data) {
143   PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
144           output_data);
145 }
146 
147 template <typename T, typename P>
PadImageStyle(const tflite::PadParams & op_params,const RuntimeShape & input_shape,const T * input_data,const P * pad_value_ptr,const RuntimeShape & output_shape,T * output_data)148 inline void PadImageStyle(const tflite::PadParams& op_params,
149                           const RuntimeShape& input_shape, const T* input_data,
150                           const P* pad_value_ptr,
151                           const RuntimeShape& output_shape, T* output_data) {
152   Pad(op_params, input_shape, input_data, pad_value_ptr, output_shape,
153       output_data);
154 }
155 
156 template <typename P>
PadImageStyle(const tflite::PadParams & op_params,const RuntimeShape & input_shape,const float * input_data,const P * pad_value_ptr,const RuntimeShape & output_shape,float * output_data)157 inline void PadImageStyle(const tflite::PadParams& op_params,
158                           const RuntimeShape& input_shape,
159                           const float* input_data, const P* pad_value_ptr,
160                           const RuntimeShape& output_shape,
161                           float* output_data) {
162   Pad(op_params, input_shape, input_data, pad_value_ptr, output_shape,
163       output_data);
164 }
165 
166 }  // namespace reference_ops
167 }  // namespace tflite
168 
169 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PAD_H_
170