/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
#include <algorithm>
#include <fstream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
26
27 // These are all common classes it's handy to reference with no namespace.
28 using tensorflow::Flag;
29 using tensorflow::int32;
30 using tensorflow::Status;
31 using tensorflow::string;
32 using tensorflow::Tensor;
33 using tensorflow::tstring;
34
35 namespace {
36
37 // Reads a model graph definition from disk, and creates a session object you
38 // can use to run it.
LoadGraph(const string & graph_file_name,std::unique_ptr<tensorflow::Session> * session)39 Status LoadGraph(const string& graph_file_name,
40 std::unique_ptr<tensorflow::Session>* session) {
41 tensorflow::GraphDef graph_def;
42 Status load_graph_status =
43 ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
44 if (!load_graph_status.ok()) {
45 return tensorflow::errors::NotFound("Failed to load compute graph at '",
46 graph_file_name, "'");
47 }
48 session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
49 Status session_create_status = (*session)->Create(graph_def);
50 if (!session_create_status.ok()) {
51 return session_create_status;
52 }
53 return Status::OK();
54 }
55
56 // Takes a file name, and loads a list of labels from it, one per line, and
57 // returns a vector of the strings.
ReadLabelsFile(const string & file_name,std::vector<string> * result)58 Status ReadLabelsFile(const string& file_name, std::vector<string>* result) {
59 std::ifstream file(file_name);
60 if (!file) {
61 return tensorflow::errors::NotFound("Labels file ", file_name,
62 " not found.");
63 }
64 result->clear();
65 string line;
66 while (std::getline(file, line)) {
67 result->push_back(line);
68 }
69 return Status::OK();
70 }
71
72 // Analyzes the output of the graph to retrieve the highest scores and
73 // their positions in the tensor.
GetTopLabels(const std::vector<Tensor> & outputs,int how_many_labels,Tensor * out_indices,Tensor * out_scores)74 void GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
75 Tensor* out_indices, Tensor* out_scores) {
76 const Tensor& unsorted_scores_tensor = outputs[0];
77 auto unsorted_scores_flat = unsorted_scores_tensor.flat<float>();
78 std::vector<std::pair<int, float>> scores;
79 scores.reserve(unsorted_scores_flat.size());
80 for (int i = 0; i < unsorted_scores_flat.size(); ++i) {
81 scores.push_back(std::pair<int, float>({i, unsorted_scores_flat(i)}));
82 }
83 std::sort(scores.begin(), scores.end(),
84 [](const std::pair<int, float>& left,
85 const std::pair<int, float>& right) {
86 return left.second > right.second;
87 });
88 scores.resize(how_many_labels);
89 Tensor sorted_indices(tensorflow::DT_INT32, {how_many_labels});
90 Tensor sorted_scores(tensorflow::DT_FLOAT, {how_many_labels});
91 for (int i = 0; i < scores.size(); ++i) {
92 sorted_indices.flat<int>()(i) = scores[i].first;
93 sorted_scores.flat<float>()(i) = scores[i].second;
94 }
95 *out_indices = sorted_indices;
96 *out_scores = sorted_scores;
97 }
98
99 } // namespace
100
main(int argc,char * argv[])101 int main(int argc, char* argv[]) {
102 string wav = "";
103 string graph = "";
104 string labels = "";
105 string input_name = "wav_data";
106 string output_name = "labels_softmax";
107 int32 how_many_labels = 3;
108 std::vector<Flag> flag_list = {
109 Flag("wav", &wav, "audio file to be identified"),
110 Flag("graph", &graph, "model to be executed"),
111 Flag("labels", &labels, "path to file containing labels"),
112 Flag("input_name", &input_name, "name of input node in model"),
113 Flag("output_name", &output_name, "name of output node in model"),
114 Flag("how_many_labels", &how_many_labels, "number of results to show"),
115 };
116 string usage = tensorflow::Flags::Usage(argv[0], flag_list);
117 const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
118 if (!parse_result) {
119 LOG(ERROR) << usage;
120 return -1;
121 }
122
123 // We need to call this to set up global state for TensorFlow.
124 tensorflow::port::InitMain(argv[0], &argc, &argv);
125 if (argc > 1) {
126 LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
127 return -1;
128 }
129
130 // First we load and initialize the model.
131 std::unique_ptr<tensorflow::Session> session;
132 Status load_graph_status = LoadGraph(graph, &session);
133 if (!load_graph_status.ok()) {
134 LOG(ERROR) << load_graph_status;
135 return -1;
136 }
137
138 std::vector<string> labels_list;
139 Status read_labels_status = ReadLabelsFile(labels, &labels_list);
140 if (!read_labels_status.ok()) {
141 LOG(ERROR) << read_labels_status;
142 return -1;
143 }
144
145 string wav_string;
146 Status read_wav_status = tensorflow::ReadFileToString(
147 tensorflow::Env::Default(), wav, &wav_string);
148 if (!read_wav_status.ok()) {
149 LOG(ERROR) << read_wav_status;
150 return -1;
151 }
152 Tensor wav_tensor(tensorflow::DT_STRING, tensorflow::TensorShape({}));
153 wav_tensor.scalar<tstring>()() = wav_string;
154
155 // Actually run the audio through the model.
156 std::vector<Tensor> outputs;
157 Status run_status =
158 session->Run({{input_name, wav_tensor}}, {output_name}, {}, &outputs);
159 if (!run_status.ok()) {
160 LOG(ERROR) << "Running model failed: " << run_status;
161 return -1;
162 }
163
164 Tensor indices;
165 Tensor scores;
166 GetTopLabels(outputs, how_many_labels, &indices, &scores);
167 tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>();
168 tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
169 for (int pos = 0; pos < how_many_labels; ++pos) {
170 const int label_index = indices_flat(pos);
171 const float score = scores_flat(pos);
172 LOG(INFO) << labels_list[label_index] << " (" << label_index
173 << "): " << score;
174 }
175
176 return 0;
177 }
178