1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for the speech commands freeze module (inference graph and SavedModel)."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os.path
22
23from tensorflow.examples.speech_commands import freeze
24from tensorflow.python.framework import graph_util
25from tensorflow.python.framework import test_util
26from tensorflow.python.ops.variables import global_variables_initializer
27from tensorflow.python.platform import test
28
29
class FreezeTest(test.TestCase):
  """Exercises `freeze.create_inference_graph` and `freeze.save_saved_model`.

  Every test builds the inference graph with the same audio settings and only
  varies the `preprocess` mode and, in one case, `feature_bin_count`. The
  shared call and the shared assertions live in private helpers so that each
  test states only what is specific to it.
  """

  # Tensors the graph-construction tests look up by name once the graph exists.
  _EXPECTED_TENSOR_NAMES = (
      'wav_data:0',
      'decoded_sample_data:0',
      'labels_softmax:0',
  )

  def _createInferenceGraph(self, preprocess, feature_bin_count=40):
    """Calls `freeze.create_inference_graph` with the settings the tests share.

    Must be called inside the `cached_session()` context of the calling test,
    because the tests inspect `sess.graph` / `sess.graph_def` afterwards.

    Args:
      preprocess: Value forwarded as the `preprocess` argument ('mfcc',
        'average' or 'micro' in the tests below).
      feature_bin_count: Value forwarded as the `feature_bin_count` argument.

    Returns:
      The return value of `freeze.create_inference_graph`, unchanged.
      `testCreateSavedModel` unpacks it as `(input_tensor, output_tensor)`.
    """
    return freeze.create_inference_graph(
        wanted_words='a,b,c,d',
        sample_rate=16000,
        clip_duration_ms=1000.0,
        clip_stride_ms=30.0,
        window_size_ms=30.0,
        window_stride_ms=10.0,
        feature_bin_count=feature_bin_count,
        model_architecture='conv',
        preprocess=preprocess)

  def _assertExpectedTensorsExist(self, graph):
    """Asserts that every name in `_EXPECTED_TENSOR_NAMES` resolves in `graph`.

    Args:
      graph: The graph to query via `get_tensor_by_name`.
    """
    for tensor_name in self._EXPECTED_TENSOR_NAMES:
      self.assertIsNotNone(graph.get_tensor_by_name(tensor_name))

  @staticmethod
  def _countOps(graph_def, op_type):
    """Returns how many nodes in `graph_def` have `op` equal to `op_type`."""
    return sum(1 for node in graph_def.node if node.op == op_type)

  @test_util.run_deprecated_v1
  def testCreateInferenceGraphWithMfcc(self):
    """With preprocess='mfcc' the graph contains exactly one Mfcc op."""
    with self.cached_session() as sess:
      self._createInferenceGraph(preprocess='mfcc')
      self._assertExpectedTensorsExist(sess.graph)
      self.assertEqual(1, self._countOps(sess.graph_def, 'Mfcc'))

  @test_util.run_deprecated_v1
  def testCreateInferenceGraphWithoutMfcc(self):
    """With preprocess='average' the graph contains no Mfcc op."""
    with self.cached_session() as sess:
      self._createInferenceGraph(preprocess='average')
      self._assertExpectedTensorsExist(sess.graph)
      self.assertEqual(0, self._countOps(sess.graph_def, 'Mfcc'))

  @test_util.run_deprecated_v1
  def testCreateInferenceGraphWithMicro(self):
    """With preprocess='micro' the expected named tensors are still present."""
    with self.cached_session() as sess:
      self._createInferenceGraph(preprocess='micro')
      self._assertExpectedTensorsExist(sess.graph)

  @test_util.run_deprecated_v1
  def testFeatureBinCount(self):
    """A non-default feature_bin_count (80) still yields the expected graph."""
    with self.cached_session() as sess:
      self._createInferenceGraph(preprocess='average', feature_bin_count=80)
      self._assertExpectedTensorsExist(sess.graph)
      self.assertEqual(0, self._countOps(sess.graph_def, 'Mfcc'))

  @test_util.run_deprecated_v1
  def testCreateSavedModel(self):
    """Builds the 'micro' graph, freezes it and writes it out as a SavedModel.

    The test makes no explicit assertions; it passes as long as none of the
    calls raises.
    """
    saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model')
    with self.cached_session() as sess:
      input_tensor, output_tensor = self._createInferenceGraph(
          preprocess='micro')
      # Variables must be initialized before they can be turned into constants.
      global_variables_initializer().run()
      # The converted GraphDef is intentionally discarded, as in the original
      # test; only the fact that the conversion succeeds is checked.
      graph_util.convert_variables_to_constants(
          sess, sess.graph_def, ['labels_softmax'])
      freeze.save_saved_model(saved_model_path, sess, input_tensor,
                              output_tensor)
130
131
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
  test.main()
134