# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts a trained checkpoint into a frozen model for mobile inference.

Once you've trained a model using the `train.py` script, you can use this tool
to convert it into a binary GraphDef file that can be loaded into the Android,
iOS, or Raspberry Pi example code. Here's an example of how to run it:

bazel run tensorflow/examples/speech_commands/freeze -- \
--sample_rate=16000 --feature_bin_count=40 --window_size_ms=20 \
--window_stride_ms=10 --clip_duration_ms=1000 \
--model_architecture=conv \
--start_checkpoint=/tmp/speech_commands_train/conv.ckpt-1300 \
--output_file=/tmp/my_frozen_graph.pb

One thing to watch out for is that you need to pass in the same arguments for
`sample_rate` and other command line variables here as you did for the training
script.

The resulting graph has an input for WAV-encoded data named 'wav_data', one for
raw PCM data (as floats in the range -1.0 to 1.0) called 'decoded_sample_data',
and the output is called 'labels_softmax'.
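
As a sketch of the inference side (the tensor names are the ones this script
creates, but the graph and WAV paths are illustrative), the frozen GraphDef
can then be loaded and run like this:

  import tensorflow as tf

  graph = tf.Graph()
  with graph.as_default():
    with tf.io.gfile.GFile('/tmp/my_frozen_graph.pb', 'rb') as f:
      graph_def = tf.compat.v1.GraphDef()
      graph_def.ParseFromString(f.read())
    tf.compat.v1.import_graph_def(graph_def, name='')
  with tf.compat.v1.Session(graph=graph) as sess:
    with open('/tmp/some_clip.wav', 'rb') as wav_file:
      wav_data = wav_file.read()
    # Feed the WAV-encoded bytes and fetch the per-label scores.
    scores = sess.run('labels_softmax:0', feed_dict={'wav_data:0': wav_data})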

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import sys

import tensorflow as tf

import input_data
import models
from tensorflow.python.framework import graph_util
from tensorflow.python.ops import gen_audio_ops as audio_ops

# If it's available, load the specialized feature generator. If this doesn't
# work, try building with bazel instead of running the Python script directly.
# bazel run tensorflow/examples/speech_commands:freeze_graph
try:
  from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op as frontend_op  # pylint:disable=g-import-not-at-top
except ImportError:
  frontend_op = None

FLAGS = None


def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
                           clip_stride_ms, window_size_ms, window_stride_ms,
                           feature_bin_count, model_architecture, preprocess):
  """Creates an audio model with the nodes needed for inference.

  Uses the supplied arguments to create a model, and inserts the input and
  output nodes that are needed to use the graph for inference.

  Args:
    wanted_words: Comma-separated list of the words we're trying to recognize.
    sample_rate: How many samples per second are in the input audio files.
    clip_duration_ms: How many milliseconds of audio to analyze for the
      pattern.
    clip_stride_ms: How often to run recognition. Useful for models with cache.
    window_size_ms: Time slice duration to estimate frequencies from.
    window_stride_ms: How far apart time slices should be.
    feature_bin_count: Number of frequency bands to analyze.
    model_architecture: Name of the kind of model to generate.
    preprocess: How the spectrogram is processed to produce features, for
      example 'mfcc', 'average', or 'micro'.

  Returns:
    Input and output tensor objects.

  Raises:
    Exception: If the preprocessing mode isn't recognized.
  """

  words_list = input_data.prepare_words_list(wanted_words.split(','))
  model_settings = models.prepare_model_settings(
      len(words_list), sample_rate, clip_duration_ms, window_size_ms,
      window_stride_ms, feature_bin_count, preprocess)
  runtime_settings = {'clip_stride_ms': clip_stride_ms}

  wav_data_placeholder = tf.compat.v1.placeholder(tf.string, [],
                                                  name='wav_data')
  decoded_sample_data = tf.audio.decode_wav(
      wav_data_placeholder,
      desired_channels=1,
      desired_samples=model_settings['desired_samples'],
      name='decoded_sample_data')
  spectrogram = audio_ops.audio_spectrogram(
      decoded_sample_data.audio,
      window_size=model_settings['window_size_samples'],
      stride=model_settings['window_stride_samples'],
      magnitude_squared=True)

  if preprocess == 'average':
    fingerprint_input = tf.nn.pool(
        input=tf.expand_dims(spectrogram, -1),
        window_shape=[1, model_settings['average_window_width']],
        strides=[1, model_settings['average_window_width']],
        pooling_type='AVG',
        padding='SAME')
  elif preprocess == 'mfcc':
    fingerprint_input = audio_ops.mfcc(
        spectrogram,
        sample_rate,
        dct_coefficient_count=model_settings['fingerprint_width'])
  elif preprocess == 'micro':
    if not frontend_op:
      raise Exception(
          'Micro frontend op is currently not available when running'
          ' TensorFlow directly from Python; you need to build and run'
          ' through Bazel, for example'
          ' `bazel run tensorflow/examples/speech_commands:freeze_graph`')
    sample_rate = model_settings['sample_rate']
    window_size_ms = (model_settings['window_size_samples'] *
                      1000) / sample_rate
    window_step_ms = (model_settings['window_stride_samples'] *
                      1000) / sample_rate
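    # The micro frontend op expects int16 PCM input, so rescale the decoded
    # float samples (in the range -1.0 to 1.0) to the full int16 range. The
    # 10.0 / 256.0 factor applied to its output matches the feature scaling
    # used by input_data.py during training, so training and inference see
    # the same feature values.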
    int16_input = tf.cast(
        tf.multiply(decoded_sample_data.audio, 32767), tf.int16)
    micro_frontend = frontend_op.audio_microfrontend(
        int16_input,
        sample_rate=sample_rate,
        window_size=window_size_ms,
        window_step=window_step_ms,
        num_channels=model_settings['fingerprint_width'],
        out_scale=1,
        out_type=tf.float32)
    fingerprint_input = tf.multiply(micro_frontend, (10.0 / 256.0))
  else:
    raise Exception('Unknown preprocess mode "%s" (should be "mfcc",'
                    ' "average", or "micro")' % (preprocess))

  fingerprint_size = model_settings['fingerprint_size']
  reshaped_input = tf.reshape(fingerprint_input, [-1, fingerprint_size])

  logits = models.create_model(
      reshaped_input, model_settings, model_architecture, is_training=False,
      runtime_settings=runtime_settings)

  # Create an output to use for inference.
  softmax = tf.nn.softmax(logits, name='labels_softmax')

  return reshaped_input, softmax


def save_graph_def(file_name, frozen_graph_def):
  """Writes a graph def file out to disk.

  Args:
    file_name: Where to save the file.
    frozen_graph_def: GraphDef proto object to save.
  """
  tf.io.write_graph(
      frozen_graph_def,
      os.path.dirname(file_name),
      os.path.basename(file_name),
      as_text=False)
  tf.compat.v1.logging.info('Saved frozen graph to %s', file_name)


def save_saved_model(file_name, sess, input_tensor, output_tensor):
  """Writes a SavedModel out to disk.

  Args:
    file_name: Where to save the file.
    sess: TensorFlow session containing the graph.
    input_tensor: Tensor object defining the input's properties.
    output_tensor: Tensor object defining the output's properties.
  """
  # Store the frozen graph as a SavedModel for v2 compatibility.
  builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(file_name)
  tensor_info_inputs = {
      'input': tf.compat.v1.saved_model.utils.build_tensor_info(input_tensor)
  }
  tensor_info_outputs = {
      'output': tf.compat.v1.saved_model.utils.build_tensor_info(output_tensor)
  }
  signature = (
      tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
          inputs=tensor_info_inputs,
          outputs=tensor_info_outputs,
          method_name=tf.compat.v1.saved_model.signature_constants
          .PREDICT_METHOD_NAME))
  builder.add_meta_graph_and_variables(
      sess,
      [tf.compat.v1.saved_model.tag_constants.SERVING],
      signature_def_map={
          tf.compat.v1.saved_model.signature_constants
          .DEFAULT_SERVING_SIGNATURE_DEF_KEY:
              signature,
      },
  )
  builder.save()


def main(_):
  if FLAGS.quantize:
    try:
      _ = tf.contrib
    except AttributeError as e:
      msg = e.args[0]
      msg += ('\n\n The --quantize option still requires contrib, which is not '
              'part of TensorFlow 2.0. Please install a previous version:'
              '\n `pip install tensorflow<=1.15`')
      e.args = (msg,)
      raise e

  # Create the model and load its weights.
  sess = tf.compat.v1.InteractiveSession()
  input_tensor, output_tensor = create_inference_graph(
      FLAGS.wanted_words, FLAGS.sample_rate, FLAGS.clip_duration_ms,
      FLAGS.clip_stride_ms, FLAGS.window_size_ms, FLAGS.window_stride_ms,
      FLAGS.feature_bin_count, FLAGS.model_architecture, FLAGS.preprocess)
  if FLAGS.quantize:
    tf.contrib.quantize.create_eval_graph()
  models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)

  # Turn all the variables into inline constants inside the graph and save it.
  frozen_graph_def = graph_util.convert_variables_to_constants(
      sess, sess.graph_def, ['labels_softmax'])

  if FLAGS.save_format == 'graph_def':
    save_graph_def(FLAGS.output_file, frozen_graph_def)
  elif FLAGS.save_format == 'saved_model':
    save_saved_model(FLAGS.output_file, sess, input_tensor, output_tensor)
  else:
    raise Exception('Unknown save format "%s" (should be "graph_def" or'
                    ' "saved_model")' % (FLAGS.save_format))
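

# As the module docstring notes, these flags mirror train.py; pass the same
# values you used for training so the frozen graph matches the checkpoint.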
if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--sample_rate',
      type=int,
      default=16000,
      help='Expected sample rate of the wavs',)
  parser.add_argument(
      '--clip_duration_ms',
      type=int,
      default=1000,
      help='Expected duration in milliseconds of the wavs',)
  parser.add_argument(
      '--clip_stride_ms',
      type=int,
      default=30,
      help='How often to run recognition. Useful for models with cache.',)
  parser.add_argument(
      '--window_size_ms',
      type=float,
      default=30.0,
      help='How long each spectrogram timeslice is',)
  parser.add_argument(
      '--window_stride_ms',
      type=float,
      default=10.0,
      help='How long the stride is between spectrogram timeslices',)
  parser.add_argument(
      '--feature_bin_count',
      type=int,
      default=40,
      help='How many bins to use for the MFCC fingerprint',
  )
  parser.add_argument(
      '--start_checkpoint',
      type=str,
      default='',
      help='If specified, restore this pretrained model before freezing.')
  parser.add_argument(
      '--model_architecture',
      type=str,
      default='conv',
      help='What model architecture to use')
  parser.add_argument(
      '--wanted_words',
      type=str,
      default='yes,no,up,down,left,right,on,off,stop,go',
      help='Words to use (others will be added to an unknown label)',)
  parser.add_argument(
      '--output_file', type=str, help='Where to save the frozen graph.')
  parser.add_argument(
      '--quantize',
      action='store_true',
      default=False,
      help='Whether to quantize the model for eight-bit deployment')
  parser.add_argument(
      '--preprocess',
      type=str,
      default='mfcc',
      help='Spectrogram processing mode. Can be "mfcc", "average", or "micro"')
  parser.add_argument(
      '--save_format',
      type=str,
      default='graph_def',
      help='How to save the result. Can be "graph_def" or "saved_model"')
  FLAGS, unparsed = parser.parse_known_args()
  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)