/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
using System;
using System.Runtime.InteropServices;
using System.Linq;

using TfLiteInterpreter = System.IntPtr;
using TfLiteInterpreterOptions = System.IntPtr;
using TfLiteModel = System.IntPtr;
using TfLiteTensor = System.IntPtr;

namespace TensorFlowLite
{
  /// <summary>
  /// Simple C# bindings for the experimental TensorFlowLite C API.
  /// </summary>
  /// <remarks>
  /// An instance owns native model, interpreter and (optionally) options handles,
  /// and keeps the model byte buffer pinned for its whole lifetime. Call
  /// <see cref="Dispose"/> to release them. After disposal, every instance member
  /// except <see cref="Dispose"/> throws <see cref="ObjectDisposedException"/>.
  /// </remarks>
  public class Interpreter : IDisposable
  {
    /// <summary>
    /// Optional interpreter configuration. The default value means "use the
    /// library defaults": no native options object is created for it.
    /// </summary>
    public struct Options: IEquatable<Options> {
      /// <summary>
      /// The number of CPU threads to use for the interpreter.
      /// </summary>
      public int threads;

      public bool Equals(Options other) {
        return threads == other.threads;
      }

      // Overridden together with IEquatable<Options> so boxed comparisons and
      // hash-based collections agree with Equals(Options).
      public override bool Equals(object obj) {
        return obj is Options && Equals((Options)obj);
      }

      public override int GetHashCode() {
        return threads;
      }
    }

    /// <summary>
    /// Name, element type, shape and quantization parameters of a tensor, as
    /// reported by the native library.
    /// </summary>
    public struct TensorInfo {
      public string name { get; internal set; }
      public DataType type { get; internal set; }
      public int[] dimensions { get; internal set; }
      public QuantizationParams quantizationParams { get; internal set; }

      public override string ToString() {
        // dimensions is null for default(TensorInfo); render an empty shape
        // rather than throwing from ToString().
        string dims = dimensions == null
            ? ""
            : string.Join(",", dimensions.Select(d => d.ToString()).ToArray());
        return string.Format("name: {0}, type: {1}, dimensions: {2}, quantizationParams: {3}",
            name,
            type,
            "[" + dims + "]",
            "{" + quantizationParams + "}");
      }
    }

    private TfLiteModel model = IntPtr.Zero;
    private TfLiteInterpreter interpreter = IntPtr.Zero;
    private TfLiteInterpreterOptions options = IntPtr.Zero;

    // Pin on the caller's model bytes. TfLiteModelCreate does not take ownership
    // of the buffer, so it must stay at a fixed address until the native model
    // and interpreter have been deleted in Dispose().
    private GCHandle modelDataHandle;

    /// <summary>
    /// Creates an interpreter for a serialized TensorFlow Lite model using default options.
    /// </summary>
    /// <param name="modelData">The model bytes. The array stays pinned until Dispose().</param>
    public Interpreter(byte[] modelData): this(modelData, default(Options)) {}

    /// <summary>
    /// Creates an interpreter for a serialized TensorFlow Lite model.
    /// </summary>
    /// <param name="modelData">The model bytes. The array stays pinned until Dispose().</param>
    /// <param name="options">
    /// Interpreter options. The default value creates no native options object.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="modelData"/> is null.</exception>
    /// <exception cref="Exception">The native model or interpreter could not be created.</exception>
    public Interpreter(byte[] modelData, Options options) {
      if (modelData == null) throw new ArgumentNullException("modelData");

      modelDataHandle = GCHandle.Alloc(modelData, GCHandleType.Pinned);
      try {
        IntPtr modelDataPtr = modelDataHandle.AddrOfPinnedObject();
        model = TfLiteModelCreate(modelDataPtr, new UIntPtr((uint)modelData.Length));
        if (model == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Model");

        if (!options.Equals(default(Options))) {
          this.options = TfLiteInterpreterOptionsCreate();
          TfLiteInterpreterOptionsSetNumThreads(this.options, options.threads);
        }

        interpreter = TfLiteInterpreterCreate(model, this.options);
        if (interpreter == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Interpreter");
      } catch {
        // The caller never receives an instance it could Dispose(), so release
        // everything acquired so far (native handles and the pin) before propagating.
        Dispose();
        throw;
      }
    }

    /// <summary>
    /// Deletes the native interpreter, model and options, then unpins the model
    /// buffer. Safe to call more than once.
    /// </summary>
    public void Dispose() {
      if (interpreter != IntPtr.Zero) TfLiteInterpreterDelete(interpreter);
      interpreter = IntPtr.Zero;
      if (model != IntPtr.Zero) TfLiteModelDelete(model);
      model = IntPtr.Zero;
      if (options != IntPtr.Zero) TfLiteInterpreterOptionsDelete(options);
      options = IntPtr.Zero;
      // Unpin only after the native objects that reference the buffer are gone.
      if (modelDataHandle.IsAllocated) modelDataHandle.Free();
    }

    /// <summary>
    /// Runs inference.
    /// </summary>
    /// <exception cref="ObjectDisposedException">The interpreter has been disposed.</exception>
    /// <exception cref="Exception">The native call reported a non-zero status.</exception>
    public void Invoke() {
      ThrowIfDisposed();
      ThrowIfError(TfLiteInterpreterInvoke(interpreter));
    }

    /// <summary>
    /// Returns the number of input tensors of the model.
    /// </summary>
    /// <exception cref="ObjectDisposedException">The interpreter has been disposed.</exception>
    public int GetInputTensorCount() {
      ThrowIfDisposed();
      return TfLiteInterpreterGetInputTensorCount(interpreter);
    }

    /// <summary>
    /// Copies the raw bytes of <paramref name="inputTensorData"/> into an input tensor.
    /// </summary>
    /// <param name="inputTensorIndex">Index in [0, GetInputTensorCount()).</param>
    /// <param name="inputTensorData">
    /// A primitive-element array. Its byte length is what the native copy checks
    /// against the tensor.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="inputTensorData"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">The index is out of range.</exception>
    /// <exception cref="ObjectDisposedException">The interpreter has been disposed.</exception>
    /// <exception cref="Exception">The native copy reported a non-zero status.</exception>
    public void SetInputTensorData(int inputTensorIndex, Array inputTensorData) {
      ThrowIfDisposed();
      if (inputTensorData == null) throw new ArgumentNullException("inputTensorData");
      CheckTensorIndex(inputTensorIndex, TfLiteInterpreterGetInputTensorCount(interpreter), "inputTensorIndex");

      GCHandle tensorDataHandle = GCHandle.Alloc(inputTensorData, GCHandleType.Pinned);
      try {
        IntPtr tensorDataPtr = tensorDataHandle.AddrOfPinnedObject();
        TfLiteTensor tensor = TfLiteInterpreterGetInputTensor(interpreter, inputTensorIndex);
        ThrowIfError(TfLiteTensorCopyFromBuffer(
            tensor, tensorDataPtr, new UIntPtr((uint)Buffer.ByteLength(inputTensorData))));
      } finally {
        // The pin is only needed for the duration of the native copy.
        tensorDataHandle.Free();
      }
    }

    /// <summary>
    /// Requests a new shape for an input tensor.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the TensorFlow Lite C API documents that AllocateTensors()
    /// must be called again before the resized tensor is used. Confirm this for
    /// the library version in use.
    /// </remarks>
    /// <param name="inputTensorIndex">Index in [0, GetInputTensorCount()).</param>
    /// <param name="inputTensorShape">The new dimensions.</param>
    /// <exception cref="ArgumentNullException"><paramref name="inputTensorShape"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">The index is out of range.</exception>
    /// <exception cref="ObjectDisposedException">The interpreter has been disposed.</exception>
    /// <exception cref="Exception">The native call reported a non-zero status.</exception>
    public void ResizeInputTensor(int inputTensorIndex, int[] inputTensorShape) {
      ThrowIfDisposed();
      if (inputTensorShape == null) throw new ArgumentNullException("inputTensorShape");
      CheckTensorIndex(inputTensorIndex, TfLiteInterpreterGetInputTensorCount(interpreter), "inputTensorIndex");
      ThrowIfError(TfLiteInterpreterResizeInputTensor(
          interpreter, inputTensorIndex, inputTensorShape, inputTensorShape.Length));
    }

    /// <summary>
    /// Allocates memory for all tensors of the interpreter.
    /// </summary>
    /// <exception cref="ObjectDisposedException">The interpreter has been disposed.</exception>
    /// <exception cref="Exception">The native call reported a non-zero status.</exception>
    public void AllocateTensors() {
      ThrowIfDisposed();
      ThrowIfError(TfLiteInterpreterAllocateTensors(interpreter));
    }

    /// <summary>
    /// Returns the number of output tensors of the model.
    /// </summary>
    /// <exception cref="ObjectDisposedException">The interpreter has been disposed.</exception>
    public int GetOutputTensorCount() {
      ThrowIfDisposed();
      return TfLiteInterpreterGetOutputTensorCount(interpreter);
    }

    /// <summary>
    /// Copies the raw bytes of an output tensor into <paramref name="outputTensorData"/>.
    /// </summary>
    /// <param name="outputTensorIndex">Index in [0, GetOutputTensorCount()).</param>
    /// <param name="outputTensorData">
    /// A primitive-element array. Its byte length is what the native copy checks
    /// against the tensor.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="outputTensorData"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">The index is out of range.</exception>
    /// <exception cref="ObjectDisposedException">The interpreter has been disposed.</exception>
    /// <exception cref="Exception">The native copy reported a non-zero status.</exception>
    public void GetOutputTensorData(int outputTensorIndex, Array outputTensorData) {
      ThrowIfDisposed();
      if (outputTensorData == null) throw new ArgumentNullException("outputTensorData");
      CheckTensorIndex(outputTensorIndex, TfLiteInterpreterGetOutputTensorCount(interpreter), "outputTensorIndex");

      GCHandle tensorDataHandle = GCHandle.Alloc(outputTensorData, GCHandleType.Pinned);
      try {
        IntPtr tensorDataPtr = tensorDataHandle.AddrOfPinnedObject();
        TfLiteTensor tensor = TfLiteInterpreterGetOutputTensor(interpreter, outputTensorIndex);
        ThrowIfError(TfLiteTensorCopyToBuffer(
            tensor, tensorDataPtr, new UIntPtr((uint)Buffer.ByteLength(outputTensorData))));
      } finally {
        // The pin is only needed for the duration of the native copy.
        tensorDataHandle.Free();
      }
    }

    /// <summary>
    /// Describes an input tensor.
    /// </summary>
    /// <param name="index">Index in [0, GetInputTensorCount()).</param>
    /// <exception cref="ArgumentOutOfRangeException">The index is out of range.</exception>
    /// <exception cref="ObjectDisposedException">The interpreter has been disposed.</exception>
    public TensorInfo GetInputTensorInfo(int index) {
      ThrowIfDisposed();
      CheckTensorIndex(index, TfLiteInterpreterGetInputTensorCount(interpreter), "index");
      TfLiteTensor tensor = TfLiteInterpreterGetInputTensor(interpreter, index);
      return GetTensorInfo(tensor);
    }

    /// <summary>
    /// Describes an output tensor.
    /// </summary>
    /// <param name="index">Index in [0, GetOutputTensorCount()).</param>
    /// <exception cref="ArgumentOutOfRangeException">The index is out of range.</exception>
    /// <exception cref="ObjectDisposedException">The interpreter has been disposed.</exception>
    public TensorInfo GetOutputTensorInfo(int index) {
      ThrowIfDisposed();
      CheckTensorIndex(index, TfLiteInterpreterGetOutputTensorCount(interpreter), "index");
      TfLiteTensor tensor = TfLiteInterpreterGetOutputTensor(interpreter, index);
      return GetTensorInfo(tensor);
    }

    /// <summary>
    /// Returns a string describing version information of the TensorFlow Lite library.
    /// TensorFlow Lite uses semantic versioning.
    /// </summary>
    /// <returns>A string describing version information</returns>
    public static string GetVersion() {
      return Marshal.PtrToStringAnsi(TfLiteVersion());
    }

    private static string GetTensorName(TfLiteTensor tensor) {
      return Marshal.PtrToStringAnsi(TfLiteTensorName(tensor));
    }

    private static TensorInfo GetTensorInfo(TfLiteTensor tensor) {
      int[] dimensions = new int[TfLiteTensorNumDims(tensor)];
      for (int i = 0; i < dimensions.Length; i++) {
        dimensions[i] = TfLiteTensorDim(tensor, i);
      }
      return new TensorInfo() {
        name = GetTensorName(tensor),
        type = TfLiteTensorType(tensor),
        dimensions = dimensions,
        quantizationParams = TfLiteTensorQuantizationParams(tensor),
      };
    }

    // Dispose() zeroes the interpreter handle. A constructed, live instance
    // always has a non-zero one, so zero means "disposed".
    private void ThrowIfDisposed() {
      if (interpreter == IntPtr.Zero) throw new ObjectDisposedException(GetType().FullName);
    }

    // The native tensor accessors index their input/output lists without bounds
    // checks, so out-of-range indices are rejected here before crossing into
    // native code.
    private static void CheckTensorIndex(int index, int count, string paramName) {
      if (index < 0 || index >= count) {
        throw new ArgumentOutOfRangeException(
            paramName, index, "Tensor index must be in [0, " + count + ").");
      }
    }

    private static void ThrowIfError(int resultCode) {
      if (resultCode != 0) throw new Exception("TensorFlowLite operation failed.");
    }

    #region Externs

#if UNITY_IPHONE && !UNITY_EDITOR
    private const string TensorFlowLibrary = "__Internal";
#else
    private const string TensorFlowLibrary = "tensorflowlite_c";
#endif

    public enum DataType {
      NoType = 0,
      Float32 = 1,
      Int32 = 2,
      UInt8 = 3,
      Int64 = 4,
      String = 5,
      Bool = 6,
      Int16 = 7,
      Complex64 = 8,
      Int8 = 9,
      Float16 = 10,
    }

    public struct QuantizationParams {
      public float scale;
      public int zeroPoint;

      public override string ToString() {
        return string.Format("scale: {0} zeroPoint: {1}", scale, zeroPoint);
      }
    }

    // Native size_t parameters/returns are mapped to UIntPtr so that the
    // marshalled width matches the platform (32- or 64-bit).

    [DllImport (TensorFlowLibrary)]
    private static extern IntPtr TfLiteVersion();

    [DllImport (TensorFlowLibrary)]
    private static extern TfLiteModel TfLiteModelCreate(IntPtr model_data, UIntPtr model_size);

    [DllImport (TensorFlowLibrary)]
    private static extern void TfLiteModelDelete(TfLiteModel model);

    [DllImport (TensorFlowLibrary)]
    private static extern TfLiteInterpreterOptions TfLiteInterpreterOptionsCreate();

    [DllImport (TensorFlowLibrary)]
    private static extern void TfLiteInterpreterOptionsDelete(TfLiteInterpreterOptions options);

    [DllImport (TensorFlowLibrary)]
    private static extern void TfLiteInterpreterOptionsSetNumThreads(
        TfLiteInterpreterOptions options,
        int num_threads
    );

    [DllImport (TensorFlowLibrary)]
    private static extern TfLiteInterpreter TfLiteInterpreterCreate(
        TfLiteModel model,
        TfLiteInterpreterOptions optional_options);

    [DllImport (TensorFlowLibrary)]
    private static extern void TfLiteInterpreterDelete(TfLiteInterpreter interpreter);

    [DllImport (TensorFlowLibrary)]
    private static extern int TfLiteInterpreterGetInputTensorCount(
        TfLiteInterpreter interpreter);

    [DllImport (TensorFlowLibrary)]
    private static extern TfLiteTensor TfLiteInterpreterGetInputTensor(
        TfLiteInterpreter interpreter,
        int input_index);

    [DllImport (TensorFlowLibrary)]
    private static extern int TfLiteInterpreterResizeInputTensor(
        TfLiteInterpreter interpreter,
        int input_index,
        int[] input_dims,
        int input_dims_size);

    [DllImport (TensorFlowLibrary)]
    private static extern int TfLiteInterpreterAllocateTensors(
        TfLiteInterpreter interpreter);

    [DllImport (TensorFlowLibrary)]
    private static extern int TfLiteInterpreterInvoke(TfLiteInterpreter interpreter);

    [DllImport (TensorFlowLibrary)]
    private static extern int TfLiteInterpreterGetOutputTensorCount(
        TfLiteInterpreter interpreter);

    [DllImport (TensorFlowLibrary)]
    private static extern TfLiteTensor TfLiteInterpreterGetOutputTensor(
        TfLiteInterpreter interpreter,
        int output_index);

    [DllImport (TensorFlowLibrary)]
    private static extern DataType TfLiteTensorType(TfLiteTensor tensor);

    [DllImport (TensorFlowLibrary)]
    private static extern int TfLiteTensorNumDims(TfLiteTensor tensor);

    [DllImport (TensorFlowLibrary)]
    private static extern int TfLiteTensorDim(TfLiteTensor tensor, int dim_index);

    [DllImport (TensorFlowLibrary)]
    private static extern UIntPtr TfLiteTensorByteSize(TfLiteTensor tensor);

    [DllImport (TensorFlowLibrary)]
    private static extern IntPtr TfLiteTensorName(TfLiteTensor tensor);

    [DllImport (TensorFlowLibrary)]
    private static extern QuantizationParams TfLiteTensorQuantizationParams(TfLiteTensor tensor);

    [DllImport (TensorFlowLibrary)]
    private static extern int TfLiteTensorCopyFromBuffer(
        TfLiteTensor tensor,
        IntPtr input_data,
        UIntPtr input_data_size);

    [DllImport (TensorFlowLibrary)]
    private static extern int TfLiteTensorCopyToBuffer(
        TfLiteTensor tensor,
        IntPtr output_data,
        UIntPtr output_data_size);

    #endregion
  }
}