1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for WAVE file labeling tool.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22 23import tensorflow as tf 24 25from tensorflow.examples.speech_commands import label_wav 26from tensorflow.python.platform import test 27 28 29class LabelWavTest(test.TestCase): 30 31 def _getWavData(self): 32 with self.cached_session(): 33 sample_data = tf.zeros([1000, 2]) 34 wav_encoder = tf.audio.encode_wav(sample_data, 16000) 35 wav_data = self.evaluate(wav_encoder) 36 return wav_data 37 38 def _saveTestWavFile(self, filename, wav_data): 39 with open(filename, "wb") as f: 40 f.write(wav_data) 41 42 def testLabelWav(self): 43 tmp_dir = self.get_temp_dir() 44 wav_data = self._getWavData() 45 wav_filename = os.path.join(tmp_dir, "wav_file.wav") 46 self._saveTestWavFile(wav_filename, wav_data) 47 input_name = "test_input" 48 output_name = "test_output" 49 graph_filename = os.path.join(tmp_dir, "test_graph.pb") 50 with tf.compat.v1.Session() as sess: 51 tf.compat.v1.placeholder(tf.string, name=input_name) 52 tf.zeros([1, 3], name=output_name) 53 with open(graph_filename, "wb") as f: 54 f.write(sess.graph.as_graph_def().SerializeToString()) 55 labels_filename = os.path.join(tmp_dir, "test_labels.txt") 56 with open(labels_filename, "w") as f: 57 f.write("a\nb\nc\n") 58 label_wav.label_wav(wav_filename, labels_filename, graph_filename, 59 input_name + ":0", output_name + ":0", 3) 60 61 62if __name__ == "__main__": 63 test.main() 64