1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for WAV file to features conversion for speech commands."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22
23import tensorflow as tf
24
25from tensorflow.examples.speech_commands import wav_to_features
26from tensorflow.python.framework import test_util
27from tensorflow.python.platform import test
28
29
class WavToFeaturesTest(test.TestCase):
  """Exercises wav_to_features.wav_to_features on synthetic all-zero audio."""

  def _getWavData(self):
    """Returns the WAV encoding of a zero-filled [32000, 2] tensor.

    The tensor is encoded with a sample rate of 16000.
    """
    with self.cached_session():
      silence = tf.zeros([32000, 2])
      return self.evaluate(tf.audio.encode_wav(silence, 16000))

  def _saveTestWavFile(self, filename, wav_data):
    """Writes the bytes in `wav_data` to `filename` in binary mode."""
    with open(filename, "wb") as wav_file:
      wav_file.write(wav_data)

  def _saveWavFolders(self, root_dir, labels, how_many):
    """Creates one sub-folder per label, each holding `how_many` WAV files.

    Args:
      root_dir: Directory in which the per-label folders are created.
      labels: Iterable of folder names, one per label.
      how_many: Number of identical WAV files written into each folder.
    """
    encoded_wav = self._getWavData()
    for label in labels:
      label_dir = os.path.join(root_dir, label)
      os.mkdir(label_dir)
      for index in range(how_many):
        self._saveTestWavFile(
            os.path.join(label_dir, "some_audio_%d.wav" % index), encoded_wav)

  def _checkWavToFeatures(self, mode):
    """Converts a test WAV file and checks the generated C source.

    Args:
      mode: String forwarded to wav_to_features.wav_to_features as its seventh
        positional argument ("average" or "micro" in the tests below).
    """
    tmp_dir = self.get_temp_dir()

    # NOTE(review): this folder of labelled clips is never handed to
    # wav_to_features below -- confirm whether the fixture is still needed.
    wav_dir = os.path.join(tmp_dir, "wavs")
    os.mkdir(wav_dir)
    self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)

    input_file_path = os.path.join(tmp_dir, "input.wav")
    output_file_path = os.path.join(tmp_dir, "output.c")
    self._saveTestWavFile(input_file_path, self._getWavData())

    wav_to_features.wav_to_features(16000, 1000, 10, 10, 40, True, mode,
                                    input_file_path, output_file_path)

    with open(output_file_path, "rb") as output_file:
      generated = output_file.read()
    # The output file must contain the declaration of the data array.
    self.assertIn(b"const unsigned char g_input_data", generated)

  # Conversion with the "average" option.
  @test_util.run_deprecated_v1
  def testWavToFeatures(self):
    self._checkWavToFeatures("average")

  # Conversion with the "micro" option.
  @test_util.run_deprecated_v1
  def testWavToFeaturesMicro(self):
    self._checkWavToFeatures("micro")
84
# Hand control to the test runner only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
  test.main()
87