# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Model definitions for simple speech recognition.
16
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import math
23
24import tensorflow as tf
25
26
27def _next_power_of_two(x):
28  """Calculates the smallest enclosing power of two for an input.
29
30  Args:
31    x: Positive float or integer number.
32
33  Returns:
34    Next largest power of two integer.
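
  Example:
    A 30 ms analysis window at 16 kHz covers 480 samples, which rounds up
    to 512:

    >>> _next_power_of_two(480)
    512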
35  """
36  return 1 if x == 0 else 2**(int(x) - 1).bit_length()
37
38
39def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
40                           window_size_ms, window_stride_ms, feature_bin_count,
41                           preprocess):
42  """Calculates common settings needed for all models.
43
44  Args:
45    label_count: How many classes are to be recognized.
46    sample_rate: Number of audio samples per second.
47    clip_duration_ms: Length of each audio clip to be analyzed.
48    window_size_ms: Duration of frequency analysis window.
49    window_stride_ms: How far to move in time between frequency windows.
50    feature_bin_count: Number of frequency bins to use for analysis.
51    preprocess: How the spectrogram is processed to produce features.
52
53  Returns:
54    Dictionary containing common settings.
55
56  Raises:
57    ValueError: If the preprocessing mode isn't recognized.
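
  Example:
    A quick sanity check with illustrative values (one-second clips at
    16 kHz, 30 ms windows, 10 ms strides, 40 MFCC bins):

    >>> settings = prepare_model_settings(
    ...     label_count=12, sample_rate=16000, clip_duration_ms=1000,
    ...     window_size_ms=30, window_stride_ms=10, feature_bin_count=40,
    ...     preprocess='mfcc')
    >>> settings['spectrogram_length']
    98
    >>> settings['fingerprint_size']
    3920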
58  """
59  desired_samples = int(sample_rate * clip_duration_ms / 1000)
60  window_size_samples = int(sample_rate * window_size_ms / 1000)
61  window_stride_samples = int(sample_rate * window_stride_ms / 1000)
62  length_minus_window = (desired_samples - window_size_samples)
63  if length_minus_window < 0:
64    spectrogram_length = 0
65  else:
66    spectrogram_length = 1 + int(length_minus_window / window_stride_samples)
67  if preprocess == 'average':
68    fft_bin_count = 1 + (_next_power_of_two(window_size_samples) / 2)
69    average_window_width = int(math.floor(fft_bin_count / feature_bin_count))
70    fingerprint_width = int(math.ceil(fft_bin_count / average_window_width))
71  elif preprocess == 'mfcc':
72    average_window_width = -1
73    fingerprint_width = feature_bin_count
74  elif preprocess == 'micro':
75    average_window_width = -1
76    fingerprint_width = feature_bin_count
77  else:
78    raise ValueError('Unknown preprocess mode "%s" (should be "mfcc",'
79                     ' "average", or "micro")' % (preprocess))
80  fingerprint_size = fingerprint_width * spectrogram_length
81  return {
82      'desired_samples': desired_samples,
83      'window_size_samples': window_size_samples,
84      'window_stride_samples': window_stride_samples,
85      'spectrogram_length': spectrogram_length,
86      'fingerprint_width': fingerprint_width,
87      'fingerprint_size': fingerprint_size,
88      'label_count': label_count,
89      'sample_rate': sample_rate,
90      'preprocess': preprocess,
91      'average_window_width': average_window_width,
92  }
93
94
95def create_model(fingerprint_input, model_settings, model_architecture,
96                 is_training, runtime_settings=None):
97  """Builds a model of the requested architecture compatible with the settings.
98
99  There are many possible ways of deriving predictions from a spectrogram
100  input, so this function provides an abstract interface for creating different
101  kinds of models in a black-box way. You need to pass in a TensorFlow node as
102  the 'fingerprint' input, and this should output a batch of 1D features that
103  describe the audio. Typically this will be derived from a spectrogram that's
104  been run through an MFCC, but in theory it can be any feature vector of the
105  size specified in model_settings['fingerprint_size'].
106
107  The function will build the graph it needs in the current TensorFlow graph,
108  and return the tensorflow output that will contain the 'logits' input to the
109  softmax prediction process. If training flag is on, it will also return a
110  placeholder node that can be used to control the dropout amount.
111
112  See the implementations below for the possible model architectures that can be
113  requested.
114
115  Args:
116    fingerprint_input: TensorFlow node that will output audio feature vectors.
117    model_settings: Dictionary of information about the model.
118    model_architecture: String specifying which kind of model to create.
119    is_training: Whether the model is going to be used for training.
120    runtime_settings: Dictionary of information about the runtime.
121
122  Returns:
123    TensorFlow node outputting logits results, and optionally a dropout
124    placeholder.
125
126  Raises:
127    Exception: If the architecture type isn't recognized.
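
  Example:
    A sketch of typical usage; the settings values are illustrative:

      settings = prepare_model_settings(
          label_count=12, sample_rate=16000, clip_duration_ms=1000,
          window_size_ms=30, window_stride_ms=10, feature_bin_count=40,
          preprocess='mfcc')
      fingerprint_input = tf.compat.v1.placeholder(
          tf.float32, [None, settings['fingerprint_size']],
          name='fingerprint_input')
      logits, dropout_rate = create_model(
          fingerprint_input, settings, 'conv', is_training=True)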
128  """
129  if model_architecture == 'single_fc':
130    return create_single_fc_model(fingerprint_input, model_settings,
131                                  is_training)
132  elif model_architecture == 'conv':
133    return create_conv_model(fingerprint_input, model_settings, is_training)
134  elif model_architecture == 'low_latency_conv':
135    return create_low_latency_conv_model(fingerprint_input, model_settings,
136                                         is_training)
137  elif model_architecture == 'low_latency_svdf':
138    return create_low_latency_svdf_model(fingerprint_input, model_settings,
139                                         is_training, runtime_settings)
140  elif model_architecture == 'tiny_conv':
141    return create_tiny_conv_model(fingerprint_input, model_settings,
142                                  is_training)
143  elif model_architecture == 'tiny_embedding_conv':
144    return create_tiny_embedding_conv_model(fingerprint_input, model_settings,
145                                            is_training)
146  else:
147    raise Exception('model_architecture argument "' + model_architecture +
148                    '" not recognized, should be one of "single_fc", "conv",' +
149                    ' "low_latency_conv, "low_latency_svdf",' +
150                    ' "tiny_conv", or "tiny_embedding_conv"')
151
152
153def load_variables_from_checkpoint(sess, start_checkpoint):
154  """Utility function to centralize checkpoint restoration.
155
156  Args:
157    sess: TensorFlow session.
158    start_checkpoint: Path to saved checkpoint on disk.
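
  Example:
    A minimal sketch; the checkpoint path here is hypothetical:

      sess = tf.compat.v1.Session()
      load_variables_from_checkpoint(
          sess, '/tmp/speech_commands_train/conv.ckpt-100')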
159  """
160  saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
161  saver.restore(sess, start_checkpoint)
162
163
164def create_single_fc_model(fingerprint_input, model_settings, is_training):
165  """Builds a model with a single hidden fully-connected layer.
166
167  This is a very simple model with just one matmul and bias layer. As you'd
168  expect, it doesn't produce very accurate results, but it is very fast and
169  simple, so it's useful for sanity testing.
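
  For example, with the illustrative MFCC settings above (fingerprint_size =
  3920) and 12 labels, the whole model is just 3920 * 12 + 12 = 47,052
  parameters.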

  Here's the layout of the graph:

  (fingerprint_input)
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v

  Args:
    fingerprint_input: TensorFlow node that will output audio feature vectors.
    model_settings: Dictionary of information about the model.
    is_training: Whether the model is going to be used for training.

  Returns:
    TensorFlow node outputting logits results, and optionally a dropout
    placeholder.
  """
  if is_training:
    dropout_rate = tf.compat.v1.placeholder(tf.float32, name='dropout_rate')
  fingerprint_size = model_settings['fingerprint_size']
  label_count = model_settings['label_count']
  weights = tf.compat.v1.get_variable(
      name='weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.001),
      shape=[fingerprint_size, label_count])
  bias = tf.compat.v1.get_variable(name='bias',
                                   initializer=tf.compat.v1.zeros_initializer,
                                   shape=[label_count])
  logits = tf.matmul(fingerprint_input, weights) + bias
  if is_training:
    return logits, dropout_rate
  else:
    return logits


def create_conv_model(fingerprint_input, model_settings, is_training):
  """Builds a standard convolutional model.

  This is roughly the network labeled as 'cnn-trad-fpool3' in the
  'Convolutional Neural Networks for Small-footprint Keyword Spotting' paper:
  http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf

  Here's the layout of the graph:

  (fingerprint_input)
          v
      [Conv2D]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [MaxPool]
          v
      [Conv2D]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v

  This produces fairly good quality results, but can involve a large number of
  weight parameters and computations. For a cheaper alternative from the same
  paper with slightly less accuracy, see 'low_latency_conv' below.

  During training, dropout nodes are introduced after each relu, controlled by a
  placeholder.
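
  For example, assuming the illustrative MFCC settings above (a 98 x 40
  input), the two SAME convolutions and the single 2x2 max-pool produce a
  49 x 20 x 64 activation, so the final fully-connected layer alone holds
  49 * 20 * 64 = 62,720 weights per label.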

  Args:
    fingerprint_input: TensorFlow node that will output audio feature vectors.
    model_settings: Dictionary of information about the model.
    is_training: Whether the model is going to be used for training.

  Returns:
    TensorFlow node outputting logits results, and optionally a dropout
    placeholder.
  """
  if is_training:
    dropout_rate = tf.compat.v1.placeholder(tf.float32, name='dropout_rate')
  input_frequency_size = model_settings['fingerprint_width']
  input_time_size = model_settings['spectrogram_length']
  fingerprint_4d = tf.reshape(fingerprint_input,
                              [-1, input_time_size, input_frequency_size, 1])
  first_filter_width = 8
  first_filter_height = 20
  first_filter_count = 64
  first_weights = tf.compat.v1.get_variable(
      name='first_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[first_filter_height, first_filter_width, 1, first_filter_count])
  first_bias = tf.compat.v1.get_variable(
      name='first_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[first_filter_count])

  first_conv = tf.nn.conv2d(input=fingerprint_4d,
                            filters=first_weights,
                            strides=[1, 1, 1, 1],
                            padding='SAME') + first_bias
  first_relu = tf.nn.relu(first_conv)
  if is_training:
    first_dropout = tf.nn.dropout(first_relu, rate=dropout_rate)
  else:
    first_dropout = first_relu
  max_pool = tf.nn.max_pool2d(input=first_dropout,
                              ksize=[1, 2, 2, 1],
                              strides=[1, 2, 2, 1],
                              padding='SAME')
  second_filter_width = 4
  second_filter_height = 10
  second_filter_count = 64
  second_weights = tf.compat.v1.get_variable(
      name='second_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[
          second_filter_height, second_filter_width, first_filter_count,
          second_filter_count
      ])
  second_bias = tf.compat.v1.get_variable(
      name='second_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[second_filter_count])
  second_conv = tf.nn.conv2d(input=max_pool,
                             filters=second_weights,
                             strides=[1, 1, 1, 1],
                             padding='SAME') + second_bias
  second_relu = tf.nn.relu(second_conv)
  if is_training:
    second_dropout = tf.nn.dropout(second_relu, rate=dropout_rate)
  else:
    second_dropout = second_relu
  second_conv_shape = second_dropout.get_shape()
  second_conv_output_width = second_conv_shape[2]
  second_conv_output_height = second_conv_shape[1]
  second_conv_element_count = int(
      second_conv_output_width * second_conv_output_height *
      second_filter_count)
  flattened_second_conv = tf.reshape(second_dropout,
                                     [-1, second_conv_element_count])
  label_count = model_settings['label_count']
  final_fc_weights = tf.compat.v1.get_variable(
      name='final_fc_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[second_conv_element_count, label_count])
  final_fc_bias = tf.compat.v1.get_variable(
      name='final_fc_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[label_count])
  final_fc = tf.matmul(flattened_second_conv, final_fc_weights) + final_fc_bias
  if is_training:
    return final_fc, dropout_rate
  else:
    return final_fc


def create_low_latency_conv_model(fingerprint_input, model_settings,
                                  is_training):
  """Builds a convolutional model with low compute requirements.

  This is roughly the network labeled as 'cnn-one-fstride4' in the
  'Convolutional Neural Networks for Small-footprint Keyword Spotting' paper:
  http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf

  Here's the layout of the graph:

  (fingerprint_input)
          v
      [Conv2D]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v

  This produces slightly lower quality results than the 'conv' model, but needs
  fewer weight parameters and computations.

  During training, dropout nodes are introduced after the relu and the
  intermediate fully-connected layers, controlled by a placeholder.
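
  For example, assuming a 98 x 40 input, the single VALID convolution below
  (a filter spanning all 98 frames, 8 bins wide, with 186 channels) produces
  a 1 x 33 x 186 output: the filter covers the whole time axis, so only the
  frequency axis retains spatial extent.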

  Args:
    fingerprint_input: TensorFlow node that will output audio feature vectors.
    model_settings: Dictionary of information about the model.
    is_training: Whether the model is going to be used for training.

  Returns:
    TensorFlow node outputting logits results, and optionally a dropout
    placeholder.
  """
  if is_training:
    dropout_rate = tf.compat.v1.placeholder(tf.float32, name='dropout_rate')
  input_frequency_size = model_settings['fingerprint_width']
  input_time_size = model_settings['spectrogram_length']
  fingerprint_4d = tf.reshape(fingerprint_input,
                              [-1, input_time_size, input_frequency_size, 1])
  first_filter_width = 8
  first_filter_height = input_time_size
  first_filter_count = 186
  first_filter_stride_x = 1
  first_filter_stride_y = 1
  first_weights = tf.compat.v1.get_variable(
      name='first_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[first_filter_height, first_filter_width, 1, first_filter_count])
  first_bias = tf.compat.v1.get_variable(
      name='first_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[first_filter_count])
  first_conv = tf.nn.conv2d(
      input=fingerprint_4d,
      filters=first_weights,
      strides=[1, first_filter_stride_y, first_filter_stride_x, 1],
      padding='VALID') + first_bias
  first_relu = tf.nn.relu(first_conv)
  if is_training:
    first_dropout = tf.nn.dropout(first_relu, rate=dropout_rate)
  else:
    first_dropout = first_relu
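  # For VALID padding, the output size per axis is
  # (input - filter + stride) // stride. Since the filter height equals
  # input_time_size, the output height is always 1.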
  first_conv_output_width = math.floor(
      (input_frequency_size - first_filter_width + first_filter_stride_x) /
      first_filter_stride_x)
  first_conv_output_height = math.floor(
      (input_time_size - first_filter_height + first_filter_stride_y) /
      first_filter_stride_y)
  first_conv_element_count = int(
      first_conv_output_width * first_conv_output_height * first_filter_count)
  flattened_first_conv = tf.reshape(first_dropout,
                                    [-1, first_conv_element_count])
  first_fc_output_channels = 128
  first_fc_weights = tf.compat.v1.get_variable(
      name='first_fc_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[first_conv_element_count, first_fc_output_channels])
  first_fc_bias = tf.compat.v1.get_variable(
      name='first_fc_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[first_fc_output_channels])
  first_fc = tf.matmul(flattened_first_conv, first_fc_weights) + first_fc_bias
  if is_training:
    second_fc_input = tf.nn.dropout(first_fc, rate=dropout_rate)
  else:
    second_fc_input = first_fc
  second_fc_output_channels = 128
  second_fc_weights = tf.compat.v1.get_variable(
      name='second_fc_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[first_fc_output_channels, second_fc_output_channels])
  second_fc_bias = tf.compat.v1.get_variable(
      name='second_fc_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[second_fc_output_channels])
  second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias
  if is_training:
    final_fc_input = tf.nn.dropout(second_fc, rate=dropout_rate)
  else:
    final_fc_input = second_fc
  label_count = model_settings['label_count']
  final_fc_weights = tf.compat.v1.get_variable(
      name='final_fc_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[second_fc_output_channels, label_count])
  final_fc_bias = tf.compat.v1.get_variable(
      name='final_fc_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[label_count])
  final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias
  if is_training:
    return final_fc, dropout_rate
  else:
    return final_fc


def create_low_latency_svdf_model(fingerprint_input, model_settings,
                                  is_training, runtime_settings):
  """Builds an SVDF model with low compute requirements.

  This is based on the topology presented in the 'Compressing Deep Neural
  Networks using a Rank-Constrained Topology' paper:
  https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/43813.pdf

  Here's the layout of the graph:

  (fingerprint_input)
          v
        [SVDF]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v

  This model produces lower recognition accuracy than the 'conv' model above,
  but requires fewer weight parameters and significantly fewer computations.

  During training, dropout nodes are introduced after the relu and the
  intermediate fully-connected layers, controlled by a placeholder.
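
  As a rough illustration of the rank-constrained idea: each node's dense
  time-frequency filter (input_time_size x input_frequency_size weights) is
  approximated by 'rank' pairs of a one-dimensional frequency filter and a
  one-dimensional time filter, costing rank * (input_time_size +
  input_frequency_size) weights instead. For an illustrative 98 x 40 input
  and rank 2, that is 2 * (98 + 40) = 276 weights per node rather than
  98 * 40 = 3,920.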

  Args:
    fingerprint_input: TensorFlow node that will output audio feature vectors.
      The node is expected to produce a 2D Tensor of shape:
        [batch, model_settings['fingerprint_width'] *
                model_settings['spectrogram_length']]
      with the features corresponding to the same time slot arranged
      contiguously, with the oldest slot at index [:, 0] and the newest at
      [:, -1].
    model_settings: Dictionary of information about the model.
    is_training: Whether the model is going to be used for training.
    runtime_settings: Dictionary of information about the runtime.

  Returns:
    TensorFlow node outputting logits results, and optionally a dropout
    placeholder.

  Raises:
    ValueError: If the inputs tensor is incorrectly shaped.
  """
  if is_training:
    dropout_rate = tf.compat.v1.placeholder(tf.float32, name='dropout_rate')

  input_frequency_size = model_settings['fingerprint_width']
  input_time_size = model_settings['spectrogram_length']

  # Validation.
  input_shape = fingerprint_input.get_shape()
  if len(input_shape) != 2:
    raise ValueError('Inputs to `SVDF` should have rank == 2.')
  if input_shape[-1].value is None:
    raise ValueError('The last dimension of the input to `SVDF` '
                     'should be defined. Found `None`.')
  if input_shape[-1].value % input_frequency_size != 0:
    raise ValueError('The last dimension of the input to `SVDF` = {0} must be '
                     'a multiple of the frame size = {1}'.format(
                         input_shape[-1].value, input_frequency_size))

  # Set number of units (i.e. nodes) and rank.
  rank = 2
  num_units = 1280
  # Number of filters: pairs of feature and time filters.
  num_filters = rank * num_units
  # Create the runtime memory: [num_filters, batch, input_time_size]
  batch = 1
  memory = tf.compat.v1.get_variable(
      initializer=tf.compat.v1.zeros_initializer,
      shape=[num_filters, batch, input_time_size],
      trainable=False,
      name='runtime-memory')
  first_time_flag = tf.compat.v1.get_variable(
      name='first_time_flag', dtype=tf.int32, initializer=1)
  # Determine the number of new frames in the input, such that we only operate
  # on those. For training we do not use the memory, and thus use all frames
  # provided in the input.
  # new_fingerprint_input: [batch, num_new_frames*input_frequency_size]
  if is_training:
    num_new_frames = input_time_size
  else:
    window_stride_ms = int(model_settings['window_stride_samples'] * 1000 /
                           model_settings['sample_rate'])
    num_new_frames = tf.cond(
        pred=tf.equal(first_time_flag, 1),
        true_fn=lambda: input_time_size,
        false_fn=lambda: int(runtime_settings['clip_stride_ms'] / window_stride_ms))  # pylint:disable=line-too-long
  first_time_flag = 0
  new_fingerprint_input = fingerprint_input[
      :, -num_new_frames*input_frequency_size:]
  # Expand to add input channels dimension.
  new_fingerprint_input = tf.expand_dims(new_fingerprint_input, 2)

  # Create the frequency filters.
  weights_frequency = tf.compat.v1.get_variable(
      name='weights_frequency',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[input_frequency_size, num_filters])
  # Expand to add input channels dimensions.
  # weights_frequency: [input_frequency_size, 1, num_filters]
  weights_frequency = tf.expand_dims(weights_frequency, 1)
  # Convolve the 1D feature filters sliding over the time dimension.
  # activations_time: [batch, num_new_frames, num_filters]
  activations_time = tf.nn.conv1d(input=new_fingerprint_input,
                                  filters=weights_frequency,
                                  stride=input_frequency_size,
                                  padding='VALID')
  # Rearrange such that we can perform the batched matmul.
  # activations_time: [num_filters, batch, num_new_frames]
  activations_time = tf.transpose(a=activations_time, perm=[2, 0, 1])

  # Runtime memory optimization.
  if not is_training:
    # We need to drop the activations corresponding to the oldest frames, and
    # then add those corresponding to the new frames.
    new_memory = memory[:, :, num_new_frames:]
    new_memory = tf.concat([new_memory, activations_time], 2)
    tf.compat.v1.assign(memory, new_memory)
    activations_time = new_memory

  # Create the time filters.
  weights_time = tf.compat.v1.get_variable(
      name='weights_time',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[num_filters, input_time_size])
  # Apply the time filter on the outputs of the feature filters.
  # weights_time: [num_filters, input_time_size, 1]
  # outputs: [num_filters, batch, 1]
  weights_time = tf.expand_dims(weights_time, 2)
  outputs = tf.matmul(activations_time, weights_time)
  # Split num_units and rank into separate dimensions (the remaining
  # dimension is input_shape[0], i.e. the batch size). This also squeezes
  # the last dimension, since it's not used.
  # [num_filters, batch, 1] => [num_units, rank, batch]
  outputs = tf.reshape(outputs, [num_units, rank, -1])
  # Sum the rank outputs per unit => [num_units, batch].
  units_output = tf.reduce_sum(input_tensor=outputs, axis=1)
  # Transpose to shape [batch, num_units]
  units_output = tf.transpose(a=units_output)

  # Apply bias.
  bias = tf.compat.v1.get_variable(name='bias',
                                   initializer=tf.compat.v1.zeros_initializer,
                                   shape=[num_units])
  first_bias = tf.nn.bias_add(units_output, bias)

  # Relu.
  first_relu = tf.nn.relu(first_bias)

  if is_training:
    first_dropout = tf.nn.dropout(first_relu, rate=dropout_rate)
  else:
    first_dropout = first_relu

  first_fc_output_channels = 256
  first_fc_weights = tf.compat.v1.get_variable(
      name='first_fc_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[num_units, first_fc_output_channels])
  first_fc_bias = tf.compat.v1.get_variable(
      name='first_fc_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[first_fc_output_channels])
  first_fc = tf.matmul(first_dropout, first_fc_weights) + first_fc_bias
  if is_training:
    second_fc_input = tf.nn.dropout(first_fc, rate=dropout_rate)
  else:
    second_fc_input = first_fc
  second_fc_output_channels = 256
  second_fc_weights = tf.compat.v1.get_variable(
      name='second_fc_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[first_fc_output_channels, second_fc_output_channels])
  second_fc_bias = tf.compat.v1.get_variable(
      name='second_fc_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[second_fc_output_channels])
  second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias
  if is_training:
    final_fc_input = tf.nn.dropout(second_fc, rate=dropout_rate)
  else:
    final_fc_input = second_fc
  label_count = model_settings['label_count']
  final_fc_weights = tf.compat.v1.get_variable(
      name='final_fc_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[second_fc_output_channels, label_count])
  final_fc_bias = tf.compat.v1.get_variable(
      name='final_fc_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[label_count])
  final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias
  if is_training:
    return final_fc, dropout_rate
  else:
    return final_fc


def create_tiny_conv_model(fingerprint_input, model_settings, is_training):
  """Builds a convolutional model aimed at microcontrollers.

  Devices like DSPs and microcontrollers can have very small amounts of
  memory and limited processing power. This model is designed to use less
  than 20KB of working RAM, and fit within 32KB of read-only (flash) memory.

  Here's the layout of the graph:

  (fingerprint_input)
          v
      [Conv2D]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v

  This doesn't produce particularly accurate results, but it's designed to be
  used as the first stage of a pipeline, running on a low-energy piece of
  hardware that can always be on, and then wake higher-power chips when a
  possible utterance has been found, so that more accurate analysis can be done.

  During training, a dropout node is introduced after the relu, controlled by a
  placeholder.
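
  For example, assuming an illustrative 49 x 40 input, the stride-2 SAME
  convolution produces a 25 x 20 x 8 activation, so the final fully-connected
  layer sees 25 * 20 * 8 = 4,000 inputs per label.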

  Args:
    fingerprint_input: TensorFlow node that will output audio feature vectors.
    model_settings: Dictionary of information about the model.
    is_training: Whether the model is going to be used for training.

  Returns:
    TensorFlow node outputting logits results, and optionally a dropout
    placeholder.
  """
  if is_training:
    dropout_rate = tf.compat.v1.placeholder(tf.float32, name='dropout_rate')
  input_frequency_size = model_settings['fingerprint_width']
  input_time_size = model_settings['spectrogram_length']
  fingerprint_4d = tf.reshape(fingerprint_input,
                              [-1, input_time_size, input_frequency_size, 1])
  first_filter_width = 8
  first_filter_height = 10
  first_filter_count = 8
  first_weights = tf.compat.v1.get_variable(
      name='first_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[first_filter_height, first_filter_width, 1, first_filter_count])
  first_bias = tf.compat.v1.get_variable(
      name='first_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[first_filter_count])
  first_conv_stride_x = 2
  first_conv_stride_y = 2
  first_conv = tf.nn.conv2d(
      input=fingerprint_4d, filters=first_weights,
      strides=[1, first_conv_stride_y, first_conv_stride_x, 1],
      padding='SAME') + first_bias
  first_relu = tf.nn.relu(first_conv)
  if is_training:
    first_dropout = tf.nn.dropout(first_relu, rate=dropout_rate)
  else:
    first_dropout = first_relu
  first_dropout_shape = first_dropout.get_shape()
  first_dropout_output_width = first_dropout_shape[2]
  first_dropout_output_height = first_dropout_shape[1]
  first_dropout_element_count = int(
      first_dropout_output_width * first_dropout_output_height *
      first_filter_count)
  flattened_first_dropout = tf.reshape(first_dropout,
                                       [-1, first_dropout_element_count])
  label_count = model_settings['label_count']
  final_fc_weights = tf.compat.v1.get_variable(
      name='final_fc_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[first_dropout_element_count, label_count])
  final_fc_bias = tf.compat.v1.get_variable(
      name='final_fc_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[label_count])
  final_fc = (
      tf.matmul(flattened_first_dropout, final_fc_weights) + final_fc_bias)
  if is_training:
    return final_fc, dropout_rate
  else:
    return final_fc


def create_tiny_embedding_conv_model(fingerprint_input, model_settings,
                                     is_training):
  """Builds a convolutional model aimed at microcontrollers.

  Devices like DSPs and microcontrollers can have very small amounts of
  memory and limited processing power. This model is designed to use less
  than 20KB of working RAM, and fit within 32KB of read-only (flash) memory.

  Here's the layout of the graph:

  (fingerprint_input)
          v
      [Conv2D]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [Conv2D]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v

  This doesn't produce particularly accurate results, but it's designed to be
  used as the first stage of a pipeline, running on a low-energy piece of
  hardware that can always be on, and then wake higher-power chips when a
  possible utterance has been found, so that more accurate analysis can be done.

  During training, a dropout node is introduced after each relu, controlled by
  a placeholder.
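
  For example, assuming an illustrative 49 x 40 input, the stride-2 first
  convolution produces a 25 x 20 x 8 activation and the stride-8 second
  convolution shrinks that to 4 x 3 x 8, so only 4 * 3 * 8 = 96 values feed
  the final fully-connected layer.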

  Args:
    fingerprint_input: TensorFlow node that will output audio feature vectors.
    model_settings: Dictionary of information about the model.
    is_training: Whether the model is going to be used for training.

  Returns:
    TensorFlow node outputting logits results, and optionally a dropout
    placeholder.
  """
  if is_training:
    dropout_rate = tf.compat.v1.placeholder(tf.float32, name='dropout_rate')
  input_frequency_size = model_settings['fingerprint_width']
  input_time_size = model_settings['spectrogram_length']
  fingerprint_4d = tf.reshape(fingerprint_input,
                              [-1, input_time_size, input_frequency_size, 1])

  first_filter_width = 8
  first_filter_height = 10
  first_filter_count = 8
  first_weights = tf.compat.v1.get_variable(
      name='first_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[first_filter_height, first_filter_width, 1, first_filter_count])
  first_bias = tf.compat.v1.get_variable(
      name='first_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[first_filter_count])
  first_conv_stride_x = 2
  first_conv_stride_y = 2

  first_conv = tf.nn.conv2d(
      input=fingerprint_4d, filters=first_weights,
      strides=[1, first_conv_stride_y, first_conv_stride_x, 1],
      padding='SAME') + first_bias
  first_relu = tf.nn.relu(first_conv)
  if is_training:
    first_dropout = tf.nn.dropout(first_relu, rate=dropout_rate)
  else:
    first_dropout = first_relu

  second_filter_width = 8
  second_filter_height = 10
  second_filter_count = 8
  second_weights = tf.compat.v1.get_variable(
      name='second_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[
          second_filter_height, second_filter_width, first_filter_count,
          second_filter_count
      ])
  second_bias = tf.compat.v1.get_variable(
      name='second_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[second_filter_count])
  second_conv_stride_x = 8
  second_conv_stride_y = 8
  second_conv = tf.nn.conv2d(
      input=first_dropout, filters=second_weights,
      strides=[1, second_conv_stride_y, second_conv_stride_x, 1],
      padding='SAME') + second_bias
  second_relu = tf.nn.relu(second_conv)
  if is_training:
    second_dropout = tf.nn.dropout(second_relu, rate=dropout_rate)
  else:
    second_dropout = second_relu

  second_dropout_shape = second_dropout.get_shape()
  second_dropout_output_width = second_dropout_shape[2]
  second_dropout_output_height = second_dropout_shape[1]
  second_dropout_element_count = int(second_dropout_output_width *
                                     second_dropout_output_height *
                                     second_filter_count)
  flattened_second_dropout = tf.reshape(second_dropout,
                                        [-1, second_dropout_element_count])
  label_count = model_settings['label_count']
  final_fc_weights = tf.compat.v1.get_variable(
      name='final_fc_weights',
      initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.01),
      shape=[second_dropout_element_count, label_count])
  final_fc_bias = tf.compat.v1.get_variable(
      name='final_fc_bias',
      initializer=tf.compat.v1.zeros_initializer,
      shape=[label_count])
  final_fc = (
      tf.matmul(flattened_second_dropout, final_fc_weights) + final_fc_bias)
  if is_training:
    return final_fc, dropout_rate
  else:
    return final_fc