1 /*
2 * Copyright (c) 2020 Friedt Professional Engineering Services, Inc
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 */
6
7 #include <zephyr/logging/log.h>
8 #include <zephyr/net/net_core.h>
9 #include <zephyr/net/net_ip.h>
10 #include <zephyr/net/socket.h>
11 #include <zephyr/net/tls_credentials.h>
12 #include <zephyr/posix/unistd.h>
13 #include <zephyr/sys/util.h>
14 #include <zephyr/ztest.h>
15
16 #include <mbedtls/x509.h>
17 #include <mbedtls/x509_crt.h>
18
LOG_MODULE_REGISTER(tls_test, CONFIG_NET_SOCKETS_LOG_LEVEL);

/**
 * @brief The message exchanged between server and client.
 *
 * The bytes are plain text; they are encrypted only in transit, by the TLS
 * sockets. The answer to life, the universe, and everything.
 *
 * See also <a href="https://en.wikipedia.org/wiki/42_(number)#The_Hitchhiker's_Guide_to_the_Galaxy">42</a>.
 */
#define SECRET "forty-two"

/**
 * @brief Size of the message passed between server and client.
 *
 * Excludes the string literal's terminating NUL, so only the characters
 * themselves go over the wire.
 */
#define SECRET_SIZE (sizeof(SECRET) - 1)

/** @brief Stack size for the server thread */
#define STACK_SIZE 8192

/** @brief IPv4 loopback address the client connects to */
#define MY_IPV4_ADDR "127.0.0.1"

/** @brief TCP port the server listens on and the client connects to */
#define PORT 4242

/**
 * @brief Arbitrary timeout value in ms
 *
 * Used when waiting for the server thread to signal that it is running.
 */
#define TIMEOUT 1000

/**
 * @brief Application-dependent TLS credential identifiers
 *
 * Since both the server and client exist in the same test
 * application in this case, both the server and client credentials
 * are loaded together.
 *
 * The server would normally need
 * - SERVER_CERTIFICATE_TAG (for both the certificate and the private key)
 * - CA_CERTIFICATE_TAG (only when client authentication is required)
 *
 * The client would normally load
 * - CA_CERTIFICATE_TAG (always required, to verify the server)
 * - CLIENT_CERTIFICATE_TAG (for both the certificate and the private key,
 *   only when client authentication is required)
 */
enum tls_tag {
	/** The Certificate Authority (CA) certificate */
	CA_CERTIFICATE_TAG,
	/** Used for both the server certificate and the server private key */
	SERVER_CERTIFICATE_TAG,
	/** Used for both the client certificate and the client private key */
	CLIENT_CERTIFICATE_TAG,
};

/**
 * @brief Synchronization object for server & client threads
 *
 * Given by the server thread just before it blocks in accept(); taken by
 * the test thread after it spawns the server thread.
 */
static struct k_sem server_sem;

/** @brief The server thread stack (a single stack, reused by every test) */
static K_THREAD_STACK_DEFINE(server_stack, STACK_SIZE);
/** @brief The server thread object (a single object, reused by every test) */
static struct k_thread server_thread;
78
#ifdef CONFIG_TLS_CREDENTIALS
/**
 * @brief The Certificate Authority (CA) Certificate
 *
 * The client needs the CA cert to verify the server certificate. Every
 * client configuration in this test loads it.
 *
 * Additionally, when the peer verification mode is
 * @ref TLS_PEER_VERIFY_OPTIONAL or @ref TLS_PEER_VERIFY_REQUIRED, then
 * the server also needs the CA cert in order to verify the client. This
 * type of configuration is often referred to as *mutual authentication*.
 */
static const unsigned char ca[] = {
#include "ca.inc"
};

/**
 * @brief The Server Certificate
 *
 * The server's public certificate, which carries the server's public key.
 */
static const unsigned char server[] = {
#include "server.inc"
};

/**
 * @brief The Server Private Key
 *
 * This is the private key of the server.
 */
static const unsigned char server_privkey[] = {
#include "server_privkey.inc"
};

/**
 * @brief The Client Certificate
 *
 * The client's public certificate, which carries the client's public key.
 */
static const unsigned char client[] = {
#include "client.inc"
};

/**
 * @brief The Client Private Key
 *
 * This is the private key of the client.
 */
static const unsigned char client_privkey[] = {
#include "client_privkey.inc"
};
#else /* CONFIG_TLS_CREDENTIALS */
/*
 * Without credential storage the credential names collapse to NULL.
 * setup() references them only inside an IS_ENABLED(CONFIG_TLS_CREDENTIALS)
 * branch, so they merely have to compile.
 */
#define ca NULL
#define server NULL
#define server_privkey NULL
#define client NULL
#define client_privkey NULL
#endif /* CONFIG_TLS_CREDENTIALS */
137
138 /**
139 * @brief The server thread function
140 *
141 * This function simply accepts a client connection and
142 * echoes the first @ref SECRET_SIZE bytes of the first
143 * packet. After that, the server is closed and connections
144 * are no longer accepted.
145 *
146 * @param arg0 a pointer to the int representing the server file descriptor
147 * @param arg1 ignored
148 * @param arg2 ignored
149 */
server_thread_fn(void * arg0,void * arg1,void * arg2)150 static void server_thread_fn(void *arg0, void *arg1, void *arg2)
151 {
152 const int server_fd = POINTER_TO_INT(arg0);
153 const int echo = POINTER_TO_INT(arg1);
154 const int expect_failure = POINTER_TO_INT(arg2);
155
156 int r;
157 int client_fd;
158 socklen_t addrlen;
159 char addrstr[INET_ADDRSTRLEN];
160 struct sockaddr_in sa;
161 char *addrstrp;
162
163 k_thread_name_set(k_current_get(), "server");
164
165 NET_DBG("Server thread running");
166
167 memset(&sa, 0, sizeof(sa));
168 addrlen = sizeof(sa);
169
170 NET_DBG("Accepting client connection..");
171 k_sem_give(&server_sem);
172 r = accept(server_fd, (struct sockaddr *)&sa, &addrlen);
173 if (expect_failure) {
174 zassert_equal(r, -1, "accept() should've failed");
175 return;
176 }
177 zassert_not_equal(r, -1, "accept() failed (%d)", r);
178 client_fd = r;
179
180 memset(addrstr, '\0', sizeof(addrstr));
181 addrstrp = (char *)inet_ntop(AF_INET, &sa.sin_addr,
182 addrstr, sizeof(addrstr));
183 zassert_not_equal(addrstrp, NULL, "inet_ntop() failed (%d)", errno);
184
185 NET_DBG("accepted connection from [%s]:%d as fd %d",
186 addrstr, ntohs(sa.sin_port), client_fd);
187
188 if (echo) {
189 NET_DBG("calling recv()");
190 r = recv(client_fd, addrstr, sizeof(addrstr), 0);
191 zassert_not_equal(r, -1, "recv() failed (%d)", errno);
192 zassert_equal(r, SECRET_SIZE, "expected: %zu actual: %d",
193 SECRET_SIZE, r);
194
195 NET_DBG("calling send()");
196 r = send(client_fd, SECRET, SECRET_SIZE, 0);
197 zassert_not_equal(r, -1, "send() failed (%d)", errno);
198 zassert_equal(r, SECRET_SIZE, "expected: %zu actual: %d",
199 SECRET_SIZE, r);
200 }
201
202 NET_DBG("closing client fd");
203 r = close(client_fd);
204 zassert_not_equal(r, -1, "close() failed on the server fd (%d)", errno);
205 }
206
/**
 * @brief Create a listening TLS 1.2 server socket and start the server thread.
 *
 * @param server_thread_id out: id of the spawned server thread
 * @param peer_verify TLS_PEER_VERIFY_NONE, TLS_PEER_VERIFY_OPTIONAL or
 *        TLS_PEER_VERIFY_REQUIRED. The last two also set TLS_PEER_VERIFY on
 *        the socket and load the CA certificate so clients can be verified.
 * @param echo forwarded to server_thread_fn() as arg1
 * @param expect_failure forwarded to server_thread_fn() as arg2
 *
 * @return the listening socket descriptor, or -1 (after a failed assertion)
 *         for an unrecognized @p peer_verify value
 */
static int test_configure_server(k_tid_t *server_thread_id, int peer_verify,
				 int echo, int expect_failure)
{
	static const sec_tag_t tags_server_only[] = {
		SERVER_CERTIFICATE_TAG,
	};
	static const sec_tag_t tags_with_ca[] = {
		CA_CERTIFICATE_TAG,
		SERVER_CERTIFICATE_TAG,
	};

	const int reuse = true;
	struct sockaddr_in bind_addr;
	char addr_text[INET_ADDRSTRLEN];
	const sec_tag_t *tags;
	size_t tags_len;
	char *text;
	int fd;
	int ret;

	k_sem_init(&server_sem, 0, 1);

	NET_DBG("Creating server socket");
	fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TLS_1_2);
	zassert_not_equal(fd, -1, "failed to create server socket (%d)", errno);

	ret = setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse));
	zassert_not_equal(ret, -1, "failed to set SO_REUSEADDR (%d)", errno);

	if (peer_verify == TLS_PEER_VERIFY_NONE) {
		tags = tags_server_only;
		tags_len = sizeof(tags_server_only);
	} else if (peer_verify == TLS_PEER_VERIFY_OPTIONAL ||
		   peer_verify == TLS_PEER_VERIFY_REQUIRED) {
		/* Verifying clients needs the CA certificate on the server too. */
		tags = tags_with_ca;
		tags_len = sizeof(tags_with_ca);

		ret = setsockopt(fd, SOL_TLS, TLS_PEER_VERIFY,
				 &peer_verify, sizeof(peer_verify));
		zassert_not_equal(ret, -1, "failed to set TLS_PEER_VERIFY (%d)",
				  errno);
	} else {
		zassert_true(false, "unrecognized TLS peer verify type %d",
			     peer_verify);
		return -1;
	}

	/* The tag list length is given in bytes, not in elements. */
	ret = setsockopt(fd, SOL_TLS, TLS_SEC_TAG_LIST, tags, tags_len);
	zassert_not_equal(ret, -1, "failed to set TLS_SEC_TAG_LIST (%d)", errno);

	ret = setsockopt(fd, SOL_TLS, TLS_HOSTNAME, "localhost",
			 sizeof("localhost"));
	zassert_not_equal(ret, -1, "failed to set TLS_HOSTNAME (%d)", errno);

	/* Listen on every network interface. */
	memset(&bind_addr, 0, sizeof(bind_addr));
	bind_addr.sin_family = AF_INET;
	bind_addr.sin_port = htons(PORT);
	bind_addr.sin_addr.s_addr = INADDR_ANY;

	ret = bind(fd, (struct sockaddr *)&bind_addr, sizeof(bind_addr));
	zassert_not_equal(ret, -1, "failed to bind (%d)", errno);

	ret = listen(fd, 1);
	zassert_not_equal(ret, -1, "failed to listen (%d)", errno);

	memset(addr_text, '\0', sizeof(addr_text));
	text = (char *)inet_ntop(AF_INET, &bind_addr.sin_addr,
				 addr_text, sizeof(addr_text));
	zassert_not_equal(text, NULL, "inet_ntop() failed (%d)", errno);

	NET_DBG("listening on [%s]:%d as fd %d",
		addr_text, ntohs(bind_addr.sin_port), fd);

	NET_DBG("Creating server thread");
	*server_thread_id = k_thread_create(&server_thread, server_stack,
					    STACK_SIZE, server_thread_fn,
					    INT_TO_POINTER(fd),
					    INT_TO_POINTER(echo),
					    INT_TO_POINTER(expect_failure),
					    K_PRIO_PREEMPT(8), 0, K_NO_WAIT);

	/* Block until the server thread reports that it is about to accept(). */
	ret = k_sem_take(&server_sem, K_MSEC(TIMEOUT));
	zassert_equal(0, ret, "failed to synchronize with server thread (%d)", ret);

	return fd;
}
300
test_configure_client(struct sockaddr_in * sa,bool own_cert,const char * hostname)301 static int test_configure_client(struct sockaddr_in *sa, bool own_cert,
302 const char *hostname)
303 {
304 static const sec_tag_t client_tag_list_verify_none[] = {
305 CA_CERTIFICATE_TAG,
306 };
307
308 static const sec_tag_t client_tag_list_verify[] = {
309 CA_CERTIFICATE_TAG,
310 CLIENT_CERTIFICATE_TAG,
311 };
312
313 char addrstr[INET_ADDRSTRLEN];
314 const sec_tag_t *sec_tag_list;
315 size_t sec_tag_list_size;
316 char *addrstrp;
317 int client_fd;
318 int r;
319
320 k_thread_name_set(k_current_get(), "client");
321
322 NET_DBG("Creating client socket");
323 r = socket(AF_INET, SOCK_STREAM, IPPROTO_TLS_1_2);
324 zassert_not_equal(r, -1, "failed to create client socket (%d)", errno);
325 client_fd = r;
326
327 if (own_cert) {
328 sec_tag_list = client_tag_list_verify;
329 sec_tag_list_size = sizeof(client_tag_list_verify);
330 } else {
331 sec_tag_list = client_tag_list_verify_none;
332 sec_tag_list_size = sizeof(client_tag_list_verify_none);
333 }
334
335 r = setsockopt(client_fd, SOL_TLS, TLS_SEC_TAG_LIST,
336 sec_tag_list, sec_tag_list_size);
337 zassert_not_equal(r, -1, "failed to set TLS_SEC_TAG_LIST (%d)", errno);
338
339 r = setsockopt(client_fd, SOL_TLS, TLS_HOSTNAME, hostname,
340 strlen(hostname) + 1);
341 zassert_not_equal(r, -1, "failed to set TLS_HOSTNAME (%d)", errno);
342
343 sa->sin_family = AF_INET;
344 sa->sin_port = htons(PORT);
345 r = inet_pton(AF_INET, MY_IPV4_ADDR, &sa->sin_addr.s_addr);
346 zassert_not_equal(-1, r, "inet_pton() failed (%d)", errno);
347 zassert_not_equal(0, r, "%s is not a valid IPv4 address", MY_IPV4_ADDR);
348 zassert_equal(1, r, "inet_pton() failed to convert %s", MY_IPV4_ADDR);
349
350 memset(addrstr, '\0', sizeof(addrstr));
351 addrstrp = (char *)inet_ntop(AF_INET, &sa->sin_addr,
352 addrstr, sizeof(addrstr));
353 zassert_not_equal(addrstrp, NULL, "inet_ntop() failed (%d)", errno);
354
355 NET_DBG("connecting to [%s]:%d with fd %d",
356 addrstr, ntohs(sa->sin_port), client_fd);
357
358 return client_fd;
359 }
/**
 * @brief Tear down a client/server pair created by the configure helpers.
 *
 * Closes the client socket, then the listening server socket. It then waits
 * for the server thread to exit, so that the single static thread object
 * and stack can be reused by the next test.
 *
 * @param client_fd client socket returned by test_configure_client()
 * @param server_fd listening socket returned by test_configure_server()
 * @param server_thread_id thread id reported by test_configure_server()
 */
static void test_shutdown(int client_fd, int server_fd, k_tid_t server_thread_id)
{
	int r;

	NET_DBG("closing client fd");
	r = close(client_fd);
	zassert_not_equal(-1, r, "close() failed on the client fd (%d)", errno);

	NET_DBG("closing server fd");
	r = close(server_fd);
	zassert_not_equal(-1, r, "close() failed on the server fd (%d)", errno);

	/*
	 * Join through the id handed back by k_thread_create() instead of
	 * silently ignoring the parameter and reaching for the global object.
	 */
	r = k_thread_join(server_thread_id, K_FOREVER);
	zassert_equal(0, r, "k_thread_join() failed (%d)", r);

	k_yield();
}
377
test_common(int peer_verify)378 static void test_common(int peer_verify)
379 {
380 k_tid_t server_thread_id;
381 struct sockaddr_in sa;
382 uint8_t rx_buf[16];
383 int server_fd;
384 int client_fd;
385 int r;
386
387 /*
388 * Server socket setup
389 */
390 server_fd = test_configure_server(&server_thread_id, peer_verify, true,
391 false);
392
393 /*
394 * Client socket setup
395 */
396 client_fd = test_configure_client(&sa, peer_verify != TLS_PEER_VERIFY_NONE,
397 "localhost");
398
399 /*
400 * The main part of the test
401 */
402
403 r = connect(client_fd, (struct sockaddr *)&sa, sizeof(sa));
404 zassert_not_equal(r, -1, "failed to connect (%d)", errno);
405
406 NET_DBG("Calling send()");
407 r = send(client_fd, SECRET, SECRET_SIZE, 0);
408 zassert_not_equal(r, -1, "send() failed (%d)", errno);
409 zassert_equal(SECRET_SIZE, r, "expected: %zu actual: %d", SECRET_SIZE, r);
410
411 NET_DBG("Calling recv()");
412 memset(rx_buf, 0, sizeof(rx_buf));
413 r = recv(client_fd, rx_buf, sizeof(rx_buf), 0);
414 zassert_not_equal(r, -1, "recv() failed (%d)", errno);
415 zassert_equal(SECRET_SIZE, r, "expected: %zu actual: %d", SECRET_SIZE, r);
416 zassert_mem_equal(SECRET, rx_buf, SECRET_SIZE,
417 "expected: %s actual: %s", SECRET, rx_buf);
418
419 /*
420 * Cleanup resources
421 */
422 test_shutdown(client_fd, server_fd, server_thread_id);
423 }
424
/* Server does not verify the client; the client loads only the CA cert. */
ZTEST(net_socket_tls_api_extension, test_tls_peer_verify_none)
{
	test_common(TLS_PEER_VERIFY_NONE);
}

/* Server verifies the client optionally; the client loads its own cert. */
ZTEST(net_socket_tls_api_extension, test_tls_peer_verify_optional)
{
	test_common(TLS_PEER_VERIFY_OPTIONAL);
}

/* Server requires client verification (mutual authentication). */
ZTEST(net_socket_tls_api_extension, test_tls_peer_verify_required)
{
	test_common(TLS_PEER_VERIFY_REQUIRED);
}
439
test_tls_cert_verify_result_opt_common(uint32_t expect)440 static void test_tls_cert_verify_result_opt_common(uint32_t expect)
441 {
442 int server_fd, client_fd, ret;
443 k_tid_t server_thread_id;
444 struct sockaddr_in sa;
445 uint32_t optval;
446 socklen_t optlen = sizeof(optval);
447 const char *hostname = "localhost";
448 int peer_verify = TLS_PEER_VERIFY_OPTIONAL;
449
450 if (expect == MBEDTLS_X509_BADCERT_CN_MISMATCH) {
451 hostname = "dummy";
452 }
453
454 server_fd = test_configure_server(&server_thread_id, TLS_PEER_VERIFY_NONE,
455 false, false);
456 client_fd = test_configure_client(&sa, false, hostname);
457
458 ret = zsock_setsockopt(client_fd, SOL_TLS, TLS_PEER_VERIFY,
459 &peer_verify, sizeof(peer_verify));
460 zassert_ok(ret, "failed to set TLS_PEER_VERIFY (%d)", errno);
461
462 ret = zsock_connect(client_fd, (struct sockaddr *)&sa, sizeof(sa));
463 zassert_not_equal(ret, -1, "failed to connect (%d)", errno);
464
465 ret = zsock_getsockopt(client_fd, SOL_TLS, TLS_CERT_VERIFY_RESULT,
466 &optval, &optlen);
467 zassert_equal(ret, 0, "getsockopt failed (%d)", errno);
468 zassert_equal(optval, expect, "getsockopt got invalid verify result %d",
469 optval);
470
471 test_shutdown(client_fd, server_fd, server_thread_id);
472 }
473
/* Hostname matches the server certificate: no verification flags set. */
ZTEST(net_socket_tls_api_extension, test_tls_cert_verify_result_opt_ok)
{
	test_tls_cert_verify_result_opt_common(0);
}

/* Client asks for a different hostname: expect the CN-mismatch flag. */
ZTEST(net_socket_tls_api_extension, test_tls_cert_verify_result_opt_bad_cn)
{
	test_tls_cert_verify_result_opt_common(MBEDTLS_X509_BADCERT_CN_MISMATCH);
}
483
484 struct test_cert_verify_ctx {
485 bool cb_called;
486 int result;
487 };
488
cert_verify_cb(void * ctx,mbedtls_x509_crt * crt,int depth,uint32_t * flags)489 static int cert_verify_cb(void *ctx, mbedtls_x509_crt *crt, int depth,
490 uint32_t *flags)
491 {
492 struct test_cert_verify_ctx *test_ctx = (struct test_cert_verify_ctx *)ctx;
493
494 test_ctx->cb_called = true;
495
496 if (test_ctx->result == 0) {
497 *flags = 0;
498 } else {
499 *flags |= MBEDTLS_X509_BADCERT_NOT_TRUSTED;
500 }
501
502 return test_ctx->result;
503 }
504
test_tls_cert_verify_cb_opt_common(int result)505 static void test_tls_cert_verify_cb_opt_common(int result)
506 {
507 int server_fd, client_fd, ret;
508 k_tid_t server_thread_id;
509 struct sockaddr_in sa;
510 struct test_cert_verify_ctx ctx = {
511 .cb_called = false,
512 .result = result,
513 };
514 struct tls_cert_verify_cb cb = {
515 .cb = cert_verify_cb,
516 .ctx = &ctx,
517 };
518
519 server_fd = test_configure_server(&server_thread_id, TLS_PEER_VERIFY_NONE,
520 false, result == 0 ? false : true);
521 client_fd = test_configure_client(&sa, false, "localhost");
522
523 ret = zsock_setsockopt(client_fd, SOL_TLS, TLS_CERT_VERIFY_CALLBACK,
524 &cb, sizeof(cb));
525 zassert_ok(ret, "failed to set TLS_CERT_VERIFY_CALLBACK (%d)", errno);
526
527 ret = zsock_connect(client_fd, (struct sockaddr *)&sa, sizeof(sa));
528 zassert_true(ctx.cb_called, "callback not called");
529 if (result == 0) {
530 zassert_equal(ret, 0, "failed to connect (%d)", errno);
531 } else {
532 zassert_equal(ret, -1, "connect() should fail");
533 zassert_equal(errno, ECONNABORTED, "invalid errno");
534 }
535
536 test_shutdown(client_fd, server_fd, server_thread_id);
537 }
538
/* Callback accepts the server certificate: connect() must succeed. */
ZTEST(net_socket_tls_api_extension, test_tls_cert_verify_cb_opt_ok)
{
	test_tls_cert_verify_cb_opt_common(0);
}

/* Callback rejects the server certificate: connect() must be aborted. */
ZTEST(net_socket_tls_api_extension, test_tls_cert_verify_cb_opt_bad_cert)
{
	test_tls_cert_verify_cb_opt_common(MBEDTLS_ERR_X509_CERT_VERIFY_FAILED);
}
548
setup(void)549 static void *setup(void)
550 {
551 int r;
552
553 /*
554 * Load both client & server credentials
555 *
556 * Normally, this would be split into separate applications but
557 * for testing purposes, we just use separate threads.
558 *
559 * Also, it has to be done before tests are run, otherwise
560 * there are errors due to attempts to load too many certificates.
561 *
562 * The server would normally load
563 * - server public key
564 * - server private key
565 * - ca cert (only when client authentication is required)
566 *
567 * The client would normally load
568 * - ca cert (to verify the server)
569 * - client public key (only when client authentication is required)
570 * - client private key (only when client authentication is required)
571 */
572 if (IS_ENABLED(CONFIG_TLS_CREDENTIALS)) {
573 NET_DBG("Loading credentials");
574 r = tls_credential_add(CA_CERTIFICATE_TAG,
575 TLS_CREDENTIAL_CA_CERTIFICATE,
576 ca, sizeof(ca));
577 zassert_equal(r, 0, "failed to add CA Certificate (%d)", r);
578
579 r = tls_credential_add(SERVER_CERTIFICATE_TAG,
580 TLS_CREDENTIAL_PUBLIC_CERTIFICATE,
581 server, sizeof(server));
582 zassert_equal(r, 0, "failed to add Server Certificate (%d)", r);
583
584 r = tls_credential_add(SERVER_CERTIFICATE_TAG,
585 TLS_CREDENTIAL_PRIVATE_KEY,
586 server_privkey, sizeof(server_privkey));
587 zassert_equal(r, 0, "failed to add Server Private Key (%d)", r);
588
589 r = tls_credential_add(CLIENT_CERTIFICATE_TAG,
590 TLS_CREDENTIAL_PUBLIC_CERTIFICATE,
591 client, sizeof(client));
592 zassert_equal(r, 0, "failed to add Client Certificate (%d)", r);
593
594 r = tls_credential_add(CLIENT_CERTIFICATE_TAG,
595 TLS_CREDENTIAL_PRIVATE_KEY,
596 client_privkey, sizeof(client_privkey));
597 zassert_equal(r, 0, "failed to add Client Private Key (%d)", r);
598 }
599 return NULL;
600 }
601
/* Register the suite; setup() loads all TLS credentials before the tests. */
ZTEST_SUITE(net_socket_tls_api_extension, NULL, setup, NULL, NULL, NULL);
603