1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/lite/experimental/microfrontend/lib/frontend.h" 16 #include "tensorflow/lite/experimental/microfrontend/lib/frontend_util.h" 17 #include "tensorflow/core/framework/op.h" 18 #include "tensorflow/core/framework/op_kernel.h" 19 #include "tensorflow/core/framework/shape_inference.h" 20 #include "tensorflow/core/framework/tensor.h" 21 #include "tensorflow/core/framework/tensor_shape.h" 22 #include "tensorflow/core/lib/core/errors.h" 23 #include "tensorflow/core/lib/core/status.h" 24 #include "tensorflow/core/platform/macros.h" 25 26 using tensorflow::OpKernel; 27 using tensorflow::OpKernelConstruction; 28 using tensorflow::OpKernelContext; 29 using tensorflow::Status; 30 using tensorflow::Tensor; 31 using tensorflow::TensorShape; 32 using tensorflow::TensorShapeUtils; 33 using tensorflow::errors::Internal; 34 using tensorflow::errors::InvalidArgument; 35 using tensorflow::shape_inference::DimensionHandle; 36 using tensorflow::shape_inference::InferenceContext; 37 using tensorflow::shape_inference::ShapeHandle; 38 39 namespace tensorflow { 40 REGISTER_OP("AudioMicrofrontend") 41 .Input("audio: int16") 42 .Output("filterbanks: out_type") 43 .Attr("sample_rate: int = 16000") 44 .Attr("window_size: int = 25") 45 .Attr("window_step: int = 10") 46 .Attr("num_channels: 
int = 32") 47 .Attr("upper_band_limit: float = 7500.0") 48 .Attr("lower_band_limit: float = 125.0") 49 .Attr("smoothing_bits: int = 10") 50 .Attr("even_smoothing: float = 0.025") 51 .Attr("odd_smoothing: float = 0.06") 52 .Attr("min_signal_remaining: float = 0.05") 53 .Attr("enable_pcan: bool = false") 54 .Attr("pcan_strength: float = 0.95") 55 .Attr("pcan_offset: float = 80.0") 56 .Attr("gain_bits: int = 21") 57 .Attr("enable_log: bool = true") 58 .Attr("scale_shift: int = 6") 59 .Attr("left_context: int = 0") 60 .Attr("right_context: int = 0") 61 .Attr("frame_stride: int = 1") 62 .Attr("zero_padding: bool = false") 63 .Attr("out_scale: int = 1") 64 .Attr("out_type: {uint16, float} = DT_UINT16") __anon3fe811a90102(InferenceContext* ctx) 65 .SetShapeFn([](InferenceContext* ctx) { 66 ShapeHandle input; 67 TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(0), 1, &input)); 68 69 int sample_rate; 70 TF_RETURN_IF_ERROR(ctx->GetAttr("sample_rate", &sample_rate)); 71 int window_size; 72 TF_RETURN_IF_ERROR(ctx->GetAttr("window_size", &window_size)); 73 window_size *= sample_rate / 1000; 74 int window_step; 75 TF_RETURN_IF_ERROR(ctx->GetAttr("window_step", &window_step)); 76 window_step *= sample_rate / 1000; 77 78 int num_channels; 79 TF_RETURN_IF_ERROR(ctx->GetAttr("num_channels", &num_channels)); 80 int left_context; 81 TF_RETURN_IF_ERROR(ctx->GetAttr("left_context", &left_context)); 82 int right_context; 83 TF_RETURN_IF_ERROR(ctx->GetAttr("right_context", &right_context)); 84 int frame_stride; 85 TF_RETURN_IF_ERROR(ctx->GetAttr("frame_stride", &frame_stride)); 86 87 DimensionHandle num_frames = ctx->Dim(input, 0); 88 if (ctx->Value(num_frames) < window_size) { 89 num_frames = ctx->MakeDim(0); 90 } else { 91 TF_RETURN_IF_ERROR(ctx->Subtract(num_frames, window_size, &num_frames)); 92 TF_RETURN_IF_ERROR( 93 ctx->Divide(num_frames, window_step, false, &num_frames)); 94 TF_RETURN_IF_ERROR( 95 ctx->Divide(num_frames, frame_stride, false, &num_frames)); 96 
TF_RETURN_IF_ERROR(ctx->Add(num_frames, 1, &num_frames)); 97 } 98 99 int stack_size = 1 + left_context + right_context; 100 DimensionHandle num_features = ctx->MakeDim(num_channels); 101 TF_RETURN_IF_ERROR( 102 ctx->Multiply(num_features, stack_size, &num_features)); 103 104 ShapeHandle output = ctx->MakeShape({num_frames, num_features}); 105 ctx->set_output(0, output); 106 return tensorflow::Status::OK(); 107 }) 108 .Doc(R"doc( 109 Audio Microfrontend Op. 110 111 This Op converts a sequence of audio data into one or more 112 feature vectors containing filterbanks of the input. The 113 conversion process uses a lightweight library to perform: 114 115 1. A slicing window function 116 2. Short-time FFTs 117 3. Filterbank calculations 118 4. Noise reduction 119 5. PCAN Auto Gain Control 120 6. Logarithmic scaling 121 122 Arguments 123 audio: 1D Tensor, int16 audio data in temporal ordering. 124 sample_rate: Integer, the sample rate of the audio in Hz. 125 window_size: Integer, length of desired time frames in ms. 126 window_step: Integer, length of step size for the next frame in ms. 127 num_channels: Integer, the number of filterbank channels to use. 128 upper_band_limit: Float, the highest frequency included in the filterbanks. 129 lower_band_limit: Float, the lowest frequency included in the filterbanks. 130 smoothing_bits: Int, scale up signal by 2^(smoothing_bits) before reduction. 131 even_smoothing: Float, smoothing coefficient for even-numbered channels. 132 odd_smoothing: Float, smoothing coefficient for odd-numbered channels. 133 min_signal_remaining: Float, fraction of signal to preserve in smoothing. 134 enable_pcan: Bool, enable PCAN auto gain control. 135 pcan_strength: Float, gain normalization exponent. 136 pcan_offset: Float, positive value added in the normalization denominator. 137 gain_bits: Int, number of fractional bits in the gain. 138 enable_log: Bool, enable logarithmic scaling of filterbanks. 
139 scale_shift: Integer, scale filterbanks by 2^(scale_shift). 140 left_context: Integer, number of preceding frames to attach to each frame. 141 right_context: Integer, number of preceding frames to attach to each frame. 142 frame_stride: Integer, M frames to skip over, where output[n] = frame[n*M]. 143 zero_padding: Bool, if left/right context is out-of-bounds, attach frame of 144 zeroes. Otherwise, frame[0] or frame[size-1] will be copied. 145 out_scale: Integer, divide all filterbanks by this number. 146 out_type: DType, type of the output Tensor, defaults to UINT16. 147 148 Returns 149 filterbanks: 2D Tensor, each row is a time frame, each column is a channel. 150 )doc"); 151 152 template <typename T> 153 class AudioMicrofrontendOp : public OpKernel { 154 public: AudioMicrofrontendOp(OpKernelConstruction * ctx)155 explicit AudioMicrofrontendOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 156 OP_REQUIRES_OK(ctx, ctx->GetAttr("sample_rate", &sample_rate_)); 157 158 int window_size; 159 OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size", &window_size)); 160 config_.window.size_ms = window_size; 161 162 int window_step; 163 OP_REQUIRES_OK(ctx, ctx->GetAttr("window_step", &window_step)); 164 config_.window.step_size_ms = window_step; 165 166 OP_REQUIRES_OK( 167 ctx, ctx->GetAttr("num_channels", &config_.filterbank.num_channels)); 168 OP_REQUIRES_OK(ctx, ctx->GetAttr("upper_band_limit", 169 &config_.filterbank.upper_band_limit)); 170 OP_REQUIRES_OK(ctx, ctx->GetAttr("lower_band_limit", 171 &config_.filterbank.lower_band_limit)); 172 OP_REQUIRES_OK(ctx, ctx->GetAttr("smoothing_bits", 173 &config_.noise_reduction.smoothing_bits)); 174 OP_REQUIRES_OK(ctx, ctx->GetAttr("even_smoothing", 175 &config_.noise_reduction.even_smoothing)); 176 OP_REQUIRES_OK(ctx, ctx->GetAttr("odd_smoothing", 177 &config_.noise_reduction.odd_smoothing)); 178 OP_REQUIRES_OK(ctx, 179 ctx->GetAttr("min_signal_remaining", 180 &config_.noise_reduction.min_signal_remaining)); 181 182 bool enable_pcan; 
183 OP_REQUIRES_OK(ctx, ctx->GetAttr("enable_pcan", &enable_pcan)); 184 config_.pcan_gain_control.enable_pcan = enable_pcan; 185 186 OP_REQUIRES_OK(ctx, ctx->GetAttr("pcan_strength", 187 &config_.pcan_gain_control.strength)); 188 OP_REQUIRES_OK( 189 ctx, ctx->GetAttr("pcan_offset", &config_.pcan_gain_control.offset)); 190 OP_REQUIRES_OK( 191 ctx, ctx->GetAttr("gain_bits", &config_.pcan_gain_control.gain_bits)); 192 193 bool enable_log; 194 OP_REQUIRES_OK(ctx, ctx->GetAttr("enable_log", &enable_log)); 195 config_.log_scale.enable_log = enable_log; 196 197 OP_REQUIRES_OK(ctx, 198 ctx->GetAttr("scale_shift", &config_.log_scale.scale_shift)); 199 200 OP_REQUIRES_OK(ctx, ctx->GetAttr("left_context", &left_context_)); 201 OP_REQUIRES_OK(ctx, ctx->GetAttr("right_context", &right_context_)); 202 OP_REQUIRES_OK(ctx, ctx->GetAttr("frame_stride", &frame_stride_)); 203 OP_REQUIRES_OK(ctx, ctx->GetAttr("zero_padding", &zero_padding_)); 204 OP_REQUIRES_OK(ctx, ctx->GetAttr("out_scale", &out_scale_)); 205 } 206 Compute(OpKernelContext * ctx)207 void Compute(OpKernelContext* ctx) override { 208 const Tensor* audio; 209 OP_REQUIRES_OK(ctx, ctx->input("audio", &audio)); 210 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(audio->shape()), 211 InvalidArgument("audio is not a vector")); 212 213 auto audio_data = 214 reinterpret_cast<const int16_t*>(audio->tensor_data().data()); 215 int audio_size = audio->NumElements(); 216 217 Tensor* filterbanks = nullptr; 218 int window_size = config_.window.size_ms * sample_rate_ / 1000; 219 int window_step = config_.window.step_size_ms * sample_rate_ / 1000; 220 int num_frames = 0; 221 int sampled_frames = 0; 222 if (audio_size >= window_size) { 223 num_frames = (audio_size - window_size) / window_step + 1; 224 sampled_frames = (num_frames - 1) / frame_stride_ + 1; 225 } 226 TensorShape filterbanks_shape{ 227 sampled_frames, 228 config_.filterbank.num_channels * (1 + left_context_ + right_context_)}; 229 OP_REQUIRES_OK(ctx, 230 ctx->allocate_output(0, 
filterbanks_shape, &filterbanks)); 231 auto filterbanks_flat = filterbanks->flat<T>(); 232 233 struct FrontendState state; 234 if (!TF_PREDICT_TRUE( 235 FrontendPopulateState(&config_, &state, sample_rate_))) { 236 ctx->CtxFailure(__FILE__, __LINE__, 237 Internal("failed to populate frontend state")); 238 FrontendFreeStateContents(&state); 239 return; 240 } 241 242 std::vector<std::vector<T>> frame_buffer(num_frames); 243 int frame_index = 0; 244 while (audio_size > 0) { 245 size_t num_samples_read; 246 struct FrontendOutput output = FrontendProcessSamples( 247 &state, audio_data, audio_size, &num_samples_read); 248 audio_data += num_samples_read; 249 audio_size -= num_samples_read; 250 251 if (output.values != nullptr) { 252 frame_buffer[frame_index].reserve(output.size); 253 int i; 254 for (i = 0; i < output.size; ++i) { 255 frame_buffer[frame_index].push_back(static_cast<T>(output.values[i]) / 256 out_scale_); 257 } 258 ++frame_index; 259 } 260 } 261 FrontendFreeStateContents(&state); 262 263 int index = 0; 264 std::vector<T> pad(config_.filterbank.num_channels, 0); 265 int anchor; 266 for (anchor = 0; anchor < frame_buffer.size(); anchor += frame_stride_) { 267 int frame; 268 for (frame = anchor - left_context_; frame <= anchor + right_context_; 269 ++frame) { 270 std::vector<T>* feature; 271 if (zero_padding_ && (frame < 0 || frame >= frame_buffer.size())) { 272 feature = &pad; 273 } else if (frame < 0) { 274 feature = &frame_buffer[0]; 275 } else if (frame >= frame_buffer.size()) { 276 feature = &frame_buffer[frame_buffer.size() - 1]; 277 } else { 278 feature = &frame_buffer[frame]; 279 } 280 for (auto f : *feature) { 281 filterbanks_flat(index++) = f; 282 } 283 } 284 } 285 } 286 287 protected: 288 int sample_rate_; 289 struct FrontendConfig config_; 290 int left_context_; 291 int right_context_; 292 int frame_stride_; 293 bool zero_padding_; 294 int out_scale_; 295 296 TF_DISALLOW_COPY_AND_ASSIGN(AudioMicrofrontendOp); 297 }; 298 299 
// Register the CPU kernel once per supported "out_type": uint16 and float.
// The type constraint routes each instantiation of the class template above.
REGISTER_KERNEL_BUILDER(Name("AudioMicrofrontend")
                            .Device(tensorflow::DEVICE_CPU)
                            .TypeConstraint<uint16>("out_type"),
                        AudioMicrofrontendOp<uint16>);
REGISTER_KERNEL_BUILDER(Name("AudioMicrofrontend")
                            .Device(tensorflow::DEVICE_CPU)
                            .TypeConstraint<float>("out_type"),
                        AudioMicrofrontendOp<float>);
}  // namespace tensorflow