15r"""Simple speech recognition to spot a limited number of keywords.
17This is a self-contained example script that will train a very basic audio
18recognition model in TensorFlow. It downloads the necessary training data and
runs with reasonable defaults to train within a few hours, even when only using a CPU.
20For more information, please see
23It is intended as an introduction to using neural networks for audio
24recognition, and is not a full speech recognition system. For more advanced
25speech systems, I recommend looking into Kaldi. This network uses a keyword
26detection style to spot discrete words from a small vocabulary, consisting of
27"yes", "no", "up", "down", "left", "right", "on", "off", "stop", and "go".
29To run the training process, use:
31bazel run tensorflow/examples/speech_commands:train
33This will write out checkpoints to /tmp/speech_commands_train/, and will
34download over 1GB of open source training data, so you'll need enough free space
35and a good internet connection. The default data is a collection of thousands of
36one-second .wav files, each containing one spoken word. This data set is
collected from https://aiyprojects.withgoogle.com/open_speech_recording. Please
38consider contributing to help improve this and other models!
40As training progresses, it will print out its accuracy metrics, which should
41rise above 90% by the end. Once it's complete, you can run the freeze script to
42get a binary GraphDef that you can easily deploy on mobile applications.
44If you want to train on your own data, you'll need to create .wavs with your
45recordings, all at a consistent length, and then arrange them into subfolders
46organized by label. For example, here's a possible file structure:
48my_wavs >
49  up >
50    audio_0.wav
51    audio_1.wav
52  down >
53    audio_2.wav
54    audio_3.wav
  other >
56    audio_4.wav
57    audio_5.wav
59You'll also need to tell the script what labels to look for, using the
60`--wanted_words` argument. In this case, 'up,down' might be what you want, and
61the audio in the 'other' folder would be used to train an 'unknown' category.
63To pull this all together, you'd run:
65bazel run tensorflow/examples/speech_commands:train -- \
66--data_dir=my_wavs --wanted_words=up,down
69from __future__ import absolute_import
70from __future__ import division
71from __future__ import print_function
73import argparse
74import os.path
75import sys
77import numpy as np
78from six.moves import xrange  # pylint: disable=redefined-builtin
79import tensorflow as tf
81import input_data
82import models
83from tensorflow.python.platform import gfile
85FLAGS = None
def _learning_rate_for_step(training_step, training_steps_list,
                            learning_rates_list):
  """Returns the learning rate scheduled for a 1-based training step.

  The schedule is piecewise constant: the first `training_steps_list[0]` steps
  use `learning_rates_list[0]`, the next `training_steps_list[1]` steps use
  `learning_rates_list[1]`, and so on.

  Args:
    training_step: 1-based index of the current training step.
    training_steps_list: Number of steps in each phase of the schedule.
    learning_rates_list: Learning rate for each phase; same length as
      `training_steps_list`.

  Returns:
    The learning rate that applies to `training_step`.

  Raises:
    ValueError: If `training_step` lies beyond the end of the schedule.
  """
  phase_end = 0
  for phase_steps, phase_rate in zip(training_steps_list, learning_rates_list):
    phase_end += phase_steps
    if training_step <= phase_end:
      return phase_rate
  raise ValueError('Training step %d is past the end of the learning rate '
                   'schedule (%d steps)' % (training_step, phase_end))


def _run_evaluation(sess, audio_processor, mode, model_settings, batch_size,
                    input_tensor, ground_truth_input, dropout_rate,
                    evaluation_step, confusion_matrix, merged_summaries=None,
                    summary_writer=None, summary_step=None):
  """Runs the model without augmentation over one whole data partition.

  Args:
    sess: Session to run the graph in.
    audio_processor: Provides `set_size(mode)` and `get_data(...)`.
    mode: Partition name passed to the audio processor, e.g. 'validation'.
    model_settings: Model settings dict forwarded to `get_data`.
    batch_size: Number of samples fetched per `sess.run` call.
    input_tensor: Tensor the fingerprints are fed into.
    ground_truth_input: Placeholder the labels are fed into.
    dropout_rate: Dropout-rate placeholder; fed 0.0 for evaluation.
    evaluation_step: Tensor computing the batch accuracy.
    confusion_matrix: Tensor computing the batch confusion matrix.
    merged_summaries: Summary op to run per batch; only used together with
      `summary_writer`.
    summary_writer: Optional writer that receives each batch's summary.
    summary_step: Step number the summaries are recorded under.

  Returns:
    A `(accuracy, confusion_matrix, set_size)` tuple. `accuracy` is the
    sample-weighted mean over all batches, and `confusion_matrix` is the sum of
    the per-batch matrices (None if the partition is empty).
  """
  set_size = audio_processor.set_size(mode)
  write_summaries = summary_writer is not None
  fetches = [evaluation_step, confusion_matrix]
  if write_summaries:
    fetches.append(merged_summaries)
  total_accuracy = 0
  total_conf_matrix = None
  for offset in xrange(0, set_size, batch_size):
    # No background noise, volume or time shift when evaluating.
    fingerprints, ground_truth = audio_processor.get_data(
        batch_size, offset, model_settings, 0.0, 0.0, 0, mode, sess)
    results = sess.run(
        fetches,
        feed_dict={
            input_tensor: fingerprints,
            ground_truth_input: ground_truth,
            dropout_rate: 0.0
        })
    batch_accuracy, batch_conf_matrix = results[0], results[1]
    if write_summaries:
      summary_writer.add_summary(results[2], summary_step)
    # The last batch may be partial, so weight each batch's accuracy by the
    # number of samples it actually holds.
    samples_in_batch = min(batch_size, set_size - offset)
    total_accuracy += (batch_accuracy * samples_in_batch) / set_size
    if total_conf_matrix is None:
      total_conf_matrix = batch_conf_matrix
    else:
      total_conf_matrix += batch_conf_matrix
  return total_accuracy, total_conf_matrix, set_size


def main(_):
  """Builds the keyword-spotting graph, trains it, and reports test accuracy.

  All configuration comes from the module-level `FLAGS`. The graph definition,
  the label list and periodic checkpoints are written to `FLAGS.train_dir`;
  TensorBoard summaries go under `FLAGS.summaries_dir`.

  Raises:
    ValueError: If `--how_many_training_steps` and `--learning_rate` have
      different lengths, or `--optimizer` is not recognized.
  """
  # Set the verbosity based on flags (default is INFO, so we see all messages)
  tf.compat.v1.logging.set_verbosity(FLAGS.verbosity)

  # Start a new TensorFlow session.
  sess = tf.compat.v1.InteractiveSession()

  # Begin by making sure we have the training data we need. If you already have
  # training data of your own, use `--data_url= ` on the command line to avoid
  # downloading.
  model_settings = models.prepare_model_settings(
      len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
      FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
      FLAGS.window_stride_ms, FLAGS.feature_bin_count, FLAGS.preprocess)
  audio_processor = input_data.AudioProcessor(
      FLAGS.data_url, FLAGS.data_dir,
      FLAGS.silence_percentage, FLAGS.unknown_percentage,
      FLAGS.wanted_words.split(','), FLAGS.validation_percentage,
      FLAGS.testing_percentage, model_settings, FLAGS.summaries_dir)
  fingerprint_size = model_settings['fingerprint_size']
  label_count = model_settings['label_count']
  time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)
  # Figure out the learning rates for each training phase. Since it's often
  # effective to have high learning rates at the start of training, followed by
  # lower levels towards the end, the number of steps and learning rates can be
  # specified as comma-separated lists to define the rate at each stage. For
  # example --how_many_training_steps=10000,3000 --learning_rate=0.001,0.0001
  # will run 13,000 training loops in total, with a rate of 0.001 for the first
  # 10,000, and 0.0001 for the final 3,000.
  training_steps_list = list(map(int, FLAGS.how_many_training_steps.split(',')))
  learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
  if len(training_steps_list) != len(learning_rates_list):
    raise ValueError(
        '--how_many_training_steps and --learning_rate must be equal length '
        'lists, but are %d and %d long instead' % (len(training_steps_list),
                                                   len(learning_rates_list)))

  # All data is fed through this placeholder. With --quantize, the model sees
  # its fake-quantized output (`fingerprint_input`). Feeding that derived
  # tensor directly would skip the fake-quant op, so always feed the
  # placeholder instead.
  input_placeholder = tf.compat.v1.placeholder(
      tf.float32, [None, fingerprint_size], name='fingerprint_input')
  if FLAGS.quantize:
    fingerprint_min, fingerprint_max = input_data.get_features_range(
        model_settings)
    fingerprint_input = tf.quantization.fake_quant_with_min_max_args(
        input_placeholder, fingerprint_min, fingerprint_max)
  else:
    fingerprint_input = input_placeholder

  logits, dropout_rate = models.create_model(
      fingerprint_input,
      model_settings,
      FLAGS.model_architecture,
      is_training=True)

  # Define loss and optimizer
  ground_truth_input = tf.compat.v1.placeholder(
      tf.int64, [None], name='groundtruth_input')

  # Optionally we can add runtime checks to spot when NaNs or other symptoms of
  # numerical errors start occurring during training.
  control_dependencies = []
  if FLAGS.check_nans:
    checks = tf.compat.v1.add_check_numerics_ops()
    control_dependencies = [checks]

  # Create the back propagation and training evaluation machinery in the graph.
  with tf.compat.v1.name_scope('cross_entropy'):
    cross_entropy_mean = tf.compat.v1.losses.sparse_softmax_cross_entropy(
        labels=ground_truth_input, logits=logits)

  if FLAGS.quantize:
    try:
      tf.contrib.quantize.create_training_graph(quant_delay=0)
    except AttributeError as e:
      msg = e.args[0]
      msg += ('\n\n The --quantize option still requires contrib, which is not '
              'part of TensorFlow 2.0. Please install a previous version:'
              '\n    `pip install tensorflow<=1.15`')
      e.args = (msg,)
      raise e

  with tf.compat.v1.name_scope('train'), tf.control_dependencies(
      control_dependencies):
    learning_rate_input = tf.compat.v1.placeholder(
        tf.float32, [], name='learning_rate_input')
    if FLAGS.optimizer == 'gradient_descent':
      train_step = tf.compat.v1.train.GradientDescentOptimizer(
          learning_rate_input).minimize(cross_entropy_mean)
    elif FLAGS.optimizer == 'momentum':
      train_step = tf.compat.v1.train.MomentumOptimizer(
          learning_rate_input, .9,
          use_nesterov=True).minimize(cross_entropy_mean)
    else:
      raise ValueError('Invalid Optimizer: %r' % FLAGS.optimizer)
  predicted_indices = tf.argmax(input=logits, axis=1)
  correct_prediction = tf.equal(predicted_indices, ground_truth_input)
  confusion_matrix = tf.math.confusion_matrix(labels=ground_truth_input,
                                              predictions=predicted_indices,
                                              num_classes=label_count)
  evaluation_step = tf.reduce_mean(input_tensor=tf.cast(correct_prediction,
                                                        tf.float32))
  with tf.compat.v1.get_default_graph().name_scope('eval'):
    tf.compat.v1.summary.scalar('cross_entropy', cross_entropy_mean)
    tf.compat.v1.summary.scalar('accuracy', evaluation_step)

  global_step = tf.compat.v1.train.get_or_create_global_step()
  increment_global_step = tf.compat.v1.assign(global_step, global_step + 1)

  saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())

  # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
  merged_summaries = tf.compat.v1.summary.merge_all(scope='eval')
  train_writer = tf.compat.v1.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                                 sess.graph)
  validation_writer = tf.compat.v1.summary.FileWriter(
      FLAGS.summaries_dir + '/validation')

  tf.compat.v1.global_variables_initializer().run()

  start_step = 1

  if FLAGS.start_checkpoint:
    models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
    # global_step is incremented in the same run call as each training step,
    # so it holds the number of steps already completed. Resume at the next
    # step instead of repeating the last checkpointed one.
    start_step = global_step.eval(session=sess) + 1

  tf.compat.v1.logging.info('Training from step: %d ', start_step)

  # Save graph.pbtxt.
  tf.io.write_graph(sess.graph_def, FLAGS.train_dir,
                    FLAGS.model_architecture + '.pbtxt')

  # Save list of words.
  with gfile.GFile(
      os.path.join(FLAGS.train_dir, FLAGS.model_architecture + '_labels.txt'),
      'w') as f:
    f.write('\n'.join(audio_processor.words_list))

  # Training loop.
  training_steps_max = np.sum(training_steps_list)
  checkpoint_path = os.path.join(FLAGS.train_dir,
                                 FLAGS.model_architecture + '.ckpt')
  for training_step in xrange(start_step, training_steps_max + 1):
    # Figure out what the current learning rate is.
    learning_rate_value = _learning_rate_for_step(
        training_step, training_steps_list, learning_rates_list)
    # Pull the audio samples we'll use for training.
    train_fingerprints, train_ground_truth = audio_processor.get_data(
        FLAGS.batch_size, 0, model_settings, FLAGS.background_frequency,
        FLAGS.background_volume, time_shift_samples, 'training', sess)
    # Run the graph with this batch of training data.
    train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
        [
            merged_summaries,
            evaluation_step,
            cross_entropy_mean,
            train_step,
            increment_global_step,
        ],
        feed_dict={
            input_placeholder: train_fingerprints,
            ground_truth_input: train_ground_truth,
            learning_rate_input: learning_rate_value,
            dropout_rate: 0.5
        })
    train_writer.add_summary(train_summary, training_step)
    tf.compat.v1.logging.debug(
        'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
        (training_step, learning_rate_value, train_accuracy * 100,
         cross_entropy_value))
    is_last_step = (training_step == training_steps_max)
    if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:
      tf.compat.v1.logging.info(
          'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
          (training_step, learning_rate_value, train_accuracy * 100,
           cross_entropy_value))
      # Run a validation pass and capture summaries for TensorBoard.
      validation_accuracy, validation_conf_matrix, set_size = _run_evaluation(
          sess, audio_processor, 'validation', model_settings,
          FLAGS.batch_size, input_placeholder, ground_truth_input,
          dropout_rate, evaluation_step, confusion_matrix,
          merged_summaries=merged_summaries,
          summary_writer=validation_writer, summary_step=training_step)
      tf.compat.v1.logging.info(
          'Confusion Matrix:\n %s' % (validation_conf_matrix))
      tf.compat.v1.logging.info(
          'Step %d: Validation accuracy = %.1f%% (N=%d)' %
          (training_step, validation_accuracy * 100, set_size))

    # Save the model checkpoint periodically.
    if training_step % FLAGS.save_step_interval == 0 or is_last_step:
      tf.compat.v1.logging.info('Saving to "%s-%d"', checkpoint_path,
                                training_step)
      saver.save(sess, checkpoint_path, global_step=training_step)

  test_accuracy, test_conf_matrix, set_size = _run_evaluation(
      sess, audio_processor, 'testing', model_settings, FLAGS.batch_size,
      input_placeholder, ground_truth_input, dropout_rate, evaluation_step,
      confusion_matrix)
  tf.compat.v1.logging.info('set_size=%d', set_size)
  tf.compat.v1.logging.warning('Confusion Matrix:\n %s' % (test_conf_matrix))
  tf.compat.v1.logging.warning('Final test accuracy = %.1f%% (N=%d)' %
                               (test_accuracy * 100, set_size))

  # Flush any buffered TensorBoard events and release the session.
  train_writer.close()
  validation_writer.close()
  sess.close()
if __name__ == '__main__':
  # Command-line entry point: parse flags into FLAGS and hand off to main().
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--data_url',
      type=str,
      # pylint: disable=line-too-long
      default='https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz',
      # pylint: enable=line-too-long
      help='Location of speech training data archive on the web.')
  parser.add_argument(
      '--data_dir',
      type=str,
      default='/tmp/speech_dataset/',
      help="""\
      Where to download the speech training data to.
      """)
  parser.add_argument(
      '--background_volume',
      type=float,
      default=0.1,
      help="""\
      How loud the background noise should be, between 0 and 1.
      """)
  parser.add_argument(
      '--background_frequency',
      type=float,
      default=0.8,
      help="""\
      How many of the training samples have background noise mixed in.
      """)
  parser.add_argument(
      '--silence_percentage',
      type=float,
      default=10.0,
      help="""\
      How much of the training data should be silence.
      """)
  parser.add_argument(
      '--unknown_percentage',
      type=float,
      default=10.0,
      help="""\
      How much of the training data should be unknown words.
      """)
  parser.add_argument(
      '--time_shift_ms',
      type=float,
      default=100.0,
      help="""\
      Range to randomly shift the training audio by in time.
      """)
  parser.add_argument(
      '--testing_percentage',
      type=int,
      default=10,
      help='What percentage of wavs to use as a test set.')
  parser.add_argument(
      '--validation_percentage',
      type=int,
      default=10,
      help='What percentage of wavs to use as a validation set.')
  parser.add_argument(
      '--sample_rate',
      type=int,
      default=16000,
      help='Expected sample rate of the wavs',)
  parser.add_argument(
      '--clip_duration_ms',
      type=int,
      default=1000,
      help='Expected duration in milliseconds of the wavs',)
  parser.add_argument(
      '--window_size_ms',
      type=float,
      default=30.0,
      help='How long each spectrogram timeslice is.',)
  parser.add_argument(
      '--window_stride_ms',
      type=float,
      default=10.0,
      help='How far to move in time between spectrogram timeslices.',
  )
  parser.add_argument(
      '--feature_bin_count',
      type=int,
      default=40,
      help='How many bins to use for the MFCC fingerprint',
  )
  parser.add_argument(
      '--how_many_training_steps',
      type=str,
      default='15000,3000',
      help='How many training loops to run',)
  parser.add_argument(
      '--eval_step_interval',
      type=int,
      default=400,
      help='How often to evaluate the training results.')
  parser.add_argument(
      '--learning_rate',
      type=str,
      default='0.001,0.0001',
      help='How large a learning rate to use when training.')
  parser.add_argument(
      '--batch_size',
      type=int,
      default=100,
      help='How many items to train with at once',)
  parser.add_argument(
      '--summaries_dir',
      type=str,
      default='/tmp/retrain_logs',
      help='Where to save summary logs for TensorBoard.')
  parser.add_argument(
      '--wanted_words',
      type=str,
      default='yes,no,up,down,left,right,on,off,stop,go',
      help='Words to use (others will be added to an unknown label)',)
  parser.add_argument(
      '--train_dir',
      type=str,
      default='/tmp/speech_commands_train',
      help='Directory to write event logs and checkpoint.')
  parser.add_argument(
      '--save_step_interval',
      type=int,
      default=100,
      help='Save model checkpoint every save_steps.')
  parser.add_argument(
      '--start_checkpoint',
      type=str,
      default='',
      help='If specified, restore this pretrained model before any training.')
  parser.add_argument(
      '--model_architecture',
      type=str,
      default='conv',
      help='What model architecture to use')

  # argparse's `type=bool` calls bool() on the raw string, so any non-empty
  # value (including "False") becomes True. Parse boolean flags explicitly.
  def bool_arg(value):
    """Parses a boolean command-line value.

    Args:
      value: String given on the command line, e.g. "True", "false" or "1".

    Returns:
      The corresponding bool. An empty string is False, matching bool('').

    Raises:
      ArgumentTypeError: The string is not a recognized boolean spelling.
    """
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
      return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
      return False
    raise argparse.ArgumentTypeError(
        'Expected a boolean value, got %r' % value)

  # nargs='?' with const=True lets the bare flag (e.g. `--quantize`) mean True,
  # while `--quantize=True` / `--quantize=False` keep working.
  parser.add_argument(
      '--check_nans',
      type=bool_arg,
      nargs='?',
      const=True,
      default=False,
      help='Whether to check for invalid numbers during processing')
  parser.add_argument(
      '--quantize',
      type=bool_arg,
      nargs='?',
      const=True,
      default=False,
      help='Whether to train the model for eight-bit deployment')
  parser.add_argument(
      '--preprocess',
      type=str,
      default='mfcc',
      help='Spectrogram processing mode. Can be "mfcc", "average", or "micro"')

  # Function used to parse --verbosity argument
  def verbosity_arg(value):
    """Parses the --verbosity argument.

    Args:
      value: Case-insensitive level name: "DEBUG", "INFO", "WARN", "ERROR" or
        "FATAL".

    Returns:
      The matching `tf.compat.v1.logging` level constant.

    Raises:
      ArgumentTypeError: Not an expected value.
    """
    levels = {
        'DEBUG': tf.compat.v1.logging.DEBUG,
        'INFO': tf.compat.v1.logging.INFO,
        'WARN': tf.compat.v1.logging.WARN,
        'ERROR': tf.compat.v1.logging.ERROR,
        'FATAL': tf.compat.v1.logging.FATAL,
    }
    level = levels.get(value.upper())
    if level is None:
      raise argparse.ArgumentTypeError('Not an expected value')
    return level

  parser.add_argument(
      '--verbosity',
      type=verbosity_arg,
      default=tf.compat.v1.logging.INFO,
      help='Log verbosity. Can be "DEBUG", "INFO", "WARN", "ERROR", or "FATAL"')
  parser.add_argument(
      '--optimizer',
      type=str,
      default='gradient_descent',
      help='Optimizer (gradient_descent or momentum)')

  FLAGS, unparsed = parser.parse_known_args()
  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)