1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15r"""Converts WAV audio files into input features for neural networks. 16 17The models used in this example take in two-dimensional spectrograms as the 18input to their neural network portions. For testing and porting purposes it's 19useful to be able to generate these spectrograms outside of the full model, so 20that on-device implementations using their own FFT and streaming code can be 21tested against the version used in training for example. The output is as a 22C source file, so it can be easily linked into an embedded test application. 23 24To use this, run: 25 26bazel run tensorflow/examples/speech_commands:wav_to_features -- \ 27--input_wav=my.wav --output_c_file=my_wav_data.c 28 29""" 30from __future__ import absolute_import 31from __future__ import division 32from __future__ import print_function 33 34import argparse 35import os.path 36import sys 37 38import tensorflow as tf 39 40import input_data 41import models 42from tensorflow.python.platform import gfile 43 44FLAGS = None 45 46 47def wav_to_features(sample_rate, clip_duration_ms, window_size_ms, 48 window_stride_ms, feature_bin_count, quantize, preprocess, 49 input_wav, output_c_file): 50 """Converts an audio file into its corresponding feature map. 51 52 Args: 53 sample_rate: Expected sample rate of the wavs. 54 clip_duration_ms: Expected duration in milliseconds of the wavs. 55 window_size_ms: How long each spectrogram timeslice is. 56 window_stride_ms: How far to move in time between spectrogram timeslices. 57 feature_bin_count: How many bins to use for the feature fingerprint. 58 quantize: Whether to train the model for eight-bit deployment. 59 preprocess: Spectrogram processing mode; "mfcc", "average" or "micro". 60 input_wav: Path to the audio WAV file to read. 61 output_c_file: Where to save the generated C source file. 62 """ 63 64 # Start a new TensorFlow session. 65 sess = tf.compat.v1.InteractiveSession() 66 67 model_settings = models.prepare_model_settings( 68 0, sample_rate, clip_duration_ms, window_size_ms, window_stride_ms, 69 feature_bin_count, preprocess) 70 audio_processor = input_data.AudioProcessor(None, None, 0, 0, '', 0, 0, 71 model_settings, None) 72 73 results = audio_processor.get_features_for_wav(input_wav, model_settings, 74 sess) 75 features = results[0] 76 77 variable_base = os.path.splitext(os.path.basename(input_wav).lower())[0] 78 79 # Save a C source file containing the feature data as an array. 80 with gfile.GFile(output_c_file, 'w') as f: 81 f.write('/* File automatically created by\n') 82 f.write(' * tensorflow/examples/speech_commands/wav_to_features.py \\\n') 83 f.write(' * --sample_rate=%d \\\n' % sample_rate) 84 f.write(' * --clip_duration_ms=%d \\\n' % clip_duration_ms) 85 f.write(' * --window_size_ms=%d \\\n' % window_size_ms) 86 f.write(' * --window_stride_ms=%d \\\n' % window_stride_ms) 87 f.write(' * --feature_bin_count=%d \\\n' % feature_bin_count) 88 if quantize: 89 f.write(' * --quantize=1 \\\n') 90 f.write(' * --preprocess="%s" \\\n' % preprocess) 91 f.write(' * --input_wav="%s" \\\n' % input_wav) 92 f.write(' * --output_c_file="%s" \\\n' % output_c_file) 93 f.write(' */\n\n') 94 f.write('const int g_%s_width = %d;\n' % 95 (variable_base, model_settings['fingerprint_width'])) 96 f.write('const int g_%s_height = %d;\n' % 97 (variable_base, model_settings['spectrogram_length'])) 98 if quantize: 99 features_min, features_max = input_data.get_features_range(model_settings) 100 f.write('const unsigned char g_%s_data[] = {' % variable_base) 101 i = 0 102 for value in features.flatten(): 103 quantized_value = int( 104 round( 105 (255 * (value - features_min)) / (features_max - features_min))) 106 if quantized_value < 0: 107 quantized_value = 0 108 if quantized_value > 255: 109 quantized_value = 255 110 if i == 0: 111 f.write('\n ') 112 f.write('%d, ' % (quantized_value)) 113 i = (i + 1) % 10 114 else: 115 f.write('const float g_%s_data[] = {\n' % variable_base) 116 i = 0 117 for value in features.flatten(): 118 if i == 0: 119 f.write('\n ') 120 f.write('%f, ' % value) 121 i = (i + 1) % 10 122 f.write('\n};\n') 123 124 125def main(_): 126 # We want to see all the logging messages. 127 tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) 128 wav_to_features(FLAGS.sample_rate, FLAGS.clip_duration_ms, 129 FLAGS.window_size_ms, FLAGS.window_stride_ms, 130 FLAGS.feature_bin_count, FLAGS.quantize, FLAGS.preprocess, 131 FLAGS.input_wav, FLAGS.output_c_file) 132 tf.compat.v1.logging.info('Wrote to "%s"' % (FLAGS.output_c_file)) 133 134 135if __name__ == '__main__': 136 parser = argparse.ArgumentParser() 137 parser.add_argument( 138 '--sample_rate', 139 type=int, 140 default=16000, 141 help='Expected sample rate of the wavs',) 142 parser.add_argument( 143 '--clip_duration_ms', 144 type=int, 145 default=1000, 146 help='Expected duration in milliseconds of the wavs',) 147 parser.add_argument( 148 '--window_size_ms', 149 type=float, 150 default=30.0, 151 help='How long each spectrogram timeslice is.',) 152 parser.add_argument( 153 '--window_stride_ms', 154 type=float, 155 default=10.0, 156 help='How far to move in time between spectrogram timeslices.', 157 ) 158 parser.add_argument( 159 '--feature_bin_count', 160 type=int, 161 default=40, 162 help='How many bins to use for the MFCC fingerprint', 163 ) 164 parser.add_argument( 165 '--quantize', 166 type=bool, 167 default=False, 168 help='Whether to train the model for eight-bit deployment') 169 parser.add_argument( 170 '--preprocess', 171 type=str, 172 default='mfcc', 173 help='Spectrogram processing mode. Can be "mfcc", "average", or "micro"') 174 parser.add_argument( 175 '--input_wav', 176 type=str, 177 default=None, 178 help='Path to the audio WAV file to read') 179 parser.add_argument( 180 '--output_c_file', 181 type=str, 182 default=None, 183 help='Where to save the generated C source file containing the features') 184 185 FLAGS, unparsed = parser.parse_known_args() 186 tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed) 187