1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for the training script for speech commands."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import unittest
23
24import tensorflow as tf
25
26from tensorflow.examples.speech_commands import train
27from tensorflow.python.framework import test_util
28from tensorflow.python.platform import gfile
29from tensorflow.python.platform import test
30
31
def requires_contrib(test_method):
  """Decorator that skips `test_method` when `tf.contrib` cannot be accessed.

  Args:
    test_method: The test callable to decorate.

  Returns:
    `test_method` unchanged if reading `tf.contrib` succeeds; otherwise the
    method wrapped with `unittest.skip` and an explanatory reason.
  """
  try:
    # Only the attribute lookup matters; the value itself is not used.
    getattr(tf, 'contrib')
  except AttributeError:
    reason = ('This test requires tf.contrib:\n'
              '    `pip install tensorflow<=1.15`')
    return unittest.skip(reason)(test_method)

  return test_method
41
42
43# Used to convert a dictionary into an object, for mocking parsed flags.
class DictStruct(object):
  """Plain object whose attributes are the keyword arguments given to it."""

  def __init__(self, **entries):
    """Stores every `name=value` pair in `entries` as an instance attribute."""
    # Write straight into the instance namespace, one entry at a time.
    namespace = vars(self)
    for name, value in entries.items():
      namespace[name] = value
48
49
class TrainTest(test.TestCase):
  """Runs `train.main` on generated WAV files and checks its output files."""

  def _getWavData(self):
    """Returns the encoded WAV bytes for a [32000, 2] tensor of zeros."""
    with self.cached_session():
      silence = tf.zeros([32000, 2])
      # 16000 is passed as the sample rate of the encoded clip.
      encoded = self.evaluate(tf.audio.encode_wav(silence, 16000))
    return encoded

  def _saveTestWavFile(self, filename, wav_data):
    """Writes `wav_data` to `filename`, opened in binary mode."""
    with open(filename, 'wb') as wav_file:
      wav_file.write(wav_data)

  def _saveWavFolders(self, root_dir, labels, how_many):
    """Creates one folder per label holding `how_many` identical WAV files.

    Args:
      root_dir: Existing directory in which the label folders are created.
      labels: Iterable of folder names, one per label.
      how_many: Number of WAV files written into each label folder.
    """
    encoded = self._getWavData()
    for label in labels:
      label_dir = os.path.join(root_dir, label)
      os.mkdir(label_dir)
      for index in range(how_many):
        self._saveTestWavFile(
            os.path.join(label_dir, 'some_audio_%d.wav' % index), encoded)

  def _prepareDummyTrainingData(self):
    """Builds a dummy data directory under the test temp dir.

    The directory contains folders 'a', 'b' and 'c' with 100 WAV files each,
    plus a '_background_noise_' folder with 10 WAV files.

    Returns:
      Path of the created 'wavs' directory.
    """
    wav_dir = os.path.join(self.get_temp_dir(), 'wavs')
    os.mkdir(wav_dir)
    self._saveWavFolders(wav_dir, ['a', 'b', 'c'], 100)

    noise_dir = os.path.join(wav_dir, '_background_noise_')
    os.mkdir(noise_dir)
    encoded = self._getWavData()
    for index in range(10):
      self._saveTestWavFile(
          os.path.join(noise_dir, 'background_audio_%d.wav' % index), encoded)
    return wav_dir

  def _getDefaultFlags(self):
    """Returns a `DictStruct` standing in for the parsed training flags.

    As a side effect this creates the dummy training data on disk, because
    the 'data_dir' flag points at the freshly generated directory.
    """
    data_dir = self._prepareDummyTrainingData()
    temp_dir = self.get_temp_dir()
    return DictStruct(
        data_url='',
        data_dir=data_dir,
        wanted_words='a,b,c',
        sample_rate=16000,
        clip_duration_ms=1000,
        window_size_ms=30,
        window_stride_ms=20,
        feature_bin_count=40,
        preprocess='mfcc',
        silence_percentage=25,
        unknown_percentage=25,
        validation_percentage=10,
        testing_percentage=10,
        summaries_dir=os.path.join(temp_dir, 'summaries'),
        train_dir=os.path.join(temp_dir, 'train'),
        time_shift_ms=100,
        how_many_training_steps='2',
        learning_rate='0.01',
        quantize=False,
        model_architecture='conv',
        check_nans=False,
        start_checkpoint='',
        batch_size=1,
        background_volume=0.25,
        background_frequency=0.8,
        eval_step_interval=1,
        save_step_interval=1,
        verbosity=tf.compat.v1.logging.INFO,
        optimizer='gradient_descent')

  @test_util.run_deprecated_v1
  def testTrain(self):
    """Trains with the default flags and checks the expected files exist."""
    train.FLAGS = self._getDefaultFlags()
    train.main('')
    # Each suffix is appended to the architecture name inside the train dir.
    for suffix in ('.pbtxt', '_labels.txt', '.ckpt-1.meta'):
      artifact = os.path.join(train.FLAGS.train_dir,
                              train.FLAGS.model_architecture + suffix)
      self.assertTrue(gfile.Exists(artifact))
135
136
# Invoke the TensorFlow test entry point only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
  test.main()
139