/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/lite/experimental/microfrontend/lib/frontend.h"
16 #include "tensorflow/lite/experimental/microfrontend/lib/frontend_util.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/op_kernel.h"
19 #include "tensorflow/core/framework/shape_inference.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/tensor_shape.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/platform/macros.h"
25 
26 using tensorflow::OpKernel;
27 using tensorflow::OpKernelConstruction;
28 using tensorflow::OpKernelContext;
29 using tensorflow::Status;
30 using tensorflow::Tensor;
31 using tensorflow::TensorShape;
32 using tensorflow::TensorShapeUtils;
33 using tensorflow::errors::Internal;
34 using tensorflow::errors::InvalidArgument;
35 using tensorflow::shape_inference::DimensionHandle;
36 using tensorflow::shape_inference::InferenceContext;
37 using tensorflow::shape_inference::ShapeHandle;
38 
39 namespace tensorflow {
40 REGISTER_OP("AudioMicrofrontend")
41     .Input("audio: int16")
42     .Output("filterbanks: out_type")
43     .Attr("sample_rate: int = 16000")
44     .Attr("window_size: int = 25")
45     .Attr("window_step: int = 10")
46     .Attr("num_channels: int = 32")
47     .Attr("upper_band_limit: float = 7500.0")
48     .Attr("lower_band_limit: float = 125.0")
49     .Attr("smoothing_bits: int = 10")
50     .Attr("even_smoothing: float = 0.025")
51     .Attr("odd_smoothing: float = 0.06")
52     .Attr("min_signal_remaining: float = 0.05")
53     .Attr("enable_pcan: bool = false")
54     .Attr("pcan_strength: float = 0.95")
55     .Attr("pcan_offset: float = 80.0")
56     .Attr("gain_bits: int = 21")
57     .Attr("enable_log: bool = true")
58     .Attr("scale_shift: int = 6")
59     .Attr("left_context: int = 0")
60     .Attr("right_context: int = 0")
61     .Attr("frame_stride: int = 1")
62     .Attr("zero_padding: bool = false")
63     .Attr("out_scale: int = 1")
64     .Attr("out_type: {uint16, float} = DT_UINT16")
__anon3fe811a90102(InferenceContext* ctx) 65     .SetShapeFn([](InferenceContext* ctx) {
66       ShapeHandle input;
67       TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(0), 1, &input));
68 
69       int sample_rate;
70       TF_RETURN_IF_ERROR(ctx->GetAttr("sample_rate", &sample_rate));
71       int window_size;
72       TF_RETURN_IF_ERROR(ctx->GetAttr("window_size", &window_size));
73       window_size *= sample_rate / 1000;
74       int window_step;
75       TF_RETURN_IF_ERROR(ctx->GetAttr("window_step", &window_step));
76       window_step *= sample_rate / 1000;
77 
78       int num_channels;
79       TF_RETURN_IF_ERROR(ctx->GetAttr("num_channels", &num_channels));
80       int left_context;
81       TF_RETURN_IF_ERROR(ctx->GetAttr("left_context", &left_context));
82       int right_context;
83       TF_RETURN_IF_ERROR(ctx->GetAttr("right_context", &right_context));
84       int frame_stride;
85       TF_RETURN_IF_ERROR(ctx->GetAttr("frame_stride", &frame_stride));
86 
87       DimensionHandle num_frames = ctx->Dim(input, 0);
88       if (ctx->Value(num_frames) < window_size) {
89         num_frames = ctx->MakeDim(0);
90       } else {
91         TF_RETURN_IF_ERROR(ctx->Subtract(num_frames, window_size, &num_frames));
92         TF_RETURN_IF_ERROR(
93             ctx->Divide(num_frames, window_step, false, &num_frames));
94         TF_RETURN_IF_ERROR(
95             ctx->Divide(num_frames, frame_stride, false, &num_frames));
96         TF_RETURN_IF_ERROR(ctx->Add(num_frames, 1, &num_frames));
97       }
98 
99       int stack_size = 1 + left_context + right_context;
100       DimensionHandle num_features = ctx->MakeDim(num_channels);
101       TF_RETURN_IF_ERROR(
102           ctx->Multiply(num_features, stack_size, &num_features));
103 
104       ShapeHandle output = ctx->MakeShape({num_frames, num_features});
105       ctx->set_output(0, output);
106       return tensorflow::Status::OK();
107     })
108     .Doc(R"doc(
109 Audio Microfrontend Op.
110 
111 This Op converts a sequence of audio data into one or more
112 feature vectors containing filterbanks of the input. The
113 conversion process uses a lightweight library to perform:
114 
115 1. A slicing window function
116 2. Short-time FFTs
117 3. Filterbank calculations
118 4. Noise reduction
119 5. PCAN Auto Gain Control
120 6. Logarithmic scaling
121 
122 Arguments
123   audio: 1D Tensor, int16 audio data in temporal ordering.
124   sample_rate: Integer, the sample rate of the audio in Hz.
125   window_size: Integer, length of desired time frames in ms.
126   window_step: Integer, length of step size for the next frame in ms.
127   num_channels: Integer, the number of filterbank channels to use.
128   upper_band_limit: Float, the highest frequency included in the filterbanks.
129   lower_band_limit: Float, the lowest frequency included in the filterbanks.
130   smoothing_bits: Int, scale up signal by 2^(smoothing_bits) before reduction.
131   even_smoothing: Float, smoothing coefficient for even-numbered channels.
132   odd_smoothing: Float, smoothing coefficient for odd-numbered channels.
133   min_signal_remaining: Float, fraction of signal to preserve in smoothing.
134   enable_pcan: Bool, enable PCAN auto gain control.
135   pcan_strength: Float, gain normalization exponent.
136   pcan_offset: Float, positive value added in the normalization denominator.
137   gain_bits: Int, number of fractional bits in the gain.
138   enable_log: Bool, enable logarithmic scaling of filterbanks.
139   scale_shift: Integer, scale filterbanks by 2^(scale_shift).
140   left_context: Integer, number of preceding frames to attach to each frame.
141   right_context: Integer, number of preceding frames to attach to each frame.
142   frame_stride: Integer, M frames to skip over, where output[n] = frame[n*M].
143   zero_padding: Bool, if left/right context is out-of-bounds, attach frame of
144                 zeroes. Otherwise, frame[0] or frame[size-1] will be copied.
145   out_scale: Integer, divide all filterbanks by this number.
146   out_type: DType, type of the output Tensor, defaults to UINT16.
147 
148 Returns
149   filterbanks: 2D Tensor, each row is a time frame, each column is a channel.
150 )doc");
151 
152 template <typename T>
153 class AudioMicrofrontendOp : public OpKernel {
154  public:
AudioMicrofrontendOp(OpKernelConstruction * ctx)155   explicit AudioMicrofrontendOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
156     OP_REQUIRES_OK(ctx, ctx->GetAttr("sample_rate", &sample_rate_));
157 
158     int window_size;
159     OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size", &window_size));
160     config_.window.size_ms = window_size;
161 
162     int window_step;
163     OP_REQUIRES_OK(ctx, ctx->GetAttr("window_step", &window_step));
164     config_.window.step_size_ms = window_step;
165 
166     OP_REQUIRES_OK(
167         ctx, ctx->GetAttr("num_channels", &config_.filterbank.num_channels));
168     OP_REQUIRES_OK(ctx, ctx->GetAttr("upper_band_limit",
169                                      &config_.filterbank.upper_band_limit));
170     OP_REQUIRES_OK(ctx, ctx->GetAttr("lower_band_limit",
171                                      &config_.filterbank.lower_band_limit));
172     OP_REQUIRES_OK(ctx, ctx->GetAttr("smoothing_bits",
173                                      &config_.noise_reduction.smoothing_bits));
174     OP_REQUIRES_OK(ctx, ctx->GetAttr("even_smoothing",
175                                      &config_.noise_reduction.even_smoothing));
176     OP_REQUIRES_OK(ctx, ctx->GetAttr("odd_smoothing",
177                                      &config_.noise_reduction.odd_smoothing));
178     OP_REQUIRES_OK(ctx,
179                    ctx->GetAttr("min_signal_remaining",
180                                 &config_.noise_reduction.min_signal_remaining));
181 
182     bool enable_pcan;
183     OP_REQUIRES_OK(ctx, ctx->GetAttr("enable_pcan", &enable_pcan));
184     config_.pcan_gain_control.enable_pcan = enable_pcan;
185 
186     OP_REQUIRES_OK(ctx, ctx->GetAttr("pcan_strength",
187                                      &config_.pcan_gain_control.strength));
188     OP_REQUIRES_OK(
189         ctx, ctx->GetAttr("pcan_offset", &config_.pcan_gain_control.offset));
190     OP_REQUIRES_OK(
191         ctx, ctx->GetAttr("gain_bits", &config_.pcan_gain_control.gain_bits));
192 
193     bool enable_log;
194     OP_REQUIRES_OK(ctx, ctx->GetAttr("enable_log", &enable_log));
195     config_.log_scale.enable_log = enable_log;
196 
197     OP_REQUIRES_OK(ctx,
198                    ctx->GetAttr("scale_shift", &config_.log_scale.scale_shift));
199 
200     OP_REQUIRES_OK(ctx, ctx->GetAttr("left_context", &left_context_));
201     OP_REQUIRES_OK(ctx, ctx->GetAttr("right_context", &right_context_));
202     OP_REQUIRES_OK(ctx, ctx->GetAttr("frame_stride", &frame_stride_));
203     OP_REQUIRES_OK(ctx, ctx->GetAttr("zero_padding", &zero_padding_));
204     OP_REQUIRES_OK(ctx, ctx->GetAttr("out_scale", &out_scale_));
205   }
206 
Compute(OpKernelContext * ctx)207   void Compute(OpKernelContext* ctx) override {
208     const Tensor* audio;
209     OP_REQUIRES_OK(ctx, ctx->input("audio", &audio));
210     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(audio->shape()),
211                 InvalidArgument("audio is not a vector"));
212 
213     auto audio_data =
214         reinterpret_cast<const int16_t*>(audio->tensor_data().data());
215     int audio_size = audio->NumElements();
216 
217     Tensor* filterbanks = nullptr;
218     int window_size = config_.window.size_ms * sample_rate_ / 1000;
219     int window_step = config_.window.step_size_ms * sample_rate_ / 1000;
220     int num_frames = 0;
221     int sampled_frames = 0;
222     if (audio_size >= window_size) {
223       num_frames = (audio_size - window_size) / window_step + 1;
224       sampled_frames = (num_frames - 1) / frame_stride_ + 1;
225     }
226     TensorShape filterbanks_shape{
227         sampled_frames,
228         config_.filterbank.num_channels * (1 + left_context_ + right_context_)};
229     OP_REQUIRES_OK(ctx,
230                    ctx->allocate_output(0, filterbanks_shape, &filterbanks));
231     auto filterbanks_flat = filterbanks->flat<T>();
232 
233     struct FrontendState state;
234     if (!TF_PREDICT_TRUE(
235             FrontendPopulateState(&config_, &state, sample_rate_))) {
236       ctx->CtxFailure(__FILE__, __LINE__,
237                       Internal("failed to populate frontend state"));
238       FrontendFreeStateContents(&state);
239       return;
240     }
241 
242     std::vector<std::vector<T>> frame_buffer(num_frames);
243     int frame_index = 0;
244     while (audio_size > 0) {
245       size_t num_samples_read;
246       struct FrontendOutput output = FrontendProcessSamples(
247           &state, audio_data, audio_size, &num_samples_read);
248       audio_data += num_samples_read;
249       audio_size -= num_samples_read;
250 
251       if (output.values != nullptr) {
252         frame_buffer[frame_index].reserve(output.size);
253         int i;
254         for (i = 0; i < output.size; ++i) {
255           frame_buffer[frame_index].push_back(static_cast<T>(output.values[i]) /
256                                               out_scale_);
257         }
258         ++frame_index;
259       }
260     }
261     FrontendFreeStateContents(&state);
262 
263     int index = 0;
264     std::vector<T> pad(config_.filterbank.num_channels, 0);
265     int anchor;
266     for (anchor = 0; anchor < frame_buffer.size(); anchor += frame_stride_) {
267       int frame;
268       for (frame = anchor - left_context_; frame <= anchor + right_context_;
269            ++frame) {
270         std::vector<T>* feature;
271         if (zero_padding_ && (frame < 0 || frame >= frame_buffer.size())) {
272           feature = &pad;
273         } else if (frame < 0) {
274           feature = &frame_buffer[0];
275         } else if (frame >= frame_buffer.size()) {
276           feature = &frame_buffer[frame_buffer.size() - 1];
277         } else {
278           feature = &frame_buffer[frame];
279         }
280         for (auto f : *feature) {
281           filterbanks_flat(index++) = f;
282         }
283       }
284     }
285   }
286 
287  protected:
288   int sample_rate_;
289   struct FrontendConfig config_;
290   int left_context_;
291   int right_context_;
292   int frame_stride_;
293   bool zero_padding_;
294   int out_scale_;
295 
296   TF_DISALLOW_COPY_AND_ASSIGN(AudioMicrofrontendOp);
297 };
298 
// Register CPU kernels for both supported output types; the out_type attr
// (uint16 or float) selects which specialization serves a given node.
REGISTER_KERNEL_BUILDER(Name("AudioMicrofrontend")
                            .Device(tensorflow::DEVICE_CPU)
                            .TypeConstraint<uint16>("out_type"),
                        AudioMicrofrontendOp<uint16>);
REGISTER_KERNEL_BUILDER(Name("AudioMicrofrontend")
                            .Device(tensorflow::DEVICE_CPU)
                            .TypeConstraint<float>("out_type"),
                        AudioMicrofrontendOp<float>);
307 }  // namespace tensorflow
308