1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for speech commands models."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import tensorflow as tf
22
23from tensorflow.examples.speech_commands import models
24from tensorflow.python.framework import test_util
25from tensorflow.python.platform import test
26
27
class ModelsTest(test.TestCase):
  """Smoke tests for `models.prepare_model_settings` and `models.create_model`.

  Each `create_model` test builds a single architecture on an all-zeros
  fingerprint batch of size 1. It then checks that the returned tensors exist
  in the session's graph and that the logits width matches the label count.
  """

  def _modelSettings(self):
    """Returns the `models.prepare_model_settings` result shared by all tests."""
    return models.prepare_model_settings(
        label_count=10,
        sample_rate=16000,
        clip_duration_ms=1000,
        window_size_ms=20,
        window_stride_ms=10,
        feature_bin_count=40,
        preprocess="mfcc")

  def _fingerprintInput(self, model_settings):
    """Returns an all-zeros tensor of shape [1, fingerprint_size].

    Args:
      model_settings: Dict returned by `_modelSettings()`.
    """
    return tf.zeros([1, model_settings["fingerprint_size"]])

  def _checkTrainingModel(self, architecture):
    """Builds `architecture` in training mode and validates its outputs.

    In training mode `create_model` is unpacked into `(logits, dropout_rate)`.
    Both tensors must be present in the session graph.

    Args:
      architecture: Architecture name forwarded to `models.create_model`.
    """
    model_settings = self._modelSettings()
    with self.cached_session() as sess:
      fingerprint_input = self._fingerprintInput(model_settings)
      logits, dropout_rate = models.create_model(
          fingerprint_input, model_settings, architecture, True)
      self.assertIsNotNone(logits)
      self.assertIsNotNone(dropout_rate)
      # get_tensor_by_name raises when the name is missing, so these lookups
      # confirm both tensors were built in the session's graph.
      self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
      self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_rate.name))
      # One logit per label is expected in the last dimension.
      self.assertEqual(model_settings["label_count"],
                       logits.shape.as_list()[-1])

  def _checkInferenceModel(self, architecture):
    """Builds `architecture` in inference mode and validates its logits.

    In inference mode `create_model` returns the logits tensor alone.

    Args:
      architecture: Architecture name forwarded to `models.create_model`.
    """
    model_settings = self._modelSettings()
    with self.cached_session() as sess:
      fingerprint_input = self._fingerprintInput(model_settings)
      logits = models.create_model(
          fingerprint_input, model_settings, architecture, False)
      self.assertIsNotNone(logits)
      self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
      self.assertEqual(model_settings["label_count"],
                       logits.shape.as_list()[-1])

  def testPrepareModelSettings(self):
    settings = self._modelSettings()
    self.assertIsNotNone(settings)
    # The create_model tests below rely on these two keys.
    self.assertEqual(10, settings["label_count"])
    self.assertGreater(settings["fingerprint_size"], 0)

  @test_util.run_deprecated_v1
  def testCreateModelConvTraining(self):
    self._checkTrainingModel("conv")

  @test_util.run_deprecated_v1
  def testCreateModelConvInference(self):
    self._checkInferenceModel("conv")

  @test_util.run_deprecated_v1
  def testCreateModelLowLatencyConvTraining(self):
    self._checkTrainingModel("low_latency_conv")

  @test_util.run_deprecated_v1
  def testCreateModelLowLatencyConvInference(self):
    self._checkInferenceModel("low_latency_conv")

  @test_util.run_deprecated_v1
  def testCreateModelFullyConnectedTraining(self):
    self._checkTrainingModel("single_fc")

  @test_util.run_deprecated_v1
  def testCreateModelFullyConnectedInference(self):
    self._checkInferenceModel("single_fc")

  def testCreateModelBadArchitecture(self):
    model_settings = self._modelSettings()
    with self.cached_session():
      # Build the input outside the assertion so only create_model is covered.
      fingerprint_input = self._fingerprintInput(model_settings)
      # NOTE(review): create_model appears to raise a generic Exception for
      # unknown names; narrow the type here if it moves to ValueError.
      with self.assertRaisesRegex(Exception, "not recognized"):
        models.create_model(fingerprint_input, model_settings,
                            "bad_architecture", True)

  @test_util.run_deprecated_v1
  def testCreateModelTinyConvTraining(self):
    self._checkTrainingModel("tiny_conv")

  @test_util.run_deprecated_v1
  def testCreateModelTinyConvInference(self):
    self._checkInferenceModel("tiny_conv")
117
118
if __name__ == "__main__":
  # Run the tests in this module through TensorFlow's test runner when the
  # file is executed directly.
  test.main()
121