1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 /// \file 16 /// Deserialization infrastructure for tflite. Provides functionality 17 /// to go from a serialized tflite model in flatbuffer format to an 18 /// in-memory representation of the model. 19 /// 20 #ifndef TENSORFLOW_LITE_MODEL_BUILDER_H_ 21 #define TENSORFLOW_LITE_MODEL_BUILDER_H_ 22 23 #include <stddef.h> 24 25 #include <memory> 26 #include <string> 27 28 #include "tensorflow/lite/allocation.h" 29 #include "tensorflow/lite/c/common.h" 30 #include "tensorflow/lite/core/api/error_reporter.h" 31 #include "tensorflow/lite/core/api/op_resolver.h" 32 #include "tensorflow/lite/core/api/verifier.h" 33 #include "tensorflow/lite/mutable_op_resolver.h" 34 #include "tensorflow/lite/schema/schema_generated.h" 35 #include "tensorflow/lite/stderr_reporter.h" 36 #include "tensorflow/lite/string_type.h" 37 38 namespace tflite { 39 40 /// An RAII object that represents a read-only tflite model, copied from disk, 41 /// or mmapped. This uses flatbuffers as the serialization format. 42 /// 43 /// NOTE: The current API requires that a FlatBufferModel instance be kept alive 44 /// by the client as long as it is in use by any dependent Interpreter 45 /// instances. As the FlatBufferModel instance is effectively immutable after 46 /// creation, the client may safely use a single model with multiple dependent 47 /// Interpreter instances, even across multiple threads (though note that each 48 /// Interpreter instance is *not* thread-safe). 49 /// 50 /// <pre><code> 51 /// using namespace tflite; 52 /// StderrReporter error_reporter; 53 /// auto model = FlatBufferModel::BuildFromFile("interesting_model.tflite", 54 /// &error_reporter); 55 /// MyOpResolver resolver; // You need to subclass OpResolver to provide 56 /// // implementations. 57 /// InterpreterBuilder builder(*model, resolver); 58 /// std::unique_ptr<Interpreter> interpreter; 59 /// if(builder(&interpreter) == kTfLiteOk) { 60 /// .. run model inference with interpreter 61 /// } 62 /// </code></pre> 63 /// 64 /// OpResolver must be defined to provide your kernel implementations to the 65 /// interpreter. This is environment specific and may consist of just the 66 /// builtin ops, or some custom operators you defined to extend tflite. 67 class FlatBufferModel { 68 public: 69 /// Builds a model based on a file. 70 /// Caller retains ownership of `error_reporter` and must ensure its lifetime 71 /// is longer than the FlatBufferModel instance. 72 /// Returns a nullptr in case of failure. 73 static std::unique_ptr<FlatBufferModel> BuildFromFile( 74 const char* filename, 75 ErrorReporter* error_reporter = DefaultErrorReporter()); 76 77 /// Verifies whether the content of the file is legit, then builds a model 78 /// based on the file. 79 /// The extra_verifier argument is an additional optional verifier for the 80 /// file contents. By default, we always check with tflite::VerifyModelBuffer. 81 /// If extra_verifier is supplied, the file contents is also checked against 82 /// the extra_verifier after the check against tflite::VerifyModelBuilder. 83 /// Caller retains ownership of `error_reporter` and must ensure its lifetime 84 /// is longer than the FlatBufferModel instance. 85 /// Returns a nullptr in case of failure. 86 static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromFile( 87 const char* filename, TfLiteVerifier* extra_verifier = nullptr, 88 ErrorReporter* error_reporter = DefaultErrorReporter()); 89 90 /// Builds a model based on a pre-loaded flatbuffer. 91 /// Caller retains ownership of the buffer and should keep it alive until 92 /// the returned object is destroyed. Caller also retains ownership of 93 /// `error_reporter` and must ensure its lifetime is longer than the 94 /// FlatBufferModel instance. 95 /// Returns a nullptr in case of failure. 96 /// NOTE: this does NOT validate the buffer so it should NOT be called on 97 /// invalid/untrusted input. Use VerifyAndBuildFromBuffer in that case 98 static std::unique_ptr<FlatBufferModel> BuildFromBuffer( 99 const char* caller_owned_buffer, size_t buffer_size, 100 ErrorReporter* error_reporter = DefaultErrorReporter()); 101 102 /// Verifies whether the content of the buffer is legit, then builds a model 103 /// based on the pre-loaded flatbuffer. 104 /// The extra_verifier argument is an additional optional verifier for the 105 /// buffer. By default, we always check with tflite::VerifyModelBuffer. If 106 /// extra_verifier is supplied, the buffer is checked against the 107 /// extra_verifier after the check against tflite::VerifyModelBuilder. The 108 /// caller retains ownership of the buffer and should keep it alive until the 109 /// returned object is destroyed. Caller retains ownership of `error_reporter` 110 /// and must ensure its lifetime is longer than the FlatBufferModel instance. 111 /// Returns a nullptr in case of failure. 112 static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromBuffer( 113 const char* caller_owned_buffer, size_t buffer_size, 114 TfLiteVerifier* extra_verifier = nullptr, 115 ErrorReporter* error_reporter = DefaultErrorReporter()); 116 117 /// Builds a model directly from an allocation. 118 /// Ownership of the allocation is passed to the model, but the caller 119 /// retains ownership of `error_reporter` and must ensure its lifetime is 120 /// longer than the FlatBufferModel instance. 121 /// Returns a nullptr in case of failure (e.g., the allocation is invalid). 122 static std::unique_ptr<FlatBufferModel> BuildFromAllocation( 123 std::unique_ptr<Allocation> allocation, 124 ErrorReporter* error_reporter = DefaultErrorReporter()); 125 126 /// Verifies whether the content of the allocation is legit, then builds a 127 /// model based on the provided allocation. 128 /// The extra_verifier argument is an additional optional verifier for the 129 /// buffer. By default, we always check with tflite::VerifyModelBuffer. If 130 /// extra_verifier is supplied, the buffer is checked against the 131 /// extra_verifier after the check against tflite::VerifyModelBuilder. 132 /// Ownership of the allocation is passed to the model, but the caller 133 /// retains ownership of `error_reporter` and must ensure its lifetime is 134 /// longer than the FlatBufferModel instance. 135 /// Returns a nullptr in case of failure. 136 static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromAllocation( 137 std::unique_ptr<Allocation> allocation, 138 TfLiteVerifier* extra_verifier = nullptr, 139 ErrorReporter* error_reporter = DefaultErrorReporter()); 140 141 /// Builds a model directly from a flatbuffer pointer 142 /// Caller retains ownership of the buffer and should keep it alive until the 143 /// returned object is destroyed. Caller retains ownership of `error_reporter` 144 /// and must ensure its lifetime is longer than the FlatBufferModel instance. 145 /// Returns a nullptr in case of failure. 146 static std::unique_ptr<FlatBufferModel> BuildFromModel( 147 const tflite::Model* caller_owned_model_spec, 148 ErrorReporter* error_reporter = DefaultErrorReporter()); 149 150 // Releases memory or unmaps mmaped memory. 151 ~FlatBufferModel(); 152 153 // Copying or assignment is disallowed to simplify ownership semantics. 154 FlatBufferModel(const FlatBufferModel&) = delete; 155 FlatBufferModel& operator=(const FlatBufferModel&) = delete; 156 initialized()157 bool initialized() const { return model_ != nullptr; } 158 const tflite::Model* operator->() const { return model_; } GetModel()159 const tflite::Model* GetModel() const { return model_; } error_reporter()160 ErrorReporter* error_reporter() const { return error_reporter_; } allocation()161 const Allocation* allocation() const { return allocation_.get(); } 162 163 // Returns the minimum runtime version from the flatbuffer. This runtime 164 // version encodes the minimum required interpreter version to run the 165 // flatbuffer model. If the minimum version can't be determined, an empty 166 // string will be returned. 167 // Note that the returned minimum version is a lower-bound but not a strict 168 // lower-bound; ops in the graph may not have an associated runtime version, 169 // in which case the actual required runtime might be greater than the 170 // reported minimum. 171 std::string GetMinimumRuntime() const; 172 173 /// Returns true if the model identifier is correct (otherwise false and 174 /// reports an error). 175 bool CheckModelIdentifier() const; 176 177 private: 178 /// Loads a model from a given allocation. FlatBufferModel will take over the 179 /// ownership of `allocation`, and delete it in destructor. The ownership of 180 /// `error_reporter`remains with the caller and must have lifetime at least 181 /// as much as FlatBufferModel. This is to allow multiple models to use the 182 /// same ErrorReporter instance. 183 FlatBufferModel(std::unique_ptr<Allocation> allocation, 184 ErrorReporter* error_reporter = DefaultErrorReporter()); 185 186 /// Loads a model from Model flatbuffer. The `model` has to remain alive and 187 /// unchanged until the end of this flatbuffermodel's lifetime. 188 FlatBufferModel(const Model* model, ErrorReporter* error_reporter); 189 190 /// Flatbuffer traverser pointer. (Model* is a pointer that is within the 191 /// allocated memory of the data allocated by allocation's internals. 192 const tflite::Model* model_ = nullptr; 193 /// The error reporter to use for model errors and subsequent errors when 194 /// the interpreter is created 195 ErrorReporter* error_reporter_; 196 /// The allocator used for holding memory of the model. Note that this will 197 /// be null if the client provides a tflite::Model directly. 198 std::unique_ptr<Allocation> allocation_; 199 }; 200 201 } // namespace tflite 202 203 #endif // TENSORFLOW_LITE_MODEL_BUILDER_H_ 204