1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <algorithm>
#include <cmath>
#include <fstream>
#include <utility>
#include <vector>
18 
19 #include "tensorflow/core/framework/tensor.h"
20 #include "tensorflow/core/lib/io/path.h"
21 #include "tensorflow/core/platform/init_main.h"
22 #include "tensorflow/core/platform/logging.h"
23 #include "tensorflow/core/platform/types.h"
24 #include "tensorflow/core/public/session.h"
25 #include "tensorflow/core/util/command_line_flags.h"
26 
27 // These are all common classes it's handy to reference with no namespace.
28 using tensorflow::Flag;
29 using tensorflow::int32;
30 using tensorflow::Status;
31 using tensorflow::string;
32 using tensorflow::Tensor;
33 using tensorflow::tstring;
34 
35 namespace {
36 
37 // Reads a model graph definition from disk, and creates a session object you
38 // can use to run it.
LoadGraph(const string & graph_file_name,std::unique_ptr<tensorflow::Session> * session)39 Status LoadGraph(const string& graph_file_name,
40                  std::unique_ptr<tensorflow::Session>* session) {
41   tensorflow::GraphDef graph_def;
42   Status load_graph_status =
43       ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
44   if (!load_graph_status.ok()) {
45     return tensorflow::errors::NotFound("Failed to load compute graph at '",
46                                         graph_file_name, "'");
47   }
48   session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
49   Status session_create_status = (*session)->Create(graph_def);
50   if (!session_create_status.ok()) {
51     return session_create_status;
52   }
53   return Status::OK();
54 }
55 
56 // Takes a file name, and loads a list of labels from it, one per line, and
57 // returns a vector of the strings.
ReadLabelsFile(const string & file_name,std::vector<string> * result)58 Status ReadLabelsFile(const string& file_name, std::vector<string>* result) {
59   std::ifstream file(file_name);
60   if (!file) {
61     return tensorflow::errors::NotFound("Labels file ", file_name,
62                                         " not found.");
63   }
64   result->clear();
65   string line;
66   while (std::getline(file, line)) {
67     result->push_back(line);
68   }
69   return Status::OK();
70 }
71 
72 // Analyzes the output of the graph to retrieve the highest scores and
73 // their positions in the tensor.
GetTopLabels(const std::vector<Tensor> & outputs,int how_many_labels,Tensor * out_indices,Tensor * out_scores)74 void GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
75                   Tensor* out_indices, Tensor* out_scores) {
76   const Tensor& unsorted_scores_tensor = outputs[0];
77   auto unsorted_scores_flat = unsorted_scores_tensor.flat<float>();
78   std::vector<std::pair<int, float>> scores;
79   scores.reserve(unsorted_scores_flat.size());
80   for (int i = 0; i < unsorted_scores_flat.size(); ++i) {
81     scores.push_back(std::pair<int, float>({i, unsorted_scores_flat(i)}));
82   }
83   std::sort(scores.begin(), scores.end(),
84             [](const std::pair<int, float>& left,
85                const std::pair<int, float>& right) {
86               return left.second > right.second;
87             });
88   scores.resize(how_many_labels);
89   Tensor sorted_indices(tensorflow::DT_INT32, {how_many_labels});
90   Tensor sorted_scores(tensorflow::DT_FLOAT, {how_many_labels});
91   for (int i = 0; i < scores.size(); ++i) {
92     sorted_indices.flat<int>()(i) = scores[i].first;
93     sorted_scores.flat<float>()(i) = scores[i].second;
94   }
95   *out_indices = sorted_indices;
96   *out_scores = sorted_scores;
97 }
98 
99 }  // namespace
100 
main(int argc,char * argv[])101 int main(int argc, char* argv[]) {
102   string wav = "";
103   string graph = "";
104   string labels = "";
105   string input_name = "wav_data";
106   string output_name = "labels_softmax";
107   int32 how_many_labels = 3;
108   std::vector<Flag> flag_list = {
109       Flag("wav", &wav, "audio file to be identified"),
110       Flag("graph", &graph, "model to be executed"),
111       Flag("labels", &labels, "path to file containing labels"),
112       Flag("input_name", &input_name, "name of input node in model"),
113       Flag("output_name", &output_name, "name of output node in model"),
114       Flag("how_many_labels", &how_many_labels, "number of results to show"),
115   };
116   string usage = tensorflow::Flags::Usage(argv[0], flag_list);
117   const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
118   if (!parse_result) {
119     LOG(ERROR) << usage;
120     return -1;
121   }
122 
123   // We need to call this to set up global state for TensorFlow.
124   tensorflow::port::InitMain(argv[0], &argc, &argv);
125   if (argc > 1) {
126     LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
127     return -1;
128   }
129 
130   // First we load and initialize the model.
131   std::unique_ptr<tensorflow::Session> session;
132   Status load_graph_status = LoadGraph(graph, &session);
133   if (!load_graph_status.ok()) {
134     LOG(ERROR) << load_graph_status;
135     return -1;
136   }
137 
138   std::vector<string> labels_list;
139   Status read_labels_status = ReadLabelsFile(labels, &labels_list);
140   if (!read_labels_status.ok()) {
141     LOG(ERROR) << read_labels_status;
142     return -1;
143   }
144 
145   string wav_string;
146   Status read_wav_status = tensorflow::ReadFileToString(
147       tensorflow::Env::Default(), wav, &wav_string);
148   if (!read_wav_status.ok()) {
149     LOG(ERROR) << read_wav_status;
150     return -1;
151   }
152   Tensor wav_tensor(tensorflow::DT_STRING, tensorflow::TensorShape({}));
153   wav_tensor.scalar<tstring>()() = wav_string;
154 
155   // Actually run the audio through the model.
156   std::vector<Tensor> outputs;
157   Status run_status =
158       session->Run({{input_name, wav_tensor}}, {output_name}, {}, &outputs);
159   if (!run_status.ok()) {
160     LOG(ERROR) << "Running model failed: " << run_status;
161     return -1;
162   }
163 
164   Tensor indices;
165   Tensor scores;
166   GetTopLabels(outputs, how_many_labels, &indices, &scores);
167   tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>();
168   tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
169   for (int pos = 0; pos < how_many_labels; ++pos) {
170     const int label_index = indices_flat(pos);
171     const float score = scores_flat(pos);
172     LOG(INFO) << labels_list[label_index] << " (" << label_index
173               << "): " << score;
174   }
175 
176   return 0;
177 }
178