# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Tool to create accuracy statistics on a continuous stream of samples.

This is designed to be an environment for running experiments on new models and
settings to understand the effects they will have in a real application. You
need to supply it with a long audio file containing sounds you want to
recognize, and a text file listing the labels of each sound along with the time
at which it occurs. With this information, and a frozen model, the tool will
process the audio stream, apply the model, and keep track of how many mistakes
and successes the model achieved.

The matched percentage is the number of sounds that were correctly classified,
as a percentage of the total number of sounds listed in the ground truth file.
A classification counts as correct when the right label is chosen within a
short time of the expected ground truth, where the time tolerance is controlled
by the 'time_tolerance_ms' command line flag.

The wrong percentage is how many sounds triggered a detection (the classifier
decided it wasn't silence or background noise), but the detected class was
wrong. This is also a percentage of the total number of ground truth sounds.

The false positive percentage is how many sounds were detected when there was
only silence or background noise. This is also expressed as a percentage of the
total number of ground truth sounds, though since it can be large it may go
above 100%.

The easiest way to get an audio file and labels to test with is by using the
'generate_streaming_test_wav' script. This will synthesize a test file with
randomly placed sounds and background noise, and output a text file with the
ground truth.

If you want to test natural data, you need to use a .wav with the same sample
rate as your model (often 16,000 samples per second), and note down where the
sounds occur in time. Save this information out as a comma-separated text file,
where the first column is the label and the second is the time in seconds from
the start of the file at which it occurs.
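
As an illustration, a ground truth file in that format might contain lines
like the following (the labels and times here are made up for the example):

yes,2.5
no,5.75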

Here's an example of how to run the tool:

bazel run tensorflow/examples/speech_commands:test_streaming_accuracy_py -- \
--wav=/tmp/streaming_test_bg.wav \
--ground_truth=/tmp/streaming_test_labels.txt --verbose \
--model=/tmp/conv_frozen.pb \
--labels=/tmp/speech_commands_train/conv_labels.txt \
--clip_duration_ms=1000 --detection_threshold=0.70 \
--average_window_duration_ms=500 \
--suppression_ms=500 --time_tolerance_ms=1500
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import numpy
import tensorflow as tf

from accuracy_utils import StreamingAccuracyStats
from recognize_commands import RecognizeCommands
from recognize_commands import RecognizeResult
from tensorflow.python.ops import io_ops

FLAGS = None


def load_graph(model_file):
  """Read a frozen TensorFlow model and import it into a new default graph."""
  graph = tf.Graph()
  with graph.as_default():
    od_graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(model_file, 'rb') as fid:
      serialized_graph = fid.read()
      od_graph_def.ParseFromString(serialized_graph)
      tf.import_graph_def(od_graph_def, name='')
  return graph


def read_label_file(file_name):
  """Load a list of labels, one per line."""
  label_list = []
  with open(file_name, 'r') as f:
    for line in f:
      label_list.append(line.strip())
  return label_list


def read_wav_file(filename):
  """Load a wav file and return its sample rate and float32 audio samples."""
  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
    wav_loader = io_ops.read_file(wav_filename_placeholder)
    wav_decoder = tf.audio.decode_wav(wav_loader, desired_channels=1)
    res = sess.run(wav_decoder, feed_dict={wav_filename_placeholder: filename})
  return res.sample_rate, res.audio.flatten()


def main(_):
  label_list = read_label_file(FLAGS.labels)
  sample_rate, data = read_wav_file(FLAGS.wav)
  # Init instance of RecognizeCommands with given parameters.
  recognize_commands = RecognizeCommands(
      labels=label_list,
      average_window_duration_ms=FLAGS.average_window_duration_ms,
      detection_threshold=FLAGS.detection_threshold,
      suppression_ms=FLAGS.suppression_ms,
      minimum_count=4)

  # Init instance of StreamingAccuracyStats and load ground truth.
  stats = StreamingAccuracyStats()
  stats.read_ground_truth_file(FLAGS.ground_truth)
  recognize_element = RecognizeResult()
  all_found_words = []
  data_samples = data.shape[0]
  clip_duration_samples = int(FLAGS.clip_duration_ms * sample_rate / 1000)
  clip_stride_samples = int(FLAGS.clip_stride_ms * sample_rate / 1000)
  audio_data_end = data_samples - clip_duration_samples

  # Load model and create a tf session to process audio pieces.
  recognize_graph = load_graph(FLAGS.model)
  with recognize_graph.as_default():
    with tf.compat.v1.Session() as sess:

      # Get input and output tensors.
      data_tensor = sess.graph.get_tensor_by_name(FLAGS.input_names[0])
      sample_rate_tensor = sess.graph.get_tensor_by_name(FLAGS.input_names[1])
      output_softmax_tensor = sess.graph.get_tensor_by_name(FLAGS.output_name)

      # Inference along audio stream.
      for audio_data_offset in range(0, audio_data_end, clip_stride_samples):
        input_start = audio_data_offset
        input_end = audio_data_offset + clip_duration_samples
        outputs = sess.run(
            output_softmax_tensor,
            feed_dict={
                data_tensor:
                    numpy.expand_dims(data[input_start:input_end], axis=-1),
                sample_rate_tensor:
                    sample_rate
            })
        outputs = numpy.squeeze(outputs)
        current_time_ms = int(audio_data_offset * 1000 / sample_rate)
        try:
          recognize_commands.process_latest_result(outputs, current_time_ms,
                                                   recognize_element)
        except ValueError as e:
          tf.compat.v1.logging.error(
              'Recognition processing failed: {}'.format(e))
          return
        if (recognize_element.is_new_command and
            recognize_element.founded_command != '_silence_'):
          all_found_words.append(
              [recognize_element.founded_command, current_time_ms])
          if FLAGS.verbose:
            stats.calculate_accuracy_stats(all_found_words, current_time_ms,
                                           FLAGS.time_tolerance_ms)
            try:
              recognition_state = stats.delta()
            except ValueError as e:
              tf.compat.v1.logging.error(
                  'Statistics delta computing failed: {}'.format(e))
            else:
              tf.compat.v1.logging.info('{}ms {}:{}{}'.format(
                  current_time_ms, recognize_element.founded_command,
                  recognize_element.score, recognition_state))
              stats.print_accuracy_stats()
  stats.calculate_accuracy_stats(all_found_words, -1, FLAGS.time_tolerance_ms)
  stats.print_accuracy_stats()


if __name__ == '__main__':
  parser = argparse.ArgumentParser(description='test_streaming_accuracy')
  parser.add_argument(
      '--wav', type=str, default='', help='The wave file path to evaluate.')
  parser.add_argument(
      '--ground_truth',
      type=str,
      default='',
      help='The ground truth file path corresponding to the wav file.')
  parser.add_argument(
      '--labels',
      type=str,
      default='',
      help='The label file path containing all possible classes.')
  parser.add_argument(
      '--model',
      type=str,
      default='',
      help='The frozen model file used for inference.')
  parser.add_argument(
      '--input_names',
      type=str,
      nargs='+',
      default=['decoded_sample_data:0', 'decoded_sample_data:1'],
      help='Input tensor names used by the model graph.')
  parser.add_argument(
      '--output_name',
      type=str,
      default='labels_softmax:0',
      help='Output tensor name used by the model graph.')
  parser.add_argument(
      '--clip_duration_ms',
      type=int,
      default=1000,
      help='Length in milliseconds of each audio clip fed into the model.')
  parser.add_argument(
      '--clip_stride_ms',
      type=int,
      default=30,
      help='Stride in milliseconds that the clip window moves along the '
      'audio stream between inferences.')
  parser.add_argument(
      '--average_window_duration_ms',
      type=int,
      default=500,
      help='Length of the averaging window used for smoothing results.')
  parser.add_argument(
      '--detection_threshold',
      type=float,
      default=0.7,
      help='Confidence threshold below which detections are ignored.')
  parser.add_argument(
      '--suppression_ms',
      type=int,
      default=500,
      help='Minimum time in milliseconds between two reported commands.')
  parser.add_argument(
      '--time_tolerance_ms',
      type=int,
      default=1500,
      help='Time tolerance before and after the timestamp of this audio clip '
      'to match ground truth.')
  parser.add_argument(
      '--verbose',
      action='store_true',
      default=False,
      help='Whether to print streaming accuracy on stdout.')

  FLAGS, unparsed = parser.parse_known_args()
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)