# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Simple speech recognition to spot a limited number of keywords.

This is a self-contained example script that will train a very basic audio
recognition model in TensorFlow. It downloads the necessary training data and
runs with reasonable defaults to train within a few hours, even using only a
CPU. For more information, please see
https://www.tensorflow.org/tutorials/audio/simple_audio.

It is intended as an introduction to using neural networks for audio
recognition, and is not a full speech recognition system. For more advanced
speech systems, I recommend looking into Kaldi. This network uses a keyword
detection style to spot discrete words from a small vocabulary, consisting of
"yes", "no", "up", "down", "left", "right", "on", "off", "stop", and "go".

To run the training process, use:

bazel run tensorflow/examples/speech_commands:train

This will write out checkpoints to /tmp/speech_commands_train/, and will
download over 1GB of open source training data, so you'll need enough free
space and a good internet connection. The default data is a collection of
thousands of one-second .wav files, each containing one spoken word. This data
set is collected from https://aiyprojects.withgoogle.com/open_speech_recording.
Please consider contributing to help improve this and other models!

As training progresses, it will print out its accuracy metrics, which should
rise above 90% by the end. Once it's complete, you can run the freeze script to
get a binary GraphDef that you can easily deploy on mobile applications.

If you want to train on your own data, you'll need to create .wavs with your
recordings, all at a consistent length, and then arrange them into subfolders
organized by label. For example, here's a possible file structure:

my_wavs >
  up >
    audio_0.wav
    audio_1.wav
  down >
    audio_2.wav
    audio_3.wav
  other >
    audio_4.wav
    audio_5.wav

You'll also need to tell the script what labels to look for, using the
`--wanted_words` argument. In this case, 'up,down' might be what you want, and
the audio in the 'other' folder would be used to train an 'unknown' category.

To pull this all together, you'd run:

bazel run tensorflow/examples/speech_commands:train -- \
--data_dir=my_wavs --wanted_words=up,down

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import sys

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

import input_data
import models
from tensorflow.python.platform import gfile

FLAGS = None


def main(_):
  # Set the verbosity based on flags (default is INFO, so we see all messages)
  tf.compat.v1.logging.set_verbosity(FLAGS.verbosity)

  # Start a new TensorFlow session.
  sess = tf.compat.v1.InteractiveSession()

  # Begin by making sure we have the training data we need. If you already have
  # training data of your own, use `--data_url= ` on the command line to avoid
  # downloading.
  model_settings = models.prepare_model_settings(
      len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
      FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
      FLAGS.window_stride_ms, FLAGS.feature_bin_count, FLAGS.preprocess)
  audio_processor = input_data.AudioProcessor(
      FLAGS.data_url, FLAGS.data_dir,
      FLAGS.silence_percentage, FLAGS.unknown_percentage,
      FLAGS.wanted_words.split(','), FLAGS.validation_percentage,
      FLAGS.testing_percentage, model_settings, FLAGS.summaries_dir)
  fingerprint_size = model_settings['fingerprint_size']
  label_count = model_settings['label_count']
  time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)
  # Figure out the learning rates for each training phase. Since it's often
  # effective to have high learning rates at the start of training, followed by
  # lower levels towards the end, the number of steps and learning rates can be
  # specified as comma-separated lists to define the rate at each stage. For
  # example --how_many_training_steps=10000,3000 --learning_rate=0.001,0.0001
  # will run 13,000 training loops in total, with a rate of 0.001 for the first
  # 10,000, and 0.0001 for the final 3,000.
  training_steps_list = list(map(int, FLAGS.how_many_training_steps.split(',')))
  learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
  if len(training_steps_list) != len(learning_rates_list):
    raise Exception(
        '--how_many_training_steps and --learning_rate must be equal length '
        'lists, but are %d and %d long instead' % (len(training_steps_list),
                                                   len(learning_rates_list)))

  input_placeholder = tf.compat.v1.placeholder(
      tf.float32, [None, fingerprint_size], name='fingerprint_input')
  if FLAGS.quantize:
    fingerprint_min, fingerprint_max = input_data.get_features_range(
        model_settings)
    fingerprint_input = tf.quantization.fake_quant_with_min_max_args(
        input_placeholder, fingerprint_min, fingerprint_max)
  else:
    fingerprint_input = input_placeholder

  logits, dropout_rate = models.create_model(
      fingerprint_input,
      model_settings,
      FLAGS.model_architecture,
      is_training=True)

  # Define loss and optimizer
  ground_truth_input = tf.compat.v1.placeholder(
      tf.int64, [None], name='groundtruth_input')

  # Optionally we can add runtime checks to spot when NaNs or other symptoms of
  # numerical errors start occurring during training.
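  # When --check_nans is set, tf.compat.v1.add_check_numerics_ops() builds the
  # assertion ops, and they are attached to the training op through the
  # control_dependencies list below, so every step fails fast on bad values.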
  control_dependencies = []
  if FLAGS.check_nans:
    checks = tf.compat.v1.add_check_numerics_ops()
    control_dependencies = [checks]

  # Create the back propagation and training evaluation machinery in the graph.
  with tf.compat.v1.name_scope('cross_entropy'):
    cross_entropy_mean = tf.compat.v1.losses.sparse_softmax_cross_entropy(
        labels=ground_truth_input, logits=logits)

  if FLAGS.quantize:
    try:
      tf.contrib.quantize.create_training_graph(quant_delay=0)
    except AttributeError as e:
      msg = e.args[0]
      msg += ('\n\n The --quantize option still requires contrib, which is not '
              'part of TensorFlow 2.0. Please install a previous version:'
              '\n    `pip install tensorflow<=1.15`')
      e.args = (msg,)
      raise e

  with tf.compat.v1.name_scope('train'), tf.control_dependencies(
      control_dependencies):
    learning_rate_input = tf.compat.v1.placeholder(
        tf.float32, [], name='learning_rate_input')
    if FLAGS.optimizer == 'gradient_descent':
      train_step = tf.compat.v1.train.GradientDescentOptimizer(
          learning_rate_input).minimize(cross_entropy_mean)
    elif FLAGS.optimizer == 'momentum':
      train_step = tf.compat.v1.train.MomentumOptimizer(
          learning_rate_input, .9,
          use_nesterov=True).minimize(cross_entropy_mean)
    else:
      raise Exception('Invalid Optimizer')
  predicted_indices = tf.argmax(input=logits, axis=1)
  correct_prediction = tf.equal(predicted_indices, ground_truth_input)
  confusion_matrix = tf.math.confusion_matrix(labels=ground_truth_input,
                                              predictions=predicted_indices,
                                              num_classes=label_count)
  evaluation_step = tf.reduce_mean(input_tensor=tf.cast(correct_prediction,
                                                        tf.float32))
  with tf.compat.v1.get_default_graph().name_scope('eval'):
    tf.compat.v1.summary.scalar('cross_entropy', cross_entropy_mean)
    tf.compat.v1.summary.scalar('accuracy', evaluation_step)

  global_step = tf.compat.v1.train.get_or_create_global_step()
  increment_global_step = tf.compat.v1.assign(global_step, global_step + 1)

  saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())

  # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
  merged_summaries = tf.compat.v1.summary.merge_all(scope='eval')
  train_writer = tf.compat.v1.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                                 sess.graph)
  validation_writer = tf.compat.v1.summary.FileWriter(
      FLAGS.summaries_dir + '/validation')

  tf.compat.v1.global_variables_initializer().run()

  start_step = 1

  if FLAGS.start_checkpoint:
    models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
    start_step = global_step.eval(session=sess)

  tf.compat.v1.logging.info('Training from step: %d ', start_step)

  # Save graph.pbtxt.
  tf.io.write_graph(sess.graph_def, FLAGS.train_dir,
                    FLAGS.model_architecture + '.pbtxt')

  # Save list of words.
  with gfile.GFile(
      os.path.join(FLAGS.train_dir, FLAGS.model_architecture + '_labels.txt'),
      'w') as f:
    f.write('\n'.join(audio_processor.words_list))

  # Training loop.
  training_steps_max = np.sum(training_steps_list)
  for training_step in xrange(start_step, training_steps_max + 1):
    # Figure out what the current learning rate is.
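    # The comma-separated schedules are walked in order, so with the defaults
    # (--how_many_training_steps=15000,3000 --learning_rate=0.001,0.0001) the
    # rate is 0.001 up to step 15,000 and 0.0001 for the remaining 3,000 steps.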
    training_steps_sum = 0
    for i in range(len(training_steps_list)):
      training_steps_sum += training_steps_list[i]
      if training_step <= training_steps_sum:
        learning_rate_value = learning_rates_list[i]
        break
    # Pull the audio samples we'll use for training.
    train_fingerprints, train_ground_truth = audio_processor.get_data(
        FLAGS.batch_size, 0, model_settings, FLAGS.background_frequency,
        FLAGS.background_volume, time_shift_samples, 'training', sess)
    # Run the graph with this batch of training data.
    train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
        [
            merged_summaries,
            evaluation_step,
            cross_entropy_mean,
            train_step,
            increment_global_step,
        ],
        feed_dict={
            fingerprint_input: train_fingerprints,
            ground_truth_input: train_ground_truth,
            learning_rate_input: learning_rate_value,
            dropout_rate: 0.5
        })
    train_writer.add_summary(train_summary, training_step)
    tf.compat.v1.logging.debug(
        'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
        (training_step, learning_rate_value, train_accuracy * 100,
         cross_entropy_value))
    is_last_step = (training_step == training_steps_max)
    if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:
      tf.compat.v1.logging.info(
          'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
          (training_step, learning_rate_value, train_accuracy * 100,
           cross_entropy_value))
      set_size = audio_processor.set_size('validation')
      total_accuracy = 0
      total_conf_matrix = None
      for i in xrange(0, set_size, FLAGS.batch_size):
        validation_fingerprints, validation_ground_truth = (
            audio_processor.get_data(FLAGS.batch_size, i, model_settings, 0.0,
                                     0.0, 0, 'validation', sess))
        # Run a validation step and capture training summaries for TensorBoard
        # with the `merged` op.
        validation_summary, validation_accuracy, conf_matrix = sess.run(
            [merged_summaries, evaluation_step, confusion_matrix],
            feed_dict={
                fingerprint_input: validation_fingerprints,
                ground_truth_input: validation_ground_truth,
                dropout_rate: 0.0
            })
        validation_writer.add_summary(validation_summary, training_step)
        batch_size = min(FLAGS.batch_size, set_size - i)
        total_accuracy += (validation_accuracy * batch_size) / set_size
        if total_conf_matrix is None:
          total_conf_matrix = conf_matrix
        else:
          total_conf_matrix += conf_matrix
      tf.compat.v1.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
      tf.compat.v1.logging.info('Step %d: Validation accuracy = %.1f%% (N=%d)' %
                                (training_step, total_accuracy * 100, set_size))

    # Save the model checkpoint periodically.
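    # Checkpoints are written to FLAGS.train_dir as <architecture>.ckpt-<step>,
    # so an interrupted run can be resumed later with --start_checkpoint.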
    if (training_step % FLAGS.save_step_interval == 0 or
        training_step == training_steps_max):
      checkpoint_path = os.path.join(FLAGS.train_dir,
                                     FLAGS.model_architecture + '.ckpt')
      tf.compat.v1.logging.info('Saving to "%s-%d"', checkpoint_path,
                                training_step)
      saver.save(sess, checkpoint_path, global_step=training_step)

  set_size = audio_processor.set_size('testing')
  tf.compat.v1.logging.info('set_size=%d', set_size)
  total_accuracy = 0
  total_conf_matrix = None
  for i in xrange(0, set_size, FLAGS.batch_size):
    test_fingerprints, test_ground_truth = audio_processor.get_data(
        FLAGS.batch_size, i, model_settings, 0.0, 0.0, 0, 'testing', sess)
    test_accuracy, conf_matrix = sess.run(
        [evaluation_step, confusion_matrix],
        feed_dict={
            fingerprint_input: test_fingerprints,
            ground_truth_input: test_ground_truth,
            dropout_rate: 0.0
        })
    batch_size = min(FLAGS.batch_size, set_size - i)
    total_accuracy += (test_accuracy * batch_size) / set_size
    if total_conf_matrix is None:
      total_conf_matrix = conf_matrix
    else:
      total_conf_matrix += conf_matrix
  tf.compat.v1.logging.warn('Confusion Matrix:\n %s' % (total_conf_matrix))
  tf.compat.v1.logging.warn('Final test accuracy = %.1f%% (N=%d)' %
                            (total_accuracy * 100, set_size))


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--data_url',
      type=str,
      # pylint: disable=line-too-long
      default='https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz',
      # pylint: enable=line-too-long
      help='Location of speech training data archive on the web.')
  parser.add_argument(
      '--data_dir',
      type=str,
      default='/tmp/speech_dataset/',
      help="""\
      Where to download the speech training data to.
      """)
  parser.add_argument(
      '--background_volume',
      type=float,
      default=0.1,
      help="""\
      How loud the background noise should be, between 0 and 1.
      """)
  parser.add_argument(
      '--background_frequency',
      type=float,
      default=0.8,
      help="""\
      How many of the training samples have background noise mixed in.
      """)
  parser.add_argument(
      '--silence_percentage',
      type=float,
      default=10.0,
      help="""\
      How much of the training data should be silence.
      """)
  parser.add_argument(
      '--unknown_percentage',
      type=float,
      default=10.0,
      help="""\
      How much of the training data should be unknown words.
      """)
  parser.add_argument(
      '--time_shift_ms',
      type=float,
      default=100.0,
      help="""\
      Range to randomly shift the training audio by in time.
      """)
  parser.add_argument(
      '--testing_percentage',
      type=int,
      default=10,
      help='What percentage of wavs to use as a test set.')
  parser.add_argument(
      '--validation_percentage',
      type=int,
      default=10,
      help='What percentage of wavs to use as a validation set.')
  parser.add_argument(
      '--sample_rate',
      type=int,
      default=16000,
      help='Expected sample rate of the wavs',)
  parser.add_argument(
      '--clip_duration_ms',
      type=int,
      default=1000,
      help='Expected duration in milliseconds of the wavs',)
  parser.add_argument(
      '--window_size_ms',
      type=float,
      default=30.0,
      help='How long each spectrogram timeslice is.',)
  parser.add_argument(
      '--window_stride_ms',
      type=float,
      default=10.0,
      help='How far to move in time between spectrogram timeslices.',
  )
  parser.add_argument(
      '--feature_bin_count',
      type=int,
      default=40,
      help='How many bins to use for the MFCC fingerprint',
  )
  parser.add_argument(
      '--how_many_training_steps',
      type=str,
      default='15000,3000',
      help='How many training loops to run',)
  parser.add_argument(
      '--eval_step_interval',
      type=int,
      default=400,
      help='How often to evaluate the training results.')
  parser.add_argument(
      '--learning_rate',
      type=str,
      default='0.001,0.0001',
      help='How large a learning rate to use when training.')
  parser.add_argument(
      '--batch_size',
      type=int,
      default=100,
      help='How many items to train with at once',)
  parser.add_argument(
      '--summaries_dir',
      type=str,
      default='/tmp/retrain_logs',
      help='Where to save summary logs for TensorBoard.')
  parser.add_argument(
      '--wanted_words',
      type=str,
      default='yes,no,up,down,left,right,on,off,stop,go',
      help='Words to use (others will be added to an unknown label)',)
  parser.add_argument(
      '--train_dir',
      type=str,
      default='/tmp/speech_commands_train',
      help='Directory to write event logs and checkpoint.')
  parser.add_argument(
      '--save_step_interval',
      type=int,
      default=100,
      help='Save model checkpoint every save_steps.')
  parser.add_argument(
      '--start_checkpoint',
      type=str,
      default='',
      help='If specified, restore this pretrained model before any training.')
  parser.add_argument(
      '--model_architecture',
      type=str,
      default='conv',
      help='What model architecture to use')
  parser.add_argument(
      '--check_nans',
      type=bool,
      default=False,
      help='Whether to check for invalid numbers during processing')
  parser.add_argument(
      '--quantize',
      type=bool,
      default=False,
      help='Whether to train the model for eight-bit deployment')
  parser.add_argument(
      '--preprocess',
      type=str,
      default='mfcc',
      help='Spectrogram processing mode. Can be "mfcc", "average", or "micro"')

  # Function used to parse --verbosity argument
  def verbosity_arg(value):
    """Parses verbosity argument.

    Args:
      value: A member of tf.logging.
    Raises:
      ArgumentTypeError: Not an expected value.
    """
    value = value.upper()
    if value == 'DEBUG':
      return tf.compat.v1.logging.DEBUG
    elif value == 'INFO':
      return tf.compat.v1.logging.INFO
    elif value == 'WARN':
      return tf.compat.v1.logging.WARN
    elif value == 'ERROR':
      return tf.compat.v1.logging.ERROR
    elif value == 'FATAL':
      return tf.compat.v1.logging.FATAL
    else:
      raise argparse.ArgumentTypeError('Not an expected value')

  parser.add_argument(
      '--verbosity',
      type=verbosity_arg,
      default=tf.compat.v1.logging.INFO,
      help='Log verbosity. Can be "DEBUG", "INFO", "WARN", "ERROR", or "FATAL"')
  parser.add_argument(
      '--optimizer',
      type=str,
      default='gradient_descent',
      help='Optimizer (gradient_descent or momentum)')

  FLAGS, unparsed = parser.parse_known_args()
  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
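
# A minimal sketch of the end-to-end workflow described in the module
# docstring, assuming the default /tmp/speech_commands_train checkpoint
# directory, the default 'conv' architecture, and the default 18,000 total
# training steps. The freeze script lives alongside this file; its flag names
# below (--start_checkpoint, --output_file) and the output path are
# assumptions taken from older tutorial material and may differ between
# TensorFlow releases, so check freeze.py before relying on them:
#
#   bazel run tensorflow/examples/speech_commands:train -- \
#       --data_dir=my_wavs --wanted_words=up,down
#   bazel run tensorflow/examples/speech_commands:freeze -- \
#       --start_checkpoint=/tmp/speech_commands_train/conv.ckpt-18000 \
#       --output_file=/tmp/my_frozen_graph.pb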