1 /*
2  * Copyright (c) 2023 Lucas Dietrich <ld.adecy@gmail.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
#include "creds/creds.h"

#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <zephyr/net/socket.h>
#include <zephyr/net/dns_resolve.h>
#include <zephyr/net/mqtt.h>
#include <zephyr/net/tls_credentials.h>
#include <zephyr/data/json.h>
#include <zephyr/random/random.h>
#include <zephyr/logging/log.h>
#include "net_sample_common.h"

#if defined(CONFIG_MBEDTLS_MEMORY_DEBUG)
#include <mbedtls/memory_buffer_alloc.h>
#endif
26 
LOG_MODULE_REGISTER(aws, LOG_LEVEL_DBG);

/* NOTE(review): not referenced anywhere in this file; confirm it is unused
 * before removing it.
 */
#define SNTP_SERVER "0.pool.ntp.org"

/* TCP port of the AWS broker, taken from Kconfig. */
#define AWS_BROKER_PORT CONFIG_AWS_MQTT_PORT

/* Size in bytes of each of the MQTT library's RX and TX buffers. */
#define MQTT_BUFFER_SIZE 256u
/* Size in bytes of the application buffer; incoming payloads larger than
 * this are read but discarded by handle_published_message().
 */
#define APP_BUFFER_SIZE	 4096u

/* Connection retry policy: attempt limit and backoff delays in ms. The
 * BACKOFF_EXP_* values apply with CONFIG_AWS_EXPONENTIAL_BACKOFF,
 * BACKOFF_CONST_MS otherwise.
 */
#define MAX_RETRIES	    10u
#define BACKOFF_EXP_BASE_MS 1000u
#define BACKOFF_EXP_MAX_MS  60000u
#define BACKOFF_CONST_MS    5000u

/* Broker address, filled in by resolve_broker_addr() from main(). */
static struct sockaddr_in aws_broker;

/* Buffers handed to the MQTT client in aws_client_setup(). */
static uint8_t rx_buffer[MQTT_BUFFER_SIZE];
static uint8_t tx_buffer[MQTT_BUFFER_SIZE];
static uint8_t buffer[APP_BUFFER_SIZE]; /* Shared between published and received messages */

/* The single MQTT client instance used by this application. */
static struct mqtt_client client_ctx;

/* MQTT client ID: the thing name from Kconfig. */
static const char mqtt_client_name[] = CONFIG_AWS_THING_NAME;

/* Number of PUBLISH messages received; its value is what publish() sends. */
static uint32_t messages_received_counter;
/* Both flags are set by mqtt_event_cb() and consumed (then cleared) by
 * aws_client_loop().
 */
static bool do_publish;	  /* Trigger client to publish */
static bool do_subscribe; /* Trigger client to subscribe */

/* ALPN protocol offered for MQTT over TLS on port 443 (non-WebSocket).
 * NOTE(review): presumably what AWS IoT needs to route MQTT on port 443 -
 * confirm against the AWS IoT Core documentation.
 */
#if (CONFIG_AWS_MQTT_PORT == 443 && !defined(CONFIG_MQTT_LIB_WEBSOCKET))
static const char * const alpn_list[] = {"x-amzn-mqtt-ca"};
#endif

/* Sec tags used with the TLS credential store. The device certificate and
 * its private key share tag 1 (so listing tag 1 once in sec_tls_tags
 * presumably covers both); the CA certificate uses tag 2.
 */
#define TLS_TAG_DEVICE_CERTIFICATE 1
#define TLS_TAG_DEVICE_PRIVATE_KEY 1
#define TLS_TAG_AWS_CA_CERTIFICATE 2

/* Sec tags attached to the MQTT TLS transport in aws_client_setup(). */
static const sec_tag_t sec_tls_tags[] = {
	TLS_TAG_DEVICE_CERTIFICATE,
	TLS_TAG_AWS_CA_CERTIFICATE,
};
67 
setup_credentials(void)68 static int setup_credentials(void)
69 {
70 	int ret;
71 
72 	ret = tls_credential_add(TLS_TAG_DEVICE_CERTIFICATE, TLS_CREDENTIAL_SERVER_CERTIFICATE,
73 				 public_cert, public_cert_len);
74 	if (ret < 0) {
75 		LOG_ERR("Failed to add device certificate: %d", ret);
76 		goto exit;
77 	}
78 
79 	ret = tls_credential_add(TLS_TAG_DEVICE_PRIVATE_KEY, TLS_CREDENTIAL_PRIVATE_KEY,
80 				 private_key, private_key_len);
81 	if (ret < 0) {
82 		LOG_ERR("Failed to add device private key: %d", ret);
83 		goto exit;
84 	}
85 
86 	ret = tls_credential_add(TLS_TAG_AWS_CA_CERTIFICATE, TLS_CREDENTIAL_CA_CERTIFICATE, ca_cert,
87 				 ca_cert_len);
88 	if (ret < 0) {
89 		LOG_ERR("Failed to add device private key: %d", ret);
90 		goto exit;
91 	}
92 
93 exit:
94 	return ret;
95 }
96 
subscribe_topic(void)97 static int subscribe_topic(void)
98 {
99 	int ret;
100 	struct mqtt_topic topics[] = {{
101 		.topic = {.utf8 = CONFIG_AWS_SUBSCRIBE_TOPIC,
102 			  .size = strlen(CONFIG_AWS_SUBSCRIBE_TOPIC)},
103 		.qos = CONFIG_AWS_QOS,
104 	}};
105 	const struct mqtt_subscription_list sub_list = {
106 		.list = topics,
107 		.list_count = ARRAY_SIZE(topics),
108 		.message_id = 1u,
109 	};
110 
111 	LOG_INF("Subscribing to %hu topic(s)", sub_list.list_count);
112 
113 	ret = mqtt_subscribe(&client_ctx, &sub_list);
114 	if (ret != 0) {
115 		LOG_ERR("Failed to subscribe to topics: %d", ret);
116 	}
117 
118 	return ret;
119 }
120 
publish_message(const char * topic,size_t topic_len,uint8_t * payload,size_t payload_len)121 static int publish_message(const char *topic, size_t topic_len, uint8_t *payload,
122 			   size_t payload_len)
123 {
124 	static uint32_t message_id = 1u;
125 
126 	int ret;
127 	struct mqtt_publish_param msg;
128 
129 	msg.retain_flag = 0u;
130 	msg.dup_flag = 0u;
131 	msg.message.topic.topic.utf8 = topic;
132 	msg.message.topic.topic.size = topic_len;
133 	msg.message.topic.qos = CONFIG_AWS_QOS;
134 	msg.message.payload.data = payload;
135 	msg.message.payload.len = payload_len;
136 	msg.message_id = message_id++;
137 
138 	ret = mqtt_publish(&client_ctx, &msg);
139 	if (ret != 0) {
140 		LOG_ERR("Failed to publish message: %d", ret);
141 	}
142 
143 	LOG_INF("PUBLISHED on topic \"%s\" [ id: %u qos: %u ], payload: %u B", topic,
144 		msg.message_id, msg.message.topic.qos, payload_len);
145 	LOG_HEXDUMP_DBG(payload, payload_len, "Published payload:");
146 
147 	return ret;
148 }
149 
handle_published_message(const struct mqtt_publish_param * pub)150 static ssize_t handle_published_message(const struct mqtt_publish_param *pub)
151 {
152 	int ret;
153 	size_t received = 0u;
154 	const size_t message_size = pub->message.payload.len;
155 	const bool discarded = message_size > APP_BUFFER_SIZE;
156 
157 	LOG_INF("RECEIVED on topic \"%s\" [ id: %u qos: %u ] payload: %u / %u B",
158 		(const char *)pub->message.topic.topic.utf8, pub->message_id,
159 		pub->message.topic.qos, message_size, APP_BUFFER_SIZE);
160 
161 	while (received < message_size) {
162 		uint8_t *p = discarded ? buffer : &buffer[received];
163 
164 		ret = mqtt_read_publish_payload_blocking(&client_ctx, p, APP_BUFFER_SIZE);
165 		if (ret < 0) {
166 			return ret;
167 		}
168 
169 		received += ret;
170 	}
171 
172 	if (!discarded) {
173 		LOG_HEXDUMP_DBG(buffer, MIN(message_size, 256u), "Received payload:");
174 	}
175 
176 	/* Send ACK */
177 	switch (pub->message.topic.qos) {
178 	case MQTT_QOS_1_AT_LEAST_ONCE: {
179 		struct mqtt_puback_param puback;
180 
181 		puback.message_id = pub->message_id;
182 		mqtt_publish_qos1_ack(&client_ctx, &puback);
183 	} break;
184 	case MQTT_QOS_2_EXACTLY_ONCE: /* unhandled (not supported by AWS) */
185 	case MQTT_QOS_0_AT_MOST_ONCE: /* nothing to do */
186 	default:
187 		break;
188 	}
189 
190 	return discarded ? -ENOMEM : received;
191 }
192 
/*
 * Return a short printable name for an MQTT event type, for logging.
 *
 * The table is indexed by the numeric event value. It assumes enum
 * mqtt_evt_type runs 0..9 in the order listed (CONNACK first); any value
 * outside the table yields "<unknown>".
 */
const char *mqtt_evt_type_to_str(enum mqtt_evt_type type)
{
	static const char *const names[] = {
		"CONNACK",
		"DISCONNECT",
		"PUBLISH",
		"PUBACK",
		"PUBREC",
		"PUBREL",
		"PUBCOMP",
		"SUBACK",
		"UNSUBACK",
		"PINGRESP",
	};
	const char *name = "<unknown>";

	if (type < ARRAY_SIZE(names)) {
		name = names[type];
	}

	return name;
}
202 
mqtt_event_cb(struct mqtt_client * client,const struct mqtt_evt * evt)203 static void mqtt_event_cb(struct mqtt_client *client, const struct mqtt_evt *evt)
204 {
205 	LOG_DBG("MQTT event: %s [%u] result: %d", mqtt_evt_type_to_str(evt->type), evt->type,
206 		evt->result);
207 
208 	switch (evt->type) {
209 	case MQTT_EVT_CONNACK: {
210 		do_subscribe = true;
211 	} break;
212 
213 	case MQTT_EVT_PUBLISH: {
214 		const struct mqtt_publish_param *pub = &evt->param.publish;
215 
216 		handle_published_message(pub);
217 		messages_received_counter++;
218 #if !defined(CONFIG_AWS_TEST_SUITE_RECV_QOS1)
219 		do_publish = true;
220 #endif
221 	} break;
222 
223 	case MQTT_EVT_SUBACK: {
224 #if !defined(CONFIG_AWS_TEST_SUITE_RECV_QOS1)
225 		do_publish = true;
226 #endif
227 	} break;
228 
229 	case MQTT_EVT_PUBACK:
230 	case MQTT_EVT_DISCONNECT:
231 	case MQTT_EVT_PUBREC:
232 	case MQTT_EVT_PUBREL:
233 	case MQTT_EVT_PUBCOMP:
234 	case MQTT_EVT_PINGRESP:
235 	case MQTT_EVT_UNSUBACK:
236 	default:
237 		break;
238 	}
239 }
240 
aws_client_setup(void)241 static void aws_client_setup(void)
242 {
243 	mqtt_client_init(&client_ctx);
244 
245 	client_ctx.broker = &aws_broker;
246 	client_ctx.evt_cb = mqtt_event_cb;
247 
248 	client_ctx.client_id.utf8 = (uint8_t *)mqtt_client_name;
249 	client_ctx.client_id.size = sizeof(mqtt_client_name) - 1;
250 	client_ctx.password = NULL;
251 	client_ctx.user_name = NULL;
252 
253 	client_ctx.keepalive = CONFIG_MQTT_KEEPALIVE;
254 
255 	client_ctx.protocol_version = MQTT_VERSION_3_1_1;
256 
257 	client_ctx.rx_buf = rx_buffer;
258 	client_ctx.rx_buf_size = MQTT_BUFFER_SIZE;
259 	client_ctx.tx_buf = tx_buffer;
260 	client_ctx.tx_buf_size = MQTT_BUFFER_SIZE;
261 
262 	/* setup TLS */
263 	client_ctx.transport.type = MQTT_TRANSPORT_SECURE;
264 	struct mqtt_sec_config *const tls_config = &client_ctx.transport.tls.config;
265 
266 	tls_config->peer_verify = TLS_PEER_VERIFY_REQUIRED;
267 	tls_config->cipher_list = NULL;
268 	tls_config->sec_tag_list = sec_tls_tags;
269 	tls_config->sec_tag_count = ARRAY_SIZE(sec_tls_tags);
270 	tls_config->hostname = CONFIG_AWS_ENDPOINT;
271 	tls_config->cert_nocopy = TLS_CERT_NOCOPY_NONE;
272 #if (CONFIG_AWS_MQTT_PORT == 443 && !defined(CONFIG_MQTT_LIB_WEBSOCKET))
273 	tls_config->alpn_protocol_name_list = alpn_list;
274 	tls_config->alpn_protocol_name_count = ARRAY_SIZE(alpn_list);
275 #endif
276 }
277 
/*
 * State of the connection retry policy: initialised by
 * backoff_context_init() and advanced by backoff_get_next().
 */
struct backoff_context {
	uint16_t retries_count; /* retry counter, advanced by backoff_get_next() */
	uint16_t max_retries;   /* retry limit, MAX_RETRIES after init */

#if defined(CONFIG_AWS_EXPONENTIAL_BACKOFF)
	/* Upper bound of the random delay for the next attempt; doubled after
	 * each computed delay and capped at max_backoff.
	 */
	uint32_t attempt_max_backoff; /* ms */
	uint32_t max_backoff;	      /* ms */
#endif
};
287 
backoff_context_init(struct backoff_context * bo)288 static void backoff_context_init(struct backoff_context *bo)
289 {
290 	__ASSERT_NO_MSG(bo != NULL);
291 
292 	bo->retries_count = 0u;
293 	bo->max_retries = MAX_RETRIES;
294 
295 #if defined(CONFIG_AWS_EXPONENTIAL_BACKOFF)
296 	bo->attempt_max_backoff = BACKOFF_EXP_BASE_MS;
297 	bo->max_backoff = BACKOFF_EXP_MAX_MS;
298 #endif
299 }
300 
301 /* https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ */
backoff_get_next(struct backoff_context * bo,uint32_t * next_backoff_ms)302 static void backoff_get_next(struct backoff_context *bo, uint32_t *next_backoff_ms)
303 {
304 	__ASSERT_NO_MSG(bo != NULL);
305 	__ASSERT_NO_MSG(next_backoff_ms != NULL);
306 
307 #if defined(CONFIG_AWS_EXPONENTIAL_BACKOFF)
308 	if (bo->retries_count <= bo->max_retries) {
309 		*next_backoff_ms = sys_rand32_get() % (bo->attempt_max_backoff + 1u);
310 
311 		/* Calculate max backoff for the next attempt (~ 2**attempt) */
312 		bo->attempt_max_backoff = MIN(bo->attempt_max_backoff * 2u, bo->max_backoff);
313 		bo->retries_count++;
314 	}
315 #else
316 	*next_backoff_ms = BACKOFF_CONST_MS;
317 #endif
318 }
319 
aws_client_try_connect(void)320 static int aws_client_try_connect(void)
321 {
322 	int ret;
323 	uint32_t backoff_ms;
324 	struct backoff_context bo;
325 
326 	backoff_context_init(&bo);
327 
328 	while (bo.retries_count <= bo.max_retries) {
329 		ret = mqtt_connect(&client_ctx);
330 		if (ret == 0) {
331 			goto exit;
332 		}
333 
334 		backoff_get_next(&bo, &backoff_ms);
335 
336 		LOG_ERR("Failed to connect: %d backoff delay: %u ms", ret, backoff_ms);
337 		k_msleep(backoff_ms);
338 	}
339 
340 exit:
341 	return ret;
342 }
343 
/* Content of the message sent by publish(): a JSON object with a single
 * numeric "counter" field.
 */
struct publish_payload {
	uint32_t counter; /* set from messages_received_counter by publish() */
};

/* JSON descriptor for struct publish_payload, used with json_obj_encode_buf(). */
static const struct json_obj_descr json_descr[] = {
	JSON_OBJ_DESCR_PRIM(struct publish_payload, counter, JSON_TOK_NUMBER),
};
351 
publish(void)352 static int publish(void)
353 {
354 	struct publish_payload pl = {.counter = messages_received_counter};
355 
356 	json_obj_encode_buf(json_descr, ARRAY_SIZE(json_descr), &pl, buffer, sizeof(buffer));
357 
358 	return publish_message(CONFIG_AWS_PUBLISH_TOPIC, strlen(CONFIG_AWS_PUBLISH_TOPIC), buffer,
359 			       strlen(buffer));
360 }
361 
aws_client_loop(void)362 void aws_client_loop(void)
363 {
364 	int rc;
365 	int timeout;
366 	struct pollfd fds;
367 
368 	aws_client_setup();
369 
370 	rc = aws_client_try_connect();
371 	if (rc != 0) {
372 		goto cleanup;
373 	}
374 
375 	fds.fd = client_ctx.transport.tcp.sock;
376 	fds.events = POLLIN;
377 
378 	for (;;) {
379 		timeout = mqtt_keepalive_time_left(&client_ctx);
380 		rc = poll(&fds, 1u, timeout);
381 		if (rc >= 0) {
382 			if (fds.revents & POLLIN) {
383 				rc = mqtt_input(&client_ctx);
384 				if (rc != 0) {
385 					LOG_ERR("Failed to read MQTT input: %d", rc);
386 					break;
387 				}
388 			}
389 
390 			if (fds.revents & (POLLHUP | POLLERR)) {
391 				LOG_ERR("Socket closed/error");
392 				break;
393 			}
394 
395 			rc = mqtt_live(&client_ctx);
396 			if ((rc != 0) && (rc != -EAGAIN)) {
397 				LOG_ERR("Failed to live MQTT: %d", rc);
398 				break;
399 			}
400 		} else {
401 			LOG_ERR("poll failed: %d", rc);
402 			break;
403 		}
404 
405 		if (do_publish) {
406 			do_publish = false;
407 			publish();
408 		}
409 
410 		if (do_subscribe) {
411 			do_subscribe = false;
412 			subscribe_topic();
413 		}
414 	}
415 
416 cleanup:
417 	mqtt_disconnect(&client_ctx, NULL);
418 
419 	close(fds.fd);
420 	fds.fd = -1;
421 }
422 
/*
 * Resolve CONFIG_AWS_ENDPOINT to an IPv4 address and store it, together with
 * AWS_BROKER_PORT, in @p broker.
 *
 * @param broker  Destination; it is written only when resolution succeeds.
 *
 * @return 0 on success, otherwise the non-zero getaddrinfo() error code.
 */
static int resolve_broker_addr(struct sockaddr_in *broker)
{
	int ret;
	struct addrinfo *ai = NULL;
	const struct addrinfo hints = {
		.ai_family = AF_INET,
		.ai_socktype = SOCK_STREAM,
		.ai_protocol = 0,
	};
	char port_string[6]; /* "65535" plus the NUL terminator */
	char addr_str[INET_ADDRSTRLEN];

	snprintf(port_string, sizeof(port_string), "%d", AWS_BROKER_PORT);

	ret = getaddrinfo(CONFIG_AWS_ENDPOINT, port_string, &hints, &ai);
	if (ret != 0) {
		/* ai was not allocated, so there is nothing to free. */
		LOG_ERR("failed to resolve hostname err = %d (errno = %d)", ret, errno);
		return ret;
	}

	/* Bound the copy by the destination, which is a sockaddr_in (not a
	 * sockaddr_storage).
	 */
	memcpy(broker, ai->ai_addr, MIN(ai->ai_addrlen, sizeof(*broker)));
	freeaddrinfo(ai);

	inet_ntop(AF_INET, &broker->sin_addr, addr_str, sizeof(addr_str));
	LOG_INF("Resolved: %s:%u", addr_str, ntohs(broker->sin_port));

	return 0;
}
452 
main(void)453 int main(void)
454 {
455 	setup_credentials();
456 
457 	wait_for_network();
458 
459 	for (;;) {
460 		resolve_broker_addr(&aws_broker);
461 
462 		aws_client_loop();
463 
464 #if defined(CONFIG_MBEDTLS_MEMORY_DEBUG)
465 		size_t cur_used, cur_blocks, max_used, max_blocks;
466 
467 		mbedtls_memory_buffer_alloc_cur_get(&cur_used, &cur_blocks);
468 		mbedtls_memory_buffer_alloc_max_get(&max_used, &max_blocks);
469 		LOG_INF("mbedTLS heap usage: MAX %u/%u (%u) CUR %u (%u)", max_used,
470 			CONFIG_MBEDTLS_HEAP_SIZE, max_blocks, cur_used, cur_blocks);
471 #endif
472 
473 		k_sleep(K_SECONDS(1));
474 	}
475 
476 	return 0;
477 }
478