# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15r"""Tool to create accuracy statistics on a continuous stream of samples.
16
17This is designed to be an environment for running experiments on new models and
18settings to understand the effects they will have in a real application. You
19need to supply it with a long audio file containing sounds you want to recognize
20and a text file listing the labels of each sound along with the time they occur.
21With this information, and a frozen model, the tool will process the audio
22stream, apply the model, and keep track of how many mistakes and successes the
23model achieved.
24
25The matched percentage is the number of sounds that were correctly classified,
26as a percentage of the total number of sounds listed in the ground truth file.
27A correct classification is when the right label is chosen within a short time
28of the expected ground truth, where the time tolerance is controlled by the
29'time_tolerance_ms' command line flag.

The wrong percentage is how many sounds triggered a detection (the classifier
figured out it wasn't silence or background noise), but the detected class was
wrong. This is also a percentage of the total number of ground truth sounds.

The false positive percentage is how many sounds were detected when there was
only silence or background noise. This is also expressed as a percentage of the
total number of ground truth sounds, though since it can be large it may go
above 100%.
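
As an illustration, if the ground truth lists ten sounds and the tool correctly
matches seven of them, detects two with the wrong label, and triggers three
times during background noise, it will report 70% matched, 20% wrong, and 30%
false positives.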

The easiest way to get an audio file and labels to test with is by using the
'generate_streaming_test_wav' script. This will synthesize a test file with
randomly placed sounds and background noise, and output a text file with the
ground truth.

If you want to test natural data, you need to use a .wav with the same sample
rate as your model (often 16,000 samples per second), and note down where the
sounds occur in time. Save this information out as a comma-separated text file,
where the first column is the label and the second is the time in seconds from
the start of the file at which it occurs.
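
For instance, a hand-labeled ground truth file for a recording containing "yes"
at one second in and "no" at two and a half seconds in might look like this:

yes,1.0
no,2.5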

Here's an example of how to run the tool:

bazel run tensorflow/examples/speech_commands:test_streaming_accuracy_py -- \
--wav=/tmp/streaming_test_bg.wav \
--ground-truth=/tmp/streaming_test_labels.txt --verbose \
--model=/tmp/conv_frozen.pb \
--labels=/tmp/speech_commands_train/conv_labels.txt \
--clip-duration-ms=1000 --detection-threshold=0.70 \
--average-window-duration-ms=500 \
--suppression-ms=500 --time-tolerance-ms=1500
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import numpy
import tensorflow as tf

from accuracy_utils import StreamingAccuracyStats
from recognize_commands import RecognizeCommands
from recognize_commands import RecognizeResult
from tensorflow.python.ops import io_ops

FLAGS = None


def load_graph(model_file):
  """Read a frozen TensorFlow model file and create a default graph object."""
  graph = tf.Graph()
  with graph.as_default():
    od_graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(model_file, 'rb') as fid:
      serialized_graph = fid.read()
      od_graph_def.ParseFromString(serialized_graph)
      tf.import_graph_def(od_graph_def, name='')
  return graph


def read_label_file(file_name):
  """Load a list of labels, one per line."""
  label_list = []
  with open(file_name, 'r') as f:
    for line in f:
      label_list.append(line.strip())
  return label_list


def read_wav_file(filename):
  """Load a wav file and return the sample rate and float32 audio data."""
  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
    wav_loader = io_ops.read_file(wav_filename_placeholder)
    wav_decoder = tf.audio.decode_wav(wav_loader, desired_channels=1)
    res = sess.run(wav_decoder, feed_dict={wav_filename_placeholder: filename})
  return res.sample_rate, res.audio.flatten()


def main(_):
  label_list = read_label_file(FLAGS.labels)
  sample_rate, data = read_wav_file(FLAGS.wav)
  # Init instance of RecognizeCommands with given parameters.
  recognize_commands = RecognizeCommands(
      labels=label_list,
      average_window_duration_ms=FLAGS.average_window_duration_ms,
      detection_threshold=FLAGS.detection_threshold,
      suppression_ms=FLAGS.suppression_ms,
      minimum_count=4)

  # Init instance of StreamingAccuracyStats and load ground truth.
  stats = StreamingAccuracyStats()
  stats.read_ground_truth_file(FLAGS.ground_truth)
  recognize_element = RecognizeResult()
  all_found_words = []
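  # Convert the clip duration and stride from milliseconds to sample counts;
  # for example, a 1000ms clip at a 16,000Hz sample rate spans 16,000 samples.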
  data_samples = data.shape[0]
  clip_duration_samples = int(FLAGS.clip_duration_ms * sample_rate / 1000)
  clip_stride_samples = int(FLAGS.clip_stride_ms * sample_rate / 1000)
  audio_data_end = data_samples - clip_duration_samples

  # Load the model and create a tf session to process audio pieces.
  recognize_graph = load_graph(FLAGS.model)
  with recognize_graph.as_default():
    with tf.compat.v1.Session() as sess:

      # Get the input and output tensors.
      data_tensor = sess.graph.get_tensor_by_name(FLAGS.input_names[0])
      sample_rate_tensor = sess.graph.get_tensor_by_name(FLAGS.input_names[1])
      output_softmax_tensor = sess.graph.get_tensor_by_name(FLAGS.output_name)

      # Run inference along the audio stream, stepping forward by the clip
      # stride on each iteration.
      for audio_data_offset in range(0, audio_data_end, clip_stride_samples):
        input_start = audio_data_offset
        input_end = audio_data_offset + clip_duration_samples
        outputs = sess.run(
            output_softmax_tensor,
            feed_dict={
                data_tensor:
                    numpy.expand_dims(data[input_start:input_end], axis=-1),
                sample_rate_tensor:
                    sample_rate
            })
        outputs = numpy.squeeze(outputs)
        current_time_ms = int(audio_data_offset * 1000 / sample_rate)
        try:
          recognize_commands.process_latest_result(outputs, current_time_ms,
                                                   recognize_element)
        except ValueError as e:
          tf.compat.v1.logging.error(
              'Recognition processing failed: {}'.format(e))
          return
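        # Record any newly detected, non-silence command together with the
        # time at which it was recognized.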
        if (recognize_element.is_new_command and
            recognize_element.founded_command != '_silence_'):
          all_found_words.append(
              [recognize_element.founded_command, current_time_ms])
          if FLAGS.verbose:
            stats.calculate_accuracy_stats(all_found_words, current_time_ms,
                                           FLAGS.time_tolerance_ms)
            try:
              recognition_state = stats.delta()
            except ValueError as e:
              tf.compat.v1.logging.error(
                  'Statistics delta computing failed: {}'.format(e))
            else:
              tf.compat.v1.logging.info('{}ms {}:{}{}'.format(
                  current_time_ms, recognize_element.founded_command,
                  recognize_element.score, recognition_state))
              stats.print_accuracy_stats()
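  # Run a final scoring pass over all found words; a time of -1 means no
  # cutoff is applied, so the whole stream is evaluated.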
  stats.calculate_accuracy_stats(all_found_words, -1, FLAGS.time_tolerance_ms)
  stats.print_accuracy_stats()


if __name__ == '__main__':
  parser = argparse.ArgumentParser(description='test_streaming_accuracy')
  parser.add_argument(
      '--wav', type=str, default='', help='The wave file path to evaluate.')
  parser.add_argument(
      '--ground-truth',
      type=str,
      default='',
      help='The ground truth file path corresponding to the wav file.')
  parser.add_argument(
      '--labels',
      type=str,
      default='',
      help='The label file path containing all possible classes.')
  parser.add_argument(
      '--model', type=str, default='', help='The model used for inference.')
  parser.add_argument(
      '--input-names',
      type=str,
      nargs='+',
      default=['decoded_sample_data:0', 'decoded_sample_data:1'],
      help='Input tensor names of the model graph.')
  parser.add_argument(
      '--output-name',
      type=str,
      default='labels_softmax:0',
      help='Output tensor name of the model graph.')
  parser.add_argument(
      '--clip-duration-ms',
      type=int,
      default=1000,
      help='Length in milliseconds of each audio clip fed into the model.')
  parser.add_argument(
      '--clip-stride-ms',
      type=int,
      default=30,
      help='Length in milliseconds of the stride between successive audio '
      'clips.')
  parser.add_argument(
      '--average-window-duration-ms',
      type=int,
      default=500,
      help='Length of the average window used for smoothing results.')
  parser.add_argument(
      '--detection-threshold',
      type=float,
      default=0.7,
      help='The confidence threshold for filtering out unreliable commands.')
  parser.add_argument(
      '--suppression-ms',
      type=int,
      default=500,
      help='The minimum time interval between two adjacent detected commands.')
  parser.add_argument(
      '--time-tolerance-ms',
      type=int,
      default=1500,
      help='Time tolerance before and after the timestamp of this audio clip '
      'to match ground truth.')
  parser.add_argument(
      '--verbose',
      action='store_true',
      default=False,
      help='Whether to print streaming accuracy on stdout.')

  FLAGS, unparsed = parser.parse_known_args()
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)