1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Model definitions for simple speech recognition. 16 17""" 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import hashlib 23import math 24import os.path 25import random 26import re 27import sys 28import tarfile 29 30import numpy as np 31from six.moves import urllib 32from six.moves import xrange # pylint: disable=redefined-builtin 33import tensorflow as tf 34 35from tensorflow.python.ops import gen_audio_ops as audio_ops 36from tensorflow.python.ops import io_ops 37from tensorflow.python.platform import gfile 38from tensorflow.python.util import compat 39 40tf.compat.v1.disable_eager_execution() 41 42# If it's available, load the specialized feature generator. If this doesn't 43# work, try building with bazel instead of running the Python script directly. 44try: 45 from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op as frontend_op # pylint:disable=g-import-not-at-top 46except ImportError: 47 frontend_op = None 48 49MAX_NUM_WAVS_PER_CLASS = 2**27 - 1 # ~134M 50SILENCE_LABEL = '_silence_' 51SILENCE_INDEX = 0 52UNKNOWN_WORD_LABEL = '_unknown_' 53UNKNOWN_WORD_INDEX = 1 54BACKGROUND_NOISE_DIR_NAME = '_background_noise_' 55RANDOM_SEED = 59185 56 57 58def prepare_words_list(wanted_words): 59 """Prepends common tokens to the custom word list. 60 61 Args: 62 wanted_words: List of strings containing the custom words. 63 64 Returns: 65 List with the standard silence and unknown tokens added. 66 """ 67 return [SILENCE_LABEL, UNKNOWN_WORD_LABEL] + wanted_words 68 69 70def which_set(filename, validation_percentage, testing_percentage): 71 """Determines which data partition the file should belong to. 72 73 We want to keep files in the same training, validation, or testing sets even 74 if new ones are added over time. This makes it less likely that testing 75 samples will accidentally be reused in training when long runs are restarted 76 for example. To keep this stability, a hash of the filename is taken and used 77 to determine which set it should belong to. This determination only depends on 78 the name and the set proportions, so it won't change as other files are added. 79 80 It's also useful to associate particular files as related (for example words 81 spoken by the same person), so anything after '_nohash_' in a filename is 82 ignored for set determination. This ensures that 'bobby_nohash_0.wav' and 83 'bobby_nohash_1.wav' are always in the same set, for example. 84 85 Args: 86 filename: File path of the data sample. 87 validation_percentage: How much of the data set to use for validation. 88 testing_percentage: How much of the data set to use for testing. 89 90 Returns: 91 String, one of 'training', 'validation', or 'testing'. 92 """ 93 base_name = os.path.basename(filename) 94 # We want to ignore anything after '_nohash_' in the file name when 95 # deciding which set to put a wav in, so the data set creator has a way of 96 # grouping wavs that are close variations of each other. 97 hash_name = re.sub(r'_nohash_.*$', '', base_name) 98 # This looks a bit magical, but we need to decide whether this file should 99 # go into the training, testing, or validation sets, and we want to keep 100 # existing files in the same set even if more files are subsequently 101 # added. 102 # To do that, we need a stable way of deciding based on just the file name 103 # itself, so we do a hash of that and then use that to generate a 104 # probability value that we use to assign it. 105 hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest() 106 percentage_hash = ((int(hash_name_hashed, 16) % 107 (MAX_NUM_WAVS_PER_CLASS + 1)) * 108 (100.0 / MAX_NUM_WAVS_PER_CLASS)) 109 if percentage_hash < validation_percentage: 110 result = 'validation' 111 elif percentage_hash < (testing_percentage + validation_percentage): 112 result = 'testing' 113 else: 114 result = 'training' 115 return result 116 117 118def load_wav_file(filename): 119 """Loads an audio file and returns a float PCM-encoded array of samples. 120 121 Args: 122 filename: Path to the .wav file to load. 123 124 Returns: 125 Numpy array holding the sample data as floats between -1.0 and 1.0. 126 """ 127 with tf.compat.v1.Session(graph=tf.Graph()) as sess: 128 wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, []) 129 wav_loader = io_ops.read_file(wav_filename_placeholder) 130 wav_decoder = tf.audio.decode_wav(wav_loader, desired_channels=1) 131 return sess.run( 132 wav_decoder, 133 feed_dict={wav_filename_placeholder: filename}).audio.flatten() 134 135 136def save_wav_file(filename, wav_data, sample_rate): 137 """Saves audio sample data to a .wav audio file. 138 139 Args: 140 filename: Path to save the file to. 141 wav_data: 2D array of float PCM-encoded audio data. 142 sample_rate: Samples per second to encode in the file. 143 """ 144 with tf.compat.v1.Session(graph=tf.Graph()) as sess: 145 wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, []) 146 sample_rate_placeholder = tf.compat.v1.placeholder(tf.int32, []) 147 wav_data_placeholder = tf.compat.v1.placeholder(tf.float32, [None, 1]) 148 wav_encoder = tf.audio.encode_wav(wav_data_placeholder, 149 sample_rate_placeholder) 150 wav_saver = io_ops.write_file(wav_filename_placeholder, wav_encoder) 151 sess.run( 152 wav_saver, 153 feed_dict={ 154 wav_filename_placeholder: filename, 155 sample_rate_placeholder: sample_rate, 156 wav_data_placeholder: np.reshape(wav_data, (-1, 1)) 157 }) 158 159 160def get_features_range(model_settings): 161 """Returns the expected min/max for generated features. 162 163 Args: 164 model_settings: Information about the current model being trained. 165 166 Returns: 167 Min/max float pair holding the range of features. 168 169 Raises: 170 Exception: If preprocessing mode isn't recognized. 171 """ 172 # TODO(petewarden): These values have been derived from the observed ranges 173 # of spectrogram and MFCC inputs. If the preprocessing pipeline changes, 174 # they may need to be updated. 175 if model_settings['preprocess'] == 'average': 176 features_min = 0.0 177 features_max = 127.5 178 elif model_settings['preprocess'] == 'mfcc': 179 features_min = -247.0 180 features_max = 30.0 181 elif model_settings['preprocess'] == 'micro': 182 features_min = 0.0 183 features_max = 26.0 184 else: 185 raise Exception('Unknown preprocess mode "%s" (should be "mfcc",' 186 ' "average", or "micro")' % (model_settings['preprocess'])) 187 return features_min, features_max 188 189 190class AudioProcessor(object): 191 """Handles loading, partitioning, and preparing audio training data.""" 192 193 def __init__(self, data_url, data_dir, silence_percentage, unknown_percentage, 194 wanted_words, validation_percentage, testing_percentage, 195 model_settings, summaries_dir): 196 if data_dir: 197 self.data_dir = data_dir 198 self.maybe_download_and_extract_dataset(data_url, data_dir) 199 self.prepare_data_index(silence_percentage, unknown_percentage, 200 wanted_words, validation_percentage, 201 testing_percentage) 202 self.prepare_background_data() 203 self.prepare_processing_graph(model_settings, summaries_dir) 204 205 def maybe_download_and_extract_dataset(self, data_url, dest_directory): 206 """Download and extract data set tar file. 207 208 If the data set we're using doesn't already exist, this function 209 downloads it from the TensorFlow.org website and unpacks it into a 210 directory. 211 If the data_url is none, don't download anything and expect the data 212 directory to contain the correct files already. 213 214 Args: 215 data_url: Web location of the tar file containing the data set. 216 dest_directory: File path to extract data to. 217 """ 218 if not data_url: 219 return 220 if not gfile.Exists(dest_directory): 221 os.makedirs(dest_directory) 222 filename = data_url.split('/')[-1] 223 filepath = os.path.join(dest_directory, filename) 224 if not gfile.Exists(filepath): 225 226 def _progress(count, block_size, total_size): 227 sys.stdout.write( 228 '\r>> Downloading %s %.1f%%' % 229 (filename, float(count * block_size) / float(total_size) * 100.0)) 230 sys.stdout.flush() 231 232 try: 233 filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress) 234 except: 235 tf.compat.v1.logging.error( 236 'Failed to download URL: {0} to folder: {1}. Please make sure you ' 237 'have enough free space and an internet connection'.format( 238 data_url, filepath)) 239 raise 240 print() 241 statinfo = os.stat(filepath) 242 tf.compat.v1.logging.info( 243 'Successfully downloaded {0} ({1} bytes)'.format( 244 filename, statinfo.st_size)) 245 tarfile.open(filepath, 'r:gz').extractall(dest_directory) 246 247 def prepare_data_index(self, silence_percentage, unknown_percentage, 248 wanted_words, validation_percentage, 249 testing_percentage): 250 """Prepares a list of the samples organized by set and label. 251 252 The training loop needs a list of all the available data, organized by 253 which partition it should belong to, and with ground truth labels attached. 254 This function analyzes the folders below the `data_dir`, figures out the 255 right 256 labels for each file based on the name of the subdirectory it belongs to, 257 and uses a stable hash to assign it to a data set partition. 258 259 Args: 260 silence_percentage: How much of the resulting data should be background. 261 unknown_percentage: How much should be audio outside the wanted classes. 262 wanted_words: Labels of the classes we want to be able to recognize. 263 validation_percentage: How much of the data set to use for validation. 264 testing_percentage: How much of the data set to use for testing. 265 266 Returns: 267 Dictionary containing a list of file information for each set partition, 268 and a lookup map for each class to determine its numeric index. 269 270 Raises: 271 Exception: If expected files are not found. 272 """ 273 # Make sure the shuffling and picking of unknowns is deterministic. 274 random.seed(RANDOM_SEED) 275 wanted_words_index = {} 276 for index, wanted_word in enumerate(wanted_words): 277 wanted_words_index[wanted_word] = index + 2 278 self.data_index = {'validation': [], 'testing': [], 'training': []} 279 unknown_index = {'validation': [], 'testing': [], 'training': []} 280 all_words = {} 281 # Look through all the subfolders to find audio samples 282 search_path = os.path.join(self.data_dir, '*', '*.wav') 283 for wav_path in gfile.Glob(search_path): 284 _, word = os.path.split(os.path.dirname(wav_path)) 285 word = word.lower() 286 # Treat the '_background_noise_' folder as a special case, since we expect 287 # it to contain long audio samples we mix in to improve training. 288 if word == BACKGROUND_NOISE_DIR_NAME: 289 continue 290 all_words[word] = True 291 set_index = which_set(wav_path, validation_percentage, testing_percentage) 292 # If it's a known class, store its detail, otherwise add it to the list 293 # we'll use to train the unknown label. 294 if word in wanted_words_index: 295 self.data_index[set_index].append({'label': word, 'file': wav_path}) 296 else: 297 unknown_index[set_index].append({'label': word, 'file': wav_path}) 298 if not all_words: 299 raise Exception('No .wavs found at ' + search_path) 300 for index, wanted_word in enumerate(wanted_words): 301 if wanted_word not in all_words: 302 raise Exception('Expected to find ' + wanted_word + 303 ' in labels but only found ' + 304 ', '.join(all_words.keys())) 305 # We need an arbitrary file to load as the input for the silence samples. 306 # It's multiplied by zero later, so the content doesn't matter. 307 silence_wav_path = self.data_index['training'][0]['file'] 308 for set_index in ['validation', 'testing', 'training']: 309 set_size = len(self.data_index[set_index]) 310 silence_size = int(math.ceil(set_size * silence_percentage / 100)) 311 for _ in range(silence_size): 312 self.data_index[set_index].append({ 313 'label': SILENCE_LABEL, 314 'file': silence_wav_path 315 }) 316 # Pick some unknowns to add to each partition of the data set. 317 random.shuffle(unknown_index[set_index]) 318 unknown_size = int(math.ceil(set_size * unknown_percentage / 100)) 319 self.data_index[set_index].extend(unknown_index[set_index][:unknown_size]) 320 # Make sure the ordering is random. 321 for set_index in ['validation', 'testing', 'training']: 322 random.shuffle(self.data_index[set_index]) 323 # Prepare the rest of the result data structure. 324 self.words_list = prepare_words_list(wanted_words) 325 self.word_to_index = {} 326 for word in all_words: 327 if word in wanted_words_index: 328 self.word_to_index[word] = wanted_words_index[word] 329 else: 330 self.word_to_index[word] = UNKNOWN_WORD_INDEX 331 self.word_to_index[SILENCE_LABEL] = SILENCE_INDEX 332 333 def prepare_background_data(self): 334 """Searches a folder for background noise audio, and loads it into memory. 335 336 It's expected that the background audio samples will be in a subdirectory 337 named '_background_noise_' inside the 'data_dir' folder, as .wavs that match 338 the sample rate of the training data, but can be much longer in duration. 339 340 If the '_background_noise_' folder doesn't exist at all, this isn't an 341 error, it's just taken to mean that no background noise augmentation should 342 be used. If the folder does exist, but it's empty, that's treated as an 343 error. 344 345 Returns: 346 List of raw PCM-encoded audio samples of background noise. 347 348 Raises: 349 Exception: If files aren't found in the folder. 350 """ 351 self.background_data = [] 352 background_dir = os.path.join(self.data_dir, BACKGROUND_NOISE_DIR_NAME) 353 if not gfile.Exists(background_dir): 354 return self.background_data 355 with tf.compat.v1.Session(graph=tf.Graph()) as sess: 356 wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, []) 357 wav_loader = io_ops.read_file(wav_filename_placeholder) 358 wav_decoder = tf.audio.decode_wav(wav_loader, desired_channels=1) 359 search_path = os.path.join(self.data_dir, BACKGROUND_NOISE_DIR_NAME, 360 '*.wav') 361 for wav_path in gfile.Glob(search_path): 362 wav_data = sess.run( 363 wav_decoder, 364 feed_dict={wav_filename_placeholder: wav_path}).audio.flatten() 365 self.background_data.append(wav_data) 366 if not self.background_data: 367 raise Exception('No background wav files were found in ' + search_path) 368 369 def prepare_processing_graph(self, model_settings, summaries_dir): 370 """Builds a TensorFlow graph to apply the input distortions. 371 372 Creates a graph that loads a WAVE file, decodes it, scales the volume, 373 shifts it in time, adds in background noise, calculates a spectrogram, and 374 then builds an MFCC fingerprint from that. 375 376 This must be called with an active TensorFlow session running, and it 377 creates multiple placeholder inputs, and one output: 378 379 - wav_filename_placeholder_: Filename of the WAV to load. 380 - foreground_volume_placeholder_: How loud the main clip should be. 381 - time_shift_padding_placeholder_: Where to pad the clip. 382 - time_shift_offset_placeholder_: How much to move the clip in time. 383 - background_data_placeholder_: PCM sample data for background noise. 384 - background_volume_placeholder_: Loudness of mixed-in background. 385 - output_: Output 2D fingerprint of processed audio. 386 387 Args: 388 model_settings: Information about the current model being trained. 389 summaries_dir: Path to save training summary information to. 390 391 Raises: 392 ValueError: If the preprocessing mode isn't recognized. 393 Exception: If the preprocessor wasn't compiled in. 394 """ 395 with tf.compat.v1.get_default_graph().name_scope('data'): 396 desired_samples = model_settings['desired_samples'] 397 self.wav_filename_placeholder_ = tf.compat.v1.placeholder( 398 tf.string, [], name='wav_filename') 399 wav_loader = io_ops.read_file(self.wav_filename_placeholder_) 400 wav_decoder = tf.audio.decode_wav( 401 wav_loader, desired_channels=1, desired_samples=desired_samples) 402 # Allow the audio sample's volume to be adjusted. 403 self.foreground_volume_placeholder_ = tf.compat.v1.placeholder( 404 tf.float32, [], name='foreground_volume') 405 scaled_foreground = tf.multiply(wav_decoder.audio, 406 self.foreground_volume_placeholder_) 407 # Shift the sample's start position, and pad any gaps with zeros. 408 self.time_shift_padding_placeholder_ = tf.compat.v1.placeholder( 409 tf.int32, [2, 2], name='time_shift_padding') 410 self.time_shift_offset_placeholder_ = tf.compat.v1.placeholder( 411 tf.int32, [2], name='time_shift_offset') 412 padded_foreground = tf.pad( 413 tensor=scaled_foreground, 414 paddings=self.time_shift_padding_placeholder_, 415 mode='CONSTANT') 416 sliced_foreground = tf.slice(padded_foreground, 417 self.time_shift_offset_placeholder_, 418 [desired_samples, -1]) 419 # Mix in background noise. 420 self.background_data_placeholder_ = tf.compat.v1.placeholder( 421 tf.float32, [desired_samples, 1], name='background_data') 422 self.background_volume_placeholder_ = tf.compat.v1.placeholder( 423 tf.float32, [], name='background_volume') 424 background_mul = tf.multiply(self.background_data_placeholder_, 425 self.background_volume_placeholder_) 426 background_add = tf.add(background_mul, sliced_foreground) 427 background_clamp = tf.clip_by_value(background_add, -1.0, 1.0) 428 # Run the spectrogram and MFCC ops to get a 2D 'fingerprint' of the audio. 429 spectrogram = audio_ops.audio_spectrogram( 430 background_clamp, 431 window_size=model_settings['window_size_samples'], 432 stride=model_settings['window_stride_samples'], 433 magnitude_squared=True) 434 tf.compat.v1.summary.image( 435 'spectrogram', tf.expand_dims(spectrogram, -1), max_outputs=1) 436 # The number of buckets in each FFT row in the spectrogram will depend on 437 # how many input samples there are in each window. This can be quite 438 # large, with a 160 sample window producing 127 buckets for example. We 439 # don't need this level of detail for classification, so we often want to 440 # shrink them down to produce a smaller result. That's what this section 441 # implements. One method is to use average pooling to merge adjacent 442 # buckets, but a more sophisticated approach is to apply the MFCC 443 # algorithm to shrink the representation. 444 if model_settings['preprocess'] == 'average': 445 self.output_ = tf.nn.pool( 446 input=tf.expand_dims(spectrogram, -1), 447 window_shape=[1, model_settings['average_window_width']], 448 strides=[1, model_settings['average_window_width']], 449 pooling_type='AVG', 450 padding='SAME') 451 tf.compat.v1.summary.image('shrunk_spectrogram', 452 self.output_, 453 max_outputs=1) 454 elif model_settings['preprocess'] == 'mfcc': 455 self.output_ = audio_ops.mfcc( 456 spectrogram, 457 wav_decoder.sample_rate, 458 dct_coefficient_count=model_settings['fingerprint_width']) 459 tf.compat.v1.summary.image( 460 'mfcc', tf.expand_dims(self.output_, -1), max_outputs=1) 461 elif model_settings['preprocess'] == 'micro': 462 if not frontend_op: 463 raise Exception( 464 'Micro frontend op is currently not available when running' 465 ' TensorFlow directly from Python, you need to build and run' 466 ' through Bazel') 467 sample_rate = model_settings['sample_rate'] 468 window_size_ms = (model_settings['window_size_samples'] * 469 1000) / sample_rate 470 window_step_ms = (model_settings['window_stride_samples'] * 471 1000) / sample_rate 472 int16_input = tf.cast(tf.multiply(background_clamp, 32768), tf.int16) 473 micro_frontend = frontend_op.audio_microfrontend( 474 int16_input, 475 sample_rate=sample_rate, 476 window_size=window_size_ms, 477 window_step=window_step_ms, 478 num_channels=model_settings['fingerprint_width'], 479 out_scale=1, 480 out_type=tf.float32) 481 self.output_ = tf.multiply(micro_frontend, (10.0 / 256.0)) 482 tf.compat.v1.summary.image( 483 'micro', 484 tf.expand_dims(tf.expand_dims(self.output_, -1), 0), 485 max_outputs=1) 486 else: 487 raise ValueError('Unknown preprocess mode "%s" (should be "mfcc", ' 488 ' "average", or "micro")' % 489 (model_settings['preprocess'])) 490 491 # Merge all the summaries and write them out to /tmp/retrain_logs (by 492 # default) 493 self.merged_summaries_ = tf.compat.v1.summary.merge_all(scope='data') 494 if summaries_dir: 495 self.summary_writer_ = tf.compat.v1.summary.FileWriter( 496 summaries_dir + '/data', tf.compat.v1.get_default_graph()) 497 498 def set_size(self, mode): 499 """Calculates the number of samples in the dataset partition. 500 501 Args: 502 mode: Which partition, must be 'training', 'validation', or 'testing'. 503 504 Returns: 505 Number of samples in the partition. 506 """ 507 return len(self.data_index[mode]) 508 509 def get_data(self, how_many, offset, model_settings, background_frequency, 510 background_volume_range, time_shift, mode, sess): 511 """Gather samples from the data set, applying transformations as needed. 512 513 When the mode is 'training', a random selection of samples will be returned, 514 otherwise the first N clips in the partition will be used. This ensures that 515 validation always uses the same samples, reducing noise in the metrics. 516 517 Args: 518 how_many: Desired number of samples to return. -1 means the entire 519 contents of this partition. 520 offset: Where to start when fetching deterministically. 521 model_settings: Information about the current model being trained. 522 background_frequency: How many clips will have background noise, 0.0 to 523 1.0. 524 background_volume_range: How loud the background noise will be. 525 time_shift: How much to randomly shift the clips by in time. 526 mode: Which partition to use, must be 'training', 'validation', or 527 'testing'. 528 sess: TensorFlow session that was active when processor was created. 529 530 Returns: 531 List of sample data for the transformed samples, and list of label indexes 532 533 Raises: 534 ValueError: If background samples are too short. 535 """ 536 # Pick one of the partitions to choose samples from. 537 candidates = self.data_index[mode] 538 if how_many == -1: 539 sample_count = len(candidates) 540 else: 541 sample_count = max(0, min(how_many, len(candidates) - offset)) 542 # Data and labels will be populated and returned. 543 data = np.zeros((sample_count, model_settings['fingerprint_size'])) 544 labels = np.zeros(sample_count) 545 desired_samples = model_settings['desired_samples'] 546 use_background = self.background_data and (mode == 'training') 547 pick_deterministically = (mode != 'training') 548 # Use the processing graph we created earlier to repeatedly to generate the 549 # final output sample data we'll use in training. 550 for i in xrange(offset, offset + sample_count): 551 # Pick which audio sample to use. 552 if how_many == -1 or pick_deterministically: 553 sample_index = i 554 else: 555 sample_index = np.random.randint(len(candidates)) 556 sample = candidates[sample_index] 557 # If we're time shifting, set up the offset for this sample. 558 if time_shift > 0: 559 time_shift_amount = np.random.randint(-time_shift, time_shift) 560 else: 561 time_shift_amount = 0 562 if time_shift_amount > 0: 563 time_shift_padding = [[time_shift_amount, 0], [0, 0]] 564 time_shift_offset = [0, 0] 565 else: 566 time_shift_padding = [[0, -time_shift_amount], [0, 0]] 567 time_shift_offset = [-time_shift_amount, 0] 568 input_dict = { 569 self.wav_filename_placeholder_: sample['file'], 570 self.time_shift_padding_placeholder_: time_shift_padding, 571 self.time_shift_offset_placeholder_: time_shift_offset, 572 } 573 # Choose a section of background noise to mix in. 574 if use_background or sample['label'] == SILENCE_LABEL: 575 background_index = np.random.randint(len(self.background_data)) 576 background_samples = self.background_data[background_index] 577 if len(background_samples) <= model_settings['desired_samples']: 578 raise ValueError( 579 'Background sample is too short! Need more than %d' 580 ' samples but only %d were found' % 581 (model_settings['desired_samples'], len(background_samples))) 582 background_offset = np.random.randint( 583 0, len(background_samples) - model_settings['desired_samples']) 584 background_clipped = background_samples[background_offset:( 585 background_offset + desired_samples)] 586 background_reshaped = background_clipped.reshape([desired_samples, 1]) 587 if sample['label'] == SILENCE_LABEL: 588 background_volume = np.random.uniform(0, 1) 589 elif np.random.uniform(0, 1) < background_frequency: 590 background_volume = np.random.uniform(0, background_volume_range) 591 else: 592 background_volume = 0 593 else: 594 background_reshaped = np.zeros([desired_samples, 1]) 595 background_volume = 0 596 input_dict[self.background_data_placeholder_] = background_reshaped 597 input_dict[self.background_volume_placeholder_] = background_volume 598 # If we want silence, mute out the main sample but leave the background. 599 if sample['label'] == SILENCE_LABEL: 600 input_dict[self.foreground_volume_placeholder_] = 0 601 else: 602 input_dict[self.foreground_volume_placeholder_] = 1 603 # Run the graph to produce the output audio. 604 summary, data_tensor = sess.run( 605 [self.merged_summaries_, self.output_], feed_dict=input_dict) 606 self.summary_writer_.add_summary(summary) 607 data[i - offset, :] = data_tensor.flatten() 608 label_index = self.word_to_index[sample['label']] 609 labels[i - offset] = label_index 610 return data, labels 611 612 def get_features_for_wav(self, wav_filename, model_settings, sess): 613 """Applies the feature transformation process to the input_wav. 614 615 Runs the feature generation process (generally producing a spectrogram from 616 the input samples) on the WAV file. This can be useful for testing and 617 verifying implementations being run on other platforms. 618 619 Args: 620 wav_filename: The path to the input audio file. 621 model_settings: Information about the current model being trained. 622 sess: TensorFlow session that was active when processor was created. 623 624 Returns: 625 Numpy data array containing the generated features. 626 """ 627 desired_samples = model_settings['desired_samples'] 628 input_dict = { 629 self.wav_filename_placeholder_: wav_filename, 630 self.time_shift_padding_placeholder_: [[0, 0], [0, 0]], 631 self.time_shift_offset_placeholder_: [0, 0], 632 self.background_data_placeholder_: np.zeros([desired_samples, 1]), 633 self.background_volume_placeholder_: 0, 634 self.foreground_volume_placeholder_: 1, 635 } 636 # Run the graph to produce the output audio. 637 data_tensor = sess.run([self.output_], feed_dict=input_dict) 638 return data_tensor 639 640 def get_unprocessed_data(self, how_many, model_settings, mode): 641 """Retrieve sample data for the given partition, with no transformations. 642 643 Args: 644 how_many: Desired number of samples to return. -1 means the entire 645 contents of this partition. 646 model_settings: Information about the current model being trained. 647 mode: Which partition to use, must be 'training', 'validation', or 648 'testing'. 649 650 Returns: 651 List of sample data for the samples, and list of labels in one-hot form. 652 """ 653 candidates = self.data_index[mode] 654 if how_many == -1: 655 sample_count = len(candidates) 656 else: 657 sample_count = how_many 658 desired_samples = model_settings['desired_samples'] 659 words_list = self.words_list 660 data = np.zeros((sample_count, desired_samples)) 661 labels = [] 662 with tf.compat.v1.Session(graph=tf.Graph()) as sess: 663 wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, []) 664 wav_loader = io_ops.read_file(wav_filename_placeholder) 665 wav_decoder = tf.audio.decode_wav( 666 wav_loader, desired_channels=1, desired_samples=desired_samples) 667 foreground_volume_placeholder = tf.compat.v1.placeholder(tf.float32, []) 668 scaled_foreground = tf.multiply(wav_decoder.audio, 669 foreground_volume_placeholder) 670 for i in range(sample_count): 671 if how_many == -1: 672 sample_index = i 673 else: 674 sample_index = np.random.randint(len(candidates)) 675 sample = candidates[sample_index] 676 input_dict = {wav_filename_placeholder: sample['file']} 677 if sample['label'] == SILENCE_LABEL: 678 input_dict[foreground_volume_placeholder] = 0 679 else: 680 input_dict[foreground_volume_placeholder] = 1 681 data[i, :] = sess.run(scaled_foreground, feed_dict=input_dict).flatten() 682 label_index = self.word_to_index[sample['label']] 683 labels.append(words_list[label_index]) 684 return data, labels 685