1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/experimental/microfrontend/lib/filterbank.h"
16 
17 #include <string.h>
18 
19 #include "tensorflow/lite/experimental/microfrontend/lib/bits.h"
20 
FilterbankConvertFftComplexToEnergy(struct FilterbankState * state,struct complex_int16_t * fft_output,int32_t * energy)21 void FilterbankConvertFftComplexToEnergy(struct FilterbankState* state,
22                                          struct complex_int16_t* fft_output,
23                                          int32_t* energy) {
24   const int end_index = state->end_index;
25   int i;
26   energy += state->start_index;
27   fft_output += state->start_index;
28   for (i = state->start_index; i < end_index; ++i) {
29     const int32_t real = fft_output->real;
30     const int32_t imag = fft_output->imag;
31     fft_output++;
32     const uint32_t mag_squared = (real * real) + (imag * imag);
33     *energy++ = mag_squared;
34   }
35 }
36 
FilterbankAccumulateChannels(struct FilterbankState * state,const int32_t * energy)37 void FilterbankAccumulateChannels(struct FilterbankState* state,
38                                   const int32_t* energy) {
39   uint64_t* work = state->work;
40   uint64_t weight_accumulator = 0;
41   uint64_t unweight_accumulator = 0;
42 
43   const int16_t* channel_frequency_starts = state->channel_frequency_starts;
44   const int16_t* channel_weight_starts = state->channel_weight_starts;
45   const int16_t* channel_widths = state->channel_widths;
46 
47   int num_channels_plus_1 = state->num_channels + 1;
48   int i;
49   for (i = 0; i < num_channels_plus_1; ++i) {
50     const int32_t* magnitudes = energy + *channel_frequency_starts++;
51     const int16_t* weights = state->weights + *channel_weight_starts;
52     const int16_t* unweights = state->unweights + *channel_weight_starts++;
53     const int width = *channel_widths++;
54     int j;
55     for (j = 0; j < width; ++j) {
56       weight_accumulator += *weights++ * ((uint64_t)*magnitudes);
57       unweight_accumulator += *unweights++ * ((uint64_t)*magnitudes);
58       ++magnitudes;
59     }
60     *work++ = weight_accumulator;
61     weight_accumulator = unweight_accumulator;
62     unweight_accumulator = 0;
63   }
64 }
65 
// Integer square root of a 32-bit value, computed bit by bit and rounded to
// the nearest integer where the 16-bit result has room for it.
// Returns 0 for an input of 0.
static uint16_t Sqrt32(uint32_t num) {
  if (num == 0) {
    return 0;
  }
  // NOTE(review): presumably MostSignificantBit32 returns the 1-based position
  // of the highest set bit, which makes `odd_shift` a leading-zero count -
  // confirm against bits.h. Forcing it odd puts the starting probe on an
  // even bit position.
  int odd_shift = 32 - MostSignificantBit32(num);
  odd_shift |= 1;
  uint32_t probe = 1U << (31 - odd_shift);
  uint32_t root = 0;
  int remaining;
  for (remaining = (31 - odd_shift) / 2 + 1; remaining > 0; --remaining) {
    const uint32_t trial = root + probe;
    root >>= 1U;
    if (num >= trial) {
      num -= trial;
      root += probe;
    }
    probe >>= 2U;
  }
  // Round up when the remainder allows it, unless that would overflow the
  // 16-bit result.
  if (num > root && root != 0xFFFF) {
    ++root;
  }
  return root;
}
90 
// Integer square root of a 64-bit value, computed bit by bit and rounded to
// the nearest integer where the 32-bit result has room for it.
static uint32_t Sqrt64(uint64_t num) {
  // When the upper word is clear, fall back to the cheaper 32-bit routine.
  // This is slightly off (by one) for inputs close to 2^32, which is accepted
  // in exchange for a big performance win.
  if ((num >> 32) == 0) {
    return Sqrt32((uint32_t)num);
  }
  // NOTE(review): presumably MostSignificantBit64 returns the 1-based position
  // of the highest set bit - confirm against bits.h. Forcing the shift odd
  // puts the starting probe on an even bit position.
  int odd_shift = 64 - MostSignificantBit64(num);
  odd_shift |= 1;
  uint64_t probe = 1ULL << (63 - odd_shift);
  uint64_t root = 0;
  int remaining;
  for (remaining = (63 - odd_shift) / 2 + 1; remaining > 0; --remaining) {
    const uint64_t trial = root + probe;
    root >>= 1U;
    if (num >= trial) {
      num -= trial;
      root += probe;
    }
    probe >>= 2U;
  }
  // Round up when the remainder allows it, unless that would overflow the
  // 32-bit result.
  if (num > root && root != 0xFFFFFFFFLL) {
    ++root;
  }
  return root;
}
118 
FilterbankSqrt(struct FilterbankState * state,int scale_down_shift)119 uint32_t* FilterbankSqrt(struct FilterbankState* state, int scale_down_shift) {
120   const int num_channels = state->num_channels;
121   const uint64_t* work = state->work + 1;
122   // Reuse the work buffer since we're fine clobbering it at this point to hold
123   // the output.
124   uint32_t* output = (uint32_t*)state->work;
125   int i;
126   for (i = 0; i < num_channels; ++i) {
127     *output++ = Sqrt64(*work++) >> scale_down_shift;
128   }
129   return (uint32_t*)state->work;
130 }
131 
FilterbankReset(struct FilterbankState * state)132 void FilterbankReset(struct FilterbankState* state) {
133   memset(state->work, 0, (state->num_channels + 1) * sizeof(*state->work));
134 }
135