# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data input for speech commands."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
import tensorflow as tf

from tensorflow.examples.speech_commands import input_data
from tensorflow.examples.speech_commands import models
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class InputDataTest(test.TestCase):
  """Tests for `input_data`: AudioProcessor plus the wav/feature helpers."""

  def _getWavData(self):
    """Returns encoded wav bytes for 32000 all-zero frames, 2 channels, 16kHz."""
    with self.cached_session():
      sample_data = tf.zeros([32000, 2])
      wav_encoder = tf.audio.encode_wav(sample_data, 16000)
      wav_data = self.evaluate(wav_encoder)
    return wav_data

  def _saveTestWavFile(self, filename, wav_data):
    """Writes the already-encoded `wav_data` bytes to `filename`."""
    with open(filename, "wb") as f:
      f.write(wav_data)

  def _saveWavFolders(self, root_dir, labels, how_many):
    """Creates one folder per label under `root_dir` with test wavs.

    Args:
      root_dir: Existing directory in which the label folders are created.
      labels: Folder (word) names to create.
      how_many: Number of `some_audio_<i>.wav` copies written per folder.
    """
    wav_data = self._getWavData()
    for label in labels:
      dir_name = os.path.join(root_dir, label)
      os.mkdir(dir_name)
      for i in range(how_many):
        file_path = os.path.join(dir_name, "some_audio_%d.wav" % i)
        self._saveTestWavFile(file_path, wav_data)

  def _saveBackgroundNoiseFolder(self, root_dir, how_many=10):
    """Creates `root_dir/_background_noise_` holding `how_many` test wavs.

    Args:
      root_dir: Existing directory in which the noise folder is created.
      how_many: Number of `background_audio_<i>.wav` files to write.
    """
    background_dir = os.path.join(root_dir, "_background_noise_")
    os.mkdir(background_dir)
    wav_data = self._getWavData()
    for i in range(how_many):
      file_path = os.path.join(background_dir, "background_audio_%d.wav" % i)
      self._saveTestWavFile(file_path, wav_data)

  def _model_settings(self, **overrides):
    """Returns a small MFCC model-settings dict used by several tests.

    Args:
      **overrides: Keys to add or replace in the default settings.

    Returns:
      A fresh dict, so callers may mutate it freely.
    """
    settings = {
        "desired_samples": 160,
        "fingerprint_size": 40,
        "label_count": 4,
        "window_size_samples": 100,
        "window_stride_samples": 100,
        "fingerprint_width": 40,
        "preprocess": "mfcc",
    }
    settings.update(overrides)
    return settings

  def _runGetDataTest(self, preprocess, window_length_ms):
    """Checks that get_data() yields 10 samples and 10 labels.

    Args:
      preprocess: Feature mode forwarded to `models.prepare_model_settings`.
      window_length_ms: Window length forwarded to
        `models.prepare_model_settings`.
    """
    tmp_dir = self.get_temp_dir()
    wav_dir = os.path.join(tmp_dir, "wavs")
    os.mkdir(wav_dir)
    self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
    self._saveBackgroundNoiseFolder(wav_dir)
    model_settings = models.prepare_model_settings(
        4, 16000, 1000, window_length_ms, 20, 40, preprocess)
    with self.cached_session() as sess:
      audio_processor = input_data.AudioProcessor(
          "", wav_dir, 10, 10, ["a", "b"], 10, 10, model_settings, tmp_dir)
      result_data, result_labels = audio_processor.get_data(
          10, 0, model_settings, 0.3, 0.1, 100, "training", sess)
      self.assertEqual(10, len(result_data))
      self.assertEqual(10, len(result_labels))

  def testPrepareWordsList(self):
    """prepare_words_list returns more entries than the wanted words."""
    words_list = ["a", "b"]
    self.assertGreater(
        len(input_data.prepare_words_list(words_list)), len(words_list))

  def testWhichSet(self):
    """which_set is deterministic and ignores text after `_nohash_`."""
    self.assertEqual(
        input_data.which_set("foo.wav", 10, 10),
        input_data.which_set("foo.wav", 10, 10))
    self.assertEqual(
        input_data.which_set("foo_nohash_0.wav", 10, 10),
        input_data.which_set("foo_nohash_1.wav", 10, 10))

  @test_util.run_deprecated_v1
  def testPrepareDataIndex(self):
    """The data index has all three splits; unwanted words map to unknown."""
    tmp_dir = self.get_temp_dir()
    self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
    audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10,
                                                ["a", "b"], 10, 10,
                                                self._model_settings(), tmp_dir)
    self.assertLess(0, audio_processor.set_size("training"))
    self.assertIn("training", audio_processor.data_index)
    self.assertIn("validation", audio_processor.data_index)
    self.assertIn("testing", audio_processor.data_index)
    # "c" is on disk but not in the wanted list, so it must map to unknown.
    self.assertEqual(input_data.UNKNOWN_WORD_INDEX,
                     audio_processor.word_to_index["c"])

  def testPrepareDataIndexEmpty(self):
    """Label folders without any wav files raise 'No .wavs found'."""
    tmp_dir = self.get_temp_dir()
    self._saveWavFolders(tmp_dir, ["a", "b", "c"], 0)
    with self.assertRaises(Exception) as e:
      _ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"], 10, 10,
                                    self._model_settings(), tmp_dir)
    self.assertIn("No .wavs found", str(e.exception))

  def testPrepareDataIndexMissing(self):
    """Requesting a word that has no folder raises 'Expected to find'."""
    tmp_dir = self.get_temp_dir()
    self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
    with self.assertRaises(Exception) as e:
      _ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b", "d"], 10,
                                    10, self._model_settings(), tmp_dir)
    self.assertIn("Expected to find", str(e.exception))

  @test_util.run_deprecated_v1
  def testPrepareBackgroundData(self):
    """All 10 background-noise wavs end up in `background_data`."""
    tmp_dir = self.get_temp_dir()
    self._saveBackgroundNoiseFolder(tmp_dir)
    self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
    audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10,
                                                ["a", "b"], 10, 10,
                                                self._model_settings(), tmp_dir)
    self.assertEqual(10, len(audio_processor.background_data))

  def testLoadWavFile(self):
    """load_wav_file returns data for a file written by encode_wav."""
    tmp_dir = self.get_temp_dir()
    file_path = os.path.join(tmp_dir, "load_test.wav")
    wav_data = self._getWavData()
    self._saveTestWavFile(file_path, wav_data)
    sample_data = input_data.load_wav_file(file_path)
    self.assertIsNotNone(sample_data)

  def testSaveWavFile(self):
    """A wav written by save_wav_file reloads with the same sample count."""
    tmp_dir = self.get_temp_dir()
    file_path = os.path.join(tmp_dir, "load_test.wav")
    save_data = np.zeros([16000, 1])
    input_data.save_wav_file(file_path, save_data, 16000)
    loaded_data = input_data.load_wav_file(file_path)
    self.assertIsNotNone(loaded_data)
    self.assertEqual(16000, len(loaded_data))

  @test_util.run_deprecated_v1
  def testPrepareProcessingGraph(self):
    """Constructing AudioProcessor creates its placeholders and output."""
    tmp_dir = self.get_temp_dir()
    wav_dir = os.path.join(tmp_dir, "wavs")
    os.mkdir(wav_dir)
    self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
    self._saveBackgroundNoiseFolder(wav_dir)
    model_settings = self._model_settings()
    audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"],
                                                10, 10, model_settings, tmp_dir)
    self.assertIsNotNone(audio_processor.wav_filename_placeholder_)
    self.assertIsNotNone(audio_processor.foreground_volume_placeholder_)
    self.assertIsNotNone(audio_processor.time_shift_padding_placeholder_)
    self.assertIsNotNone(audio_processor.time_shift_offset_placeholder_)
    self.assertIsNotNone(audio_processor.background_data_placeholder_)
    self.assertIsNotNone(audio_processor.background_volume_placeholder_)
    self.assertIsNotNone(audio_processor.output_)

  @test_util.run_deprecated_v1
  def testGetDataAverage(self):
    self._runGetDataTest("average", 10)

  @test_util.run_deprecated_v1
  def testGetDataAverageLongWindow(self):
    self._runGetDataTest("average", 30)

  @test_util.run_deprecated_v1
  def testGetDataMfcc(self):
    self._runGetDataTest("mfcc", 30)

  @test_util.run_deprecated_v1
  def testGetDataMicro(self):
    self._runGetDataTest("micro", 20)

  @test_util.run_deprecated_v1
  def testGetUnprocessedData(self):
    """get_unprocessed_data yields 10 samples and 10 labels."""
    tmp_dir = self.get_temp_dir()
    wav_dir = os.path.join(tmp_dir, "wavs")
    os.mkdir(wav_dir)
    self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
    model_settings = self._model_settings()
    audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"],
                                                10, 10, model_settings, tmp_dir)
    result_data, result_labels = audio_processor.get_unprocessed_data(
        10, model_settings, "training")
    self.assertEqual(10, len(result_data))
    self.assertEqual(10, len(result_labels))

  @test_util.run_deprecated_v1
  def testGetFeaturesForWav(self):
    """get_features_for_wav on a known waveform gives the expected features."""
    tmp_dir = self.get_temp_dir()
    wav_dir = os.path.join(tmp_dir, "wavs")
    os.mkdir(wav_dir)
    self._saveWavFolders(wav_dir, ["a", "b", "c"], 1)
    desired_samples = 1600
    model_settings = self._model_settings(
        desired_samples=desired_samples,
        average_window_width=6,
        preprocess="average")
    with self.cached_session() as sess:
      audio_processor = input_data.AudioProcessor(
          "", wav_dir, 10, 10, ["a", "b"], 10, 10, model_settings, tmp_dir)
      # Waveform repeating 0, -1, 0, 1 every four samples, shaped
      # [desired_samples, 1] (float64, as np.zeros would produce).
      pattern = np.array([0.0, -1.0, 0.0, 1.0])
      sample_data = pattern[np.arange(desired_samples) % len(pattern)].reshape(
          [desired_samples, 1])
      test_wav_path = os.path.join(tmp_dir, "test_wav.wav")
      input_data.save_wav_file(test_wav_path, sample_data, 16000)

      results = audio_processor.get_features_for_wav(test_wav_path,
                                                     model_settings, sess)
      spectrogram = results[0]
      self.assertEqual(1, spectrogram.shape[0])
      self.assertEqual(16, spectrogram.shape[1])
      self.assertEqual(11, spectrogram.shape[2])
      self.assertNear(0, spectrogram[0, 0, 0], 0.1)
      self.assertNear(200, spectrogram[0, 0, 5], 0.1)

  def testGetFeaturesRange(self):
    """The "average" preprocess mode reports a minimum feature value of 0."""
    model_settings = {
        "preprocess": "average",
    }
    features_min, _ = input_data.get_features_range(model_settings)
    self.assertNear(0.0, features_min, 1e-5)

  def testGetMfccFeaturesRange(self):
    """The "mfcc" preprocess mode reports a non-empty feature range."""
    model_settings = {
        "preprocess": "mfcc",
    }
    features_min, features_max = input_data.get_features_range(model_settings)
    self.assertLess(features_min, features_max)


if __name__ == "__main__":
  test.main()