1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for data input for speech commands.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os.path 22 23from tensorflow.examples.speech_commands import freeze 24from tensorflow.python.framework import graph_util 25from tensorflow.python.framework import test_util 26from tensorflow.python.ops.variables import global_variables_initializer 27from tensorflow.python.platform import test 28 29 30class FreezeTest(test.TestCase): 31 32 @test_util.run_deprecated_v1 33 def testCreateInferenceGraphWithMfcc(self): 34 with self.cached_session() as sess: 35 freeze.create_inference_graph( 36 wanted_words='a,b,c,d', 37 sample_rate=16000, 38 clip_duration_ms=1000.0, 39 clip_stride_ms=30.0, 40 window_size_ms=30.0, 41 window_stride_ms=10.0, 42 feature_bin_count=40, 43 model_architecture='conv', 44 preprocess='mfcc') 45 self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0')) 46 self.assertIsNotNone( 47 sess.graph.get_tensor_by_name('decoded_sample_data:0')) 48 self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0')) 49 ops = [node.op for node in sess.graph_def.node] 50 self.assertEqual(1, ops.count('Mfcc')) 51 52 @test_util.run_deprecated_v1 53 def testCreateInferenceGraphWithoutMfcc(self): 54 with self.cached_session() as sess: 55 freeze.create_inference_graph( 56 wanted_words='a,b,c,d', 57 sample_rate=16000, 58 clip_duration_ms=1000.0, 59 clip_stride_ms=30.0, 60 window_size_ms=30.0, 61 window_stride_ms=10.0, 62 feature_bin_count=40, 63 model_architecture='conv', 64 preprocess='average') 65 self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0')) 66 self.assertIsNotNone( 67 sess.graph.get_tensor_by_name('decoded_sample_data:0')) 68 self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0')) 69 ops = [node.op for node in sess.graph_def.node] 70 self.assertEqual(0, ops.count('Mfcc')) 71 72 @test_util.run_deprecated_v1 73 def testCreateInferenceGraphWithMicro(self): 74 with self.cached_session() as sess: 75 freeze.create_inference_graph( 76 wanted_words='a,b,c,d', 77 sample_rate=16000, 78 clip_duration_ms=1000.0, 79 clip_stride_ms=30.0, 80 window_size_ms=30.0, 81 window_stride_ms=10.0, 82 feature_bin_count=40, 83 model_architecture='conv', 84 preprocess='micro') 85 self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0')) 86 self.assertIsNotNone( 87 sess.graph.get_tensor_by_name('decoded_sample_data:0')) 88 self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0')) 89 90 @test_util.run_deprecated_v1 91 def testFeatureBinCount(self): 92 with self.cached_session() as sess: 93 freeze.create_inference_graph( 94 wanted_words='a,b,c,d', 95 sample_rate=16000, 96 clip_duration_ms=1000.0, 97 clip_stride_ms=30.0, 98 window_size_ms=30.0, 99 window_stride_ms=10.0, 100 feature_bin_count=80, 101 model_architecture='conv', 102 preprocess='average') 103 self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0')) 104 self.assertIsNotNone( 105 sess.graph.get_tensor_by_name('decoded_sample_data:0')) 106 self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0')) 107 ops = [node.op for node in sess.graph_def.node] 108 self.assertEqual(0, ops.count('Mfcc')) 109 110 @test_util.run_deprecated_v1 111 def testCreateSavedModel(self): 112 tmp_dir = self.get_temp_dir() 113 saved_model_path = os.path.join(tmp_dir, 'saved_model') 114 with self.cached_session() as sess: 115 input_tensor, output_tensor = freeze.create_inference_graph( 116 wanted_words='a,b,c,d', 117 sample_rate=16000, 118 clip_duration_ms=1000.0, 119 clip_stride_ms=30.0, 120 window_size_ms=30.0, 121 window_stride_ms=10.0, 122 feature_bin_count=40, 123 model_architecture='conv', 124 preprocess='micro') 125 global_variables_initializer().run() 126 graph_util.convert_variables_to_constants( 127 sess, sess.graph_def, ['labels_softmax']) 128 freeze.save_saved_model(saved_model_path, sess, input_tensor, 129 output_tensor) 130 131 132if __name__ == '__main__': 133 test.main() 134