1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15r"""Runs a trained audio graph against a WAVE file and reports the results. 16 17The model, labels and .wav file specified in the arguments will be loaded, and 18then the predictions from running the model against the audio data will be 19printed to the console. This is a useful script for sanity checking trained 20models, and as an example of how to use an audio model from Python. 21 22Here's an example of running it: 23 24python tensorflow/examples/speech_commands/label_wav.py \ 25--graph=/tmp/my_frozen_graph.pb \ 26--labels=/tmp/speech_commands_train/conv_labels.txt \ 27--wav=/tmp/speech_dataset/left/a5d485dc_nohash_0.wav 28 29""" 30from __future__ import absolute_import 31from __future__ import division 32from __future__ import print_function 33 34import argparse 35import sys 36 37import tensorflow as tf 38 39 40FLAGS = None 41 42 43def load_graph(filename): 44 """Unpersists graph from file as default graph.""" 45 with tf.io.gfile.GFile(filename, 'rb') as f: 46 graph_def = tf.compat.v1.GraphDef() 47 graph_def.ParseFromString(f.read()) 48 tf.import_graph_def(graph_def, name='') 49 50 51def load_labels(filename): 52 """Read in labels, one label per line.""" 53 return [line.rstrip() for line in tf.io.gfile.GFile(filename)] 54 55 56def run_graph(wav_data, labels, input_layer_name, output_layer_name, 57 num_top_predictions): 58 """Runs the audio data through the graph and prints predictions.""" 59 with tf.compat.v1.Session() as sess: 60 # Feed the audio data as input to the graph. 61 # predictions will contain a two-dimensional array, where one 62 # dimension represents the input image count, and the other has 63 # predictions per class 64 softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name) 65 predictions, = sess.run(softmax_tensor, {input_layer_name: wav_data}) 66 67 # Sort to show labels in order of confidence 68 top_k = predictions.argsort()[-num_top_predictions:][::-1] 69 for node_id in top_k: 70 human_string = labels[node_id] 71 score = predictions[node_id] 72 print('%s (score = %.5f)' % (human_string, score)) 73 74 return 0 75 76 77def label_wav(wav, labels, graph, input_name, output_name, how_many_labels): 78 """Loads the model and labels, and runs the inference to print predictions.""" 79 if not wav or not tf.io.gfile.exists(wav): 80 raise ValueError('Audio file does not exist at {0}'.format(wav)) 81 if not labels or not tf.io.gfile.exists(labels): 82 raise ValueError('Labels file does not exist at {0}'.format(labels)) 83 84 if not graph or not tf.io.gfile.exists(graph): 85 raise ValueError('Graph file does not exist at {0}'.format(graph)) 86 87 labels_list = load_labels(labels) 88 89 # load graph, which is stored in the default session 90 load_graph(graph) 91 92 with open(wav, 'rb') as wav_file: 93 wav_data = wav_file.read() 94 95 run_graph(wav_data, labels_list, input_name, output_name, how_many_labels) 96 97 98def main(_): 99 """Entry point for script, converts flags to arguments.""" 100 label_wav(FLAGS.wav, FLAGS.labels, FLAGS.graph, FLAGS.input_name, 101 FLAGS.output_name, FLAGS.how_many_labels) 102 103 104if __name__ == '__main__': 105 parser = argparse.ArgumentParser() 106 parser.add_argument( 107 '--wav', type=str, default='', help='Audio file to be identified.') 108 parser.add_argument( 109 '--graph', type=str, default='', help='Model to use for identification.') 110 parser.add_argument( 111 '--labels', type=str, default='', help='Path to file containing labels.') 112 parser.add_argument( 113 '--input_name', 114 type=str, 115 default='wav_data:0', 116 help='Name of WAVE data input node in model.') 117 parser.add_argument( 118 '--output_name', 119 type=str, 120 default='labels_softmax:0', 121 help='Name of node outputting a prediction in the model.') 122 parser.add_argument( 123 '--how_many_labels', 124 type=int, 125 default=3, 126 help='Number of results to show.') 127 128 FLAGS, unparsed = parser.parse_known_args() 129 tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed) 130