# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts a trained checkpoint into a frozen model for mobile inference.

Once you've trained a model using the `train.py` script, you can use this tool
to convert it into a binary GraphDef file that can be loaded into the Android,
iOS, or Raspberry Pi example code. Here's an example of how to run it:

bazel run tensorflow/examples/speech_commands:freeze -- \
--sample_rate=16000 --feature_bin_count=40 --window_size_ms=20 \
--window_stride_ms=10 --clip_duration_ms=1000 \
--model_architecture=conv \
--start_checkpoint=/tmp/speech_commands_train/conv.ckpt-1300 \
--output_file=/tmp/my_frozen_graph.pb

Note that you must pass the same `sample_rate` and other command line flags
here that you used for the training script, so that the frozen graph matches
the trained model.

The resulting graph has an input for WAV-encoded data named 'wav_data', one for
raw PCM data (as floats in the range -1.0 to 1.0) called 'decoded_sample_data',
and the output is called 'labels_softmax'.

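Once you have a frozen graph, you can sanity-check it from Python. The snippet
below is only a sketch (the paths are hypothetical, and it assumes TF 1.x-style
session execution, along the lines of `label_wav.py`):

  import tensorflow as tf

  with tf.compat.v1.Session() as sess:
    with tf.io.gfile.GFile('/tmp/my_frozen_graph.pb', 'rb') as f:
      graph_def = tf.compat.v1.GraphDef()
      graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
    with open('/tmp/speech_example.wav', 'rb') as wav_file:
      predictions = sess.run('labels_softmax:0',
                             feed_dict={'wav_data:0': wav_file.read()})
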
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import sys

import tensorflow as tf

import input_data
import models
from tensorflow.python.framework import graph_util
from tensorflow.python.ops import gen_audio_ops as audio_ops

# If it's available, load the specialized feature generator. If this doesn't
# work, try building with bazel instead of running the Python script directly.
# bazel run tensorflow/examples/speech_commands:freeze_graph
try:
  from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op as frontend_op  # pylint:disable=g-import-not-at-top
except ImportError:
  frontend_op = None

FLAGS = None


def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
                           clip_stride_ms, window_size_ms, window_stride_ms,
                           feature_bin_count, model_architecture, preprocess):
  """Creates an audio model with the nodes needed for inference.

  Uses the supplied arguments to create a model, and inserts the input and
  output nodes that are needed to use the graph for inference.

  Args:
    wanted_words: Comma-separated list of the words we're trying to recognize.
    sample_rate: How many samples per second are in the input audio files.
    clip_duration_ms: Length of each audio clip to be analyzed, in ms.
    clip_stride_ms: How often to run recognition. Useful for models with cache.
    window_size_ms: Time slice duration to estimate frequencies from.
    window_stride_ms: How far apart time slices should be.
    feature_bin_count: Number of frequency bands to analyze.
    model_architecture: Name of the kind of model to generate.
    preprocess: How the spectrogram is processed to produce features, for
      example 'mfcc', 'average', or 'micro'.

  Returns:
    Input and output tensor objects.

  Raises:
    Exception: If the preprocessing mode isn't recognized.
  """

  words_list = input_data.prepare_words_list(wanted_words.split(','))
  model_settings = models.prepare_model_settings(
      len(words_list), sample_rate, clip_duration_ms, window_size_ms,
      window_stride_ms, feature_bin_count, preprocess)
  runtime_settings = {'clip_stride_ms': clip_stride_ms}

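  # Build the named inputs described in the module docstring: a serialized-WAV
  # string placeholder, plus the decoded float PCM samples derived from it.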
  wav_data_placeholder = tf.compat.v1.placeholder(tf.string, [],
                                                  name='wav_data')
  decoded_sample_data = tf.audio.decode_wav(
      wav_data_placeholder,
      desired_channels=1,
      desired_samples=model_settings['desired_samples'],
      name='decoded_sample_data')
  spectrogram = audio_ops.audio_spectrogram(
      decoded_sample_data.audio,
      window_size=model_settings['window_size_samples'],
      stride=model_settings['window_stride_samples'],
      magnitude_squared=True)
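  # `spectrogram` has shape [channels, frames, fft_bins], with channels == 1
  # since decode_wav was asked for mono audio. The preprocessing mode chosen
  # below must match the one used during training.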

  if preprocess == 'average':
    fingerprint_input = tf.nn.pool(
        input=tf.expand_dims(spectrogram, -1),
        window_shape=[1, model_settings['average_window_width']],
        strides=[1, model_settings['average_window_width']],
        pooling_type='AVG',
        padding='SAME')
  elif preprocess == 'mfcc':
    fingerprint_input = audio_ops.mfcc(
        spectrogram,
        sample_rate,
        dct_coefficient_count=model_settings['fingerprint_width'])
  elif preprocess == 'micro':
    if not frontend_op:
      raise Exception(
          'Micro frontend op is currently not available when running TensorFlow'
          ' directly from Python, you need to build and run through Bazel, for'
          ' example'
          ' `bazel run tensorflow/examples/speech_commands:freeze_graph`')
    sample_rate = model_settings['sample_rate']
    window_size_ms = (model_settings['window_size_samples'] *
                      1000) / sample_rate
    window_step_ms = (model_settings['window_stride_samples'] *
                      1000) / sample_rate
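    # The micro frontend consumes int16 PCM, so rescale the [-1.0, 1.0] floats
    # produced by decode_wav.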
    int16_input = tf.cast(
        tf.multiply(decoded_sample_data.audio, 32767), tf.int16)
    micro_frontend = frontend_op.audio_microfrontend(
        int16_input,
        sample_rate=sample_rate,
        window_size=window_size_ms,
        window_step=window_step_ms,
        num_channels=model_settings['fingerprint_width'],
        out_scale=1,
        out_type=tf.float32)
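    # Rescale the frontend's integer-scaled output by 10/256, matching the
    # feature scaling that input_data.py applies during training.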
    fingerprint_input = tf.multiply(micro_frontend, (10.0 / 256.0))
  else:
    raise Exception('Unknown preprocess mode "%s" (should be "mfcc",'
                    ' "average", or "micro")' % (preprocess))

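  # All of the model architectures expect a flat [batch, fingerprint_size]
  # input, whatever feature layout the preprocessing stage produced.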
  fingerprint_size = model_settings['fingerprint_size']
  reshaped_input = tf.reshape(fingerprint_input, [-1, fingerprint_size])

  logits = models.create_model(
      reshaped_input, model_settings, model_architecture, is_training=False,
      runtime_settings=runtime_settings)

  # Create an output to use for inference.
  softmax = tf.nn.softmax(logits, name='labels_softmax')

  return reshaped_input, softmax


def save_graph_def(file_name, frozen_graph_def):
  """Writes a graph def file out to disk.

  Args:
    file_name: Where to save the file.
    frozen_graph_def: GraphDef proto object to save.
  """
  tf.io.write_graph(
      frozen_graph_def,
      os.path.dirname(file_name),
      os.path.basename(file_name),
      as_text=False)
  tf.compat.v1.logging.info('Saved frozen graph to %s', file_name)


def save_saved_model(file_name, sess, input_tensor, output_tensor):
  """Writes a SavedModel out to disk.

  Args:
    file_name: Where to save the file.
    sess: TensorFlow session containing the graph.
    input_tensor: Tensor object defining the input's properties.
    output_tensor: Tensor object defining the output's properties.
  """
  # Store the frozen graph as a SavedModel for v2 compatibility.
  builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(file_name)
  tensor_info_inputs = {
      'input': tf.compat.v1.saved_model.utils.build_tensor_info(input_tensor)
  }
  tensor_info_outputs = {
      'output': tf.compat.v1.saved_model.utils.build_tensor_info(output_tensor)
  }
  signature = (
      tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
          inputs=tensor_info_inputs,
          outputs=tensor_info_outputs,
          method_name=tf.compat.v1.saved_model.signature_constants
          .PREDICT_METHOD_NAME))
  builder.add_meta_graph_and_variables(
      sess,
      [tf.compat.v1.saved_model.tag_constants.SERVING],
      signature_def_map={
          tf.compat.v1.saved_model.signature_constants
          .DEFAULT_SERVING_SIGNATURE_DEF_KEY:
              signature,
      },
  )
  builder.save()
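  # A consumer can later restore this with, for example (a sketch, assuming a
  # TF 2.x runtime):
  #   loaded = tf.saved_model.load(file_name)
  #   infer = loaded.signatures['serving_default']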


def main(_):
  if FLAGS.quantize:
    try:
      _ = tf.contrib
    except AttributeError as e:
      msg = e.args[0]
      msg += ('\n\n The --quantize option still requires contrib, which is not '
              'part of TensorFlow 2.0. Please install a previous version:'
              '\n    `pip install "tensorflow<=1.15"`')
      e.args = (msg,)
      raise e

  # Create the model and load its weights.
  sess = tf.compat.v1.InteractiveSession()
  input_tensor, output_tensor = create_inference_graph(
      FLAGS.wanted_words, FLAGS.sample_rate, FLAGS.clip_duration_ms,
      FLAGS.clip_stride_ms, FLAGS.window_size_ms, FLAGS.window_stride_ms,
      FLAGS.feature_bin_count, FLAGS.model_architecture, FLAGS.preprocess)
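  # If the model was trained with --quantize, the same fake-quantization nodes
  # must be added to the inference graph before the checkpoint is restored, so
  # the checkpoint's quantization variables have a place to load into.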
  if FLAGS.quantize:
    tf.contrib.quantize.create_eval_graph()
  models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)

  # Turn all the variables into inline constants inside the graph and save it.
  frozen_graph_def = graph_util.convert_variables_to_constants(
      sess, sess.graph_def, ['labels_softmax'])

  if FLAGS.save_format == 'graph_def':
    save_graph_def(FLAGS.output_file, frozen_graph_def)
  elif FLAGS.save_format == 'saved_model':
    save_saved_model(FLAGS.output_file, sess, input_tensor, output_tensor)
  else:
    raise Exception('Unknown save format "%s" (should be "graph_def" or'
                    ' "saved_model")' % (FLAGS.save_format))


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--sample_rate',
      type=int,
      default=16000,
      help='Expected sample rate of the wavs',)
  parser.add_argument(
      '--clip_duration_ms',
      type=int,
      default=1000,
      help='Expected duration in milliseconds of the wavs',)
  parser.add_argument(
      '--clip_stride_ms',
      type=int,
      default=30,
      help='How often to run recognition. Useful for models with cache.',)
  parser.add_argument(
      '--window_size_ms',
      type=float,
      default=30.0,
      help='How long each spectrogram timeslice is',)
  parser.add_argument(
      '--window_stride_ms',
      type=float,
      default=10.0,
      help='How long the stride is between spectrogram timeslices',)
  parser.add_argument(
      '--feature_bin_count',
      type=int,
      default=40,
      help='How many bins to use for the MFCC fingerprint',
  )
  parser.add_argument(
      '--start_checkpoint',
      type=str,
      default='',
      help='Checkpoint to load the trained model weights from.')
  parser.add_argument(
      '--model_architecture',
      type=str,
      default='conv',
      help='What model architecture to use')
  parser.add_argument(
      '--wanted_words',
      type=str,
      default='yes,no,up,down,left,right,on,off,stop,go',
      help='Words to use (others will be added to an unknown label)',)
  parser.add_argument(
      '--output_file', type=str, help='Where to save the frozen graph.')
  parser.add_argument(
      '--quantize',
      action='store_true',
      default=False,
      help='Whether to freeze the model for eight-bit deployment')
  parser.add_argument(
      '--preprocess',
      type=str,
      default='mfcc',
      help='Spectrogram processing mode. Can be "mfcc", "average", or "micro"')
  parser.add_argument(
      '--save_format',
      type=str,
      default='graph_def',
      help='How to save the result. Can be "graph_def" or "saved_model"')
  FLAGS, unparsed = parser.parse_known_args()
  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)