/*
 * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * Copyright (c) 2025 Aerlync Labs Inc.
 */

#include "main_functions.h"

#include <cmath>
#include <cstdint>

#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/micro/micro_log.h>
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
#include <tensorflow/lite/micro/system_setup.h>
#include <tensorflow/lite/schema/schema_generated.h>

#include "constants.h"
#include "model.hpp"
#include "output_handler.hpp"

/* Globals, used for compatibility with Arduino-style sketches. */
namespace {
	const tflite::Model *model = nullptr;
	tflite::MicroInterpreter *interpreter = nullptr;
	TfLiteTensor *input = nullptr;
	TfLiteTensor *output = nullptr;
	int inference_count = 0;

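	/* The arena provides working memory for the model's input, output
	 * and intermediate tensors. 2000 bytes is enough for this small
	 * sine model; the right size is model-dependent and is usually
	 * found by experimentation.
	 */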
	constexpr int kTensorArenaSize = 2000;
	uint8_t tensor_arena[kTensorArenaSize];
}  /* namespace */

/* The name of this function is important for Arduino compatibility. */
void setup(void)
{
	/* Map the model into a usable data structure. This doesn't involve any
	 * copying or parsing; it's a very lightweight operation.
	 */
	model = tflite::GetModel(g_model);
	if (model->version() != TFLITE_SCHEMA_VERSION) {
		MicroPrintf("Model provided is schema version %d not equal "
			    "to supported version %d.",
			    model->version(), TFLITE_SCHEMA_VERSION);
		return;
	}

	/* This pulls in the operation implementations we need.
	 * NOLINTNEXTLINE(runtime-global-variables)
	 */
	static tflite::MicroMutableOpResolver<1> resolver;
	resolver.AddFullyConnected();
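	/* The <1> template argument reserves room for exactly one op
	 * registration; a model that uses more op kinds needs a larger
	 * count and one matching Add*() call per op.
	 */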

	/* Build an interpreter to run the model with. */
	static tflite::MicroInterpreter static_interpreter(
		model, resolver, tensor_arena, kTensorArenaSize);
	interpreter = &static_interpreter;

	/* Allocate memory from the tensor_arena for the model's tensors. */
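	/* This fails if kTensorArenaSize is too small for the model. */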
	TfLiteStatus allocate_status = interpreter->AllocateTensors();
	if (allocate_status != kTfLiteOk) {
		MicroPrintf("AllocateTensors() failed");
		return;
	}

	/* Obtain pointers to the model's input and output tensors. */
	input = interpreter->input(0);
	output = interpreter->output(0);
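	/* Both pointers refer to buffers inside tensor_arena and are only
	 * valid once AllocateTensors() has succeeded.
	 */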

	/* Keep track of how many inferences we have performed. */
	inference_count = 0;
}

/* The name of this function is important for Arduino compatibility. */
void loop(void)
{
	/* Calculate an x value to feed into the model. We compare the current
	 * inference_count to the number of inferences per cycle to determine
	 * our position within the range of possible x values the model was
	 * trained on, and use this to calculate a value.
	 */
	float position = static_cast<float>(inference_count) /
			 static_cast<float>(kInferencesPerCycle);
	float x = position * kXrange;
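	/* For example, halfway through a cycle position is 0.5, so
	 * x = 0.5 * kXrange, which is about pi here since kXrange spans
	 * the sample's 0 to 2*pi training range.
	 */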

	/* Quantize the input from floating-point to integer. */
	int8_t x_quantized = static_cast<int8_t>(
		std::round(x / input->params.scale) + input->params.zero_point);
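	/* Illustrative numbers only: with scale = 0.05 and zero_point = -128,
	 * x = 1.57 maps to round(1.57 / 0.05) + (-128) = 31 - 128 = -97.
	 * The real scale and zero point come from the model itself.
	 */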
	/* Place the quantized input in the model's input tensor. */
	input->data.int8[0] = x_quantized;

	/* Run inference, and report any error. */
	TfLiteStatus invoke_status = interpreter->Invoke();
	if (invoke_status != kTfLiteOk) {
		MicroPrintf("Invoke failed on x: %f\n", static_cast<double>(x));
		return;
	}

	/* Obtain the quantized output from the model's output tensor. */
	int8_t y_quantized = output->data.int8[0];
	/* Dequantize the output from integer to floating-point. */
	float y = (y_quantized - output->params.zero_point) * output->params.scale;
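	/* Continuing the illustrative numbers above: y_quantized = -97 with
	 * scale = 0.05 and zero_point = -128 dequantizes to
	 * (-97 - (-128)) * 0.05 = 1.55.
	 */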

	/* Output the results. A custom HandleOutput function can be implemented
	 * for each supported hardware target.
	 */
	HandleOutput(x, y);

	/* Increment the inference counter, and reset it once we have reached
	 * the total number per cycle.
	 */
	inference_count += 1;
	if (inference_count >= kInferencesPerCycle) {
		inference_count = 0;
	}
}