1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/experimental/microfrontend/lib/filterbank_util.h"
16
17 #include <assert.h>
18 #include <math.h>
19 #include <stdio.h>
20
// Byte alignment requested for the first frequency index of each channel.
// FilterbankPopulateState converts it into an index count by dividing by
// sizeof(int16_t) and rounds each channel's start index down to a multiple of
// that count.
#define kFilterbankIndexAlignment 4
// Granularity to which each channel's weight count is padded (channel widths
// are rounded up to a multiple of this). It is also the size of the all-zero
// weight block that channels without any frequency bins are redirected to.
#define kFilterbankChannelBlockSize 4
23
FilterbankFillConfigWithDefaults(struct FilterbankConfig * config)24 void FilterbankFillConfigWithDefaults(struct FilterbankConfig* config) {
25 config->num_channels = 32;
26 config->lower_band_limit = 125.0f;
27 config->upper_band_limit = 7500.0f;
28 config->output_scale_shift = 7;
29 }
30
FreqToMel(float freq)31 static float FreqToMel(float freq) { return 1127.0 * log1p(freq / 700.0); }
32
// Fills center_frequencies[0 .. num_channels-1] with points spaced evenly on
// the mel scale above mel(lower_frequency_limit). Entry k holds
// mel_low + (k + 1) * spacing, so the last entry lands (up to float rounding)
// on mel(upper_frequency_limit). The outputs are mel values, not Hz.
static void CalculateCenterFrequencies(const int num_channels,
                                       const float lower_frequency_limit,
                                       const float upper_frequency_limit,
                                       float* center_frequencies) {
  assert(lower_frequency_limit >= 0.0f);
  assert(upper_frequency_limit > lower_frequency_limit);

  const float mel_low = FreqToMel(lower_frequency_limit);
  const float mel_step =
      (FreqToMel(upper_frequency_limit) - mel_low) / ((float)num_channels);

  float* out = center_frequencies;
  for (int k = 1; k <= num_channels; ++k) {
    *out++ = mel_low + (mel_step * k);
  }
}
49
QuantizeFilterbankWeights(const float float_weight,int16_t * weight,int16_t * unweight)50 static void QuantizeFilterbankWeights(const float float_weight, int16_t* weight,
51 int16_t* unweight) {
52 *weight = floor(float_weight * (1 << kFilterbankBits) + 0.5);
53 *unweight = floor((1.0 - float_weight) * (1 << kFilterbankBits) + 0.5);
54 }
55
FilterbankPopulateState(const struct FilterbankConfig * config,struct FilterbankState * state,int sample_rate,int spectrum_size)56 int FilterbankPopulateState(const struct FilterbankConfig* config,
57 struct FilterbankState* state, int sample_rate,
58 int spectrum_size) {
59 state->num_channels = config->num_channels;
60 const int num_channels_plus_1 = config->num_channels + 1;
61
62 // How should we align things to index counts given the byte alignment?
63 const int index_alignment =
64 (kFilterbankIndexAlignment < sizeof(int16_t)
65 ? 1
66 : kFilterbankIndexAlignment / sizeof(int16_t));
67
68 state->channel_frequency_starts =
69 malloc(num_channels_plus_1 * sizeof(*state->channel_frequency_starts));
70 state->channel_weight_starts =
71 malloc(num_channels_plus_1 * sizeof(*state->channel_weight_starts));
72 state->channel_widths =
73 malloc(num_channels_plus_1 * sizeof(*state->channel_widths));
74 state->work = malloc(num_channels_plus_1 * sizeof(*state->work));
75
76 float* center_mel_freqs =
77 malloc(num_channels_plus_1 * sizeof(*center_mel_freqs));
78 int16_t* actual_channel_starts =
79 malloc(num_channels_plus_1 * sizeof(*actual_channel_starts));
80 int16_t* actual_channel_widths =
81 malloc(num_channels_plus_1 * sizeof(*actual_channel_widths));
82
83 if (state->channel_frequency_starts == NULL ||
84 state->channel_weight_starts == NULL || state->channel_widths == NULL ||
85 center_mel_freqs == NULL || actual_channel_starts == NULL ||
86 actual_channel_widths == NULL) {
87 free(center_mel_freqs);
88 free(actual_channel_starts);
89 free(actual_channel_widths);
90 fprintf(stderr, "Failed to allocate channel buffers\n");
91 return 0;
92 }
93
94 CalculateCenterFrequencies(num_channels_plus_1, config->lower_band_limit,
95 config->upper_band_limit, center_mel_freqs);
96
97 // Always exclude DC.
98 const float hz_per_sbin = 0.5 * sample_rate / ((float)spectrum_size - 1);
99 state->start_index = 1.5 + config->lower_band_limit / hz_per_sbin;
100 state->end_index = 0; // Initialized to zero here, but actually set below.
101
102 // For each channel, we need to figure out what frequencies belong to it, and
103 // how much padding we need to add so that we can efficiently multiply the
104 // weights and unweights for accumulation. To simplify the multiplication
105 // logic, all channels will have some multiplication to do (even if there are
106 // no frequencies that accumulate to that channel) - they will be directed to
107 // a set of zero weights.
108 int chan_freq_index_start = state->start_index;
109 int weight_index_start = 0;
110 int needs_zeros = 0;
111
112 int chan;
113 for (chan = 0; chan < num_channels_plus_1; ++chan) {
114 // Keep jumping frequencies until we overshoot the bound on this channel.
115 int freq_index = chan_freq_index_start;
116 while (FreqToMel((freq_index)*hz_per_sbin) <= center_mel_freqs[chan]) {
117 ++freq_index;
118 }
119
120 const int width = freq_index - chan_freq_index_start;
121 actual_channel_starts[chan] = chan_freq_index_start;
122 actual_channel_widths[chan] = width;
123
124 if (width == 0) {
125 // This channel doesn't actually get anything from the frequencies, it's
126 // always zero. We need then to insert some 'zero' weights into the
127 // output, and just redirect this channel to do a single multiplication at
128 // this point. For simplicity, the zeros are placed at the beginning of
129 // the weights arrays, so we have to go and update all the other
130 // weight_starts to reflect this shift (but only once).
131 state->channel_frequency_starts[chan] = 0;
132 state->channel_weight_starts[chan] = 0;
133 state->channel_widths[chan] = kFilterbankChannelBlockSize;
134 if (!needs_zeros) {
135 needs_zeros = 1;
136 int j;
137 for (j = 0; j < chan; ++j) {
138 state->channel_weight_starts[j] += kFilterbankChannelBlockSize;
139 }
140 weight_index_start += kFilterbankChannelBlockSize;
141 }
142 } else {
143 // How far back do we need to go to ensure that we have the proper
144 // alignment?
145 const int aligned_start =
146 (chan_freq_index_start / index_alignment) * index_alignment;
147 const int aligned_width = (chan_freq_index_start - aligned_start + width);
148 const int padded_width =
149 (((aligned_width - 1) / kFilterbankChannelBlockSize) + 1) *
150 kFilterbankChannelBlockSize;
151
152 state->channel_frequency_starts[chan] = aligned_start;
153 state->channel_weight_starts[chan] = weight_index_start;
154 state->channel_widths[chan] = padded_width;
155 weight_index_start += padded_width;
156 }
157 chan_freq_index_start = freq_index;
158 }
159
160 // Allocate the two arrays to store the weights - weight_index_start contains
161 // the index of what would be the next set of weights that we would need to
162 // add, so that's how many weights we need to allocate.
163 state->weights = calloc(weight_index_start, sizeof(*state->weights));
164 state->unweights = calloc(weight_index_start, sizeof(*state->unweights));
165
166 // If the alloc failed, we also need to nuke the arrays.
167 if (state->weights == NULL || state->unweights == NULL) {
168 free(center_mel_freqs);
169 free(actual_channel_starts);
170 free(actual_channel_widths);
171 fprintf(stderr, "Failed to allocate weights or unweights\n");
172 return 0;
173 }
174
175 // Next pass, compute all the weights. Since everything has been memset to
176 // zero, we only need to fill in the weights that correspond to some frequency
177 // for a channel.
178 const float mel_low = FreqToMel(config->lower_band_limit);
179 for (chan = 0; chan < num_channels_plus_1; ++chan) {
180 int frequency = actual_channel_starts[chan];
181 const int num_frequencies = actual_channel_widths[chan];
182 const int frequency_offset =
183 frequency - state->channel_frequency_starts[chan];
184 const int weight_start = state->channel_weight_starts[chan];
185 const float denom_val = (chan == 0) ? mel_low : center_mel_freqs[chan - 1];
186
187 int j;
188 for (j = 0; j < num_frequencies; ++j, ++frequency) {
189 const float weight =
190 (center_mel_freqs[chan] - FreqToMel(frequency * hz_per_sbin)) /
191 (center_mel_freqs[chan] - denom_val);
192
193 // Make the float into an integer for the weights (and unweights).
194 const int weight_index = weight_start + frequency_offset + j;
195 QuantizeFilterbankWeights(weight, state->weights + weight_index,
196 state->unweights + weight_index);
197 }
198 if (frequency > state->end_index) {
199 state->end_index = frequency;
200 }
201 }
202
203 free(center_mel_freqs);
204 free(actual_channel_starts);
205 free(actual_channel_widths);
206 if (state->end_index >= spectrum_size) {
207 fprintf(stderr, "Filterbank end_index is above spectrum size.\n");
208 return 0;
209 }
210 return 1;
211 }
212
FilterbankFreeStateContents(struct FilterbankState * state)213 void FilterbankFreeStateContents(struct FilterbankState* state) {
214 free(state->channel_frequency_starts);
215 free(state->channel_weight_starts);
216 free(state->channel_widths);
217 free(state->weights);
218 free(state->unweights);
219 free(state->work);
220 }
221