1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7   http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 using System;
16 using System.Runtime.InteropServices;
17 using System.Linq;
18 
19 using TfLiteInterpreter = System.IntPtr;
20 using TfLiteInterpreterOptions = System.IntPtr;
21 using TfLiteModel = System.IntPtr;
22 using TfLiteTensor = System.IntPtr;
23 
24 namespace TensorFlowLite
25 {
26   /// <summary>
27   /// Simple C# bindings for the experimental TensorFlowLite C API.
28   /// </summary>
29   public class Interpreter : IDisposable
30   {
31     public struct Options: IEquatable<Options> {
32       /// <summary>
33       /// The number of CPU threads to use for the interpreter.
34       /// </summary>
35       public int threads;
36 
EqualsTensorFlowLite.Interpreter.Options37       public bool Equals(Options other) {
38         return threads == other.threads;
39       }
40     }
41 
42     public struct TensorInfo {
43       public string name { get; internal set; }
44       public DataType type { get; internal set; }
45       public int[] dimensions { get; internal set; }
46       public QuantizationParams quantizationParams { get; internal set; }
47 
ToStringTensorFlowLite.Interpreter.TensorInfo48       public override string ToString() {
49         return string.Format("name: {0}, type: {1}, dimensions: {2}, quantizationParams: {3}",
50           name,
51           type,
52           "[" + string.Join(",", dimensions.Select(d => d.ToString()).ToArray()) + "]",
53           "{" + quantizationParams + "}");
54       }
55     }
56 
    // Native handles owned by this instance. Each is IntPtr.Zero until created
    // in the constructor and again after Dispose(). 'options' stays zero when
    // the caller passed default(Options) (see the constructor).
    private TfLiteModel model = IntPtr.Zero;
    private TfLiteInterpreter interpreter = IntPtr.Zero;
    private TfLiteInterpreterOptions options = IntPtr.Zero;
60 
    /// <summary>
    /// Constructs an interpreter for the given serialized model, using
    /// default options.
    /// </summary>
    /// <param name="modelData">Flatbuffer contents of a .tflite model.</param>
    public Interpreter(byte[] modelData): this(modelData, default(Options)) {}
62 
Interpreter(byte[] modelData, Options options)63     public Interpreter(byte[] modelData, Options options) {
64       GCHandle modelDataHandle = GCHandle.Alloc(modelData, GCHandleType.Pinned);
65       IntPtr modelDataPtr = modelDataHandle.AddrOfPinnedObject();
66       model = TfLiteModelCreate(modelDataPtr, modelData.Length);
67       if (model == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Model");
68 
69       if (!options.Equals(default(Options))) {
70         this.options = TfLiteInterpreterOptionsCreate();
71         TfLiteInterpreterOptionsSetNumThreads(this.options, options.threads);
72       }
73 
74       interpreter = TfLiteInterpreterCreate(model, this.options);
75       if (interpreter == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Interpreter");
76     }
77 
Dispose()78     public void Dispose() {
79       if (interpreter != IntPtr.Zero) TfLiteInterpreterDelete(interpreter);
80       interpreter = IntPtr.Zero;
81       if (model != IntPtr.Zero) TfLiteModelDelete(model);
82       model = IntPtr.Zero;
83       if (options != IntPtr.Zero) TfLiteInterpreterOptionsDelete(options);
84       options = IntPtr.Zero;
85     }
86 
Invoke()87     public void Invoke() {
88       ThrowIfError(TfLiteInterpreterInvoke(interpreter));
89     }
90 
GetInputTensorCount()91     public int GetInputTensorCount() {
92       return TfLiteInterpreterGetInputTensorCount(interpreter);
93     }
94 
SetInputTensorData(int inputTensorIndex, Array inputTensorData)95     public void SetInputTensorData(int inputTensorIndex, Array inputTensorData) {
96       GCHandle tensorDataHandle = GCHandle.Alloc(inputTensorData, GCHandleType.Pinned);
97       IntPtr tensorDataPtr = tensorDataHandle.AddrOfPinnedObject();
98       TfLiteTensor tensor = TfLiteInterpreterGetInputTensor(interpreter, inputTensorIndex);
99       ThrowIfError(TfLiteTensorCopyFromBuffer(
100           tensor, tensorDataPtr, Buffer.ByteLength(inputTensorData)));
101     }
102 
ResizeInputTensor(int inputTensorIndex, int[] inputTensorShape)103     public void ResizeInputTensor(int inputTensorIndex, int[] inputTensorShape) {
104       ThrowIfError(TfLiteInterpreterResizeInputTensor(
105           interpreter, inputTensorIndex, inputTensorShape, inputTensorShape.Length));
106     }
107 
AllocateTensors()108     public void AllocateTensors() {
109       ThrowIfError(TfLiteInterpreterAllocateTensors(interpreter));
110     }
111 
GetOutputTensorCount()112     public int GetOutputTensorCount() {
113       return TfLiteInterpreterGetOutputTensorCount(interpreter);
114     }
115 
GetOutputTensorData(int outputTensorIndex, Array outputTensorData)116     public void GetOutputTensorData(int outputTensorIndex, Array outputTensorData) {
117       GCHandle tensorDataHandle = GCHandle.Alloc(outputTensorData, GCHandleType.Pinned);
118       IntPtr tensorDataPtr = tensorDataHandle.AddrOfPinnedObject();
119       TfLiteTensor tensor = TfLiteInterpreterGetOutputTensor(interpreter, outputTensorIndex);
120       ThrowIfError(TfLiteTensorCopyToBuffer(
121           tensor, tensorDataPtr, Buffer.ByteLength(outputTensorData)));
122     }
123 
GetInputTensorInfo(int index)124     public TensorInfo GetInputTensorInfo(int index) {
125       TfLiteTensor tensor = TfLiteInterpreterGetInputTensor(interpreter, index);
126       return GetTensorInfo(tensor);
127     }
128 
GetOutputTensorInfo(int index)129     public TensorInfo GetOutputTensorInfo(int index) {
130       TfLiteTensor tensor = TfLiteInterpreterGetOutputTensor(interpreter, index);
131       return GetTensorInfo(tensor);
132     }
133 
134     /// <summary>
135     /// Returns a string describing version information of the TensorFlow Lite library.
136     /// TensorFlow Lite uses semantic versioning.
137     /// </summary>
138     /// <returns>A string describing version information</returns>
GetVersion()139     public static string GetVersion() {
140       return Marshal.PtrToStringAnsi(TfLiteVersion());
141     }
142 
GetTensorName(TfLiteTensor tensor)143     private static string GetTensorName(TfLiteTensor tensor) {
144       return Marshal.PtrToStringAnsi(TfLiteTensorName(tensor));
145     }
146 
GetTensorInfo(TfLiteTensor tensor)147     private static TensorInfo GetTensorInfo(TfLiteTensor tensor) {
148       int[] dimensions = new int[TfLiteTensorNumDims(tensor)];
149       for (int i = 0; i < dimensions.Length; i++) {
150         dimensions[i] = TfLiteTensorDim(tensor, i);
151       }
152       return new TensorInfo() {
153         name = GetTensorName(tensor),
154         type = TfLiteTensorType(tensor),
155         dimensions = dimensions,
156         quantizationParams = TfLiteTensorQuantizationParams(tensor),
157       };
158     }
159 
ThrowIfError(int resultCode)160     private static void ThrowIfError(int resultCode) {
161       if (resultCode != 0) throw new Exception("TensorFlowLite operation failed.");
162     }
163 
    #region Externs

    // Name of the native library passed to [DllImport]. On iOS device builds
    // Unity plugins are linked into the app binary, so the special name
    // "__Internal" is used; everywhere else the dynamic library
    // "tensorflowlite_c" is loaded at runtime.
    #if UNITY_IPHONE && !UNITY_EDITOR
    private const string TensorFlowLibrary = "__Internal";
#else
    private const string TensorFlowLibrary = "tensorflowlite_c";
#endif
171 
    /// <summary>
    /// Element type of a tensor, as returned by TfLiteTensorType.
    /// NOTE: the numeric values mirror the native TfLiteType enum and must
    /// stay in sync with the TensorFlow Lite C library — do not renumber.
    /// </summary>
    public enum DataType {
      NoType = 0,
      Float32 = 1,
      Int32 = 2,
      UInt8 = 3,
      Int64 = 4,
      String = 5,
      Bool = 6,
      Int16 = 7,
      Complex64 = 8,
      Int8 = 9,
      Float16 = 10,
    }
185 
186     public struct QuantizationParams {
187       public float scale;
188       public int zeroPoint;
189 
ToStringTensorFlowLite.Interpreter.QuantizationParams190       public override string ToString() {
191         return string.Format("scale: {0} zeroPoint: {1}", scale, zeroPoint);
192       }
193     }
194 
    // Returns a pointer to a NUL-terminated version string; marshaled to a
    // managed string by GetVersion().
    [DllImport (TensorFlowLibrary)]
    private static extern unsafe IntPtr TfLiteVersion();
197 
198     [DllImport (TensorFlowLibrary)]
TfLiteModelCreate(IntPtr model_data, int model_size)199     private static extern unsafe TfLiteInterpreter TfLiteModelCreate(IntPtr model_data, int model_size);
200 
201     [DllImport (TensorFlowLibrary)]
TfLiteModelDelete(TfLiteModel model)202     private static extern unsafe TfLiteInterpreter TfLiteModelDelete(TfLiteModel model);
203 
    // ---- Interpreter options lifecycle ----

    [DllImport (TensorFlowLibrary)]
    private static extern unsafe TfLiteInterpreterOptions TfLiteInterpreterOptionsCreate();

    [DllImport (TensorFlowLibrary)]
    private static extern unsafe void TfLiteInterpreterOptionsDelete(TfLiteInterpreterOptions options);

    [DllImport (TensorFlowLibrary)]
    private static extern unsafe void TfLiteInterpreterOptionsSetNumThreads(
        TfLiteInterpreterOptions options,
        int num_threads
    );

    // ---- Interpreter lifecycle and execution ----

    // optional_options may be IntPtr.Zero for library-default options.
    [DllImport (TensorFlowLibrary)]
    private static extern unsafe TfLiteInterpreter TfLiteInterpreterCreate(
        TfLiteModel model,
        TfLiteInterpreterOptions optional_options);

    [DllImport (TensorFlowLibrary)]
    private static extern unsafe void TfLiteInterpreterDelete(TfLiteInterpreter interpreter);

    [DllImport (TensorFlowLibrary)]
    private static extern unsafe int TfLiteInterpreterGetInputTensorCount(
        TfLiteInterpreter interpreter);

    [DllImport (TensorFlowLibrary)]
    private static extern unsafe TfLiteTensor TfLiteInterpreterGetInputTensor(
        TfLiteInterpreter interpreter,
        int input_index);

    // int returns below are native status codes; 0 means success
    // (see ThrowIfError).
    [DllImport (TensorFlowLibrary)]
    private static extern unsafe int TfLiteInterpreterResizeInputTensor(
        TfLiteInterpreter interpreter,
        int input_index,
        int[] input_dims,
        int input_dims_size);

    [DllImport (TensorFlowLibrary)]
    private static extern unsafe int TfLiteInterpreterAllocateTensors(
        TfLiteInterpreter interpreter);

    [DllImport (TensorFlowLibrary)]
    private static extern unsafe int TfLiteInterpreterInvoke(TfLiteInterpreter interpreter);

    [DllImport (TensorFlowLibrary)]
    private static extern unsafe int TfLiteInterpreterGetOutputTensorCount(
        TfLiteInterpreter interpreter);

    [DllImport (TensorFlowLibrary)]
    private static extern unsafe TfLiteTensor TfLiteInterpreterGetOutputTensor(
        TfLiteInterpreter interpreter,
        int output_index);

    // ---- Tensor inspection ----

    [DllImport (TensorFlowLibrary)]
    private static extern unsafe DataType TfLiteTensorType(TfLiteTensor tensor);

    [DllImport (TensorFlowLibrary)]
    private static extern unsafe int TfLiteTensorNumDims(TfLiteTensor tensor);

    [DllImport (TensorFlowLibrary)]
    private static extern int TfLiteTensorDim(TfLiteTensor tensor, int dim_index);

    // Currently unused by this wrapper; kept for completeness of the binding.
    [DllImport (TensorFlowLibrary)]
    private static extern uint TfLiteTensorByteSize(TfLiteTensor tensor);

    // Returns a pointer to a NUL-terminated name; marshaled by GetTensorName().
    [DllImport (TensorFlowLibrary)]
    private static extern unsafe IntPtr TfLiteTensorName(TfLiteTensor tensor);

    [DllImport (TensorFlowLibrary)]
    private static extern unsafe QuantizationParams TfLiteTensorQuantizationParams(TfLiteTensor tensor);

    // ---- Tensor data transfer (both return native status codes) ----

    [DllImport (TensorFlowLibrary)]
    private static extern unsafe int TfLiteTensorCopyFromBuffer(
        TfLiteTensor tensor,
        IntPtr input_data,
        int input_data_size);

    [DllImport (TensorFlowLibrary)]
    private static extern unsafe int TfLiteTensorCopyToBuffer(
        TfLiteTensor tensor,
        IntPtr output_data,
        int output_data_size);
285 
286     #endregion
287   }
288 }
289