1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15r"""Runs a trained audio graph against a WAVE file and reports the results.
16
17The model, labels and .wav file specified in the arguments will be loaded, and
18then the predictions from running the model against the audio data will be
19printed to the console. This is a useful script for sanity checking trained
20models, and as an example of how to use an audio model from Python.
21
22Here's an example of running it:
23
24python tensorflow/examples/speech_commands/label_wav.py \
25--graph=/tmp/my_frozen_graph.pb \
26--labels=/tmp/speech_commands_train/conv_labels.txt \
27--wav=/tmp/speech_dataset/left/a5d485dc_nohash_0.wav
28
29"""
30from __future__ import absolute_import
31from __future__ import division
32from __future__ import print_function
33
34import argparse
35import sys
36
37import tensorflow as tf
38
39
40FLAGS = None
41
42
def load_graph(filename):
  """Imports a serialized GraphDef from a file into the default graph.

  Args:
    filename: Path to a binary-serialized GraphDef protobuf.
  """
  with tf.io.gfile.GFile(filename, 'rb') as graph_file:
    parsed_def = tf.compat.v1.GraphDef()
    serialized = graph_file.read()
    parsed_def.ParseFromString(serialized)
    # An empty name means the imported nodes get no prefix.
    tf.import_graph_def(parsed_def, name='')
49
50
def load_labels(filename):
  """Reads labels from a text file, one label per line.

  Args:
    filename: Path to the labels file.

  Returns:
    A list of strings, one per line of the file, with trailing whitespace
    (including the newline) stripped.
  """
  # Open through a context manager so the file handle is closed as soon as
  # the labels are read, rather than whenever the object is garbage collected.
  with tf.io.gfile.GFile(filename) as label_file:
    return [line.rstrip() for line in label_file]
54
55
def run_graph(wav_data, labels, input_layer_name, output_layer_name,
              num_top_predictions):
  """Runs the audio data through the graph and prints predictions.

  Args:
    wav_data: Audio data to feed into the input tensor.
    labels: Sequence of label strings, indexed by predicted class id.
    input_layer_name: Name of the tensor that `wav_data` is fed into.
    output_layer_name: Name of the tensor the predictions are fetched from.
    num_top_predictions: Maximum number of predictions to print. Zero or a
      negative value prints nothing.

  Returns:
    Always 0.
  """
  with tf.compat.v1.Session() as sess:
    # Feed the audio data as input to the graph.
    #   The fetched value is unpacked as a single-element batch, so
    #   `predictions` holds one score per class for the one audio clip.
    softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name)
    predictions, = sess.run(softmax_tensor, {input_layer_name: wav_data})

    # Sort to show labels in order of confidence. Reverse first and then take
    # a prefix: slicing the tail with `[-n:]` would select *every* class when
    # n == 0 and the wrong end of the ranking when n < 0.
    how_many = max(num_top_predictions, 0)
    top_k = predictions.argsort()[::-1][:how_many]
    for node_id in top_k:
      human_string = labels[node_id]
      score = predictions[node_id]
      print('%s (score = %.5f)' % (human_string, score))

    return 0
75
76
def label_wav(wav, labels, graph, input_name, output_name, how_many_labels):
  """Loads the model and labels, and runs the inference to print predictions.

  Args:
    wav: Path to the audio file to be identified.
    labels: Path to the file containing labels, one per line.
    graph: Path to the serialized graph file to load.
    input_name: Name of the tensor the audio data is fed into.
    output_name: Name of the tensor the predictions are fetched from.
    how_many_labels: Number of top predictions to print.

  Raises:
    ValueError: If the wav, labels or graph path is empty or does not exist.
  """
  if not wav or not tf.io.gfile.exists(wav):
    raise ValueError('Audio file does not exist at {0}'.format(wav))
  if not labels or not tf.io.gfile.exists(labels):
    raise ValueError('Labels file does not exist at {0}'.format(labels))

  if not graph or not tf.io.gfile.exists(graph):
    raise ValueError('Graph file does not exist at {0}'.format(graph))

  labels_list = load_labels(labels)

  # load graph, which is imported into the default graph
  load_graph(graph)

  # Read the audio through tf.io.gfile, the same layer used for the existence
  # check above, so any path that passes validation can also be read.
  with tf.io.gfile.GFile(wav, 'rb') as wav_file:
    wav_data = wav_file.read()

  run_graph(wav_data, labels_list, input_name, output_name, how_many_labels)
96
97
def main(_):
  """Forwards the parsed command-line flags to label_wav.

  Args:
    _: Unused argument supplied by the app runner.
  """
  label_wav(
      wav=FLAGS.wav,
      labels=FLAGS.labels,
      graph=FLAGS.graph,
      input_name=FLAGS.input_name,
      output_name=FLAGS.output_name,
      how_many_labels=FLAGS.how_many_labels)
102
103
if __name__ == '__main__':
  # Command-line flags as (name, type, default, help), registered in order.
  flag_specs = (
      ('--wav', str, '', 'Audio file to be identified.'),
      ('--graph', str, '', 'Model to use for identification.'),
      ('--labels', str, '', 'Path to file containing labels.'),
      ('--input_name', str, 'wav_data:0',
       'Name of WAVE data input node in model.'),
      ('--output_name', str, 'labels_softmax:0',
       'Name of node outputting a prediction in the model.'),
      ('--how_many_labels', int, 3, 'Number of results to show.'),
  )

  parser = argparse.ArgumentParser()
  for flag_name, flag_type, flag_default, flag_help in flag_specs:
    parser.add_argument(
        flag_name, type=flag_type, default=flag_default, help=flag_help)

  # Recognized flags go into the module-level FLAGS; anything left over is
  # passed on to the app runner together with the program name.
  FLAGS, unparsed = parser.parse_known_args()
  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
130