1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for data input for speech commands."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22
23import numpy as np
24import tensorflow as tf
25
26
27from tensorflow.examples.speech_commands import input_data
28from tensorflow.examples.speech_commands import models
29from tensorflow.python.framework import test_util
30from tensorflow.python.platform import test
31
32
class InputDataTest(test.TestCase):
  """Tests for the `input_data` helpers and `input_data.AudioProcessor`.

  Each test builds its own fixture tree under `self.get_temp_dir()`: one
  folder of WAV files per label and, where needed, a `_background_noise_`
  folder. It then exercises the processor or helper against that tree.
  """

  def _getWavData(self):
    """Returns encoded WAV bytes for a [32000, 2] tensor of zeros at 16 kHz."""
    with self.cached_session():
      sample_data = tf.zeros([32000, 2])
      wav_encoder = tf.audio.encode_wav(sample_data, 16000)
      wav_data = self.evaluate(wav_encoder)
    return wav_data

  def _saveTestWavFile(self, filename, wav_data):
    """Writes the raw `wav_data` bytes to `filename` in binary mode."""
    with open(filename, "wb") as f:
      f.write(wav_data)

  def _saveWavFolders(self, root_dir, labels, how_many):
    """Creates `root_dir/<label>/` holding `how_many` WAV files per label.

    Args:
      root_dir: Existing directory in which the label folders are created.
      labels: Folder names to create. `os.mkdir` raises if one already exists.
      how_many: Number of identical WAV files written per label (may be 0).
    """
    wav_data = self._getWavData()
    for label in labels:
      dir_name = os.path.join(root_dir, label)
      os.mkdir(dir_name)
      for i in range(how_many):
        file_path = os.path.join(dir_name, "some_audio_%d.wav" % i)
        self._saveTestWavFile(file_path, wav_data)

  def _saveBackgroundNoise(self, root_dir, how_many=10):
    """Creates `root_dir/_background_noise_/` holding `how_many` WAV files.

    Args:
      root_dir: Existing directory in which `_background_noise_` is created.
      how_many: Number of identical WAV files to write into that folder.
    """
    background_dir = os.path.join(root_dir, "_background_noise_")
    os.mkdir(background_dir)
    wav_data = self._getWavData()
    for i in range(how_many):
      file_path = os.path.join(background_dir, "background_audio_%d.wav" % i)
      self._saveTestWavFile(file_path, wav_data)

  def _model_settings(self):
    """Returns a fresh, minimal MFCC model-settings dict for the tests.

    A new dict is built on every call, so callers may mutate it safely.
    """
    return {
        "desired_samples": 160,
        "fingerprint_size": 40,
        "label_count": 4,
        "window_size_samples": 100,
        "window_stride_samples": 100,
        "fingerprint_width": 40,
        "preprocess": "mfcc",
    }

  def _runGetDataTest(self, preprocess, window_length_ms):
    """Builds a fixture tree and checks `get_data` yields a batch of 10.

    Args:
      preprocess: Preprocessing mode forwarded to
        `models.prepare_model_settings` (e.g. "average", "mfcc", "micro").
      window_length_ms: Window length in ms forwarded to
        `models.prepare_model_settings`.
    """
    tmp_dir = self.get_temp_dir()
    wav_dir = os.path.join(tmp_dir, "wavs")
    os.mkdir(wav_dir)
    self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
    self._saveBackgroundNoise(wav_dir)
    model_settings = models.prepare_model_settings(
        4, 16000, 1000, window_length_ms, 20, 40, preprocess)
    with self.cached_session() as sess:
      audio_processor = input_data.AudioProcessor(
          "", wav_dir, 10, 10, ["a", "b"], 10, 10, model_settings, tmp_dir)
      result_data, result_labels = audio_processor.get_data(
          10, 0, model_settings, 0.3, 0.1, 100, "training", sess)
      self.assertEqual(10, len(result_data))
      self.assertEqual(10, len(result_labels))

  def testPrepareWordsList(self):
    words_list = ["a", "b"]
    # prepare_words_list is expected to add extra entries beyond the wanted
    # words, so the result must be strictly longer than the input.
    self.assertGreater(
        len(input_data.prepare_words_list(words_list)), len(words_list))

  def testWhichSet(self):
    # The same filename must always be assigned to the same set.
    self.assertEqual(
        input_data.which_set("foo.wav", 10, 10),
        input_data.which_set("foo.wav", 10, 10))
    # Files differing only after "_nohash_" must be assigned to the same set.
    self.assertEqual(
        input_data.which_set("foo_nohash_0.wav", 10, 10),
        input_data.which_set("foo_nohash_1.wav", 10, 10))

  @test_util.run_deprecated_v1
  def testPrepareDataIndex(self):
    tmp_dir = self.get_temp_dir()
    self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
    audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10,
                                                ["a", "b"], 10, 10,
                                                self._model_settings(), tmp_dir)
    self.assertLess(0, audio_processor.set_size("training"))
    self.assertIn("training", audio_processor.data_index)
    self.assertIn("validation", audio_processor.data_index)
    self.assertIn("testing", audio_processor.data_index)
    # "c" is present on disk but not among the wanted words.
    self.assertEqual(input_data.UNKNOWN_WORD_INDEX,
                     audio_processor.word_to_index["c"])

  def testPrepareDataIndexEmpty(self):
    tmp_dir = self.get_temp_dir()
    self._saveWavFolders(tmp_dir, ["a", "b", "c"], 0)
    with self.assertRaises(Exception) as e:
      _ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"], 10, 10,
                                    self._model_settings(), tmp_dir)
    self.assertIn("No .wavs found", str(e.exception))

  def testPrepareDataIndexMissing(self):
    tmp_dir = self.get_temp_dir()
    self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
    # "d" is requested but has no folder on disk.
    with self.assertRaises(Exception) as e:
      _ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b", "d"], 10,
                                    10, self._model_settings(), tmp_dir)
    self.assertIn("Expected to find", str(e.exception))

  @test_util.run_deprecated_v1
  def testPrepareBackgroundData(self):
    tmp_dir = self.get_temp_dir()
    self._saveBackgroundNoise(tmp_dir)
    self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
    audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10,
                                                ["a", "b"], 10, 10,
                                                self._model_settings(), tmp_dir)
    self.assertEqual(10, len(audio_processor.background_data))

  def testLoadWavFile(self):
    tmp_dir = self.get_temp_dir()
    file_path = os.path.join(tmp_dir, "load_test.wav")
    wav_data = self._getWavData()
    self._saveTestWavFile(file_path, wav_data)
    sample_data = input_data.load_wav_file(file_path)
    self.assertIsNotNone(sample_data)

  def testSaveWavFile(self):
    tmp_dir = self.get_temp_dir()
    file_path = os.path.join(tmp_dir, "load_test.wav")
    save_data = np.zeros([16000, 1])
    input_data.save_wav_file(file_path, save_data, 16000)
    loaded_data = input_data.load_wav_file(file_path)
    self.assertIsNotNone(loaded_data)
    self.assertEqual(16000, len(loaded_data))

  @test_util.run_deprecated_v1
  def testPrepareProcessingGraph(self):
    tmp_dir = self.get_temp_dir()
    wav_dir = os.path.join(tmp_dir, "wavs")
    os.mkdir(wav_dir)
    self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
    self._saveBackgroundNoise(wav_dir)
    model_settings = self._model_settings()
    audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"],
                                                10, 10, model_settings, tmp_dir)
    self.assertIsNotNone(audio_processor.wav_filename_placeholder_)
    self.assertIsNotNone(audio_processor.foreground_volume_placeholder_)
    self.assertIsNotNone(audio_processor.time_shift_padding_placeholder_)
    self.assertIsNotNone(audio_processor.time_shift_offset_placeholder_)
    self.assertIsNotNone(audio_processor.background_data_placeholder_)
    self.assertIsNotNone(audio_processor.background_volume_placeholder_)
    self.assertIsNotNone(audio_processor.output_)

  @test_util.run_deprecated_v1
  def testGetDataAverage(self):
    self._runGetDataTest("average", 10)

  @test_util.run_deprecated_v1
  def testGetDataAverageLongWindow(self):
    self._runGetDataTest("average", 30)

  @test_util.run_deprecated_v1
  def testGetDataMfcc(self):
    self._runGetDataTest("mfcc", 30)

  @test_util.run_deprecated_v1
  def testGetDataMicro(self):
    self._runGetDataTest("micro", 20)

  @test_util.run_deprecated_v1
  def testGetUnprocessedData(self):
    tmp_dir = self.get_temp_dir()
    wav_dir = os.path.join(tmp_dir, "wavs")
    os.mkdir(wav_dir)
    self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
    model_settings = self._model_settings()
    audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"],
                                                10, 10, model_settings, tmp_dir)
    result_data, result_labels = audio_processor.get_unprocessed_data(
        10, model_settings, "training")
    self.assertEqual(10, len(result_data))
    self.assertEqual(10, len(result_labels))

  @test_util.run_deprecated_v1
  def testGetFeaturesForWav(self):
    tmp_dir = self.get_temp_dir()
    wav_dir = os.path.join(tmp_dir, "wavs")
    os.mkdir(wav_dir)
    self._saveWavFolders(wav_dir, ["a", "b", "c"], 1)
    desired_samples = 1600
    model_settings = {
        "desired_samples": desired_samples,
        "fingerprint_size": 40,
        "label_count": 4,
        "window_size_samples": 100,
        "window_stride_samples": 100,
        "fingerprint_width": 40,
        "average_window_width": 6,
        "preprocess": "average",
    }
    with self.cached_session() as sess:
      audio_processor = input_data.AudioProcessor(
          "", wav_dir, 10, 10, ["a", "b"], 10, 10, model_settings, tmp_dir)
      # Periodic waveform 0, -1, 0, 1, ... with a 4-sample period (fs / 4,
      # i.e. 4 kHz at 16 kHz), shaped [desired_samples, 1] like np.zeros would be.
      period = np.array([0.0, -1.0, 0.0, 1.0])
      sample_data = period[np.arange(desired_samples) % 4].reshape(-1, 1)
      test_wav_path = os.path.join(tmp_dir, "test_wav.wav")
      input_data.save_wav_file(test_wav_path, sample_data, 16000)

      results = audio_processor.get_features_for_wav(test_wav_path,
                                                     model_settings, sess)
      spectrogram = results[0]
      self.assertEqual(1, spectrogram.shape[0])
      self.assertEqual(16, spectrogram.shape[1])
      self.assertEqual(11, spectrogram.shape[2])
      self.assertNear(0, spectrogram[0, 0, 0], 0.1)
      self.assertNear(200, spectrogram[0, 0, 5], 0.1)

  def testGetFeaturesRange(self):
    model_settings = {
        "preprocess": "average",
    }
    features_min, _ = input_data.get_features_range(model_settings)
    self.assertNear(0.0, features_min, 1e-5)

  def testGetMfccFeaturesRange(self):
    model_settings = {
        "preprocess": "mfcc",
    }
    features_min, features_max = input_data.get_features_range(model_settings)
    self.assertLess(features_min, features_max)
288
if __name__ == "__main__":
  # Run the tests in this module via tensorflow.python.platform.test.main().
  test.main()
291