# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for speech commands models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.examples.speech_commands import models
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class ModelsTest(test.TestCase):
  """Checks that `models.create_model` builds each supported architecture."""

  def _modelSettings(self):
    """Returns the model settings shared by every test in this class."""
    return models.prepare_model_settings(
        label_count=10,
        sample_rate=16000,
        clip_duration_ms=1000,
        window_size_ms=20,
        window_stride_ms=10,
        feature_bin_count=40,
        preprocess="mfcc")

  def _assertTrainingModelCreated(self, architecture):
    """Builds `architecture` in training mode and checks its output tensors.

    In training mode `models.create_model` returns a (logits, dropout_rate)
    pair. Both must be non-None and resolvable by name in the session graph.
    Callers are expected to run under `test_util.run_deprecated_v1`, because
    the check relies on graph-mode tensor names.

    Args:
      architecture: Model architecture name forwarded to
        `models.create_model`, e.g. "conv".
    """
    model_settings = self._modelSettings()
    with self.cached_session() as sess:
      fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
      logits, dropout_rate = models.create_model(
          fingerprint_input, model_settings, architecture, True)
      self.assertIsNotNone(logits)
      self.assertIsNotNone(dropout_rate)
      self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
      self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_rate.name))

  def testPrepareModelSettings(self):
    model_settings = self._modelSettings()
    self.assertIsNotNone(model_settings)
    # Every model-construction test below indexes this key, so check it here.
    self.assertIn("fingerprint_size", model_settings)

  @test_util.run_deprecated_v1
  def testCreateModelConvTraining(self):
    self._assertTrainingModelCreated("conv")

  @test_util.run_deprecated_v1
  def testCreateModelConvInference(self):
    model_settings = self._modelSettings()
    with self.cached_session() as sess:
      fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
      # With is_training=False only the logits tensor is returned.
      logits = models.create_model(fingerprint_input, model_settings, "conv",
                                   False)
      self.assertIsNotNone(logits)
      self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))

  @test_util.run_deprecated_v1
  def testCreateModelLowLatencyConvTraining(self):
    self._assertTrainingModelCreated("low_latency_conv")

  @test_util.run_deprecated_v1
  def testCreateModelFullyConnectedTraining(self):
    self._assertTrainingModelCreated("single_fc")

  def testCreateModelBadArchitecture(self):
    model_settings = self._modelSettings()
    with self.cached_session():
      fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
      with self.assertRaises(Exception) as e:
        models.create_model(fingerprint_input, model_settings,
                            "bad_architecture", True)
      # Inspect the message only after the context manager has caught the
      # exception. Inside the `with` body this line would never execute.
      self.assertIn("not recognized", str(e.exception))

  @test_util.run_deprecated_v1
  def testCreateModelTinyConvTraining(self):
    self._assertTrainingModelCreated("tiny_conv")


if __name__ == "__main__":
  test.main()