1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15r"""Simple speech recognition to spot a limited number of keywords.
16
17This is a self-contained example script that will train a very basic audio
18recognition model in TensorFlow. It downloads the necessary training data and
19runs with reasonable defaults to train within a few hours even only using a CPU.
20For more information, please see
21https://www.tensorflow.org/tutorials/audio/simple_audio.
22
23It is intended as an introduction to using neural networks for audio
24recognition, and is not a full speech recognition system. For more advanced
25speech systems, I recommend looking into Kaldi. This network uses a keyword
26detection style to spot discrete words from a small vocabulary, consisting of
27"yes", "no", "up", "down", "left", "right", "on", "off", "stop", and "go".
28
29To run the training process, use:
30
31bazel run tensorflow/examples/speech_commands:train
32
33This will write out checkpoints to /tmp/speech_commands_train/, and will
34download over 1GB of open source training data, so you'll need enough free space
35and a good internet connection. The default data is a collection of thousands of
36one-second .wav files, each containing one spoken word. This data set is
37collected from https://aiyprojects.withgoogle.com/open_speech_recording, please
38consider contributing to help improve this and other models!
39
40As training progresses, it will print out its accuracy metrics, which should
41rise above 90% by the end. Once it's complete, you can run the freeze script to
42get a binary GraphDef that you can easily deploy on mobile applications.
43
44If you want to train on your own data, you'll need to create .wavs with your
45recordings, all at a consistent length, and then arrange them into subfolders
46organized by label. For example, here's a possible file structure:
47
48my_wavs >
49  up >
50    audio_0.wav
51    audio_1.wav
52  down >
53    audio_2.wav
54    audio_3.wav
55  other>
56    audio_4.wav
57    audio_5.wav
58
59You'll also need to tell the script what labels to look for, using the
60`--wanted_words` argument. In this case, 'up,down' might be what you want, and
61the audio in the 'other' folder would be used to train an 'unknown' category.
62
63To pull this all together, you'd run:
64
65bazel run tensorflow/examples/speech_commands:train -- \
66--data_dir=my_wavs --wanted_words=up,down
67
68"""
69from __future__ import absolute_import
70from __future__ import division
71from __future__ import print_function
72
73import argparse
74import os.path
75import sys
76
77import numpy as np
78from six.moves import xrange  # pylint: disable=redefined-builtin
79import tensorflow as tf
80
81import input_data
82import models
83from tensorflow.python.platform import gfile
84
85FLAGS = None
86
87
88def main(_):
89  # Set the verbosity based on flags (default is INFO, so we see all messages)
90  tf.compat.v1.logging.set_verbosity(FLAGS.verbosity)
91
92  # Start a new TensorFlow session.
93  sess = tf.compat.v1.InteractiveSession()
94
95  # Begin by making sure we have the training data we need. If you already have
96  # training data of your own, use `--data_url= ` on the command line to avoid
97  # downloading.
98  model_settings = models.prepare_model_settings(
99      len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
100      FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
101      FLAGS.window_stride_ms, FLAGS.feature_bin_count, FLAGS.preprocess)
102  audio_processor = input_data.AudioProcessor(
103      FLAGS.data_url, FLAGS.data_dir,
104      FLAGS.silence_percentage, FLAGS.unknown_percentage,
105      FLAGS.wanted_words.split(','), FLAGS.validation_percentage,
106      FLAGS.testing_percentage, model_settings, FLAGS.summaries_dir)
107  fingerprint_size = model_settings['fingerprint_size']
108  label_count = model_settings['label_count']
109  time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)
110  # Figure out the learning rates for each training phase. Since it's often
111  # effective to have high learning rates at the start of training, followed by
112  # lower levels towards the end, the number of steps and learning rates can be
113  # specified as comma-separated lists to define the rate at each stage. For
114  # example --how_many_training_steps=10000,3000 --learning_rate=0.001,0.0001
115  # will run 13,000 training loops in total, with a rate of 0.001 for the first
116  # 10,000, and 0.0001 for the final 3,000.
117  training_steps_list = list(map(int, FLAGS.how_many_training_steps.split(',')))
118  learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
119  if len(training_steps_list) != len(learning_rates_list):
120    raise Exception(
121        '--how_many_training_steps and --learning_rate must be equal length '
122        'lists, but are %d and %d long instead' % (len(training_steps_list),
123                                                   len(learning_rates_list)))
124
125  input_placeholder = tf.compat.v1.placeholder(
126      tf.float32, [None, fingerprint_size], name='fingerprint_input')
127  if FLAGS.quantize:
128    fingerprint_min, fingerprint_max = input_data.get_features_range(
129        model_settings)
130    fingerprint_input = tf.quantization.fake_quant_with_min_max_args(
131        input_placeholder, fingerprint_min, fingerprint_max)
132  else:
133    fingerprint_input = input_placeholder
134
135  logits, dropout_rate = models.create_model(
136      fingerprint_input,
137      model_settings,
138      FLAGS.model_architecture,
139      is_training=True)
140
141  # Define loss and optimizer
142  ground_truth_input = tf.compat.v1.placeholder(
143      tf.int64, [None], name='groundtruth_input')
144
145  # Optionally we can add runtime checks to spot when NaNs or other symptoms of
146  # numerical errors start occurring during training.
147  control_dependencies = []
148  if FLAGS.check_nans:
149    checks = tf.compat.v1.add_check_numerics_ops()
150    control_dependencies = [checks]
151
152  # Create the back propagation and training evaluation machinery in the graph.
153  with tf.compat.v1.name_scope('cross_entropy'):
154    cross_entropy_mean = tf.compat.v1.losses.sparse_softmax_cross_entropy(
155        labels=ground_truth_input, logits=logits)
156
157  if FLAGS.quantize:
158    try:
159      tf.contrib.quantize.create_training_graph(quant_delay=0)
160    except AttributeError as e:
161      msg = e.args[0]
162      msg += ('\n\n The --quantize option still requires contrib, which is not '
163              'part of TensorFlow 2.0. Please install a previous version:'
164              '\n    `pip install tensorflow<=1.15`')
165      e.args = (msg,)
166      raise e
167
168  with tf.compat.v1.name_scope('train'), tf.control_dependencies(
169      control_dependencies):
170    learning_rate_input = tf.compat.v1.placeholder(
171        tf.float32, [], name='learning_rate_input')
172    if FLAGS.optimizer == 'gradient_descent':
173      train_step = tf.compat.v1.train.GradientDescentOptimizer(
174          learning_rate_input).minimize(cross_entropy_mean)
175    elif FLAGS.optimizer == 'momentum':
176      train_step = tf.compat.v1.train.MomentumOptimizer(
177          learning_rate_input, .9,
178          use_nesterov=True).minimize(cross_entropy_mean)
179    else:
180      raise Exception('Invalid Optimizer')
181  predicted_indices = tf.argmax(input=logits, axis=1)
182  correct_prediction = tf.equal(predicted_indices, ground_truth_input)
183  confusion_matrix = tf.math.confusion_matrix(labels=ground_truth_input,
184                                              predictions=predicted_indices,
185                                              num_classes=label_count)
186  evaluation_step = tf.reduce_mean(input_tensor=tf.cast(correct_prediction,
187                                                        tf.float32))
188  with tf.compat.v1.get_default_graph().name_scope('eval'):
189    tf.compat.v1.summary.scalar('cross_entropy', cross_entropy_mean)
190    tf.compat.v1.summary.scalar('accuracy', evaluation_step)
191
192  global_step = tf.compat.v1.train.get_or_create_global_step()
193  increment_global_step = tf.compat.v1.assign(global_step, global_step + 1)
194
195  saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
196
197  # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
198  merged_summaries = tf.compat.v1.summary.merge_all(scope='eval')
199  train_writer = tf.compat.v1.summary.FileWriter(FLAGS.summaries_dir + '/train',
200                                                 sess.graph)
201  validation_writer = tf.compat.v1.summary.FileWriter(
202      FLAGS.summaries_dir + '/validation')
203
204  tf.compat.v1.global_variables_initializer().run()
205
206  start_step = 1
207
208  if FLAGS.start_checkpoint:
209    models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
210    start_step = global_step.eval(session=sess)
211
212  tf.compat.v1.logging.info('Training from step: %d ', start_step)
213
214  # Save graph.pbtxt.
215  tf.io.write_graph(sess.graph_def, FLAGS.train_dir,
216                    FLAGS.model_architecture + '.pbtxt')
217
218  # Save list of words.
219  with gfile.GFile(
220      os.path.join(FLAGS.train_dir, FLAGS.model_architecture + '_labels.txt'),
221      'w') as f:
222    f.write('\n'.join(audio_processor.words_list))
223
224  # Training loop.
225  training_steps_max = np.sum(training_steps_list)
226  for training_step in xrange(start_step, training_steps_max + 1):
227    # Figure out what the current learning rate is.
228    training_steps_sum = 0
229    for i in range(len(training_steps_list)):
230      training_steps_sum += training_steps_list[i]
231      if training_step <= training_steps_sum:
232        learning_rate_value = learning_rates_list[i]
233        break
234    # Pull the audio samples we'll use for training.
235    train_fingerprints, train_ground_truth = audio_processor.get_data(
236        FLAGS.batch_size, 0, model_settings, FLAGS.background_frequency,
237        FLAGS.background_volume, time_shift_samples, 'training', sess)
238    # Run the graph with this batch of training data.
239    train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
240        [
241            merged_summaries,
242            evaluation_step,
243            cross_entropy_mean,
244            train_step,
245            increment_global_step,
246        ],
247        feed_dict={
248            fingerprint_input: train_fingerprints,
249            ground_truth_input: train_ground_truth,
250            learning_rate_input: learning_rate_value,
251            dropout_rate: 0.5
252        })
253    train_writer.add_summary(train_summary, training_step)
254    tf.compat.v1.logging.debug(
255        'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
256        (training_step, learning_rate_value, train_accuracy * 100,
257         cross_entropy_value))
258    is_last_step = (training_step == training_steps_max)
259    if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:
260      tf.compat.v1.logging.info(
261          'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
262          (training_step, learning_rate_value, train_accuracy * 100,
263           cross_entropy_value))
264      set_size = audio_processor.set_size('validation')
265      total_accuracy = 0
266      total_conf_matrix = None
267      for i in xrange(0, set_size, FLAGS.batch_size):
268        validation_fingerprints, validation_ground_truth = (
269            audio_processor.get_data(FLAGS.batch_size, i, model_settings, 0.0,
270                                     0.0, 0, 'validation', sess))
271        # Run a validation step and capture training summaries for TensorBoard
272        # with the `merged` op.
273        validation_summary, validation_accuracy, conf_matrix = sess.run(
274            [merged_summaries, evaluation_step, confusion_matrix],
275            feed_dict={
276                fingerprint_input: validation_fingerprints,
277                ground_truth_input: validation_ground_truth,
278                dropout_rate: 0.0
279            })
280        validation_writer.add_summary(validation_summary, training_step)
281        batch_size = min(FLAGS.batch_size, set_size - i)
282        total_accuracy += (validation_accuracy * batch_size) / set_size
283        if total_conf_matrix is None:
284          total_conf_matrix = conf_matrix
285        else:
286          total_conf_matrix += conf_matrix
287      tf.compat.v1.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
288      tf.compat.v1.logging.info('Step %d: Validation accuracy = %.1f%% (N=%d)' %
289                                (training_step, total_accuracy * 100, set_size))
290
291    # Save the model checkpoint periodically.
292    if (training_step % FLAGS.save_step_interval == 0 or
293        training_step == training_steps_max):
294      checkpoint_path = os.path.join(FLAGS.train_dir,
295                                     FLAGS.model_architecture + '.ckpt')
296      tf.compat.v1.logging.info('Saving to "%s-%d"', checkpoint_path,
297                                training_step)
298      saver.save(sess, checkpoint_path, global_step=training_step)
299
300  set_size = audio_processor.set_size('testing')
301  tf.compat.v1.logging.info('set_size=%d', set_size)
302  total_accuracy = 0
303  total_conf_matrix = None
304  for i in xrange(0, set_size, FLAGS.batch_size):
305    test_fingerprints, test_ground_truth = audio_processor.get_data(
306        FLAGS.batch_size, i, model_settings, 0.0, 0.0, 0, 'testing', sess)
307    test_accuracy, conf_matrix = sess.run(
308        [evaluation_step, confusion_matrix],
309        feed_dict={
310            fingerprint_input: test_fingerprints,
311            ground_truth_input: test_ground_truth,
312            dropout_rate: 0.0
313        })
314    batch_size = min(FLAGS.batch_size, set_size - i)
315    total_accuracy += (test_accuracy * batch_size) / set_size
316    if total_conf_matrix is None:
317      total_conf_matrix = conf_matrix
318    else:
319      total_conf_matrix += conf_matrix
320  tf.compat.v1.logging.warn('Confusion Matrix:\n %s' % (total_conf_matrix))
321  tf.compat.v1.logging.warn('Final test accuracy = %.1f%% (N=%d)' %
322                            (total_accuracy * 100, set_size))
323
324
325if __name__ == '__main__':
326  parser = argparse.ArgumentParser()
327  parser.add_argument(
328      '--data_url',
329      type=str,
330      # pylint: disable=line-too-long
331      default='https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz',
332      # pylint: enable=line-too-long
333      help='Location of speech training data archive on the web.')
334  parser.add_argument(
335      '--data_dir',
336      type=str,
337      default='/tmp/speech_dataset/',
338      help="""\
339      Where to download the speech training data to.
340      """)
341  parser.add_argument(
342      '--background_volume',
343      type=float,
344      default=0.1,
345      help="""\
346      How loud the background noise should be, between 0 and 1.
347      """)
348  parser.add_argument(
349      '--background_frequency',
350      type=float,
351      default=0.8,
352      help="""\
353      How many of the training samples have background noise mixed in.
354      """)
355  parser.add_argument(
356      '--silence_percentage',
357      type=float,
358      default=10.0,
359      help="""\
360      How much of the training data should be silence.
361      """)
362  parser.add_argument(
363      '--unknown_percentage',
364      type=float,
365      default=10.0,
366      help="""\
367      How much of the training data should be unknown words.
368      """)
369  parser.add_argument(
370      '--time_shift_ms',
371      type=float,
372      default=100.0,
373      help="""\
374      Range to randomly shift the training audio by in time.
375      """)
376  parser.add_argument(
377      '--testing_percentage',
378      type=int,
379      default=10,
380      help='What percentage of wavs to use as a test set.')
381  parser.add_argument(
382      '--validation_percentage',
383      type=int,
384      default=10,
385      help='What percentage of wavs to use as a validation set.')
386  parser.add_argument(
387      '--sample_rate',
388      type=int,
389      default=16000,
390      help='Expected sample rate of the wavs',)
391  parser.add_argument(
392      '--clip_duration_ms',
393      type=int,
394      default=1000,
395      help='Expected duration in milliseconds of the wavs',)
396  parser.add_argument(
397      '--window_size_ms',
398      type=float,
399      default=30.0,
400      help='How long each spectrogram timeslice is.',)
401  parser.add_argument(
402      '--window_stride_ms',
403      type=float,
404      default=10.0,
405      help='How far to move in time between spectrogram timeslices.',
406  )
407  parser.add_argument(
408      '--feature_bin_count',
409      type=int,
410      default=40,
411      help='How many bins to use for the MFCC fingerprint',
412  )
413  parser.add_argument(
414      '--how_many_training_steps',
415      type=str,
416      default='15000,3000',
417      help='How many training loops to run',)
418  parser.add_argument(
419      '--eval_step_interval',
420      type=int,
421      default=400,
422      help='How often to evaluate the training results.')
423  parser.add_argument(
424      '--learning_rate',
425      type=str,
426      default='0.001,0.0001',
427      help='How large a learning rate to use when training.')
428  parser.add_argument(
429      '--batch_size',
430      type=int,
431      default=100,
432      help='How many items to train with at once',)
433  parser.add_argument(
434      '--summaries_dir',
435      type=str,
436      default='/tmp/retrain_logs',
437      help='Where to save summary logs for TensorBoard.')
438  parser.add_argument(
439      '--wanted_words',
440      type=str,
441      default='yes,no,up,down,left,right,on,off,stop,go',
442      help='Words to use (others will be added to an unknown label)',)
443  parser.add_argument(
444      '--train_dir',
445      type=str,
446      default='/tmp/speech_commands_train',
447      help='Directory to write event logs and checkpoint.')
448  parser.add_argument(
449      '--save_step_interval',
450      type=int,
451      default=100,
452      help='Save model checkpoint every save_steps.')
453  parser.add_argument(
454      '--start_checkpoint',
455      type=str,
456      default='',
457      help='If specified, restore this pretrained model before any training.')
458  parser.add_argument(
459      '--model_architecture',
460      type=str,
461      default='conv',
462      help='What model architecture to use')
463  parser.add_argument(
464      '--check_nans',
465      type=bool,
466      default=False,
467      help='Whether to check for invalid numbers during processing')
468  parser.add_argument(
469      '--quantize',
470      type=bool,
471      default=False,
472      help='Whether to train the model for eight-bit deployment')
473  parser.add_argument(
474      '--preprocess',
475      type=str,
476      default='mfcc',
477      help='Spectrogram processing mode. Can be "mfcc", "average", or "micro"')
478
479  # Function used to parse --verbosity argument
480  def verbosity_arg(value):
481    """Parses verbosity argument.
482
483    Args:
484      value: A member of tf.logging.
485    Raises:
486      ArgumentTypeError: Not an expected value.
487    """
488    value = value.upper()
489    if value == 'DEBUG':
490      return tf.compat.v1.logging.DEBUG
491    elif value == 'INFO':
492      return tf.compat.v1.logging.INFO
493    elif value == 'WARN':
494      return tf.compat.v1.logging.WARN
495    elif value == 'ERROR':
496      return tf.compat.v1.logging.ERROR
497    elif value == 'FATAL':
498      return tf.compat.v1.logging.FATAL
499    else:
500      raise argparse.ArgumentTypeError('Not an expected value')
501  parser.add_argument(
502      '--verbosity',
503      type=verbosity_arg,
504      default=tf.compat.v1.logging.INFO,
505      help='Log verbosity. Can be "DEBUG", "INFO", "WARN", "ERROR", or "FATAL"')
506  parser.add_argument(
507      '--optimizer',
508      type=str,
509      default='gradient_descent',
510      help='Optimizer (gradient_descent or momentum)')
511
512  FLAGS, unparsed = parser.parse_known_args()
513  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
514