# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Saves out a .wav file with synthesized conversational data and labels.

The best way to estimate the real-world performance of an audio recognition
model is by running it against a continuous stream of data, the way that it
would be used in an application. Training evaluations are only run against
discrete individual samples, so the results aren't as realistic.

To make it easy to run evaluations against audio streams, this script uses
samples from the testing partition of the data set, mixes them in at random
positions together with background noise, and saves out the result as one long
audio file.

Here's an example of generating a test file:

bazel run tensorflow/examples/speech_commands:generate_streaming_test_wav -- \
--data_dir=/tmp/my_wavs --background_dir=/tmp/my_backgrounds \
--background_volume=0.1 --test_duration_seconds=600 \
--output_audio_file=/tmp/streaming_test.wav \
--output_labels_file=/tmp/streaming_test_labels.txt

Once you've created a streaming audio file, you can then use the
test_streaming_accuracy tool to calculate accuracy metrics for a model.
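
For example (the paths below are placeholders, and you should check the
test_streaming_accuracy documentation for its full set of flags):

bazel run tensorflow/examples/speech_commands:test_streaming_accuracy -- \
--graph=/tmp/my_frozen_graph.pb \
--labels=/tmp/speech_commands_train/conv_labels.txt \
--wav=/tmp/streaming_test.wav \
--ground_truth=/tmp/streaming_test_labels.txt --verbose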
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import math
import sys

import numpy as np
import tensorflow as tf

import input_data
import models

FLAGS = None


def mix_in_audio_sample(track_data, track_offset, sample_data, sample_offset,
                        clip_duration, sample_volume, ramp_in, ramp_out):
  """Mixes the sample data into the main track at the specified offset.

  Args:
    track_data: Numpy array holding main audio data. Modified in-place.
    track_offset: Where to mix the sample into the main track.
    sample_data: Numpy array of audio data to mix into the main track.
    sample_offset: Where to start in the audio sample.
    clip_duration: How long the sample segment is.
    sample_volume: Loudness to mix the sample in at.
    ramp_in: Length in samples of volume increase stage.
    ramp_out: Length in samples of volume decrease stage.
  """
  ramp_out_index = clip_duration - ramp_out
  # Clamp the mix region so it stays within both the main track and the sample.
  track_end = min(track_offset + clip_duration, track_data.shape[0])
  track_end = min(track_end,
                  track_offset + (sample_data.shape[0] - sample_offset))
  sample_range = track_end - track_offset
  for i in range(sample_range):
    # Apply a trapezoidal envelope: fade in over ramp_in samples, hold at full
    # volume, then fade out over the final ramp_out samples of the clip.
    if i < ramp_in:
      envelope_scale = i / ramp_in
    elif i > ramp_out_index:
      envelope_scale = (clip_duration - i) / ramp_out
    else:
      envelope_scale = 1
    sample_input = sample_data[sample_offset + i]
    track_data[track_offset
               + i] += sample_input * envelope_scale * sample_volume

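# A minimal sketch of how mix_in_audio_sample can be exercised on its own; the
# track length, tone, and ramp values below are illustrative only and are not
# taken from the real data set:
#
#   track = np.zeros((16000,), dtype=np.float32)
#   tone = np.sin(np.linspace(0, 2 * np.pi * 440, 8000)).astype(np.float32)
#   mix_in_audio_sample(track, track_offset=4000, sample_data=tone,
#                       sample_offset=0, clip_duration=8000, sample_volume=0.5,
#                       ramp_in=800, ramp_out=800)
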

def main(_):
  words_list = input_data.prepare_words_list(FLAGS.wanted_words.split(','))
  model_settings = models.prepare_model_settings(
      len(words_list), FLAGS.sample_rate, FLAGS.clip_duration_ms,
      FLAGS.window_size_ms, FLAGS.window_stride_ms, FLAGS.feature_bin_count,
      'mfcc')
  audio_processor = input_data.AudioProcessor(
      '', FLAGS.data_dir, FLAGS.silence_percentage, 10,
      FLAGS.wanted_words.split(','), FLAGS.validation_percentage,
      FLAGS.testing_percentage, model_settings, FLAGS.data_dir)

  output_audio_sample_count = FLAGS.sample_rate * FLAGS.test_duration_seconds
  output_audio = np.zeros((output_audio_sample_count,), dtype=np.float32)

  # Set up background audio.
  background_crossover_ms = 500
  background_segment_duration_ms = (
      FLAGS.clip_duration_ms + background_crossover_ms)
  background_segment_duration_samples = int(
      (background_segment_duration_ms * FLAGS.sample_rate) / 1000)
  background_segment_stride_samples = int(
      (FLAGS.clip_duration_ms * FLAGS.sample_rate) / 1000)
  background_ramp_samples = int(
      ((background_crossover_ms / 2) * FLAGS.sample_rate) / 1000)

  # Mix the background audio into the main track.
  how_many_backgrounds = int(
      math.ceil(output_audio_sample_count / background_segment_stride_samples))
  for i in range(how_many_backgrounds):
    output_offset = int(i * background_segment_stride_samples)
    background_index = np.random.randint(len(audio_processor.background_data))
    background_samples = audio_processor.background_data[background_index]
    background_offset = np.random.randint(
        0, len(background_samples) - model_settings['desired_samples'])
    background_volume = np.random.uniform(0, FLAGS.background_volume)
    mix_in_audio_sample(output_audio, output_offset, background_samples,
                        background_offset, background_segment_duration_samples,
                        background_volume, background_ramp_samples,
                        background_ramp_samples)

  # Mix the words into the main track, noting their labels and positions.
  output_labels = []
  word_stride_ms = FLAGS.clip_duration_ms + FLAGS.word_gap_ms
  word_stride_samples = int((word_stride_ms * FLAGS.sample_rate) / 1000)
  clip_duration_samples = int(
      (FLAGS.clip_duration_ms * FLAGS.sample_rate) / 1000)
  word_gap_samples = int((FLAGS.word_gap_ms * FLAGS.sample_rate) / 1000)
  how_many_words = int(
      math.floor(output_audio_sample_count / word_stride_samples))
  all_test_data, all_test_labels = audio_processor.get_unprocessed_data(
      -1, model_settings, 'testing')
  for i in range(how_many_words):
    output_offset = (
        int(i * word_stride_samples) + np.random.randint(word_gap_samples))
    output_offset_ms = (output_offset * 1000) / FLAGS.sample_rate
    is_unknown = np.random.randint(100) < FLAGS.unknown_percentage
    if is_unknown:
      wanted_label = input_data.UNKNOWN_WORD_LABEL
    else:
      # Skip the first two entries in words_list, which are the silence and
      # unknown labels.
      wanted_label = words_list[2 + np.random.randint(len(words_list) - 2)]
    # Search the shuffled test set, starting from a random position, for a
    # sample that has the label we want.
    test_data_start = np.random.randint(len(all_test_data))
    found_sample_data = None
    index_lookup = np.arange(len(all_test_data), dtype=np.int32)
    np.random.shuffle(index_lookup)
    for test_data_offset in range(len(all_test_data)):
      test_data_index = index_lookup[(
          test_data_start + test_data_offset) % len(all_test_data)]
      current_label = all_test_labels[test_data_index]
      if current_label == wanted_label:
        found_sample_data = all_test_data[test_data_index]
        break
    mix_in_audio_sample(output_audio, output_offset, found_sample_data, 0,
                        clip_duration_samples, 1.0, 500, 500)
    output_labels.append({'label': wanted_label, 'time': output_offset_ms})

  input_data.save_wav_file(FLAGS.output_audio_file, output_audio,
                           FLAGS.sample_rate)
  tf.compat.v1.logging.info('Saved streaming test wav to %s',
                            FLAGS.output_audio_file)

  with open(FLAGS.output_labels_file, 'w') as f:
    for output_label in output_labels:
      f.write('%s, %f\n' % (output_label['label'], output_label['time']))
  tf.compat.v1.logging.info('Saved streaming test labels to %s',
                            FLAGS.output_labels_file)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--data_url',
      type=str,
      # pylint: disable=line-too-long
      default='https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz',
      # pylint: enable=line-too-long
      help='Location of speech training data archive on the web.')
  parser.add_argument(
      '--data_dir',
      type=str,
      default='/tmp/speech_dataset',
      help="""\
      Where to download the speech training data to.
      """)
  parser.add_argument(
      '--background_dir',
      type=str,
      default='',
      help="""\
      Path to a directory of .wav files to mix in as background noise during training.
      """)
  parser.add_argument(
      '--background_volume',
      type=float,
      default=0.1,
      help="""\
      How loud the background noise should be, between 0 and 1.
      """)
  parser.add_argument(
      '--background_frequency',
      type=float,
      default=0.8,
      help="""\
      How many of the training samples have background noise mixed in.
      """)
  parser.add_argument(
      '--silence_percentage',
      type=float,
      default=10.0,
      help="""\
      How much of the training data should be silence.
      """)
  parser.add_argument(
      '--testing_percentage',
      type=int,
      default=10,
      help='What percentage of wavs to use as a test set.')
  parser.add_argument(
      '--validation_percentage',
      type=int,
      default=10,
      help='What percentage of wavs to use as a validation set.')
  parser.add_argument(
      '--sample_rate',
      type=int,
      default=16000,
      help='Expected sample rate of the wavs.')
  parser.add_argument(
      '--clip_duration_ms',
      type=int,
      default=1000,
      help='Expected duration in milliseconds of the wavs.')
  parser.add_argument(
      '--window_size_ms',
      type=float,
      default=30.0,
      help='How long each spectrogram timeslice is.')
  parser.add_argument(
      '--window_stride_ms',
      type=float,
      default=10.0,
      help='How long the stride is between spectrogram timeslices.')
  parser.add_argument(
      '--feature_bin_count',
      type=int,
      default=40,
      help='How many bins to use for the MFCC fingerprint.')
  parser.add_argument(
      '--wanted_words',
      type=str,
      default='yes,no,up,down,left,right,on,off,stop,go',
      help='Words to use (others will be added to an unknown label).')
  parser.add_argument(
      '--output_audio_file',
      type=str,
      default='/tmp/speech_commands_train/streaming_test.wav',
      help='File to save the generated test audio to.')
  parser.add_argument(
      '--output_labels_file',
      type=str,
      default='/tmp/speech_commands_train/streaming_test_labels.txt',
      help='File to save the generated test labels to.')
  parser.add_argument(
      '--test_duration_seconds',
      type=int,
      default=600,
      help='How long the generated test audio file should be.')
  parser.add_argument(
      '--word_gap_ms',
      type=int,
      default=2000,
      help='How long the average gap should be between words.')
  parser.add_argument(
      '--unknown_percentage',
      type=int,
      default=30,
      help='What percentage of words should be unknown.')

  FLAGS, unparsed = parser.parse_known_args()
  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)