# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15r"""Converts WAV audio files into input features for neural networks.
16
17The models used in this example take in two-dimensional spectrograms as the
18input to their neural network portions. For testing and porting purposes it's
19useful to be able to generate these spectrograms outside of the full model, so
20that on-device implementations using their own FFT and streaming code can be
21tested against the version used in training for example. The output is as a
22C source file, so it can be easily linked into an embedded test application.
23
24To use this, run:
25
26bazel run tensorflow/examples/speech_commands:wav_to_features -- \
27--input_wav=my.wav --output_c_file=my_wav_data.c
28
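The generated file declares the feature dimensions and data as C globals. As
an illustrative sketch, with the default flags plus --quantize=1 and an input
named my.wav, the output looks roughly like this (data values elided):

const int g_my_width = 40;
const int g_my_height = 98;
const unsigned char g_my_data[] = {
  ...
};
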
29"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import sys

import tensorflow as tf

import input_data
import models
from tensorflow.python.platform import gfile

FLAGS = None


def wav_to_features(sample_rate, clip_duration_ms, window_size_ms,
                    window_stride_ms, feature_bin_count, quantize, preprocess,
                    input_wav, output_c_file):
  """Converts an audio file into its corresponding feature map.

  Args:
    sample_rate: Expected sample rate of the wavs.
    clip_duration_ms: Expected duration in milliseconds of the wavs.
    window_size_ms: How long each spectrogram timeslice is.
    window_stride_ms: How far to move in time between spectrogram timeslices.
    feature_bin_count: How many bins to use for the feature fingerprint.
    quantize: Whether to quantize the feature values to eight bits.
    preprocess: Spectrogram processing mode; "mfcc", "average" or "micro".
    input_wav: Path to the audio WAV file to read.
    output_c_file: Where to save the generated C source file.
  """

  # Start a new TensorFlow session.
  sess = tf.compat.v1.InteractiveSession()

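  # A label count of zero is passed as the first argument because no
  # classification head is needed just to compute features.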
  model_settings = models.prepare_model_settings(
      0, sample_rate, clip_duration_ms, window_size_ms, window_stride_ms,
      feature_bin_count, preprocess)
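  # No dataset arguments are supplied; the AudioProcessor is constructed only
  # so its feature-extraction graph can be reused on a single file.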
  audio_processor = input_data.AudioProcessor(None, None, 0, 0, '', 0, 0,
                                              model_settings, None)

  results = audio_processor.get_features_for_wav(input_wav, model_settings,
                                                 sess)
  features = results[0]

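  # Derive the C variable name prefix from the WAV file name, so for example
  # my.wav produces arrays named g_my_data, g_my_width, and g_my_height.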
  variable_base = os.path.splitext(os.path.basename(input_wav).lower())[0]

  # Save a C source file containing the feature data as an array.
  with gfile.GFile(output_c_file, 'w') as f:
    f.write('/* File automatically created by\n')
    f.write(' * tensorflow/examples/speech_commands/wav_to_features.py \\\n')
    f.write(' * --sample_rate=%d \\\n' % sample_rate)
    f.write(' * --clip_duration_ms=%d \\\n' % clip_duration_ms)
    f.write(' * --window_size_ms=%d \\\n' % window_size_ms)
    f.write(' * --window_stride_ms=%d \\\n' % window_stride_ms)
    f.write(' * --feature_bin_count=%d \\\n' % feature_bin_count)
    if quantize:
      f.write(' * --quantize=1 \\\n')
    f.write(' * --preprocess="%s" \\\n' % preprocess)
    f.write(' * --input_wav="%s" \\\n' % input_wav)
    f.write(' * --output_c_file="%s" \\\n' % output_c_file)
    f.write(' */\n\n')
    f.write('const int g_%s_width = %d;\n' %
            (variable_base, model_settings['fingerprint_width']))
    f.write('const int g_%s_height = %d;\n' %
            (variable_base, model_settings['spectrogram_length']))
    if quantize:
      features_min, features_max = input_data.get_features_range(model_settings)
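      # Linearly rescale each feature value from the expected
      # [features_min, features_max] range to a byte in [0, 255], clamping
      # anything that falls outside it.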
      f.write('const unsigned char g_%s_data[] = {' % variable_base)
      i = 0
      for value in features.flatten():
        quantized_value = int(
            round(
                (255 * (value - features_min)) / (features_max - features_min)))
        if quantized_value < 0:
          quantized_value = 0
        if quantized_value > 255:
          quantized_value = 255
        if i == 0:
          f.write('\n  ')
        f.write('%d, ' % (quantized_value))
        i = (i + 1) % 10
    else:
      f.write('const float g_%s_data[] = {' % variable_base)
      i = 0
      for value in features.flatten():
        if i == 0:
          f.write('\n  ')
        f.write('%f, ' % value)
        i = (i + 1) % 10
    f.write('\n};\n')

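# A minimal sketch of driving the conversion directly from Python rather than
# through the command-line flags (the paths here are hypothetical):
#
#   wav_to_features(
#       sample_rate=16000, clip_duration_ms=1000, window_size_ms=30.0,
#       window_stride_ms=10.0, feature_bin_count=40, quantize=True,
#       preprocess='micro', input_wav='/tmp/yes.wav',
#       output_c_file='/tmp/yes_features.c')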

def main(_):
  # We want to see all the logging messages.
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  wav_to_features(FLAGS.sample_rate, FLAGS.clip_duration_ms,
                  FLAGS.window_size_ms, FLAGS.window_stride_ms,
                  FLAGS.feature_bin_count, FLAGS.quantize, FLAGS.preprocess,
                  FLAGS.input_wav, FLAGS.output_c_file)
  tf.compat.v1.logging.info('Wrote to "%s"', FLAGS.output_c_file)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--sample_rate',
      type=int,
      default=16000,
      help='Expected sample rate of the wavs')
  parser.add_argument(
      '--clip_duration_ms',
      type=int,
      default=1000,
      help='Expected duration in milliseconds of the wavs')
  parser.add_argument(
      '--window_size_ms',
      type=float,
      default=30.0,
      help='How long each spectrogram timeslice is.')
  parser.add_argument(
      '--window_stride_ms',
      type=float,
      default=10.0,
      help='How far to move in time between spectrogram timeslices.')
  parser.add_argument(
      '--feature_bin_count',
      type=int,
      default=40,
      help='How many bins to use for the feature fingerprint')
  parser.add_argument(
      '--quantize',
      # argparse's type=bool treats any non-empty string (including '0' and
      # 'False') as True, so parse boolean flag values explicitly.
      type=lambda value: value.lower() in ('true', '1'),
      default=False,
      help='Whether to quantize the feature data to eight bits')
  parser.add_argument(
      '--preprocess',
      type=str,
      default='mfcc',
      help='Spectrogram processing mode. Can be "mfcc", "average", or "micro"')
  parser.add_argument(
      '--input_wav',
      type=str,
      default=None,
      help='Path to the audio WAV file to read')
  parser.add_argument(
      '--output_c_file',
      type=str,
      default=None,
      help='Where to save the generated C source file containing the features')

  FLAGS, unparsed = parser.parse_known_args()
  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)