# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Model definitions for simple speech recognition.
16
17"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import hashlib
import math
import os.path
import random
import re
import sys
import tarfile

import numpy as np
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from tensorflow.python.ops import gen_audio_ops as audio_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.platform import gfile
from tensorflow.python.util import compat

tf.compat.v1.disable_eager_execution()

# If it's available, load the specialized feature generator. If this doesn't
# work, try building with bazel instead of running the Python script directly.
try:
  from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op as frontend_op  # pylint:disable=g-import-not-at-top
except ImportError:
  frontend_op = None

MAX_NUM_WAVS_PER_CLASS = 2**27 - 1  # ~134M
SILENCE_LABEL = '_silence_'
SILENCE_INDEX = 0
UNKNOWN_WORD_LABEL = '_unknown_'
UNKNOWN_WORD_INDEX = 1
BACKGROUND_NOISE_DIR_NAME = '_background_noise_'
RANDOM_SEED = 59185


def prepare_words_list(wanted_words):
  """Prepends common tokens to the custom word list.

  Args:
    wanted_words: List of strings containing the custom words.

  Returns:
    List with the standard silence and unknown tokens added.
  """
  return [SILENCE_LABEL, UNKNOWN_WORD_LABEL] + wanted_words


def which_set(filename, validation_percentage, testing_percentage):
  """Determines which data partition the file should belong to.

  We want to keep files in the same training, validation, or testing sets even
  if new ones are added over time. This makes it less likely that testing
  samples will accidentally be reused in training when long runs are restarted
  for example. To keep this stability, a hash of the filename is taken and used
  to determine which set it should belong to. This determination only depends on
  the name and the set proportions, so it won't change as other files are added.

  It's also useful to associate particular files as related (for example words
  spoken by the same person), so anything after '_nohash_' in a filename is
  ignored for set determination. This ensures that 'bobby_nohash_0.wav' and
  'bobby_nohash_1.wav' are always in the same set, for example.

  Args:
    filename: File path of the data sample.
    validation_percentage: How much of the data set to use for validation.
    testing_percentage: How much of the data set to use for testing.

  Returns:
    String, one of 'training', 'validation', or 'testing'.
  """
  base_name = os.path.basename(filename)
  # We want to ignore anything after '_nohash_' in the file name when
  # deciding which set to put a wav in, so the data set creator has a way of
  # grouping wavs that are close variations of each other.
  hash_name = re.sub(r'_nohash_.*$', '', base_name)
  # This looks a bit magical, but we need to decide whether this file should
  # go into the training, testing, or validation sets, and we want to keep
  # existing files in the same set even if more files are subsequently
  # added.
  # To do that, we need a stable way of deciding based on just the file name
  # itself, so we do a hash of that and then use that to generate a
  # probability value that we use to assign it.
  hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest()
  percentage_hash = ((int(hash_name_hashed, 16) %
                      (MAX_NUM_WAVS_PER_CLASS + 1)) *
                     (100.0 / MAX_NUM_WAVS_PER_CLASS))
  if percentage_hash < validation_percentage:
    result = 'validation'
  elif percentage_hash < (testing_percentage + validation_percentage):
    result = 'testing'
  else:
    result = 'training'
  return result


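# A minimal usage sketch, not part of the original API: because which_set()
# hashes only the part of the basename before '_nohash_', recordings that
# share that prefix always land in the same partition, no matter how many
# other files are added later. The 10/10 percentages below are illustrative.
def _example_partition_stability():
  """Illustrative only: related recordings always share a partition."""
  first = which_set('data/yes/bobby_nohash_0.wav', 10, 10)
  second = which_set('data/yes/bobby_nohash_1.wav', 10, 10)
  assert first == second  # Both hash 'bobby', so the answer never differs.
  return first

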
def load_wav_file(filename):
  """Loads an audio file and returns a float PCM-encoded array of samples.

  Args:
    filename: Path to the .wav file to load.

  Returns:
    Numpy array holding the sample data as floats between -1.0 and 1.0.
  """
  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
    wav_loader = io_ops.read_file(wav_filename_placeholder)
    wav_decoder = tf.audio.decode_wav(wav_loader, desired_channels=1)
    return sess.run(
        wav_decoder,
        feed_dict={wav_filename_placeholder: filename}).audio.flatten()


def save_wav_file(filename, wav_data, sample_rate):
  """Saves audio sample data to a .wav audio file.

  Args:
    filename: Path to save the file to.
    wav_data: 2D array of float PCM-encoded audio data.
    sample_rate: Samples per second to encode in the file.
  """
  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
    sample_rate_placeholder = tf.compat.v1.placeholder(tf.int32, [])
    wav_data_placeholder = tf.compat.v1.placeholder(tf.float32, [None, 1])
    wav_encoder = tf.audio.encode_wav(wav_data_placeholder,
                                      sample_rate_placeholder)
    wav_saver = io_ops.write_file(wav_filename_placeholder, wav_encoder)
    sess.run(
        wav_saver,
        feed_dict={
            wav_filename_placeholder: filename,
            sample_rate_placeholder: sample_rate,
            wav_data_placeholder: np.reshape(wav_data, (-1, 1))
        })


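# Illustrative sketch, not part of the original API: load_wav_file() and
# save_wav_file() round-trip the same samples, which can be handy for checking
# augmented clips by ear. The file paths and sample rate here are assumptions.
def _example_wav_round_trip(input_path='/tmp/input.wav',
                            output_path='/tmp/copy.wav',
                            sample_rate=16000):
  """Illustrative only: reads a wav and writes the same samples back out."""
  samples = load_wav_file(input_path)  # 1D float array in [-1.0, 1.0].
  # save_wav_file() expects a column of samples, so add a channel axis.
  save_wav_file(output_path, samples.reshape(-1, 1), sample_rate)
  return samples

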
def get_features_range(model_settings):
  """Returns the expected min/max for generated features.

  Args:
    model_settings: Information about the current model being trained.

  Returns:
    Min/max float pair holding the range of features.

  Raises:
    Exception: If preprocessing mode isn't recognized.
  """
  # TODO(petewarden): These values have been derived from the observed ranges
  # of spectrogram and MFCC inputs. If the preprocessing pipeline changes,
  # they may need to be updated.
  if model_settings['preprocess'] == 'average':
    features_min = 0.0
    features_max = 127.5
  elif model_settings['preprocess'] == 'mfcc':
    features_min = -247.0
    features_max = 30.0
  elif model_settings['preprocess'] == 'micro':
    features_min = 0.0
    features_max = 26.0
  else:
    raise Exception('Unknown preprocess mode "%s" (should be "mfcc",'
                    ' "average", or "micro")' % (model_settings['preprocess']))
  return features_min, features_max


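# A hedged sketch of how the range above can be used: a quantizing caller can
# map float features onto uint8 by scaling with the min/max reported for the
# chosen preprocessing mode. This helper is illustrative and not part of the
# original API.
def _example_quantize_features(features, model_settings):
  """Illustrative only: scales float features into the 0-255 uint8 range."""
  features_min, features_max = get_features_range(model_settings)
  scaled = (features - features_min) / (features_max - features_min)
  return np.clip(scaled * 255.0, 0.0, 255.0).astype(np.uint8)

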
class AudioProcessor(object):
  """Handles loading, partitioning, and preparing audio training data."""

  def __init__(self, data_url, data_dir, silence_percentage, unknown_percentage,
               wanted_words, validation_percentage, testing_percentage,
               model_settings, summaries_dir):
    if data_dir:
      self.data_dir = data_dir
      self.maybe_download_and_extract_dataset(data_url, data_dir)
      self.prepare_data_index(silence_percentage, unknown_percentage,
                              wanted_words, validation_percentage,
                              testing_percentage)
      self.prepare_background_data()
    self.prepare_processing_graph(model_settings, summaries_dir)

  def maybe_download_and_extract_dataset(self, data_url, dest_directory):
    """Download and extract data set tar file.

    If the data set we're using doesn't already exist, this function
    downloads it from the TensorFlow.org website and unpacks it into a
    directory.
    If data_url is None, don't download anything and expect the data
    directory to already contain the correct files.

    Args:
      data_url: Web location of the tar file containing the data set.
      dest_directory: File path to extract data to.
    """
    if not data_url:
      return
    if not gfile.Exists(dest_directory):
      os.makedirs(dest_directory)
    filename = data_url.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not gfile.Exists(filepath):

      def _progress(count, block_size, total_size):
        sys.stdout.write(
            '\r>> Downloading %s %.1f%%' %
            (filename, float(count * block_size) / float(total_size) * 100.0))
        sys.stdout.flush()

      try:
        filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress)
      except:
        tf.compat.v1.logging.error(
            'Failed to download URL: {0} to folder: {1}. Please make sure you '
            'have enough free space and an internet connection'.format(
                data_url, filepath))
        raise
      print()
      statinfo = os.stat(filepath)
      tf.compat.v1.logging.info(
          'Successfully downloaded {0} ({1} bytes)'.format(
              filename, statinfo.st_size))
      tarfile.open(filepath, 'r:gz').extractall(dest_directory)

  def prepare_data_index(self, silence_percentage, unknown_percentage,
                         wanted_words, validation_percentage,
                         testing_percentage):
    """Prepares a list of the samples organized by set and label.

    The training loop needs a list of all the available data, organized by
    which partition it should belong to, and with ground truth labels attached.
    This function analyzes the folders below the `data_dir`, figures out the
    right labels for each file based on the name of the subdirectory it belongs
    to, and uses a stable hash to assign it to a data set partition.

    Args:
      silence_percentage: How much of the resulting data should be background.
      unknown_percentage: How much should be audio outside the wanted classes.
      wanted_words: Labels of the classes we want to be able to recognize.
      validation_percentage: How much of the data set to use for validation.
      testing_percentage: How much of the data set to use for testing.

    The results are stored on the processor rather than returned: a list of
    file information for each set partition in `self.data_index`, and a lookup
    map from each class to its numeric index in `self.word_to_index`.

    Raises:
      Exception: If expected files are not found.
    """
    # Make sure the shuffling and picking of unknowns is deterministic.
    random.seed(RANDOM_SEED)
    wanted_words_index = {}
    for index, wanted_word in enumerate(wanted_words):
      wanted_words_index[wanted_word] = index + 2
    self.data_index = {'validation': [], 'testing': [], 'training': []}
    unknown_index = {'validation': [], 'testing': [], 'training': []}
    all_words = {}
    # Look through all the subfolders to find audio samples
    search_path = os.path.join(self.data_dir, '*', '*.wav')
    for wav_path in gfile.Glob(search_path):
      _, word = os.path.split(os.path.dirname(wav_path))
      word = word.lower()
      # Treat the '_background_noise_' folder as a special case, since we expect
      # it to contain long audio samples we mix in to improve training.
      if word == BACKGROUND_NOISE_DIR_NAME:
        continue
      all_words[word] = True
      set_index = which_set(wav_path, validation_percentage, testing_percentage)
      # If it's a known class, store its detail, otherwise add it to the list
      # we'll use to train the unknown label.
      if word in wanted_words_index:
        self.data_index[set_index].append({'label': word, 'file': wav_path})
      else:
        unknown_index[set_index].append({'label': word, 'file': wav_path})
    if not all_words:
      raise Exception('No .wavs found at ' + search_path)
    for index, wanted_word in enumerate(wanted_words):
      if wanted_word not in all_words:
        raise Exception('Expected to find ' + wanted_word +
                        ' in labels but only found ' +
                        ', '.join(all_words.keys()))
    # We need an arbitrary file to load as the input for the silence samples.
    # It's multiplied by zero later, so the content doesn't matter.
    silence_wav_path = self.data_index['training'][0]['file']
    for set_index in ['validation', 'testing', 'training']:
      set_size = len(self.data_index[set_index])
      silence_size = int(math.ceil(set_size * silence_percentage / 100))
      for _ in range(silence_size):
        self.data_index[set_index].append({
            'label': SILENCE_LABEL,
            'file': silence_wav_path
        })
      # Pick some unknowns to add to each partition of the data set.
      random.shuffle(unknown_index[set_index])
      unknown_size = int(math.ceil(set_size * unknown_percentage / 100))
      self.data_index[set_index].extend(unknown_index[set_index][:unknown_size])
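    # Worked example with illustrative numbers: if a partition starts with
    # 3000 files and silence_percentage and unknown_percentage are both 10,
    # the loop above appends 300 silence entries and 300 randomly chosen
    # unknown-word entries to that partition.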
    # Make sure the ordering is random.
    for set_index in ['validation', 'testing', 'training']:
      random.shuffle(self.data_index[set_index])
    # Prepare the rest of the result data structure.
    self.words_list = prepare_words_list(wanted_words)
    self.word_to_index = {}
    for word in all_words:
      if word in wanted_words_index:
        self.word_to_index[word] = wanted_words_index[word]
      else:
        self.word_to_index[word] = UNKNOWN_WORD_INDEX
    self.word_to_index[SILENCE_LABEL] = SILENCE_INDEX

  def prepare_background_data(self):
    """Searches a folder for background noise audio, and loads it into memory.

    It's expected that the background audio samples will be in a subdirectory
    named '_background_noise_' inside the 'data_dir' folder, as .wavs that match
    the sample rate of the training data, but can be much longer in duration.

    If the '_background_noise_' folder doesn't exist at all, this isn't an
    error, it's just taken to mean that no background noise augmentation should
    be used. If the folder does exist, but it's empty, that's treated as an
    error.

    Returns:
      List of raw PCM-encoded audio samples of background noise.

    Raises:
      Exception: If files aren't found in the folder.
    """
    self.background_data = []
    background_dir = os.path.join(self.data_dir, BACKGROUND_NOISE_DIR_NAME)
    if not gfile.Exists(background_dir):
      return self.background_data
    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
      wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
      wav_loader = io_ops.read_file(wav_filename_placeholder)
      wav_decoder = tf.audio.decode_wav(wav_loader, desired_channels=1)
      search_path = os.path.join(self.data_dir, BACKGROUND_NOISE_DIR_NAME,
                                 '*.wav')
      for wav_path in gfile.Glob(search_path):
        wav_data = sess.run(
            wav_decoder,
            feed_dict={wav_filename_placeholder: wav_path}).audio.flatten()
        self.background_data.append(wav_data)
      if not self.background_data:
        raise Exception('No background wav files were found in ' + search_path)

  def prepare_processing_graph(self, model_settings, summaries_dir):
    """Builds a TensorFlow graph to apply the input distortions.

    Creates a graph that loads a WAVE file, decodes it, scales the volume,
    shifts it in time, adds in background noise, calculates a spectrogram, and
    then builds an MFCC fingerprint from that.

    This must be called with an active TensorFlow session running, and it
    creates multiple placeholder inputs, and one output:

      - wav_filename_placeholder_: Filename of the WAV to load.
      - foreground_volume_placeholder_: How loud the main clip should be.
      - time_shift_padding_placeholder_: Where to pad the clip.
      - time_shift_offset_placeholder_: How much to move the clip in time.
      - background_data_placeholder_: PCM sample data for background noise.
      - background_volume_placeholder_: Loudness of mixed-in background.
      - output_: Output 2D fingerprint of processed audio.

    Args:
      model_settings: Information about the current model being trained.
      summaries_dir: Path to save training summary information to.

    Raises:
      ValueError: If the preprocessing mode isn't recognized.
      Exception: If the preprocessor wasn't compiled in.
    """
    with tf.compat.v1.get_default_graph().name_scope('data'):
      desired_samples = model_settings['desired_samples']
      self.wav_filename_placeholder_ = tf.compat.v1.placeholder(
          tf.string, [], name='wav_filename')
      wav_loader = io_ops.read_file(self.wav_filename_placeholder_)
      wav_decoder = tf.audio.decode_wav(
          wav_loader, desired_channels=1, desired_samples=desired_samples)
      # Allow the audio sample's volume to be adjusted.
      self.foreground_volume_placeholder_ = tf.compat.v1.placeholder(
          tf.float32, [], name='foreground_volume')
      scaled_foreground = tf.multiply(wav_decoder.audio,
                                      self.foreground_volume_placeholder_)
      # Shift the sample's start position, and pad any gaps with zeros.
      self.time_shift_padding_placeholder_ = tf.compat.v1.placeholder(
          tf.int32, [2, 2], name='time_shift_padding')
      self.time_shift_offset_placeholder_ = tf.compat.v1.placeholder(
          tf.int32, [2], name='time_shift_offset')
      padded_foreground = tf.pad(
          tensor=scaled_foreground,
          paddings=self.time_shift_padding_placeholder_,
          mode='CONSTANT')
      sliced_foreground = tf.slice(padded_foreground,
                                   self.time_shift_offset_placeholder_,
                                   [desired_samples, -1])
      # Mix in background noise.
      self.background_data_placeholder_ = tf.compat.v1.placeholder(
          tf.float32, [desired_samples, 1], name='background_data')
      self.background_volume_placeholder_ = tf.compat.v1.placeholder(
          tf.float32, [], name='background_volume')
      background_mul = tf.multiply(self.background_data_placeholder_,
                                   self.background_volume_placeholder_)
      background_add = tf.add(background_mul, sliced_foreground)
      background_clamp = tf.clip_by_value(background_add, -1.0, 1.0)
      # Run the spectrogram and MFCC ops to get a 2D 'fingerprint' of the audio.
      spectrogram = audio_ops.audio_spectrogram(
          background_clamp,
          window_size=model_settings['window_size_samples'],
          stride=model_settings['window_stride_samples'],
          magnitude_squared=True)
      tf.compat.v1.summary.image(
          'spectrogram', tf.expand_dims(spectrogram, -1), max_outputs=1)
      # The number of buckets in each FFT row in the spectrogram will depend on
      # how many input samples there are in each window. This can be quite
      # large, with a 160 sample window producing 127 buckets for example. We
      # don't need this level of detail for classification, so we often want to
      # shrink them down to produce a smaller result. That's what this section
      # implements. One method is to use average pooling to merge adjacent
      # buckets, but a more sophisticated approach is to apply the MFCC
      # algorithm to shrink the representation.
      if model_settings['preprocess'] == 'average':
        self.output_ = tf.nn.pool(
            input=tf.expand_dims(spectrogram, -1),
            window_shape=[1, model_settings['average_window_width']],
            strides=[1, model_settings['average_window_width']],
            pooling_type='AVG',
            padding='SAME')
        tf.compat.v1.summary.image('shrunk_spectrogram',
                                   self.output_,
                                   max_outputs=1)
      elif model_settings['preprocess'] == 'mfcc':
        self.output_ = audio_ops.mfcc(
            spectrogram,
            wav_decoder.sample_rate,
            dct_coefficient_count=model_settings['fingerprint_width'])
        tf.compat.v1.summary.image(
            'mfcc', tf.expand_dims(self.output_, -1), max_outputs=1)
      elif model_settings['preprocess'] == 'micro':
        if not frontend_op:
          raise Exception(
              'Micro frontend op is currently not available when running'
              ' TensorFlow directly from Python, you need to build and run'
              ' through Bazel')
        sample_rate = model_settings['sample_rate']
        window_size_ms = (model_settings['window_size_samples'] *
                          1000) / sample_rate
        window_step_ms = (model_settings['window_stride_samples'] *
                          1000) / sample_rate
        int16_input = tf.cast(tf.multiply(background_clamp, 32768), tf.int16)
        micro_frontend = frontend_op.audio_microfrontend(
            int16_input,
            sample_rate=sample_rate,
            window_size=window_size_ms,
            window_step=window_step_ms,
            num_channels=model_settings['fingerprint_width'],
            out_scale=1,
            out_type=tf.float32)
        self.output_ = tf.multiply(micro_frontend, (10.0 / 256.0))
        tf.compat.v1.summary.image(
            'micro',
            tf.expand_dims(tf.expand_dims(self.output_, -1), 0),
            max_outputs=1)
      else:
        raise ValueError('Unknown preprocess mode "%s" (should be "mfcc", '
                         ' "average", or "micro")' %
                         (model_settings['preprocess']))

      # Merge all the summaries and write them out to /tmp/retrain_logs (by
      # default)
      self.merged_summaries_ = tf.compat.v1.summary.merge_all(scope='data')
      if summaries_dir:
        self.summary_writer_ = tf.compat.v1.summary.FileWriter(
            summaries_dir + '/data', tf.compat.v1.get_default_graph())

  def set_size(self, mode):
    """Calculates the number of samples in the dataset partition.

    Args:
      mode: Which partition, must be 'training', 'validation', or 'testing'.

    Returns:
      Number of samples in the partition.
    """
    return len(self.data_index[mode])

  def get_data(self, how_many, offset, model_settings, background_frequency,
               background_volume_range, time_shift, mode, sess):
    """Gather samples from the data set, applying transformations as needed.

    When the mode is 'training', a random selection of samples will be returned,
    otherwise the first N clips in the partition will be used. This ensures that
    validation always uses the same samples, reducing noise in the metrics.

    Args:
      how_many: Desired number of samples to return. -1 means the entire
        contents of this partition.
      offset: Where to start when fetching deterministically.
      model_settings: Information about the current model being trained.
      background_frequency: How many clips will have background noise, 0.0 to
        1.0.
      background_volume_range: How loud the background noise will be.
      time_shift: How much to randomly shift the clips by in time.
      mode: Which partition to use, must be 'training', 'validation', or
        'testing'.
      sess: TensorFlow session that was active when processor was created.

    Returns:
      List of sample data for the transformed samples, and list of label indexes

    Raises:
      ValueError: If background samples are too short.
    """
    # Pick one of the partitions to choose samples from.
    candidates = self.data_index[mode]
    if how_many == -1:
      sample_count = len(candidates)
    else:
      sample_count = max(0, min(how_many, len(candidates) - offset))
    # Data and labels will be populated and returned.
    data = np.zeros((sample_count, model_settings['fingerprint_size']))
    labels = np.zeros(sample_count)
    desired_samples = model_settings['desired_samples']
    use_background = self.background_data and (mode == 'training')
    pick_deterministically = (mode != 'training')
    # Use the processing graph we created earlier to repeatedly generate the
    # final output sample data we'll use in training.
    for i in xrange(offset, offset + sample_count):
      # Pick which audio sample to use.
      if how_many == -1 or pick_deterministically:
        sample_index = i
      else:
        sample_index = np.random.randint(len(candidates))
      sample = candidates[sample_index]
      # If we're time shifting, set up the offset for this sample.
      if time_shift > 0:
        time_shift_amount = np.random.randint(-time_shift, time_shift)
      else:
        time_shift_amount = 0
      if time_shift_amount > 0:
        time_shift_padding = [[time_shift_amount, 0], [0, 0]]
        time_shift_offset = [0, 0]
      else:
        time_shift_padding = [[0, -time_shift_amount], [0, 0]]
        time_shift_offset = [-time_shift_amount, 0]
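      # Worked example, for clarity: with desired_samples of, say, 16000 and a
      # draw of time_shift_amount = +100, the clip gets 100 zeros padded at the
      # front and is sliced from offset 0, so the audio moves later in time.
      # With a draw of -100, 100 zeros are appended and the slice starts at
      # sample 100, moving the audio earlier. Either way the result stays
      # exactly desired_samples long.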
      input_dict = {
          self.wav_filename_placeholder_: sample['file'],
          self.time_shift_padding_placeholder_: time_shift_padding,
          self.time_shift_offset_placeholder_: time_shift_offset,
      }
      # Choose a section of background noise to mix in.
      if use_background or sample['label'] == SILENCE_LABEL:
        background_index = np.random.randint(len(self.background_data))
        background_samples = self.background_data[background_index]
        if len(background_samples) <= model_settings['desired_samples']:
          raise ValueError(
              'Background sample is too short! Need more than %d'
              ' samples but only %d were found' %
              (model_settings['desired_samples'], len(background_samples)))
        background_offset = np.random.randint(
            0, len(background_samples) - model_settings['desired_samples'])
        background_clipped = background_samples[background_offset:(
            background_offset + desired_samples)]
        background_reshaped = background_clipped.reshape([desired_samples, 1])
        if sample['label'] == SILENCE_LABEL:
          background_volume = np.random.uniform(0, 1)
        elif np.random.uniform(0, 1) < background_frequency:
          background_volume = np.random.uniform(0, background_volume_range)
        else:
          background_volume = 0
      else:
        background_reshaped = np.zeros([desired_samples, 1])
        background_volume = 0
      input_dict[self.background_data_placeholder_] = background_reshaped
      input_dict[self.background_volume_placeholder_] = background_volume
      # If we want silence, mute out the main sample but leave the background.
      if sample['label'] == SILENCE_LABEL:
        input_dict[self.foreground_volume_placeholder_] = 0
      else:
        input_dict[self.foreground_volume_placeholder_] = 1
      # Run the graph to produce the output audio.
      summary, data_tensor = sess.run(
          [self.merged_summaries_, self.output_], feed_dict=input_dict)
      self.summary_writer_.add_summary(summary)
      data[i - offset, :] = data_tensor.flatten()
      label_index = self.word_to_index[sample['label']]
      labels[i - offset] = label_index
    return data, labels

  def get_features_for_wav(self, wav_filename, model_settings, sess):
    """Applies the feature transformation process to the input_wav.

    Runs the feature generation process (generally producing a spectrogram from
    the input samples) on the WAV file. This can be useful for testing and
    verifying implementations being run on other platforms.

    Args:
      wav_filename: The path to the input audio file.
      model_settings: Information about the current model being trained.
      sess: TensorFlow session that was active when processor was created.

    Returns:
      Numpy data array containing the generated features.
    """
    desired_samples = model_settings['desired_samples']
    input_dict = {
        self.wav_filename_placeholder_: wav_filename,
        self.time_shift_padding_placeholder_: [[0, 0], [0, 0]],
        self.time_shift_offset_placeholder_: [0, 0],
        self.background_data_placeholder_: np.zeros([desired_samples, 1]),
        self.background_volume_placeholder_: 0,
        self.foreground_volume_placeholder_: 1,
    }
    # Run the graph to produce the output audio.
    data_tensor = sess.run([self.output_], feed_dict=input_dict)
    return data_tensor

  def get_unprocessed_data(self, how_many, model_settings, mode):
    """Retrieve sample data for the given partition, with no transformations.

    Args:
      how_many: Desired number of samples to return. -1 means the entire
        contents of this partition.
      model_settings: Information about the current model being trained.
      mode: Which partition to use, must be 'training', 'validation', or
        'testing'.

    Returns:
      List of sample data for the samples, and list of label strings.
    """
    candidates = self.data_index[mode]
    if how_many == -1:
      sample_count = len(candidates)
    else:
      sample_count = how_many
    desired_samples = model_settings['desired_samples']
    words_list = self.words_list
    data = np.zeros((sample_count, desired_samples))
    labels = []
    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
      wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
      wav_loader = io_ops.read_file(wav_filename_placeholder)
      wav_decoder = tf.audio.decode_wav(
          wav_loader, desired_channels=1, desired_samples=desired_samples)
      foreground_volume_placeholder = tf.compat.v1.placeholder(tf.float32, [])
      scaled_foreground = tf.multiply(wav_decoder.audio,
                                      foreground_volume_placeholder)
      for i in range(sample_count):
        if how_many == -1:
          sample_index = i
        else:
          sample_index = np.random.randint(len(candidates))
        sample = candidates[sample_index]
        input_dict = {wav_filename_placeholder: sample['file']}
        if sample['label'] == SILENCE_LABEL:
          input_dict[foreground_volume_placeholder] = 0
        else:
          input_dict[foreground_volume_placeholder] = 1
        data[i, :] = sess.run(scaled_foreground, feed_dict=input_dict).flatten()
        label_index = self.word_to_index[sample['label']]
        labels.append(words_list[label_index])
    return data, labels

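
# End-to-end usage sketch, under stated assumptions: the directory paths, the
# percentages, and the word list below are illustrative placeholders, and
# model_settings is assumed to come from models.prepare_model_settings() in
# the surrounding speech_commands code. This is not part of the original API.
def _example_audio_processor_usage(model_settings):
  """Illustrative only: builds an AudioProcessor and fetches one batch."""
  audio_processor = AudioProcessor(
      data_url=None,  # Assume the archive is already unpacked on disk.
      data_dir='/tmp/speech_dataset',
      silence_percentage=10,
      unknown_percentage=10,
      wanted_words=['yes', 'no'],
      validation_percentage=10,
      testing_percentage=10,
      model_settings=model_settings,
      summaries_dir='/tmp/retrain_logs')
  with tf.compat.v1.Session() as sess:
    # Pull 100 augmented training samples, mixing background noise into 80%
    # of clips at up to 10% volume and shifting each clip by up to 100 samples.
    fingerprints, ground_truth = audio_processor.get_data(
        how_many=100,
        offset=0,
        model_settings=model_settings,
        background_frequency=0.8,
        background_volume_range=0.1,
        time_shift=100,
        mode='training',
        sess=sess)
  return fingerprints, ground_truth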