1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for data input for speech commands.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22import unittest 23 24import tensorflow as tf 25 26from tensorflow.examples.speech_commands import train 27from tensorflow.python.framework import test_util 28from tensorflow.python.platform import gfile 29from tensorflow.python.platform import test 30 31 32def requires_contrib(test_method): 33 try: 34 _ = tf.contrib 35 except AttributeError: 36 test_method = unittest.skip( 37 'This test requires tf.contrib:\n `pip install tensorflow<=1.15`')( 38 test_method) 39 40 return test_method 41 42 43# Used to convert a dictionary into an object, for mocking parsed flags. 
class DictStruct(object):
  """Attribute bag: every keyword argument becomes an instance attribute."""

  def __init__(self, **entries):
    self.__dict__.update(entries)


class TrainTest(test.TestCase):
  """End-to-end smoke test for `train.main` on a tiny synthetic dataset."""

  def _getWavData(self):
    """Returns the encoded bytes of a silent WAV clip.

    The clip is a `[32000, 2]` tensor of zeros encoded at a 16000 Hz sample
    rate, so 32000 samples x 2 channels (two seconds of silence).
    """
    with self.cached_session():
      sample_data = tf.zeros([32000, 2])
      wav_encoder = tf.audio.encode_wav(sample_data, 16000)
      wav_data = self.evaluate(wav_encoder)
    return wav_data

  def _saveTestWavFile(self, filename, wav_data):
    """Writes the raw `wav_data` bytes to `filename` in binary mode."""
    with open(filename, 'wb') as f:
      f.write(wav_data)

  def _saveWavFolders(self, root_dir, labels, how_many, wav_data=None):
    """Creates one folder per label under `root_dir`, each with WAV files.

    Args:
      root_dir: Existing directory that receives the label folders.
      labels: Iterable of label names; each becomes a new sub-directory
        (created with `os.mkdir`, so it must not already exist).
      how_many: Number of `some_audio_<i>.wav` files written per label.
      wav_data: Optional pre-encoded WAV bytes to write into every file. When
        None, a clip is produced with `_getWavData()`.
    """
    if wav_data is None:
      wav_data = self._getWavData()
    for label in labels:
      dir_name = os.path.join(root_dir, label)
      os.mkdir(dir_name)
      for i in range(how_many):
        file_path = os.path.join(dir_name, 'some_audio_%d.wav' % i)
        self._saveTestWavFile(file_path, wav_data)

  def _prepareDummyTrainingData(self):
    """Builds a dummy dataset under the test temp dir and returns its path.

    Layout (all files hold the same silent clip):
      wavs/{a,b,c}/some_audio_{0..99}.wav
      wavs/_background_noise_/background_audio_{0..9}.wav

    Returns:
      Path of the `wavs` directory.
    """
    tmp_dir = self.get_temp_dir()
    wav_dir = os.path.join(tmp_dir, 'wavs')
    os.mkdir(wav_dir)
    # Every file contains identical audio, so encode it once and reuse it
    # instead of building and evaluating the encode op twice.
    wav_data = self._getWavData()
    self._saveWavFolders(wav_dir, ['a', 'b', 'c'], 100, wav_data=wav_data)
    background_dir = os.path.join(wav_dir, '_background_noise_')
    os.mkdir(background_dir)
    for i in range(10):
      file_path = os.path.join(background_dir, 'background_audio_%d.wav' % i)
      self._saveTestWavFile(file_path, wav_data)
    return wav_dir

  def _getDefaultFlags(self):
    """Returns mock parsed flags for a minimal two-step training run.

    Side effect: creates the dummy dataset via `_prepareDummyTrainingData`.
    NOTE(review): 'how_many_training_steps' and 'learning_rate' are strings,
    presumably because train.py accepts comma-separated schedules; verify
    against train.py before changing their types.
    """
    flags = {
        'data_url': '',
        'data_dir': self._prepareDummyTrainingData(),
        'wanted_words': 'a,b,c',
        'sample_rate': 16000,
        'clip_duration_ms': 1000,
        'window_size_ms': 30,
        'window_stride_ms': 20,
        'feature_bin_count': 40,
        'preprocess': 'mfcc',
        'silence_percentage': 25,
        'unknown_percentage': 25,
        'validation_percentage': 10,
        'testing_percentage': 10,
        'summaries_dir': os.path.join(self.get_temp_dir(), 'summaries'),
        'train_dir': os.path.join(self.get_temp_dir(), 'train'),
        'time_shift_ms': 100,
        'how_many_training_steps': '2',
        'learning_rate': '0.01',
        'quantize': False,
        'model_architecture': 'conv',
        'check_nans': False,
        'start_checkpoint': '',
        'batch_size': 1,
        'background_volume': 0.25,
        'background_frequency': 0.8,
        'eval_step_interval': 1,
        'save_step_interval': 1,
        'verbosity': tf.compat.v1.logging.INFO,
        'optimizer': 'gradient_descent'
    }
    return DictStruct(**flags)

  @test_util.run_deprecated_v1
  def testTrain(self):
    """Runs `train.main` and checks the graph, labels and checkpoint exist."""
    flags = self._getDefaultFlags()
    # train.main() is driven by the module-level train.FLAGS, so the mock is
    # installed there. Restore the previous value afterwards so the mutation
    # does not leak into other tests sharing this process.
    saved_flags = getattr(train, 'FLAGS', None)
    train.FLAGS = flags
    try:
      train.main('')
    finally:
      train.FLAGS = saved_flags

    expected_files = (
        flags.model_architecture + '.pbtxt',
        flags.model_architecture + '_labels.txt',
        flags.model_architecture + '.ckpt-1.meta',
    )
    for filename in expected_files:
      path = os.path.join(flags.train_dir, filename)
      self.assertTrue(gfile.Exists(path), 'Missing expected file: %s' % path)


if __name__ == '__main__':
  test.main()