/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_ADD_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_ADD_H_

#include <algorithm>

#include "fixedpoint/fixedpoint.h"
#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace optimized_integer_ops {
// Element-wise add that can often be used for the inner loop of a broadcast
// add, as well as for a non-broadcast add.
inline void AddElementwise(int size, const ArithmeticParams& params,
                           const int8* input1_data, const int8* input2_data,
                           int8* output_data) {
  ruy::profiler::ScopeLabel label("AddElementwiseInt8/8bit");
  int i = 0;
  TFLITE_DCHECK_GT(params.input1_offset, -256);
  TFLITE_DCHECK_GT(params.input2_offset, -256);
  TFLITE_DCHECK_LT(params.input1_offset, 256);
  TFLITE_DCHECK_LT(params.input2_offset, 256);

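  // NEON fast path: processes 16 int8 values per iteration. Loop-invariant
  // constants (activation bounds, offsets, and shift amounts) are broadcast
  // into vector registers once, outside the loop.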
#ifdef USE_NEON
  const int8x16_t output_activation_min_vector =
      vdupq_n_s8(params.quantized_activation_min);
  const int8x16_t output_activation_max_vector =
      vdupq_n_s8(params.quantized_activation_max);

  const int input1_left_shift = params.left_shift + params.input1_shift;
  const int input2_left_shift = params.left_shift + params.input2_shift;
  const int32x4_t input1_left_dup = vdupq_n_s32(input1_left_shift);
  const int32x4_t input2_left_dup = vdupq_n_s32(input2_left_shift);

  const int16x8_t input1_offset_dup = vdupq_n_s16(params.input1_offset);
  const int16x8_t input2_offset_dup = vdupq_n_s16(params.input2_offset);

  for (; i <= size - 16; i += 16) {
    const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
    const int8x16_t input2_val_original = vld1q_s8(input2_data + i);

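    // Widen each half of the 16 int8 lanes to int16 and add the input
    // offsets; the offsets lie in (-256, 256), so these sums cannot overflow
    // int16.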
    const int16x8_t input1_val_s16_high =
        vmovl_s8(vget_high_s8(input1_val_original));
    const int16x8_t input1_val_s16_low =
        vmovl_s8(vget_low_s8(input1_val_original));

    const int16x8_t input2_val_s16_high =
        vmovl_s8(vget_high_s8(input2_val_original));
    const int16x8_t input2_val_s16_low =
        vmovl_s8(vget_low_s8(input2_val_original));
    const int16x8_t input1_val_high =
        vaddq_s16(input1_val_s16_high, input1_offset_dup);
    const int16x8_t input2_val_high =
        vaddq_s16(input2_val_s16_high, input2_offset_dup);
    const int16x8_t input1_val_low =
        vaddq_s16(input1_val_s16_low, input1_offset_dup);
    const int16x8_t input2_val_low =
        vaddq_s16(input2_val_s16_low, input2_offset_dup);
    const int16x4_t input1_val_high_high = vget_high_s16(input1_val_high);
    const int16x4_t input1_val_high_low = vget_low_s16(input1_val_high);
    const int16x4_t input1_val_low_high = vget_high_s16(input1_val_low);
    const int16x4_t input1_val_low_low = vget_low_s16(input1_val_low);
    const int16x4_t input2_val_high_high = vget_high_s16(input2_val_high);
    const int16x4_t input2_val_high_low = vget_low_s16(input2_val_high);
    const int16x4_t input2_val_low_high = vget_high_s16(input2_val_low);
    const int16x4_t input2_val_low_low = vget_low_s16(input2_val_low);
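    // Widen to int32 in four groups of 4 lanes; the remaining arithmetic is
    // done in 32-bit fixed point.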
    int32x4_t x111 = vmovl_s16(input1_val_low_low);
    int32x4_t x112 = vmovl_s16(input1_val_low_high);
    int32x4_t x121 = vmovl_s16(input1_val_high_low);
    int32x4_t x122 = vmovl_s16(input1_val_high_high);
    int32x4_t x211 = vmovl_s16(input2_val_low_low);
    int32x4_t x212 = vmovl_s16(input2_val_low_high);
    int32x4_t x221 = vmovl_s16(input2_val_high_low);
    int32x4_t x222 = vmovl_s16(input2_val_high_high);

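    // Apply each input's combined left shift, then scale by its quantized
    // multiplier with a saturating rounding doubling high multiply.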
    x111 = vshlq_s32(x111, input1_left_dup);
    x112 = vshlq_s32(x112, input1_left_dup);
    x121 = vshlq_s32(x121, input1_left_dup);
    x122 = vshlq_s32(x122, input1_left_dup);
    x211 = vshlq_s32(x211, input2_left_dup);
    x212 = vshlq_s32(x212, input2_left_dup);
    x221 = vshlq_s32(x221, input2_left_dup);
    x222 = vshlq_s32(x222, input2_left_dup);
    x111 = vqrdmulhq_n_s32(x111, params.input1_multiplier);
    x112 = vqrdmulhq_n_s32(x112, params.input1_multiplier);
    x121 = vqrdmulhq_n_s32(x121, params.input1_multiplier);
    x122 = vqrdmulhq_n_s32(x122, params.input1_multiplier);
    x211 = vqrdmulhq_n_s32(x211, params.input2_multiplier);
    x212 = vqrdmulhq_n_s32(x212, params.input2_multiplier);
    x221 = vqrdmulhq_n_s32(x221, params.input2_multiplier);
    x222 = vqrdmulhq_n_s32(x222, params.input2_multiplier);
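    // Sum the rescaled inputs, then requantize to the output scale: multiply
    // by output_multiplier and round-shift right by -output_shift.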
    int32x4_t s11 = vaddq_s32(x111, x211);
    int32x4_t s12 = vaddq_s32(x112, x212);
    int32x4_t s21 = vaddq_s32(x121, x221);
    int32x4_t s22 = vaddq_s32(x122, x222);
    s11 = vqrdmulhq_n_s32(s11, params.output_multiplier);
    s12 = vqrdmulhq_n_s32(s12, params.output_multiplier);
    s21 = vqrdmulhq_n_s32(s21, params.output_multiplier);
    s22 = vqrdmulhq_n_s32(s22, params.output_multiplier);
    using gemmlowp::RoundingDivideByPOT;
    s11 = RoundingDivideByPOT(s11, -params.output_shift);
    s12 = RoundingDivideByPOT(s12, -params.output_shift);
    s21 = RoundingDivideByPOT(s21, -params.output_shift);
    s22 = RoundingDivideByPOT(s22, -params.output_shift);
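    // Narrow to int16, add the output offset, saturate to int8, and clamp to
    // the activation bounds.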
    const int16x4_t s11_narrowed = vmovn_s32(s11);
    const int16x4_t s12_narrowed = vmovn_s32(s12);
    const int16x4_t s21_narrowed = vmovn_s32(s21);
    const int16x4_t s22_narrowed = vmovn_s32(s22);
    const int16x8_t s1 = vaddq_s16(vcombine_s16(s11_narrowed, s12_narrowed),
                                   vdupq_n_s16(params.output_offset));
    const int16x8_t s2 = vaddq_s16(vcombine_s16(s21_narrowed, s22_narrowed),
                                   vdupq_n_s16(params.output_offset));
    const int8x16_t s = vcombine_s8(vqmovn_s16(s1), vqmovn_s16(s2));

    const int8x16_t clamped =
        vmaxq_s8(output_activation_min_vector,
                 vminq_s8(output_activation_max_vector, s));
    vst1q_s8(output_data + i, clamped);
  }
#endif  // USE_NEON

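  // Scalar path: handles the vector loop's remainder, or the entire range
  // when NEON is unavailable, using the same formula as the reference kernel.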
  for (; i < size; ++i) {
    const int32 input1_val = params.input1_offset + input1_data[i];
    const int32 input2_val = params.input2_offset + input2_data[i];
    const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
    const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
    const int32 scaled_input1_val =
        MultiplyByQuantizedMultiplierSmallerThanOneExp(
            shifted_input1_val, params.input1_multiplier, params.input1_shift);
    const int32 scaled_input2_val =
        MultiplyByQuantizedMultiplierSmallerThanOneExp(
            shifted_input2_val, params.input2_multiplier, params.input2_shift);
    const int32 raw_sum = scaled_input1_val + scaled_input2_val;
    const int32 raw_output =
        MultiplyByQuantizedMultiplierSmallerThanOneExp(
            raw_sum, params.output_multiplier, params.output_shift) +
        params.output_offset;
    const int32 clamped_output =
        std::min(params.quantized_activation_max,
                 std::max(params.quantized_activation_min, raw_output));
    output_data[i] = static_cast<int8>(clamped_output);
  }
}

// Scalar-broadcast add that can be used for the inner loop of a more general
// broadcast add, so that, for example, a scalar broadcast with batch will
// still be fast.
inline void AddScalarBroadcast(int size, const ArithmeticParams& params,
                               int8 input1_data, const int8* input2_data,
                               int8* output_data) {
  using gemmlowp::RoundingDivideByPOT;

  ruy::profiler::ScopeLabel label("AddScalarBroadcastInt8/8bit");
  TFLITE_DCHECK_GT(params.input1_offset, -256);
  TFLITE_DCHECK_GT(params.input2_offset, -256);
  TFLITE_DCHECK_LT(params.input1_offset, 256);
  TFLITE_DCHECK_LT(params.input2_offset, 256);

  int i = 0;

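  // NEON fast path: the broadcast scalar input is fully preprocessed once
  // (widen, add offset, shift, rescale), then reused while processing 8
  // elements of input2 per iteration.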
#ifdef USE_NEON
  const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
  const int8x8_t output_activation_min_vector =
      vdup_n_s8(params.quantized_activation_min);
  const int8x8_t output_activation_max_vector =
      vdup_n_s8(params.quantized_activation_max);

  // Process broadcast scalar.
  const int8x8_t input1_val_original = vdup_n_s8(input1_data);
  const int16x8_t input1_val_s16 = vmovl_s8(input1_val_original);
  const int16x8_t input1_val =
      vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
  const int16x4_t input1_val_high = vget_high_s16(input1_val);
  const int16x4_t input1_val_low = vget_low_s16(input1_val);
  int32x4_t x11 = vmovl_s16(input1_val_low);
  int32x4_t x12 = vmovl_s16(input1_val_high);
  x11 = vshlq_s32(x11, left_shift_dup);
  x12 = vshlq_s32(x12, left_shift_dup);
  x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
  x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
  const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
  x11 = vshlq_s32(x11, input1_shift_dup);
  x12 = vshlq_s32(x12, input1_shift_dup);

  for (; i <= size - 8; i += 8) {
    const int8x8_t input2_val_original = vld1_s8(input2_data + i);
    const int16x8_t input2_val_s16 = vmovl_s8(input2_val_original);
    const int16x8_t input2_val =
        vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
    const int16x4_t input2_val_high = vget_high_s16(input2_val);
    const int16x4_t input2_val_low = vget_low_s16(input2_val);
    int32x4_t x21 = vmovl_s16(input2_val_low);
    int32x4_t x22 = vmovl_s16(input2_val_high);
    x21 = vshlq_s32(x21, left_shift_dup);
    x22 = vshlq_s32(x22, left_shift_dup);
    x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
    x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
    const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
    x21 = vshlq_s32(x21, input2_shift_dup);
    x22 = vshlq_s32(x22, input2_shift_dup);
    int32x4_t s1 = vaddq_s32(x11, x21);
    int32x4_t s2 = vaddq_s32(x12, x22);
    s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
    s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
    s1 = RoundingDivideByPOT(s1, -params.output_shift);
    s2 = RoundingDivideByPOT(s2, -params.output_shift);
    const int16x4_t s1_narrowed = vmovn_s32(s1);
    const int16x4_t s2_narrowed = vmovn_s32(s2);
    const int16x8_t s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
                                  vdupq_n_s16(params.output_offset));
    const int8x8_t clamped =
        vmax_s8(output_activation_min_vector,
                vmin_s8(output_activation_max_vector, vqmovn_s16(s)));
    vst1_s8(output_data + i, clamped);
  }
#endif  // USE_NEON

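  // Scalar path: preprocess the broadcast scalar once, then compute any
  // remaining elements with the same formula as the reference kernel.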
  if (i < size) {
    // Process broadcast scalar.
    const int32 input1_val = params.input1_offset + input1_data;
    const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
    const int32 scaled_input1_val =
        MultiplyByQuantizedMultiplierSmallerThanOneExp(
            shifted_input1_val, params.input1_multiplier, params.input1_shift);

    for (; i < size; ++i) {
      const int32 input2_val = params.input2_offset + input2_data[i];
      const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
      const int32 scaled_input2_val =
          MultiplyByQuantizedMultiplierSmallerThanOneExp(
              shifted_input2_val, params.input2_multiplier,
              params.input2_shift);
      const int32 raw_sum = scaled_input1_val + scaled_input2_val;
      const int32 raw_output =
          MultiplyByQuantizedMultiplierSmallerThanOneExp(
              raw_sum, params.output_multiplier, params.output_shift) +
          params.output_offset;
      const int32 clamped_output =
          std::min(params.quantized_activation_max,
                   std::max(params.quantized_activation_min, raw_output));
      output_data[i] = static_cast<int8>(clamped_output);
    }
  }
}

inline void Add(const ArithmeticParams& params,
                const RuntimeShape& input1_shape, const int8* input1_data,
                const RuntimeShape& input2_shape, const int8* input2_data,
                const RuntimeShape& output_shape, int8* output_data) {
  TFLITE_DCHECK_LE(params.quantized_activation_min,
                   params.quantized_activation_max);
  ruy::profiler::ScopeLabel label("AddInt8/8bit");
  const int flat_size =
      MatchingElementsSize(input1_shape, input2_shape, output_shape);

  TFLITE_DCHECK_GT(params.input1_offset, -256);
  TFLITE_DCHECK_GT(params.input2_offset, -256);
  TFLITE_DCHECK_LT(params.input1_offset, 256);
  TFLITE_DCHECK_LT(params.input2_offset, 256);
  AddElementwise(flat_size, params, input1_data, input2_data, output_data);
}

inline void BroadcastAddDispatch(const ArithmeticParams& params,
                                 const RuntimeShape& input1_shape,
                                 const int8* input1_data,
                                 const RuntimeShape& input2_shape,
                                 const int8* input2_data,
                                 const RuntimeShape& output_shape,
                                 int8* output_data) {
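  // Broadcasts that do not fit the five-fold pattern fall back to the slow
  // generic reference kernel.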
  if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
    return reference_integer_ops::BroadcastAdd4DSlow(
        params, input1_shape, input1_data, input2_shape, input2_data,
        output_shape, output_data);
  }

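  // Otherwise, use the five-fold broadcast strategy, with the elementwise and
  // scalar-broadcast kernels above as its inner loops.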
  optimized_ops::BinaryBroadcastFiveFold(
      params, input1_shape, input1_data, input2_shape, input2_data,
      output_shape, output_data, AddElementwise, AddScalarBroadcast);
}

}  // namespace optimized_integer_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_ADD_H_