/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_ADD_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_ADD_H_

#include <algorithm>

#include "fixedpoint/fixedpoint.h"
#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace optimized_integer_ops {

// Element-wise add that can often be used for inner loop of broadcast add as
// well as the non-broadcast add.
AddElementwise(int size,const ArithmeticParams & params,const int8 * input1_data,const int8 * input2_data,int8 * output_data)35 inline void AddElementwise(int size, const ArithmeticParams& params,
36 const int8* input1_data, const int8* input2_data,
37 int8* output_data) {
38 ruy::profiler::ScopeLabel label("AddElementwiseInt8/8bit");
39 int i = 0;
40 TFLITE_DCHECK_GT(params.input1_offset, -256);
41 TFLITE_DCHECK_GT(params.input2_offset, -256);
42 TFLITE_DCHECK_LT(params.input1_offset, 256);
43 TFLITE_DCHECK_LT(params.input2_offset, 256);
44
45 #ifdef USE_NEON
46 const int8x16_t output_activation_min_vector =
47 vdupq_n_s8(params.quantized_activation_min);
48 const int8x16_t output_activation_max_vector =
49 vdupq_n_s8(params.quantized_activation_max);
50
51 const int input1_left_shift = params.left_shift + params.input1_shift;
52 const int input2_left_shift = params.left_shift + params.input2_shift;
53 const int32x4_t input1_left_dup = vdupq_n_s32(input1_left_shift);
54 const int32x4_t input2_left_dup = vdupq_n_s32(input2_left_shift);
55
56 const int16x8_t input1_offset_dup = vdupq_n_s16(params.input1_offset);
57 const int16x8_t input2_offset_dup = vdupq_n_s16(params.input2_offset);
58
59 for (; i <= size - 16; i += 16) {
60 const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
61 const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
62
63 const int16x8_t input1_val_s16_high =
64 vmovl_s8(vget_high_s8(input1_val_original));
65 const int16x8_t input1_val_s16_low =
66 vmovl_s8(vget_low_s8(input1_val_original));
67
68 const int16x8_t input2_val_s16_high =
69 vmovl_s8(vget_high_s8(input2_val_original));
70 const int16x8_t input2_val_s16_low =
71 vmovl_s8(vget_low_s8(input2_val_original));
72 const int16x8_t input1_val_high =
73 vaddq_s16(input1_val_s16_high, input1_offset_dup);
74 const int16x8_t input2_val_high =
75 vaddq_s16(input2_val_s16_high, input2_offset_dup);
76 const int16x8_t input1_val_low =
77 vaddq_s16(input1_val_s16_low, input1_offset_dup);
78 const int16x8_t input2_val_low =
79 vaddq_s16(input2_val_s16_low, input2_offset_dup);
80 const int16x4_t input1_val_high_high = vget_high_s16(input1_val_high);
81 const int16x4_t input1_val_high_low = vget_low_s16(input1_val_high);
82 const int16x4_t input1_val_low_high = vget_high_s16(input1_val_low);
83 const int16x4_t input1_val_low_low = vget_low_s16(input1_val_low);
84 const int16x4_t input2_val_high_high = vget_high_s16(input2_val_high);
85 const int16x4_t input2_val_high_low = vget_low_s16(input2_val_high);
86 const int16x4_t input2_val_low_high = vget_high_s16(input2_val_low);
87 const int16x4_t input2_val_low_low = vget_low_s16(input2_val_low);
88 int32x4_t x111 = vmovl_s16(input1_val_low_low);
89 int32x4_t x112 = vmovl_s16(input1_val_low_high);
90 int32x4_t x121 = vmovl_s16(input1_val_high_low);
91 int32x4_t x122 = vmovl_s16(input1_val_high_high);
92 int32x4_t x211 = vmovl_s16(input2_val_low_low);
93 int32x4_t x212 = vmovl_s16(input2_val_low_high);
94 int32x4_t x221 = vmovl_s16(input2_val_high_low);
95 int32x4_t x222 = vmovl_s16(input2_val_high_high);
96
97 x111 = vshlq_s32(x111, input1_left_dup);
98 x112 = vshlq_s32(x112, input1_left_dup);
99 x121 = vshlq_s32(x121, input1_left_dup);
100 x122 = vshlq_s32(x122, input1_left_dup);
101 x211 = vshlq_s32(x211, input2_left_dup);
102 x212 = vshlq_s32(x212, input2_left_dup);
103 x221 = vshlq_s32(x221, input2_left_dup);
104 x222 = vshlq_s32(x222, input2_left_dup);
105 x111 = vqrdmulhq_n_s32(x111, params.input1_multiplier);
106 x112 = vqrdmulhq_n_s32(x112, params.input1_multiplier);
107 x121 = vqrdmulhq_n_s32(x121, params.input1_multiplier);
108 x122 = vqrdmulhq_n_s32(x122, params.input1_multiplier);
109 x211 = vqrdmulhq_n_s32(x211, params.input2_multiplier);
110 x212 = vqrdmulhq_n_s32(x212, params.input2_multiplier);
111 x221 = vqrdmulhq_n_s32(x221, params.input2_multiplier);
112 x222 = vqrdmulhq_n_s32(x222, params.input2_multiplier);
113 int32x4_t s11 = vaddq_s32(x111, x211);
114 int32x4_t s12 = vaddq_s32(x112, x212);
115 int32x4_t s21 = vaddq_s32(x121, x221);
116 int32x4_t s22 = vaddq_s32(x122, x222);
117 s11 = vqrdmulhq_n_s32(s11, params.output_multiplier);
118 s12 = vqrdmulhq_n_s32(s12, params.output_multiplier);
119 s21 = vqrdmulhq_n_s32(s21, params.output_multiplier);
120 s22 = vqrdmulhq_n_s32(s22, params.output_multiplier);
121 using gemmlowp::RoundingDivideByPOT;
122 s11 = RoundingDivideByPOT(s11, -params.output_shift);
123 s12 = RoundingDivideByPOT(s12, -params.output_shift);
124 s21 = RoundingDivideByPOT(s21, -params.output_shift);
125 s22 = RoundingDivideByPOT(s22, -params.output_shift);
126 const int16x4_t s11_narrowed = vmovn_s32(s11);
127 const int16x4_t s12_narrowed = vmovn_s32(s12);
128 const int16x4_t s21_narrowed = vmovn_s32(s21);
129 const int16x4_t s22_narrowed = vmovn_s32(s22);
130 const int16x8_t s1 = vaddq_s16(vcombine_s16(s11_narrowed, s12_narrowed),
131 vdupq_n_s16(params.output_offset));
132 const int16x8_t s2 = vaddq_s16(vcombine_s16(s21_narrowed, s22_narrowed),
133 vdupq_n_s16(params.output_offset));
134 const int8x16_t s = vcombine_s8(vqmovn_s16(s1), vqmovn_s16(s2));
135
136 const int8x16_t clamped =
137 vmaxq_s8(output_activation_min_vector,
138 vminq_s8(output_activation_max_vector, s));
139 vst1q_s8(output_data + i, clamped);
140 }
141 #endif // NEON
142
143 for (; i < size; ++i) {
144 const int32 input1_val = params.input1_offset + input1_data[i];
145 const int32 input2_val = params.input2_offset + input2_data[i];
146 const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
147 const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
148 const int32 scaled_input1_val =
149 MultiplyByQuantizedMultiplierSmallerThanOneExp(
150 shifted_input1_val, params.input1_multiplier, params.input1_shift);
151 const int32 scaled_input2_val =
152 MultiplyByQuantizedMultiplierSmallerThanOneExp(
153 shifted_input2_val, params.input2_multiplier, params.input2_shift);
154 const int32 raw_sum = scaled_input1_val + scaled_input2_val;
155 const int32 raw_output =
156 MultiplyByQuantizedMultiplierSmallerThanOneExp(
157 raw_sum, params.output_multiplier, params.output_shift) +
158 params.output_offset;
159 const int32 clamped_output =
160 std::min(params.quantized_activation_max,
161 std::max(params.quantized_activation_min, raw_output));
162 output_data[i] = static_cast<int8>(clamped_output);
163 }
164 }

// Scalar-broadcast add that can be used for inner loop of more general
// broadcast add, so that, for example, scalar-broadcast with batch will still
// be fast.
AddScalarBroadcast(int size,const ArithmeticParams & params,int8 input1_data,const int8 * input2_data,int8 * output_data)169 inline void AddScalarBroadcast(int size, const ArithmeticParams& params,
170 int8 input1_data, const int8* input2_data,
171 int8* output_data) {
172 using gemmlowp::RoundingDivideByPOT;
173
174 ruy::profiler::ScopeLabel label("AddScalarBroadcastInt8/8bit");
175 TFLITE_DCHECK_GT(params.input1_offset, -256);
176 TFLITE_DCHECK_GT(params.input2_offset, -256);
177 TFLITE_DCHECK_LT(params.input1_offset, 256);
178 TFLITE_DCHECK_LT(params.input2_offset, 256);
179
180 int i = 0;
181
182 #ifdef USE_NEON
183 const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
184 const int8x8_t output_activation_min_vector =
185 vdup_n_s8(params.quantized_activation_min);
186 const int8x8_t output_activation_max_vector =
187 vdup_n_s8(params.quantized_activation_max);
188
189 // Process broadcast scalar.
190 const int8x8_t input1_val_original = vdup_n_s8(input1_data);
191 const int16x8_t input1_val_s16 = vmovl_s8(input1_val_original);
192 const int16x8_t input1_val =
193 vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
194 const int16x4_t input1_val_high = vget_high_s16(input1_val);
195 const int16x4_t input1_val_low = vget_low_s16(input1_val);
196 int32x4_t x11 = vmovl_s16(input1_val_low);
197 int32x4_t x12 = vmovl_s16(input1_val_high);
198 x11 = vshlq_s32(x11, left_shift_dup);
199 x12 = vshlq_s32(x12, left_shift_dup);
200 x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
201 x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
202 const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
203 x11 = vshlq_s32(x11, input1_shift_dup);
204 x12 = vshlq_s32(x12, input1_shift_dup);
205
206 for (; i <= size - 8; i += 8) {
207 const int8x8_t input2_val_original = vld1_s8(input2_data + i);
208 const int16x8_t input2_val_s16 = vmovl_s8(input2_val_original);
209 const int16x8_t input2_val =
210 vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
211 const int16x4_t input2_val_high = vget_high_s16(input2_val);
212 const int16x4_t input2_val_low = vget_low_s16(input2_val);
213 int32x4_t x21 = vmovl_s16(input2_val_low);
214 int32x4_t x22 = vmovl_s16(input2_val_high);
215 x21 = vshlq_s32(x21, left_shift_dup);
216 x22 = vshlq_s32(x22, left_shift_dup);
217 x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
218 x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
219 const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
220 x21 = vshlq_s32(x21, input2_shift_dup);
221 x22 = vshlq_s32(x22, input2_shift_dup);
222 int32x4_t s1 = vaddq_s32(x11, x21);
223 int32x4_t s2 = vaddq_s32(x12, x22);
224 s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
225 s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
226 s1 = RoundingDivideByPOT(s1, -params.output_shift);
227 s2 = RoundingDivideByPOT(s2, -params.output_shift);
228 const int16x4_t s1_narrowed = vmovn_s32(s1);
229 const int16x4_t s2_narrowed = vmovn_s32(s2);
230 const int16x8_t s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
231 vdupq_n_s16(params.output_offset));
232 const int8x8_t clamped =
233 vmax_s8(output_activation_min_vector,
234 vmin_s8(output_activation_max_vector, vqmovn_s16(s)));
235 vst1_s8(output_data + i, clamped);
236 }
237 #endif // NEON
238
239 if (i < size) {
240 // Process broadcast scalar.
241 const int32 input1_val = params.input1_offset + input1_data;
242 const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
243 const int32 scaled_input1_val =
244 MultiplyByQuantizedMultiplierSmallerThanOneExp(
245 shifted_input1_val, params.input1_multiplier, params.input1_shift);
246
247 for (; i < size; ++i) {
248 const int32 input2_val = params.input2_offset + input2_data[i];
249 const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
250 const int32 scaled_input2_val =
251 MultiplyByQuantizedMultiplierSmallerThanOneExp(
252 shifted_input2_val, params.input2_multiplier,
253 params.input2_shift);
254 const int32 raw_sum = scaled_input1_val + scaled_input2_val;
255 const int32 raw_output =
256 MultiplyByQuantizedMultiplierSmallerThanOneExp(
257 raw_sum, params.output_multiplier, params.output_shift) +
258 params.output_offset;
259 const int32 clamped_output =
260 std::min(params.quantized_activation_max,
261 std::max(params.quantized_activation_min, raw_output));
262 output_data[i] = static_cast<int8>(clamped_output);
263 }
264 }
265 }
Add(const ArithmeticParams & params,const RuntimeShape & input1_shape,const int8 * input1_data,const RuntimeShape & input2_shape,const int8 * input2_data,const RuntimeShape & output_shape,int8 * output_data)267 inline void Add(const ArithmeticParams& params,
268 const RuntimeShape& input1_shape, const int8* input1_data,
269 const RuntimeShape& input2_shape, const int8* input2_data,
270 const RuntimeShape& output_shape, int8* output_data) {
271 TFLITE_DCHECK_LE(params.quantized_activation_min,
272 params.quantized_activation_max);
273 ruy::profiler::ScopeLabel label("AddInt8/8bit");
274 const int flat_size =
275 MatchingElementsSize(input1_shape, input2_shape, output_shape);
276
277 TFLITE_DCHECK_GT(params.input1_offset, -256);
278 TFLITE_DCHECK_GT(params.input2_offset, -256);
279 TFLITE_DCHECK_LT(params.input1_offset, 256);
280 TFLITE_DCHECK_LT(params.input2_offset, 256);
281 AddElementwise(flat_size, params, input1_data, input2_data, output_data);
282 }
BroadcastAddDispatch(const ArithmeticParams & params,const RuntimeShape & input1_shape,const int8 * input1_data,const RuntimeShape & input2_shape,const int8 * input2_data,const RuntimeShape & output_shape,int8 * output_data)284 inline void BroadcastAddDispatch(const ArithmeticParams& params,
285 const RuntimeShape& input1_shape,
286 const int8* input1_data,
287 const RuntimeShape& input2_shape,
288 const int8* input2_data,
289 const RuntimeShape& output_shape,
290 int8* output_data) {
291 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
292 return reference_integer_ops::BroadcastAdd4DSlow(
293 params, input1_shape, input1_data, input2_shape, input2_data,
294 output_shape, output_data);
295 }
296
297 optimized_ops::BinaryBroadcastFiveFold(
298 params, input1_shape, input1_data, input2_shape, input2_data,
299 output_shape, output_data, AddElementwise, AddScalarBroadcast);
300 }

}  // namespace optimized_integer_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_ADD_H_