1 /*
2 * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *
16 * Copyright (c) 2025 Aerlync Labs Inc.
17 */
18
#include "main_functions.h"

#include <cmath>
#include <cstdint>
#include <limits>

#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/micro/micro_log.h>
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
#include <tensorflow/lite/micro/system_setup.h>
#include <tensorflow/lite/schema/schema_generated.h>

#include "constants.h"
#include "model.hpp"
#include "output_handler.hpp"
29
30 /* Globals, used for compatibility with Arduino-style sketches. */
31 namespace {
32 const tflite::Model *model = nullptr;
33 tflite::MicroInterpreter *interpreter = nullptr;
34 TfLiteTensor *input = nullptr;
35 TfLiteTensor *output = nullptr;
36 int inference_count = 0;
37
38 constexpr int kTensorArenaSize = 2000;
39 uint8_t tensor_arena[kTensorArenaSize];
40 } /* namespace */
41
42 /* The name of this function is important for Arduino compatibility. */
setup(void)43 void setup(void)
44 {
45 /* Map the model into a usable data structure. This doesn't involve any
46 * copying or parsing, it's a very lightweight operation.
47 */
48 model = tflite::GetModel(g_model);
49 if (model->version() != TFLITE_SCHEMA_VERSION) {
50 MicroPrintf("Model provided is schema version %d not equal "
51 "to supported version %d.",
52 model->version(), TFLITE_SCHEMA_VERSION);
53 return;
54 }
55
56 /* This pulls in the operation implementations we need.
57 * NOLINTNEXTLINE(runtime-global-variables)
58 */
59 static tflite::MicroMutableOpResolver <1> resolver;
60 resolver.AddFullyConnected();
61
62 /* Build an interpreter to run the model with. */
63 static tflite::MicroInterpreter static_interpreter(
64 model, resolver, tensor_arena, kTensorArenaSize);
65 interpreter = &static_interpreter;
66
67 /* Allocate memory from the tensor_arena for the model's tensors. */
68 TfLiteStatus allocate_status = interpreter->AllocateTensors();
69 if (allocate_status != kTfLiteOk) {
70 MicroPrintf("AllocateTensors() failed");
71 return;
72 }
73
74 /* Obtain pointers to the model's input and output tensors. */
75 input = interpreter->input(0);
76 output = interpreter->output(0);
77
78 /* Keep track of how many inferences we have performed. */
79 inference_count = 0;
80 }
81
82 /* The name of this function is important for Arduino compatibility. */
loop(void)83 void loop(void)
84 {
85 /* Calculate an x value to feed into the model. We compare the current
86 * inference_count to the number of inferences per cycle to determine
87 * our position within the range of possible x values the model was
88 * trained on, and use this to calculate a value.
89 */
90 float position = static_cast < float > (inference_count) /
91 static_cast < float > (kInferencesPerCycle);
92 float x = position * kXrange;
93
94 /* Quantize the input from floating-point to integer */
95 int8_t x_quantized = (int8_t)round(x / input->params.scale)
96 + input->params.zero_point;
97 /* Place the quantized input in the model's input tensor */
98 input->data.int8[0] = x_quantized;
99
100 /* Run inference, and report any error */
101 TfLiteStatus invoke_status = interpreter->Invoke();
102 if (invoke_status != kTfLiteOk) {
103 MicroPrintf("Invoke failed on x: %f\n", static_cast < double > (x));
104 return;
105 }
106
107 /* Obtain the quantized output from model's output tensor */
108 int8_t y_quantized = output->data.int8[0];
109 /* Dequantize the output from integer to floating-point */
110 float y = (y_quantized - output->params.zero_point) * output->params.scale;
111
112 /* Output the results. A custom HandleOutput function can be implemented
113 * for each supported hardware target.
114 */
115 HandleOutput(x, y);
116
117 /* Increment the inference_counter, and reset it if we have reached
118 * the total number per cycle
119 */
120 inference_count += 1;
121 if (inference_count >= kInferencesPerCycle) inference_count = 0;
122 }
123