1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/experimental/microfrontend/lib/filterbank_util.h"
16 
17 #include <assert.h>
18 #include <math.h>
19 #include <stdio.h>
20 
// Byte alignment applied to the first frequency index of each channel. It is
// converted to a count of int16_t entries in FilterbankPopulateState and each
// channel's start index is rounded down to a multiple of that count.
#define kFilterbankIndexAlignment 4
// Granularity of per-channel work: every channel width is padded up to a
// multiple of this many entries, and the shared block of zero weights used by
// empty channels is this many entries long.
#define kFilterbankChannelBlockSize 4
23 
FilterbankFillConfigWithDefaults(struct FilterbankConfig * config)24 void FilterbankFillConfigWithDefaults(struct FilterbankConfig* config) {
25   config->num_channels = 32;
26   config->lower_band_limit = 125.0f;
27   config->upper_band_limit = 7500.0f;
28   config->output_scale_shift = 7;
29 }
30 
FreqToMel(float freq)31 static float FreqToMel(float freq) { return 1127.0 * log1p(freq / 700.0); }
32 
// Fills `center_frequencies[0 .. num_channels - 1]` with evenly spaced mel
// values. The distance between the mel equivalents of the two frequency
// limits is divided into `num_channels` equal steps, and entry k holds
// mel(lower) + (k + 1) * step, so the lower limit itself is never emitted.
//
// Both limits are checked with assert(): the lower one must be non-negative
// and the upper one strictly greater than the lower one.
static void CalculateCenterFrequencies(const int num_channels,
                                       const float lower_frequency_limit,
                                       const float upper_frequency_limit,
                                       float* center_frequencies) {
  assert(lower_frequency_limit >= 0.0f);
  assert(upper_frequency_limit > lower_frequency_limit);

  const float mel_floor = FreqToMel(lower_frequency_limit);
  const float mel_ceiling = FreqToMel(upper_frequency_limit);
  const float mel_range = mel_ceiling - mel_floor;
  const float mel_step = mel_range / ((float)num_channels);

  // `step_count` is the 1-based position of the channel being written.
  for (int step_count = 1; step_count <= num_channels; ++step_count) {
    center_frequencies[step_count - 1] = mel_floor + (mel_step * step_count);
  }
}
49 
QuantizeFilterbankWeights(const float float_weight,int16_t * weight,int16_t * unweight)50 static void QuantizeFilterbankWeights(const float float_weight, int16_t* weight,
51                                       int16_t* unweight) {
52   *weight = floor(float_weight * (1 << kFilterbankBits) + 0.5);
53   *unweight = floor((1.0 - float_weight) * (1 << kFilterbankBits) + 0.5);
54 }
55 
FilterbankPopulateState(const struct FilterbankConfig * config,struct FilterbankState * state,int sample_rate,int spectrum_size)56 int FilterbankPopulateState(const struct FilterbankConfig* config,
57                             struct FilterbankState* state, int sample_rate,
58                             int spectrum_size) {
59   state->num_channels = config->num_channels;
60   const int num_channels_plus_1 = config->num_channels + 1;
61 
62   // How should we align things to index counts given the byte alignment?
63   const int index_alignment =
64       (kFilterbankIndexAlignment < sizeof(int16_t)
65            ? 1
66            : kFilterbankIndexAlignment / sizeof(int16_t));
67 
68   state->channel_frequency_starts =
69       malloc(num_channels_plus_1 * sizeof(*state->channel_frequency_starts));
70   state->channel_weight_starts =
71       malloc(num_channels_plus_1 * sizeof(*state->channel_weight_starts));
72   state->channel_widths =
73       malloc(num_channels_plus_1 * sizeof(*state->channel_widths));
74   state->work = malloc(num_channels_plus_1 * sizeof(*state->work));
75 
76   float* center_mel_freqs =
77       malloc(num_channels_plus_1 * sizeof(*center_mel_freqs));
78   int16_t* actual_channel_starts =
79       malloc(num_channels_plus_1 * sizeof(*actual_channel_starts));
80   int16_t* actual_channel_widths =
81       malloc(num_channels_plus_1 * sizeof(*actual_channel_widths));
82 
83   if (state->channel_frequency_starts == NULL ||
84       state->channel_weight_starts == NULL || state->channel_widths == NULL ||
85       center_mel_freqs == NULL || actual_channel_starts == NULL ||
86       actual_channel_widths == NULL) {
87     free(center_mel_freqs);
88     free(actual_channel_starts);
89     free(actual_channel_widths);
90     fprintf(stderr, "Failed to allocate channel buffers\n");
91     return 0;
92   }
93 
94   CalculateCenterFrequencies(num_channels_plus_1, config->lower_band_limit,
95                              config->upper_band_limit, center_mel_freqs);
96 
97   // Always exclude DC.
98   const float hz_per_sbin = 0.5 * sample_rate / ((float)spectrum_size - 1);
99   state->start_index = 1.5 + config->lower_band_limit / hz_per_sbin;
100   state->end_index = 0;  // Initialized to zero here, but actually set below.
101 
102   // For each channel, we need to figure out what frequencies belong to it, and
103   // how much padding we need to add so that we can efficiently multiply the
104   // weights and unweights for accumulation. To simplify the multiplication
105   // logic, all channels will have some multiplication to do (even if there are
106   // no frequencies that accumulate to that channel) - they will be directed to
107   // a set of zero weights.
108   int chan_freq_index_start = state->start_index;
109   int weight_index_start = 0;
110   int needs_zeros = 0;
111 
112   int chan;
113   for (chan = 0; chan < num_channels_plus_1; ++chan) {
114     // Keep jumping frequencies until we overshoot the bound on this channel.
115     int freq_index = chan_freq_index_start;
116     while (FreqToMel((freq_index)*hz_per_sbin) <= center_mel_freqs[chan]) {
117       ++freq_index;
118     }
119 
120     const int width = freq_index - chan_freq_index_start;
121     actual_channel_starts[chan] = chan_freq_index_start;
122     actual_channel_widths[chan] = width;
123 
124     if (width == 0) {
125       // This channel doesn't actually get anything from the frequencies, it's
126       // always zero. We need then to insert some 'zero' weights into the
127       // output, and just redirect this channel to do a single multiplication at
128       // this point. For simplicity, the zeros are placed at the beginning of
129       // the weights arrays, so we have to go and update all the other
130       // weight_starts to reflect this shift (but only once).
131       state->channel_frequency_starts[chan] = 0;
132       state->channel_weight_starts[chan] = 0;
133       state->channel_widths[chan] = kFilterbankChannelBlockSize;
134       if (!needs_zeros) {
135         needs_zeros = 1;
136         int j;
137         for (j = 0; j < chan; ++j) {
138           state->channel_weight_starts[j] += kFilterbankChannelBlockSize;
139         }
140         weight_index_start += kFilterbankChannelBlockSize;
141       }
142     } else {
143       // How far back do we need to go to ensure that we have the proper
144       // alignment?
145       const int aligned_start =
146           (chan_freq_index_start / index_alignment) * index_alignment;
147       const int aligned_width = (chan_freq_index_start - aligned_start + width);
148       const int padded_width =
149           (((aligned_width - 1) / kFilterbankChannelBlockSize) + 1) *
150           kFilterbankChannelBlockSize;
151 
152       state->channel_frequency_starts[chan] = aligned_start;
153       state->channel_weight_starts[chan] = weight_index_start;
154       state->channel_widths[chan] = padded_width;
155       weight_index_start += padded_width;
156     }
157     chan_freq_index_start = freq_index;
158   }
159 
160   // Allocate the two arrays to store the weights - weight_index_start contains
161   // the index of what would be the next set of weights that we would need to
162   // add, so that's how many weights we need to allocate.
163   state->weights = calloc(weight_index_start, sizeof(*state->weights));
164   state->unweights = calloc(weight_index_start, sizeof(*state->unweights));
165 
166   // If the alloc failed, we also need to nuke the arrays.
167   if (state->weights == NULL || state->unweights == NULL) {
168     free(center_mel_freqs);
169     free(actual_channel_starts);
170     free(actual_channel_widths);
171     fprintf(stderr, "Failed to allocate weights or unweights\n");
172     return 0;
173   }
174 
175   // Next pass, compute all the weights. Since everything has been memset to
176   // zero, we only need to fill in the weights that correspond to some frequency
177   // for a channel.
178   const float mel_low = FreqToMel(config->lower_band_limit);
179   for (chan = 0; chan < num_channels_plus_1; ++chan) {
180     int frequency = actual_channel_starts[chan];
181     const int num_frequencies = actual_channel_widths[chan];
182     const int frequency_offset =
183         frequency - state->channel_frequency_starts[chan];
184     const int weight_start = state->channel_weight_starts[chan];
185     const float denom_val = (chan == 0) ? mel_low : center_mel_freqs[chan - 1];
186 
187     int j;
188     for (j = 0; j < num_frequencies; ++j, ++frequency) {
189       const float weight =
190           (center_mel_freqs[chan] - FreqToMel(frequency * hz_per_sbin)) /
191           (center_mel_freqs[chan] - denom_val);
192 
193       // Make the float into an integer for the weights (and unweights).
194       const int weight_index = weight_start + frequency_offset + j;
195       QuantizeFilterbankWeights(weight, state->weights + weight_index,
196                                 state->unweights + weight_index);
197     }
198     if (frequency > state->end_index) {
199       state->end_index = frequency;
200     }
201   }
202 
203   free(center_mel_freqs);
204   free(actual_channel_starts);
205   free(actual_channel_widths);
206   if (state->end_index >= spectrum_size) {
207     fprintf(stderr, "Filterbank end_index is above spectrum size.\n");
208     return 0;
209   }
210   return 1;
211 }
212 
FilterbankFreeStateContents(struct FilterbankState * state)213 void FilterbankFreeStateContents(struct FilterbankState* state) {
214   free(state->channel_frequency_starts);
215   free(state->channel_weight_starts);
216   free(state->channel_widths);
217   free(state->weights);
218   free(state->unweights);
219   free(state->work);
220 }
221