1 /** \file ssl_helpers.c
2  *
3  * \brief Helper functions to set up a TLS connection.
4  */
5 
6 /*
7  *  Copyright The Mbed TLS Contributors
8  *  SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
9  */
10 
11 #include <test/ssl_helpers.h>
12 #include "mbedtls/psa_util.h"
13 
14 #include <limits.h>
15 
16 #if defined(MBEDTLS_SSL_TLS_C)
/*
 * RNG callback for tests: fill \p output with \p output_len bytes taken from
 * the C library rand() generator (one rand() call per byte, truncated to
 * unsigned char), so the stream is reproducible after srand().
 * \p p_rng is ignored. Always returns 0.
 */
int mbedtls_test_random(void *p_rng, unsigned char *output, size_t output_len)
{
    size_t remaining = output_len;
    size_t pos = 0;

    (void) p_rng;

    while (remaining-- > 0) {
        output[pos++] = (unsigned char) rand();
    }

    return 0;
}
26 
mbedtls_test_ssl_log_analyzer(void * ctx,int level,const char * file,int line,const char * str)27 void mbedtls_test_ssl_log_analyzer(void *ctx, int level,
28                                    const char *file, int line,
29                                    const char *str)
30 {
31     mbedtls_test_ssl_log_pattern *p = (mbedtls_test_ssl_log_pattern *) ctx;
32 
33 /* Change 0 to 1 for debugging of test cases that use this function. */
34 #if 0
35     const char *q, *basename;
36     /* Extract basename from file */
37     for (q = basename = file; *q != '\0'; q++) {
38         if (*q == '/' || *q == '\\') {
39             basename = q + 1;
40         }
41     }
42     printf("%s:%04d: |%d| %s",
43            basename, line, level, str);
44 #else
45     (void) level;
46     (void) line;
47     (void) file;
48 #endif
49 
50     if (NULL != p &&
51         NULL != p->pattern &&
52         NULL != strstr(str, p->pattern)) {
53         p->counter++;
54     }
55 }
56 
/*
 * Fill \p opts with the default handshake test options.
 *
 * With MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED, rand() is re-seeded on every
 * call. The seed starts at 0xBEEF and advances by 0xD0 per call, so
 * consecutive calls use different (but run-to-run deterministic) seeds.
 *
 * With MBEDTLS_SSL_CACHE_C, a session cache is allocated into opts->cache.
 * Release it with mbedtls_test_free_handshake_options(). The TEST_* macros
 * below branch to `exit` on failure, which is why the label exists; in that
 * case opts->cache may still be NULL.
 */
void mbedtls_test_init_handshake_options(
    mbedtls_test_handshake_test_options *opts)
{
#if defined(MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED)
    /* Static so that each call picks a different seed. */
    static int rng_seed = 0xBEEF;

    srand(rng_seed);
    rng_seed += 0xD0;
#endif

    memset(opts, 0, sizeof(*opts));

    /* Empty name: set_ciphersuite() then leaves the ciphersuite list
     * unconfigured. */
    opts->cipher = "";
    /* UNKNOWN: the endpoint setup does not override the configured
     * min/max TLS versions. */
    opts->client_min_version = MBEDTLS_SSL_VERSION_UNKNOWN;
    opts->client_max_version = MBEDTLS_SSL_VERSION_UNKNOWN;
    opts->server_min_version = MBEDTLS_SSL_VERSION_UNKNOWN;
    opts->server_max_version = MBEDTLS_SSL_VERSION_UNKNOWN;
    opts->expected_negotiated_version = MBEDTLS_SSL_VERSION_TLS1_3;
    opts->pk_alg = MBEDTLS_PK_RSA;
    opts->srv_auth_mode = MBEDTLS_SSL_VERIFY_REQUIRED;
    opts->mfl = MBEDTLS_SSL_MAX_FRAG_LEN_NONE;
    opts->cli_msg_len = 100;
    opts->srv_msg_len = 100;
    opts->expected_cli_fragments = 1;
    opts->expected_srv_fragments = 1;
    opts->legacy_renegotiation = MBEDTLS_SSL_LEGACY_NO_RENEGOTIATION;
    opts->resize_buffers = 1;
    opts->early_data = MBEDTLS_SSL_EARLY_DATA_DISABLED;
    /* Negative: the server endpoint setup skips
     * mbedtls_ssl_conf_max_early_data_size(). */
    opts->max_early_data_size = -1;
#if defined(MBEDTLS_SSL_CACHE_C)
    TEST_CALLOC(opts->cache, 1);
    mbedtls_ssl_cache_init(opts->cache);
#if defined(MBEDTLS_HAVE_TIME)
    /* Sanity check: a freshly initialized cache uses the default timeout. */
    TEST_EQUAL(mbedtls_ssl_cache_get_timeout(opts->cache),
               MBEDTLS_SSL_CACHE_DEFAULT_TIMEOUT);
#endif
exit:
    return;
#endif
}
97 
/*
 * Release the resources allocated by mbedtls_test_init_handshake_options().
 *
 * With MBEDTLS_SSL_CACHE_C, this frees the session cache in opts->cache.
 * The pointer may be NULL, because the allocation in the init function can
 * fail before the cache is set up; in that case nothing is done. Afterwards
 * the pointer is reset to NULL, so a repeated call is harmless.
 */
void mbedtls_test_free_handshake_options(
    mbedtls_test_handshake_test_options *opts)
{
#if defined(MBEDTLS_SSL_CACHE_C)
    /* Do not rely on mbedtls_ssl_cache_free() accepting NULL. */
    if (opts->cache != NULL) {
        mbedtls_ssl_cache_free(opts->cache);
        mbedtls_free(opts->cache);
        opts->cache = NULL;
    }
#else
    (void) opts;
#endif
}
108 
109 #if defined(MBEDTLS_TEST_HOOKS)
/* Store the pointers \p cur, \p end and the byte count \p need in \p args. */
static void set_chk_buf_ptr_args(
    mbedtls_ssl_chk_buf_ptr_args *args,
    unsigned char *cur, unsigned char *end, size_t need)
{
    args->need = need;
    args->end = end;
    args->cur = cur;
}
118 
/* Zero every field of \p args. */
static void reset_chk_buf_ptr_args(mbedtls_ssl_chk_buf_ptr_args *args)
{
    memset(args, 0, sizeof(mbedtls_ssl_chk_buf_ptr_args));
}
123 #endif /* MBEDTLS_TEST_HOOKS */
124 
/* Put \p buf into the empty, unallocated state (all fields zero). */
void mbedtls_test_ssl_buffer_init(mbedtls_test_ssl_buffer *buf)
{
    memset(buf, 0, sizeof(mbedtls_test_ssl_buffer));
}
129 
/*
 * Allocate zeroed storage of \p capacity bytes for \p buf.
 *
 * Returns 0 on success, or MBEDTLS_ERR_SSL_ALLOC_FAILED if the allocation
 * fails (buf->buffer is then NULL and buf->capacity is left unchanged).
 */
int mbedtls_test_ssl_buffer_setup(mbedtls_test_ssl_buffer *buf,
                                  size_t capacity)
{
    unsigned char *storage =
        (unsigned char *) mbedtls_calloc(capacity, sizeof(unsigned char));

    buf->buffer = storage;
    if (storage != NULL) {
        buf->capacity = capacity;
        return 0;
    }

    return MBEDTLS_ERR_SSL_ALLOC_FAILED;
}
142 
/* Free the storage of \p buf (if any) and reset every field to zero. */
void mbedtls_test_ssl_buffer_free(mbedtls_test_ssl_buffer *buf)
{
    unsigned char *storage = buf->buffer;

    if (storage != NULL) {
        mbedtls_free(storage);
    }

    memset(buf, 0, sizeof(mbedtls_test_ssl_buffer));
}
151 
/*
 * Append up to \p input_len bytes from \p input to the ring buffer \p buf.
 *
 * Only as many bytes as fit in the free space are stored, and never more
 * than INT_MAX. The return value is the number of bytes actually stored
 * (0 when the buffer is full), so it always matches what was consumed.
 *
 * Returns -1 if buf is NULL or not set up, or if input is NULL while
 * there is room for a non-zero number of bytes.
 */
int mbedtls_test_ssl_buffer_put(mbedtls_test_ssl_buffer *buf,
                                const unsigned char *input, size_t input_len)
{
    size_t overflow = 0;

    if ((buf == NULL) || (buf->buffer == NULL)) {
        return -1;
    }

    /* Reduce input_len to the free space. Compare against the free space
     * instead of computing content_length + input_len: for a huge input_len
     * that sum wraps around and would skip the reduction. */
    if (input_len > buf->capacity - buf->content_length) {
        input_len = buf->capacity - buf->content_length;
    }
    /* Never store more than the int return value can report. Otherwise the
     * caller would believe fewer bytes were written than were stored. */
    if (input_len > INT_MAX) {
        input_len = INT_MAX;
    }

    if (input == NULL) {
        return (input_len == 0) ? 0 : -1;
    }

    /* Check if the buffer has not come full circle and free space is not in
     * the middle */
    if (buf->start + buf->content_length < buf->capacity) {

        /* Calculate the number of bytes that need to be placed at lower memory
         * address */
        if (buf->start + buf->content_length + input_len
            > buf->capacity) {
            overflow = (buf->start + buf->content_length + input_len)
                       % buf->capacity;
        }

        memcpy(buf->buffer + buf->start + buf->content_length, input,
               input_len - overflow);
        memcpy(buf->buffer, input + input_len - overflow, overflow);

    } else {
        /* The buffer has come full circle and free space is in the middle */
        memcpy(buf->buffer + buf->start + buf->content_length - buf->capacity,
               input, input_len);
    }

    buf->content_length += input_len;
    return (int) input_len;
}
195 
/*
 * Remove up to \p output_len bytes from the front of the ring buffer \p buf.
 *
 * If \p output is non-NULL, the bytes are copied there; if it is NULL, they
 * are discarded. At most min(output_len, buf->content_length, INT_MAX)
 * bytes are consumed, and that count is returned (0 if nothing is
 * buffered).
 *
 * Returns -1 if buf is NULL or has not been set up.
 */
int mbedtls_test_ssl_buffer_get(mbedtls_test_ssl_buffer *buf,
                                unsigned char *output, size_t output_len)
{
    size_t overflow = 0;

    if ((buf == NULL) || (buf->buffer == NULL)) {
        return -1;
    }

    if (buf->content_length < output_len) {
        output_len = buf->content_length;
    }
    /* Never consume more than the int return value can report, so that the
     * caller's count matches what was removed. */
    if (output_len > INT_MAX) {
        output_len = INT_MAX;
    }
    /* Nothing to transfer. Returning here also avoids the modulo by capacity
     * below, which would divide by zero for a zero-capacity buffer. */
    if (output_len == 0) {
        return 0;
    }

    /* Calculate the number of bytes that need to be drawn from lower memory
     * address */
    if (buf->start + output_len > buf->capacity) {
        overflow = (buf->start + output_len) % buf->capacity;
    }

    if (output != NULL) {
        memcpy(output, buf->buffer + buf->start, output_len - overflow);
        memcpy(output + output_len - overflow, buf->buffer, overflow);
    }

    buf->content_length -= output_len;
    buf->start = (buf->start + output_len) % buf->capacity;

    return (int) output_len;
}
229 
/*
 * Allocate room for \p capacity message lengths in \p queue and mark it
 * empty. The stored capacity is clamped to INT_MAX.
 *
 * Returns 0 on success or MBEDTLS_ERR_SSL_ALLOC_FAILED.
 */
int mbedtls_test_ssl_message_queue_setup(
    mbedtls_test_ssl_message_queue *queue, size_t capacity)
{
    size_t *lengths = (size_t *) mbedtls_calloc(capacity, sizeof(size_t));

    queue->messages = lengths;
    if (lengths == NULL) {
        return MBEDTLS_ERR_SSL_ALLOC_FAILED;
    }

    queue->capacity = (capacity <= INT_MAX) ? (int) capacity : INT_MAX;
    queue->num = 0;
    queue->pos = 0;

    return 0;
}
244 
/* Free the length array of \p queue (if any) and zero the queue. NULL is a no-op. */
void mbedtls_test_ssl_message_queue_free(
    mbedtls_test_ssl_message_queue *queue)
{
    if (queue != NULL) {
        if (queue->messages != NULL) {
            mbedtls_free(queue->messages);
        }
        memset(queue, 0, sizeof(*queue));
    }
}
258 
/*
 * Append the message length \p len at the tail of \p queue.
 *
 * Returns len (clamped to INT_MAX) on success,
 * MBEDTLS_TEST_ERROR_ARG_NULL if queue is NULL, or
 * MBEDTLS_ERR_SSL_WANT_WRITE if the queue is full.
 */
int mbedtls_test_ssl_message_queue_push_info(
    mbedtls_test_ssl_message_queue *queue, size_t len)
{
    if (queue == NULL) {
        return MBEDTLS_TEST_ERROR_ARG_NULL;
    }

    if (queue->num < queue->capacity) {
        /* The tail slot is num entries after the head, wrapping around. */
        const int slot = (queue->pos + queue->num) % queue->capacity;

        queue->messages[slot] = len;
        queue->num++;
        return (len > INT_MAX) ? INT_MAX : (int) len;
    }

    return MBEDTLS_ERR_SSL_WANT_WRITE;
}
276 
/*
 * Remove the oldest message length from \p queue.
 *
 * Returns min(message length, buf_len) clamped to INT_MAX,
 * MBEDTLS_TEST_ERROR_ARG_NULL if queue is NULL, or
 * MBEDTLS_ERR_SSL_WANT_READ if the queue is empty.
 */
int mbedtls_test_ssl_message_queue_pop_info(
    mbedtls_test_ssl_message_queue *queue, size_t buf_len)
{
    size_t head_len;
    size_t reported;

    if (queue == NULL) {
        return MBEDTLS_TEST_ERROR_ARG_NULL;
    }
    if (queue->num == 0) {
        return MBEDTLS_ERR_SSL_WANT_READ;
    }

    /* Take the head entry and clear its slot. */
    head_len = queue->messages[queue->pos];
    queue->messages[queue->pos] = 0;
    queue->num--;

    /* Advance the head, wrapping at the end of the array. */
    queue->pos = (queue->pos + 1) % queue->capacity;
    if (queue->pos < 0) {
        queue->pos += queue->capacity;
    }

    reported = (head_len < buf_len) ? head_len : buf_len;
    return (reported > INT_MAX) ? INT_MAX : (int) reported;
}
300 
301 /*
302  * Take a peek on the info about the next message length from the queue.
303  * This will be the oldest inserted message length(fifo).
304  *
305  * \retval  MBEDTLS_TEST_ERROR_ARG_NULL, if the queue is null.
306  * \retval  MBEDTLS_ERR_SSL_WANT_READ, if the queue is empty.
307  * \retval  0, if the peek was successful.
308  * \retval  MBEDTLS_TEST_ERROR_MESSAGE_TRUNCATED, if the given buffer length is
309  *          too small to fit the message. In this case the \p msg_len will be
310  *          set to the full message length so that the
311  *          caller knows what portion of the message can be dropped.
312  */
test_ssl_message_queue_peek_info(mbedtls_test_ssl_message_queue * queue,size_t buf_len,size_t * msg_len)313 static int test_ssl_message_queue_peek_info(
314     mbedtls_test_ssl_message_queue *queue,
315     size_t buf_len, size_t *msg_len)
316 {
317     if (queue == NULL || msg_len == NULL) {
318         return MBEDTLS_TEST_ERROR_ARG_NULL;
319     }
320     if (queue->num == 0) {
321         return MBEDTLS_ERR_SSL_WANT_READ;
322     }
323 
324     *msg_len = queue->messages[queue->pos];
325     return (*msg_len > buf_len) ? MBEDTLS_TEST_ERROR_MESSAGE_TRUNCATED : 0;
326 }
327 
/* Put \p socket into the unconnected state with no buffers (all fields zero). */
void mbedtls_test_mock_socket_init(mbedtls_test_mock_socket *socket)
{
    memset(socket, 0, sizeof(mbedtls_test_mock_socket));
}
332 
/*
 * Tear down \p socket: free its input and output buffers, zero its peer
 * (if any), and zero the socket itself. NULL is a no-op.
 *
 * After mbedtls_test_mock_socket_connect(), the input buffer is the peer's
 * output buffer, so this releases both directions of the connection. The
 * peer is wiped so it does not keep pointers to the freed buffers.
 */
void mbedtls_test_mock_socket_close(mbedtls_test_mock_socket *socket)
{
    mbedtls_test_ssl_buffer *streams[2];

    if (socket == NULL) {
        return;
    }

    streams[0] = socket->input;
    streams[1] = socket->output;
    for (size_t k = 0; k < 2; k++) {
        if (streams[k] != NULL) {
            mbedtls_test_ssl_buffer_free(streams[k]);
            mbedtls_free(streams[k]);
        }
    }

    if (socket->peer != NULL) {
        memset(socket->peer, 0, sizeof(*socket->peer));
    }

    memset(socket, 0, sizeof(*socket));
}
355 
/*
 * Connect two mock sockets back to back.
 *
 * Each peer gets its own output buffer of \p bufsize bytes, and each
 * peer's input is the other peer's output. On success both are marked
 * MBEDTLS_MOCK_SOCKET_CONNECTED and 0 is returned. On failure both
 * sockets are closed and MBEDTLS_ERR_SSL_ALLOC_FAILED (or the buffer
 * setup error) is returned.
 */
int mbedtls_test_mock_socket_connect(mbedtls_test_mock_socket *peer1,
                                     mbedtls_test_mock_socket *peer2,
                                     size_t bufsize)
{
    mbedtls_test_mock_socket *const ends[2] = { peer1, peer2 };
    int ret = -1;

    /* Give each endpoint its own outgoing buffer. */
    for (size_t k = 0; k < 2; k++) {
        mbedtls_test_ssl_buffer *out =
            (mbedtls_test_ssl_buffer *) mbedtls_calloc(
                1, sizeof(mbedtls_test_ssl_buffer));

        ends[k]->output = out;
        if (out == NULL) {
            ret = MBEDTLS_ERR_SSL_ALLOC_FAILED;
            goto exit;
        }
        mbedtls_test_ssl_buffer_init(out);
        ret = mbedtls_test_ssl_buffer_setup(out, bufsize);
        if (ret != 0) {
            goto exit;
        }
    }

    /* Cross-wire: each side reads what the other side writes. */
    peer1->peer = peer2;
    peer2->peer = peer1;
    peer1->input = peer2->output;
    peer2->input = peer1->output;

    peer2->status = MBEDTLS_MOCK_SOCKET_CONNECTED;
    peer1->status = MBEDTLS_MOCK_SOCKET_CONNECTED;
    ret = 0;

exit:
    if (ret != 0) {
        mbedtls_test_mock_socket_close(peer1);
        mbedtls_test_mock_socket_close(peer2);
    }

    return ret;
}
403 
mbedtls_test_mock_tcp_send_b(void * ctx,const unsigned char * buf,size_t len)404 int mbedtls_test_mock_tcp_send_b(void *ctx,
405                                  const unsigned char *buf, size_t len)
406 {
407     mbedtls_test_mock_socket *socket = (mbedtls_test_mock_socket *) ctx;
408 
409     if (socket == NULL || socket->status != MBEDTLS_MOCK_SOCKET_CONNECTED) {
410         return -1;
411     }
412 
413     return mbedtls_test_ssl_buffer_put(socket->output, buf, len);
414 }
415 
mbedtls_test_mock_tcp_recv_b(void * ctx,unsigned char * buf,size_t len)416 int mbedtls_test_mock_tcp_recv_b(void *ctx, unsigned char *buf, size_t len)
417 {
418     mbedtls_test_mock_socket *socket = (mbedtls_test_mock_socket *) ctx;
419 
420     if (socket == NULL || socket->status != MBEDTLS_MOCK_SOCKET_CONNECTED) {
421         return -1;
422     }
423 
424     return mbedtls_test_ssl_buffer_get(socket->input, buf, len);
425 }
426 
mbedtls_test_mock_tcp_send_nb(void * ctx,const unsigned char * buf,size_t len)427 int mbedtls_test_mock_tcp_send_nb(void *ctx,
428                                   const unsigned char *buf, size_t len)
429 {
430     mbedtls_test_mock_socket *socket = (mbedtls_test_mock_socket *) ctx;
431 
432     if (socket == NULL || socket->status != MBEDTLS_MOCK_SOCKET_CONNECTED) {
433         return -1;
434     }
435 
436     if (socket->output->capacity == socket->output->content_length) {
437         return MBEDTLS_ERR_SSL_WANT_WRITE;
438     }
439 
440     return mbedtls_test_ssl_buffer_put(socket->output, buf, len);
441 }
442 
mbedtls_test_mock_tcp_recv_nb(void * ctx,unsigned char * buf,size_t len)443 int mbedtls_test_mock_tcp_recv_nb(void *ctx, unsigned char *buf, size_t len)
444 {
445     mbedtls_test_mock_socket *socket = (mbedtls_test_mock_socket *) ctx;
446 
447     if (socket == NULL || socket->status != MBEDTLS_MOCK_SOCKET_CONNECTED) {
448         return -1;
449     }
450 
451     if (socket->input->content_length == 0) {
452         return MBEDTLS_ERR_SSL_WANT_READ;
453     }
454 
455     return mbedtls_test_ssl_buffer_get(socket->input, buf, len);
456 }
457 
/* Detach \p ctx from any queues and socket (all pointers set to NULL). */
void mbedtls_test_message_socket_init(
    mbedtls_test_message_socket_context *ctx)
{
    ctx->socket = NULL;
    ctx->queue_output = NULL;
    ctx->queue_input = NULL;
}
465 
/*
 * Prepare a datagram-style message context.
 *
 * Only \p queue_input is allocated here (with \p queue_capacity entries);
 * \p queue_output is just referenced, presumably the peer's input queue
 * (confirm with callers). \p socket is reset with
 * mbedtls_test_mock_socket_init().
 *
 * Returns 0 on success, or the error from
 * mbedtls_test_ssl_message_queue_setup(), in which case ctx is not touched.
 */
int mbedtls_test_message_socket_setup(
    mbedtls_test_ssl_message_queue *queue_input,
    mbedtls_test_ssl_message_queue *queue_output,
    size_t queue_capacity,
    mbedtls_test_mock_socket *socket,
    mbedtls_test_message_socket_context *ctx)
{
    int ret;

    ret = mbedtls_test_ssl_message_queue_setup(queue_input, queue_capacity);
    if (ret == 0) {
        ctx->queue_input = queue_input;
        ctx->queue_output = queue_output;
        ctx->socket = socket;
        mbedtls_test_mock_socket_init(socket);
    }

    return ret;
}
484 
/*
 * Release what \p ctx owns: its input queue and its mock socket. The
 * output queue is not freed here. Then zero the context. NULL is a no-op.
 */
void mbedtls_test_message_socket_close(
    mbedtls_test_message_socket_context *ctx)
{
    if (ctx != NULL) {
        mbedtls_test_ssl_message_queue_free(ctx->queue_input);
        mbedtls_test_mock_socket_close(ctx->socket);
        memset(ctx, 0, sizeof(*ctx));
    }
}
496 
mbedtls_test_mock_tcp_send_msg(void * ctx,const unsigned char * buf,size_t len)497 int mbedtls_test_mock_tcp_send_msg(void *ctx,
498                                    const unsigned char *buf, size_t len)
499 {
500     mbedtls_test_ssl_message_queue *queue;
501     mbedtls_test_mock_socket *socket;
502     mbedtls_test_message_socket_context *context =
503         (mbedtls_test_message_socket_context *) ctx;
504 
505     if (context == NULL || context->socket == NULL
506         || context->queue_output == NULL) {
507         return MBEDTLS_TEST_ERROR_CONTEXT_ERROR;
508     }
509 
510     queue = context->queue_output;
511     socket = context->socket;
512 
513     if (queue->num >= queue->capacity) {
514         return MBEDTLS_ERR_SSL_WANT_WRITE;
515     }
516 
517     if (mbedtls_test_mock_tcp_send_b(socket, buf, len) != (int) len) {
518         return MBEDTLS_TEST_ERROR_SEND_FAILED;
519     }
520 
521     return mbedtls_test_ssl_message_queue_push_info(queue, len);
522 }
523 
mbedtls_test_mock_tcp_recv_msg(void * ctx,unsigned char * buf,size_t buf_len)524 int mbedtls_test_mock_tcp_recv_msg(void *ctx,
525                                    unsigned char *buf, size_t buf_len)
526 {
527     mbedtls_test_ssl_message_queue *queue;
528     mbedtls_test_mock_socket *socket;
529     mbedtls_test_message_socket_context *context =
530         (mbedtls_test_message_socket_context *) ctx;
531     size_t drop_len = 0;
532     size_t msg_len;
533     int ret;
534 
535     if (context == NULL || context->socket == NULL
536         || context->queue_input == NULL) {
537         return MBEDTLS_TEST_ERROR_CONTEXT_ERROR;
538     }
539 
540     queue = context->queue_input;
541     socket = context->socket;
542 
543     /* Peek first, so that in case of a socket error the data remains in
544      * the queue. */
545     ret = test_ssl_message_queue_peek_info(queue, buf_len, &msg_len);
546     if (ret == MBEDTLS_TEST_ERROR_MESSAGE_TRUNCATED) {
547         /* Calculate how much to drop */
548         drop_len = msg_len - buf_len;
549 
550         /* Set the requested message len to be buffer length */
551         msg_len = buf_len;
552     } else if (ret != 0) {
553         return ret;
554     }
555 
556     if (mbedtls_test_mock_tcp_recv_b(socket, buf, msg_len) != (int) msg_len) {
557         return MBEDTLS_TEST_ERROR_RECV_FAILED;
558     }
559 
560     if (ret == MBEDTLS_TEST_ERROR_MESSAGE_TRUNCATED) {
561         /* Drop the remaining part of the message */
562         if (mbedtls_test_mock_tcp_recv_b(socket, NULL, drop_len) !=
563             (int) drop_len) {
564             /* Inconsistent state - part of the message was read,
565              * and a part couldn't. Not much we can do here, but it should not
566              * happen in test environment, unless forced manually. */
567         }
568     }
569     ret = mbedtls_test_ssl_message_queue_pop_info(queue, buf_len);
570     if (ret < 0) {
571         return ret;
572     }
573 
574     return (msg_len > INT_MAX) ? INT_MAX : (int) msg_len;
575 }
576 
577 
578 #if defined(MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED) && \
579     defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)  && \
580     defined(MBEDTLS_SSL_SRV_C)
/*
 * PSK identity callback that ignores all of its arguments and always
 * returns 0. It does not configure any key.
 */
static int psk_dummy_callback(void *p_info, mbedtls_ssl_context *ssl,
                              const unsigned char *name, size_t name_len)
{
    (void) name_len;
    (void) name;
    (void) ssl;
    (void) p_info;

    return 0;
}
591 #endif /* MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED &&
592           MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED  &&
593           MBEDTLS_SSL_SRV_C */
594 
595 #if defined(MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED)
596 
/*
 * Restrict endpoint \p ep to the single ciphersuite named \p cipher.
 *
 * A NULL or empty name counts as success and leaves the configuration
 * untouched. Otherwise a zero-terminated two-entry list is allocated in
 * ep->ciphersuites and installed in ep->conf. The configured TLS version
 * range is narrowed to the versions the ciphersuite supports.
 * ep->ciphersuites is stored in the endpoint, presumably because the config
 * keeps referencing the list (confirm against mbedtls_ssl_conf_ciphersuites()).
 *
 * Returns 1 on success, and 0 if a TEST_* check failed (they branch to `exit`).
 */
static int set_ciphersuite(mbedtls_test_ssl_endpoint *ep,
                           const char *cipher)
{
    if (cipher == NULL || cipher[0] == 0) {
        return 1;
    }

    int ok = 0;

    TEST_CALLOC(ep->ciphersuites, 2);
    ep->ciphersuites[0] = mbedtls_ssl_get_ciphersuite_id(cipher);
    ep->ciphersuites[1] = 0;

    const mbedtls_ssl_ciphersuite_t *ciphersuite_info =
        mbedtls_ssl_ciphersuite_from_id(ep->ciphersuites[0]);

    /* The suite must be known and overlap the configured version range. */
    TEST_ASSERT(ciphersuite_info != NULL);
    TEST_ASSERT(ciphersuite_info->min_tls_version <= ep->conf.max_tls_version);
    TEST_ASSERT(ciphersuite_info->max_tls_version >= ep->conf.min_tls_version);

    /* Clamp the configured range to the versions the suite supports. */
    if (ep->conf.max_tls_version > ciphersuite_info->max_tls_version) {
        ep->conf.max_tls_version = (mbedtls_ssl_protocol_version) ciphersuite_info->max_tls_version;
    }
    if (ep->conf.min_tls_version < ciphersuite_info->min_tls_version) {
        ep->conf.min_tls_version = (mbedtls_ssl_protocol_version) ciphersuite_info->min_tls_version;
    }

    mbedtls_ssl_conf_ciphersuites(&ep->conf, ep->ciphersuites);
    ok = 1;

exit:
    return ok;
}
630 
631 /*
632  * Deinitializes certificates from endpoint represented by \p ep.
633  */
test_ssl_endpoint_certificate_free(mbedtls_test_ssl_endpoint * ep)634 static void test_ssl_endpoint_certificate_free(mbedtls_test_ssl_endpoint *ep)
635 {
636     if (ep->ca_chain != NULL) {
637         mbedtls_x509_crt_free(ep->ca_chain);
638         mbedtls_free(ep->ca_chain);
639         ep->ca_chain = NULL;
640     }
641     if (ep->cert != NULL) {
642         mbedtls_x509_crt_free(ep->cert);
643         mbedtls_free(ep->cert);
644         ep->cert = NULL;
645     }
646     if (ep->pkey != NULL) {
647         if (mbedtls_pk_get_type(ep->pkey) == MBEDTLS_PK_OPAQUE) {
648             psa_destroy_key(ep->pkey->priv_id);
649         }
650         mbedtls_pk_free(ep->pkey);
651         mbedtls_free(ep->pkey);
652         ep->pkey = NULL;
653     }
654 }
655 
/*
 * Parse the RSA test certificate and private key for endpoint \p ep into
 * ep->cert and ep->pkey. A server gets the server credentials (SHA-256
 * certificate); a client gets the client credentials.
 *
 * Returns 0 on success. On failure, returns the failing parse result
 * (TEST_EQUAL branches to `exit`).
 */
static int load_endpoint_rsa(mbedtls_test_ssl_endpoint *ep)
{
    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
    const unsigned char *crt_der;
    size_t crt_der_len;
    const unsigned char *key_der;
    size_t key_der_len;

    if (ep->conf.endpoint == MBEDTLS_SSL_IS_SERVER) {
        crt_der = (const unsigned char *) mbedtls_test_srv_crt_rsa_sha256_der;
        crt_der_len = mbedtls_test_srv_crt_rsa_sha256_der_len;
        key_der = (const unsigned char *) mbedtls_test_srv_key_rsa_der;
        key_der_len = mbedtls_test_srv_key_rsa_der_len;
    } else {
        crt_der = (const unsigned char *) mbedtls_test_cli_crt_rsa_der;
        crt_der_len = mbedtls_test_cli_crt_rsa_der_len;
        key_der = (const unsigned char *) mbedtls_test_cli_key_rsa_der;
        key_der_len = mbedtls_test_cli_key_rsa_der_len;
    }

    ret = mbedtls_x509_crt_parse(ep->cert, crt_der, crt_der_len);
    TEST_EQUAL(ret, 0);
    ret = mbedtls_pk_parse_key(ep->pkey, key_der, key_der_len, NULL, 0);
    TEST_EQUAL(ret, 0);

exit:
    return ret;
}
686 
/*
 * Parse the EC test certificate and private key for endpoint \p ep into
 * ep->cert and ep->pkey. A server gets the server credentials; a client
 * gets the client credentials. Every DER buffer is paired with its own
 * `_der_len` length.
 *
 * Returns 0 on success. On failure, returns the failing parse result
 * (TEST_EQUAL branches to `exit`).
 */
static int load_endpoint_ecc(mbedtls_test_ssl_endpoint *ep)
{
    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
    if (ep->conf.endpoint == MBEDTLS_SSL_IS_SERVER) {
        ret = mbedtls_x509_crt_parse(
            ep->cert,
            (const unsigned char *) mbedtls_test_srv_crt_ec_der,
            mbedtls_test_srv_crt_ec_der_len);
        TEST_EQUAL(ret, 0);
        ret = mbedtls_pk_parse_key(
            ep->pkey,
            (const unsigned char *) mbedtls_test_srv_key_ec_der,
            mbedtls_test_srv_key_ec_der_len, NULL, 0);
        TEST_EQUAL(ret, 0);
    } else {
        /* Use the DER length that belongs to the DER buffer. The generic
         * mbedtls_test_cli_crt_ec_len describes a different (config-selected)
         * encoding and need not match the size of this array. */
        ret = mbedtls_x509_crt_parse(
            ep->cert,
            (const unsigned char *) mbedtls_test_cli_crt_ec_der,
            mbedtls_test_cli_crt_ec_der_len);
        TEST_EQUAL(ret, 0);
        ret = mbedtls_pk_parse_key(
            ep->pkey,
            (const unsigned char *) mbedtls_test_cli_key_ec_der,
            mbedtls_test_cli_key_ec_der_len, NULL, 0);
        TEST_EQUAL(ret, 0);
    }

exit:
    return ret;
}
717 
/*
 * Allocate and load the trusted CA chain, own certificate and private key
 * for endpoint \p ep, and install them in ep->conf.
 *
 * \p pk_alg selects RSA credentials (MBEDTLS_PK_RSA); any other value
 * selects EC credentials. If \p opaque_alg is non-zero, the private key is
 * imported into a PSA key with usage \p opaque_usage, algorithm
 * \p opaque_alg and (unless PSA_ALG_NONE) enrollment algorithm
 * \p opaque_alg2. ep->pkey is then replaced by an opaque wrapper around
 * that PSA key.
 *
 * Returns 0 on success, MBEDTLS_ERR_SSL_BAD_INPUT_DATA if ep is NULL, or
 * another non-zero value on failure. On failure, everything allocated here
 * is released, including a PSA key that ep->pkey does not own yet.
 */
int mbedtls_test_ssl_endpoint_certificate_init(mbedtls_test_ssl_endpoint *ep,
                                               int pk_alg,
                                               int opaque_alg, int opaque_alg2,
                                               int opaque_usage)
{
    int i = 0;
    int ret = -1;
    int ok = 0;
    mbedtls_svc_key_id_t key_slot = MBEDTLS_SVC_KEY_ID_INIT;
    /* Set while key_slot holds a PSA key that ep->pkey does not wrap.
     * test_ssl_endpoint_certificate_free() only destroys keys of opaque pk
     * contexts, so such a key would leak on failure. */
    int key_slot_orphaned = 0;

    if (ep == NULL) {
        return MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
    }

    TEST_CALLOC(ep->ca_chain, 1);
    TEST_CALLOC(ep->cert, 1);
    TEST_CALLOC(ep->pkey, 1);

    mbedtls_x509_crt_init(ep->ca_chain);
    mbedtls_x509_crt_init(ep->cert);
    mbedtls_pk_init(ep->pkey);

    /* Load the trusted CA */

    for (i = 0; mbedtls_test_cas_der[i] != NULL; i++) {
        ret = mbedtls_x509_crt_parse_der(
            ep->ca_chain,
            (const unsigned char *) mbedtls_test_cas_der[i],
            mbedtls_test_cas_der_len[i]);
        TEST_EQUAL(ret, 0);
    }

    /* Load own certificate and private key */

    if (pk_alg == MBEDTLS_PK_RSA) {
        TEST_EQUAL(load_endpoint_rsa(ep), 0);
    } else {
        TEST_EQUAL(load_endpoint_ecc(ep), 0);
    }

    if (opaque_alg != 0) {
        psa_key_attributes_t key_attr = PSA_KEY_ATTRIBUTES_INIT;
        /* Use a fake key usage to get a successful initial guess for the PSA attributes. */
        TEST_EQUAL(mbedtls_pk_get_psa_attributes(ep->pkey, PSA_KEY_USAGE_SIGN_HASH,
                                                 &key_attr), 0);
        /* Then manually usage, alg and alg2 as requested by the test. */
        psa_set_key_usage_flags(&key_attr, opaque_usage);
        psa_set_key_algorithm(&key_attr, opaque_alg);
        if (opaque_alg2 != PSA_ALG_NONE) {
            psa_set_key_enrollment_algorithm(&key_attr, opaque_alg2);
        }
        TEST_EQUAL(mbedtls_pk_import_into_psa(ep->pkey, &key_attr, &key_slot), 0);
        key_slot_orphaned = 1;
        mbedtls_pk_free(ep->pkey);
        mbedtls_pk_init(ep->pkey);
        TEST_EQUAL(mbedtls_pk_wrap_psa(ep->pkey, key_slot), 0);
        /* ep->pkey now owns the PSA key and destroys it when freed. */
        key_slot_orphaned = 0;
    }

    mbedtls_ssl_conf_ca_chain(&(ep->conf), ep->ca_chain, NULL);

    ret = mbedtls_ssl_conf_own_cert(&(ep->conf), ep->cert,
                                    ep->pkey);
    TEST_EQUAL(ret, 0);

    ok = 1;

exit:
    if (ret == 0 && !ok) {
        /* Exiting due to a test assertion that isn't ret == 0 */
        ret = -1;
    }
    if (ret != 0) {
        if (key_slot_orphaned) {
            psa_destroy_key(key_slot);
        }
        test_ssl_endpoint_certificate_free(ep);
    }

    return ret;
}
794 
mbedtls_test_ssl_endpoint_init_conf(mbedtls_test_ssl_endpoint * ep,int endpoint_type,const mbedtls_test_handshake_test_options * options)795 int mbedtls_test_ssl_endpoint_init_conf(
796     mbedtls_test_ssl_endpoint *ep, int endpoint_type,
797     const mbedtls_test_handshake_test_options *options)
798 {
799     int ret = -1;
800 #if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
801     const char *psk_identity = "foo";
802 #endif
803 
804     if (ep == NULL) {
805         return MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
806     }
807 
808     memset(ep, 0, sizeof(*ep));
809 
810     ep->name = (endpoint_type == MBEDTLS_SSL_IS_SERVER) ? "Server" : "Client";
811 
812     mbedtls_ssl_init(&(ep->ssl));
813     mbedtls_ssl_config_init(&(ep->conf));
814     mbedtls_test_message_socket_init(&ep->dtls_context);
815 
816     TEST_ASSERT(mbedtls_ssl_conf_get_user_data_p(&ep->conf) == NULL);
817     TEST_EQUAL(mbedtls_ssl_conf_get_user_data_n(&ep->conf), 0);
818     TEST_ASSERT(mbedtls_ssl_get_user_data_p(&ep->ssl) == NULL);
819     TEST_EQUAL(mbedtls_ssl_get_user_data_n(&ep->ssl), 0);
820 
821     (void) mbedtls_test_rnd_std_rand(NULL,
822                                      (void *) &ep->user_data_cookie,
823                                      sizeof(ep->user_data_cookie));
824     mbedtls_ssl_conf_set_user_data_n(&ep->conf, ep->user_data_cookie);
825     mbedtls_ssl_set_user_data_n(&ep->ssl, ep->user_data_cookie);
826 
827     mbedtls_test_mock_socket_init(&(ep->socket));
828 
829     ret = mbedtls_ssl_config_defaults(&(ep->conf), endpoint_type,
830                                       options->dtls ?
831                                       MBEDTLS_SSL_TRANSPORT_DATAGRAM :
832                                       MBEDTLS_SSL_TRANSPORT_STREAM,
833                                       MBEDTLS_SSL_PRESET_DEFAULT);
834     TEST_EQUAL(ret, 0);
835 
836     if (MBEDTLS_SSL_IS_CLIENT == endpoint_type) {
837         if (options->client_min_version != MBEDTLS_SSL_VERSION_UNKNOWN) {
838             mbedtls_ssl_conf_min_tls_version(&(ep->conf),
839                                              options->client_min_version);
840         }
841 
842         if (options->client_max_version != MBEDTLS_SSL_VERSION_UNKNOWN) {
843             mbedtls_ssl_conf_max_tls_version(&(ep->conf),
844                                              options->client_max_version);
845         }
846     } else {
847         if (options->server_min_version != MBEDTLS_SSL_VERSION_UNKNOWN) {
848             mbedtls_ssl_conf_min_tls_version(&(ep->conf),
849                                              options->server_min_version);
850         }
851 
852         if (options->server_max_version != MBEDTLS_SSL_VERSION_UNKNOWN) {
853             mbedtls_ssl_conf_max_tls_version(&(ep->conf),
854                                              options->server_max_version);
855         }
856     }
857 
858     if (MBEDTLS_SSL_IS_CLIENT == endpoint_type) {
859         TEST_ASSERT(set_ciphersuite(ep, options->cipher));
860     }
861 
862     if (options->group_list != NULL) {
863         mbedtls_ssl_conf_groups(&(ep->conf), options->group_list);
864     }
865 
866     if (MBEDTLS_SSL_IS_SERVER == endpoint_type) {
867         mbedtls_ssl_conf_authmode(&(ep->conf), options->srv_auth_mode);
868     } else {
869         mbedtls_ssl_conf_authmode(&(ep->conf), MBEDTLS_SSL_VERIFY_REQUIRED);
870     }
871 
872 #if defined(MBEDTLS_SSL_EARLY_DATA)
873     mbedtls_ssl_conf_early_data(&(ep->conf), options->early_data);
874 #if defined(MBEDTLS_SSL_SRV_C)
875     if (endpoint_type == MBEDTLS_SSL_IS_SERVER &&
876         (options->max_early_data_size >= 0)) {
877         mbedtls_ssl_conf_max_early_data_size(&(ep->conf),
878                                              options->max_early_data_size);
879     }
880 #endif
881 
882 #if defined(MBEDTLS_SSL_ALPN)
883     /* check that alpn_list contains at least one valid entry */
884     if (options->alpn_list[0] != NULL) {
885         mbedtls_ssl_conf_alpn_protocols(&(ep->conf), options->alpn_list);
886     }
887 #endif
888 #endif
889 
890 #if defined(MBEDTLS_SSL_RENEGOTIATION)
891     if (options->renegotiate) {
892         mbedtls_ssl_conf_renegotiation(&ep->conf,
893                                        MBEDTLS_SSL_RENEGOTIATION_ENABLED);
894         mbedtls_ssl_conf_legacy_renegotiation(&ep->conf,
895                                               options->legacy_renegotiation);
896     }
897 #endif /* MBEDTLS_SSL_RENEGOTIATION */
898 
899 #if defined(MBEDTLS_SSL_CACHE_C) && defined(MBEDTLS_SSL_SRV_C)
900     if (endpoint_type == MBEDTLS_SSL_IS_SERVER && options->cache != NULL) {
901         mbedtls_ssl_conf_session_cache(&(ep->conf), options->cache,
902                                        mbedtls_ssl_cache_get,
903                                        mbedtls_ssl_cache_set);
904     }
905 #endif
906 
907 #if defined(MBEDTLS_SSL_MAX_FRAGMENT_LENGTH)
908     TEST_EQUAL(mbedtls_ssl_conf_max_frag_len(&ep->conf,
909                                              (unsigned char) options->mfl),
910                0);
911 #else
912     TEST_EQUAL(MBEDTLS_SSL_MAX_FRAG_LEN_NONE, options->mfl);
913 #endif /* MBEDTLS_SSL_MAX_FRAGMENT_LENGTH */
914 
915 #if defined(MBEDTLS_SSL_PROTO_DTLS) && defined(MBEDTLS_SSL_SRV_C)
916     if (endpoint_type == MBEDTLS_SSL_IS_SERVER && options->dtls) {
917         mbedtls_ssl_conf_dtls_cookies(&(ep->conf), NULL, NULL, NULL);
918     }
919 #endif
920 
921 #if defined(MBEDTLS_DEBUG_C)
922 #if defined(MBEDTLS_SSL_SRV_C)
923     if (endpoint_type == MBEDTLS_SSL_IS_SERVER &&
924         options->srv_log_fun != NULL) {
925         mbedtls_ssl_conf_dbg(&(ep->conf), options->srv_log_fun,
926                              options->srv_log_obj);
927     }
928 #endif
929 #if defined(MBEDTLS_SSL_CLI_C)
930     if (endpoint_type == MBEDTLS_SSL_IS_CLIENT &&
931         options->cli_log_fun != NULL) {
932         mbedtls_ssl_conf_dbg(&(ep->conf), options->cli_log_fun,
933                              options->cli_log_obj);
934     }
935 #endif
936 #endif /* MBEDTLS_DEBUG_C */
937 
938     ret = mbedtls_test_ssl_endpoint_certificate_init(ep, options->pk_alg,
939                                                      options->opaque_alg,
940                                                      options->opaque_alg2,
941                                                      options->opaque_usage);
942     TEST_EQUAL(ret, 0);
943 
944 #if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
945     if (options->psk_str != NULL && options->psk_str->len > 0) {
946         TEST_EQUAL(mbedtls_ssl_conf_psk(
947                        &ep->conf, options->psk_str->x,
948                        options->psk_str->len,
949                        (const unsigned char *) psk_identity,
950                        strlen(psk_identity)), 0);
951 #if defined(MBEDTLS_SSL_SRV_C)
952         if (MBEDTLS_SSL_IS_SERVER == endpoint_type) {
953             mbedtls_ssl_conf_psk_cb(&ep->conf, psk_dummy_callback, NULL);
954         }
955 #endif
956     }
957 #endif
958 
959     TEST_EQUAL(mbedtls_ssl_conf_get_user_data_n(&ep->conf),
960                ep->user_data_cookie);
961     mbedtls_ssl_conf_set_user_data_p(&ep->conf, ep);
962 
963     return 0;
964 
965 exit:
966     if (ret == 0) {
967         /* Exiting due to a test assertion that isn't ret == 0 */
968         ret = -1;
969     }
970     return ret;
971 }
972 
/*
 * Create the SSL context of a test endpoint on top of its already
 * configured ep->conf.
 *
 * Steps: mbedtls_ssl_setup(); on client endpoints set the expected
 * hostname to "localhost"; install the mock transport callbacks (message
 * based for DTLS, plus timer callbacks when MBEDTLS_TIMING_C is enabled;
 * non-blocking stream callbacks otherwise); finally check that the user
 * data cookie was propagated and store ep as the context's user pointer.
 *
 * Returns 0 on success, the failing library return code, or -1 when a
 * test assertion on something other than `ret` failed.
 */
int mbedtls_test_ssl_endpoint_init_ssl(
    mbedtls_test_ssl_endpoint *ep,
    const mbedtls_test_handshake_test_options *options)
{
    const int is_client =
        (mbedtls_ssl_conf_get_endpoint(&ep->conf) == MBEDTLS_SSL_IS_CLIENT);
    int ret = mbedtls_ssl_setup(&ep->ssl, &ep->conf);
    TEST_EQUAL(ret, 0);

    if (is_client) {
        ret = mbedtls_ssl_set_hostname(&ep->ssl, "localhost");
        TEST_EQUAL(ret, 0);
    }

    /* Non-blocking callbacks; no receive-with-timeout callback is set. */
    if (!options->dtls) {
        mbedtls_ssl_set_bio(&ep->ssl, &ep->socket,
                            mbedtls_test_mock_tcp_send_nb,
                            mbedtls_test_mock_tcp_recv_nb,
                            NULL);
    } else {
        mbedtls_ssl_set_bio(&ep->ssl, &ep->dtls_context,
                            mbedtls_test_mock_tcp_send_msg,
                            mbedtls_test_mock_tcp_recv_msg,
                            NULL);
#if defined(MBEDTLS_TIMING_C)
        mbedtls_ssl_set_timer_cb(&ep->ssl, &ep->timer,
                                 mbedtls_timing_set_delay,
                                 mbedtls_timing_get_delay);
#endif
    }

    TEST_EQUAL(mbedtls_ssl_get_user_data_n(&ep->ssl), ep->user_data_cookie);
    mbedtls_ssl_set_user_data_p(&ep->ssl, ep);

    return 0;

exit:
    /* ret is still 0 when the failing assertion was not about ret. */
    return (ret != 0) ? ret : -1;
}
1018 
/*
 * Fully initialise a test endpoint: first its configuration, then the SSL
 * context bound to that configuration. Stops at the first step that
 * fails and returns its result; returns 0 when both steps succeed.
 */
int mbedtls_test_ssl_endpoint_init(
    mbedtls_test_ssl_endpoint *ep, int endpoint_type,
    const mbedtls_test_handshake_test_options *options)
{
    int status = mbedtls_test_ssl_endpoint_init_conf(ep, endpoint_type,
                                                     options);

    if (status == 0) {
        status = mbedtls_test_ssl_endpoint_init_ssl(ep, options);
    }

    return status;
}
1030 
/*
 * Release the resources held by a test endpoint, in this order:
 *   - the SSL context and its configuration;
 *   - the ciphersuite list (the pointer is reset to NULL);
 *   - the certificate material;
 *   - the transport, which is the DTLS message socket when one is
 *     attached (dtls_context.socket != NULL) and the plain mock socket
 *     otherwise.
 */
void mbedtls_test_ssl_endpoint_free(
    mbedtls_test_ssl_endpoint *ep)
{
    mbedtls_ssl_free(&ep->ssl);
    mbedtls_ssl_config_free(&ep->conf);

    mbedtls_free(ep->ciphersuites);
    ep->ciphersuites = NULL;
    test_ssl_endpoint_certificate_free(ep);

    if (ep->dtls_context.socket == NULL) {
        mbedtls_test_mock_socket_close(&ep->socket);
    } else {
        mbedtls_test_message_socket_close(&ep->dtls_context);
    }
}
1047 
/*
 * Connect two DTLS test endpoints through message queues.
 *
 * Each endpoint gets a message socket (in its dtls_context) built from
 * its own queue_input, the peer's queue_input, a capacity of 100 and its
 * own mock socket. The client side is set up first, then the server side.
 *
 * Returns 0 on success, otherwise the first non-zero setup result.
 */
int mbedtls_test_ssl_dtls_join_endpoints(mbedtls_test_ssl_endpoint *client,
                                         mbedtls_test_ssl_endpoint *server)
{
    /* (local, peer) pairs, in the order they are wired up. */
    mbedtls_test_ssl_endpoint *const pairs[2][2] = {
        { client, server },
        { server, client },
    };
    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;

    for (size_t i = 0; i < 2; i++) {
        mbedtls_test_ssl_endpoint *const local = pairs[i][0];
        mbedtls_test_ssl_endpoint *const peer = pairs[i][1];

        ret = mbedtls_test_message_socket_setup(&local->queue_input,
                                                &peer->queue_input,
                                                100, &(local->socket),
                                                &local->dtls_context);
        TEST_EQUAL(ret, 0);
    }

exit:
    return ret;
}
1068 
/*
 * Advance the handshakes of two connected contexts in lock-step until
 * ssl->state equals `state`.
 *
 * Each iteration:
 *   - advances second_ssl by one step, unless its handshake is already
 *     over;
 *   - then advances ssl by one step.
 * MBEDTLS_ERR_SSL_WANT_READ / WANT_WRITE are treated as "no progress yet".
 * Any other non-zero step result is returned immediately.
 *
 * Returns:
 *   - MBEDTLS_ERR_SSL_BAD_INPUT_DATA if either context is NULL;
 *   - -1 if the target state is not reached within 1000 iterations;
 *   - otherwise the last mbedtls_ssl_handshake_step(ssl) result (0,
 *     WANT_READ or WANT_WRITE), or 0 if ssl was already in `state`.
 */
int mbedtls_test_move_handshake_to_state(mbedtls_ssl_context *ssl,
                                         mbedtls_ssl_context *second_ssl,
                                         int state)
{
    int max_steps = 1000;
    int ret = 0;

    if (ssl == NULL || second_ssl == NULL) {
        return MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
    }

    /* Perform communication via connected sockets */
    while ((ssl->state != state) && (--max_steps >= 0)) {
        /* If \p second_ssl ends the handshake procedure before \p ssl then
         * there is no need to call the next step */
        if (!mbedtls_ssl_is_handshake_over(second_ssl)) {
            ret = mbedtls_ssl_handshake_step(second_ssl);
            if (ret != 0 && ret != MBEDTLS_ERR_SSL_WANT_READ &&
                ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
                return ret;
            }
        }

        /* We only care about the \p ssl state and returns, so we call it last,
         * to leave the iteration as soon as the state is as expected. */
        ret = mbedtls_ssl_handshake_step(ssl);
        if (ret != 0 && ret != MBEDTLS_ERR_SSL_WANT_READ &&
            ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
            return ret;
        }
    }

    /* A negative budget means the loop gave up before reaching `state`. */
    return (max_steps >= 0) ? ret : -1;
}
1104 
1105 #endif /* MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED */
1106 
1107 /*
1108  * Write application data. Increase write counter if necessary.
1109  */
mbedtls_ssl_write_fragment(mbedtls_ssl_context * ssl,unsigned char * buf,int buf_len,int * written,const int expected_fragments)1110 static int mbedtls_ssl_write_fragment(mbedtls_ssl_context *ssl,
1111                                       unsigned char *buf, int buf_len,
1112                                       int *written,
1113                                       const int expected_fragments)
1114 {
1115     int ret;
1116     /* Verify that calling mbedtls_ssl_write with a NULL buffer and zero length is
1117      * a valid no-op for TLS connections. */
1118     if (ssl->conf->transport != MBEDTLS_SSL_TRANSPORT_DATAGRAM) {
1119         TEST_EQUAL(mbedtls_ssl_write(ssl, NULL, 0), 0);
1120     }
1121 
1122     ret = mbedtls_ssl_write(ssl, buf + *written, buf_len - *written);
1123     if (ret > 0) {
1124         *written += ret;
1125     }
1126 
1127     if (expected_fragments == 0) {
1128         /* Used for DTLS and the message size larger than MFL. In that case
1129          * the message can not be fragmented and the library should return
1130          * MBEDTLS_ERR_SSL_BAD_INPUT_DATA error. This error must be returned
1131          * to prevent a dead loop inside mbedtls_test_ssl_exchange_data(). */
1132         return ret;
1133     } else if (expected_fragments == 1) {
1134         /* Used for TLS/DTLS and the message size lower than MFL */
1135         TEST_ASSERT(ret == buf_len ||
1136                     ret == MBEDTLS_ERR_SSL_WANT_READ ||
1137                     ret == MBEDTLS_ERR_SSL_WANT_WRITE);
1138     } else {
1139         /* Used for TLS and the message size larger than MFL */
1140         TEST_ASSERT(expected_fragments > 1);
1141         TEST_ASSERT((ret >= 0 && ret <= buf_len) ||
1142                     ret == MBEDTLS_ERR_SSL_WANT_READ ||
1143                     ret == MBEDTLS_ERR_SSL_WANT_WRITE);
1144     }
1145 
1146     return 0;
1147 
1148 exit:
1149     /* Some of the tests failed */
1150     return -1;
1151 }
1152 
1153 /*
1154  * Read application data and increase read counter and fragments counter
1155  * if necessary.
1156  */
mbedtls_ssl_read_fragment(mbedtls_ssl_context * ssl,unsigned char * buf,int buf_len,int * read,int * fragments,const int expected_fragments)1157 static int mbedtls_ssl_read_fragment(mbedtls_ssl_context *ssl,
1158                                      unsigned char *buf, int buf_len,
1159                                      int *read, int *fragments,
1160                                      const int expected_fragments)
1161 {
1162     int ret;
1163     /* Verify that calling mbedtls_ssl_write with a NULL buffer and zero length is
1164      * a valid no-op for TLS connections. */
1165     if (ssl->conf->transport != MBEDTLS_SSL_TRANSPORT_DATAGRAM) {
1166         TEST_EQUAL(mbedtls_ssl_read(ssl, NULL, 0), 0);
1167     }
1168 
1169     ret = mbedtls_ssl_read(ssl, buf + *read, buf_len - *read);
1170     if (ret > 0) {
1171         (*fragments)++;
1172         *read += ret;
1173     }
1174 
1175     if (expected_fragments == 0) {
1176         TEST_EQUAL(ret, 0);
1177     } else if (expected_fragments == 1) {
1178         TEST_ASSERT(ret == buf_len ||
1179                     ret == MBEDTLS_ERR_SSL_WANT_READ ||
1180                     ret == MBEDTLS_ERR_SSL_WANT_WRITE);
1181     } else {
1182         TEST_ASSERT(expected_fragments > 1);
1183         TEST_ASSERT((ret >= 0 && ret <= buf_len) ||
1184                     ret == MBEDTLS_ERR_SSL_WANT_READ ||
1185                     ret == MBEDTLS_ERR_SSL_WANT_WRITE);
1186     }
1187 
1188     return 0;
1189 
1190 exit:
1191     /* Some of the tests failed */
1192     return -1;
1193 }
1194 
1195 #if defined(MBEDTLS_SSL_PROTO_TLS1_2) && \
1196     defined(PSA_WANT_ALG_CBC_NO_PADDING) && defined(PSA_WANT_KEY_TYPE_AES)
/*
 * Encrypt `input` (ilen bytes) with the transform's PSA encryption key and
 * algorithm (transform->psa_key_enc / transform->psa_alg) under the given
 * IV. The ciphertext goes to `output` and its total length to *olen.
 *
 * `output` is passed to PSA as an ilen-byte buffer. That is enough for an
 * unpadded block mode such as the CBC_NO_PADDING this helper is compiled
 * for. NOTE(review): confirm before using it with other algorithms.
 *
 * Returns 0 on success or the PSA status converted with
 * PSA_TO_MBEDTLS_ERR(). On every failure path the cipher operation is
 * aborted, so no PSA operation state is leaked.
 */
int mbedtls_test_psa_cipher_encrypt_helper(mbedtls_ssl_transform *transform,
                                           const unsigned char *iv,
                                           size_t iv_len,
                                           const unsigned char *input,
                                           size_t ilen,
                                           unsigned char *output,
                                           size_t *olen)
{
    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
    psa_cipher_operation_t cipher_op = PSA_CIPHER_OPERATION_INIT;
    size_t part_len = 0;

    status = psa_cipher_encrypt_setup(&cipher_op,
                                      transform->psa_key_enc,
                                      transform->psa_alg);
    if (status != PSA_SUCCESS) {
        goto exit;
    }

    status = psa_cipher_set_iv(&cipher_op, iv, iv_len);
    if (status != PSA_SUCCESS) {
        goto exit;
    }

    status = psa_cipher_update(&cipher_op, input, ilen, output, ilen, olen);
    if (status != PSA_SUCCESS) {
        goto exit;
    }

    status = psa_cipher_finish(&cipher_op, output + *olen, ilen - *olen,
                               &part_len);
    if (status != PSA_SUCCESS) {
        goto exit;
    }

    *olen += part_len;
    return 0;

exit:
    /* After a failed PSA call the operation is in an error state and must
     * be aborted to release it. Aborting a freshly initialised operation
     * is also allowed. */
    psa_cipher_abort(&cipher_op);
    return PSA_TO_MBEDTLS_ERR(status);
}
1239 #endif /* MBEDTLS_SSL_PROTO_TLS1_2 && PSA_WANT_ALG_CBC_NO_PADDING &&
1240           PSA_WANT_KEY_TYPE_AES */
1241 
/*
 * Map a cipher type to its mode, key size in bits and IV length in bytes,
 * as used by the record-layer tests.
 *
 *   - AES/ARIA/Camellia CBC: 16-byte IV.
 *   - AES/Camellia CCM and GCM, ChaCha20-Poly1305: 12-byte IV.
 *   - MBEDTLS_CIPHER_NULL: MBEDTLS_MODE_STREAM, 0 key bits, 0 IV bytes.
 *   - Any other type: MBEDTLS_MODE_NONE, 0 key bits, 0 IV bytes.
 */
static void mbedtls_test_ssl_cipher_info_from_type(mbedtls_cipher_type_t cipher_type,
                                                   mbedtls_cipher_mode_t *cipher_mode,
                                                   size_t *key_bits, size_t *iv_len)
{
    mbedtls_cipher_mode_t mode = MBEDTLS_MODE_NONE;
    size_t bits = 0;

    /* First pass: mode and key size. */
    switch (cipher_type) {
        case MBEDTLS_CIPHER_AES_128_CBC:
        case MBEDTLS_CIPHER_ARIA_128_CBC:
        case MBEDTLS_CIPHER_CAMELLIA_128_CBC:
            mode = MBEDTLS_MODE_CBC;
            bits = 128;
            break;
        case MBEDTLS_CIPHER_AES_256_CBC:
        case MBEDTLS_CIPHER_ARIA_256_CBC:
        case MBEDTLS_CIPHER_CAMELLIA_256_CBC:
            mode = MBEDTLS_MODE_CBC;
            bits = 256;
            break;

        case MBEDTLS_CIPHER_AES_128_CCM:
        case MBEDTLS_CIPHER_CAMELLIA_128_CCM:
            mode = MBEDTLS_MODE_CCM;
            bits = 128;
            break;
        case MBEDTLS_CIPHER_AES_192_CCM:
        case MBEDTLS_CIPHER_CAMELLIA_192_CCM:
            mode = MBEDTLS_MODE_CCM;
            bits = 192;
            break;
        case MBEDTLS_CIPHER_AES_256_CCM:
        case MBEDTLS_CIPHER_CAMELLIA_256_CCM:
            mode = MBEDTLS_MODE_CCM;
            bits = 256;
            break;

        case MBEDTLS_CIPHER_AES_128_GCM:
        case MBEDTLS_CIPHER_CAMELLIA_128_GCM:
            mode = MBEDTLS_MODE_GCM;
            bits = 128;
            break;
        case MBEDTLS_CIPHER_AES_192_GCM:
        case MBEDTLS_CIPHER_CAMELLIA_192_GCM:
            mode = MBEDTLS_MODE_GCM;
            bits = 192;
            break;
        case MBEDTLS_CIPHER_AES_256_GCM:
        case MBEDTLS_CIPHER_CAMELLIA_256_GCM:
            mode = MBEDTLS_MODE_GCM;
            bits = 256;
            break;

        case MBEDTLS_CIPHER_CHACHA20_POLY1305:
            mode = MBEDTLS_MODE_CHACHAPOLY;
            bits = 256;
            break;

        case MBEDTLS_CIPHER_NULL:
            mode = MBEDTLS_MODE_STREAM;
            break;

        default:
            break;
    }

    *cipher_mode = mode;
    *key_bits = bits;

    /* Second pass: the IV length depends only on the mode. */
    switch (mode) {
        case MBEDTLS_MODE_CBC:
            *iv_len = 16;
            break;
        case MBEDTLS_MODE_CCM:
        case MBEDTLS_MODE_GCM:
        case MBEDTLS_MODE_CHACHAPOLY:
            *iv_len = 12;
            break;
        default:
            *iv_len = 0;
            break;
    }
}
1358 
/*
 * Build a mirrored pair of record transforms so that records protected
 * with t_out can be processed with t_in, and vice versa, without a
 * handshake.
 *
 * Key material is swapped between the two transforms:
 *   - t_in encrypts with key0 and decrypts with key1; t_out the reverse.
 *   - HMAC keys md0/md1 and the iv_enc/iv_dec buffers are mirrored the
 *     same way.
 *   - With DTLS connection IDs, t_in's incoming CID is t_out's outgoing
 *     CID and vice versa.
 *
 * Parameters:
 *   cipher_type - an mbedtls_cipher_type_t value.
 *   hash_id     - mbedtls_md_type_t used for the HMAC keys. Only consulted
 *                 for CBC and STREAM (NULL cipher) modes, and only when
 *                 MBEDTLS_SSL_SOME_SUITES_USE_MAC is enabled.
 *   etm         - Encrypt-then-MAC setting. Stored only when ETM and MAC
 *                 suites are enabled.
 *   tag_mode    - 0 selects a full 16-byte AEAD tag, 1 an 8-byte tag.
 *                 CBC/STREAM accept only 0.
 *   tls_version - stored in both transforms.
 *   cid0_len, cid1_len - connection ID lengths (DTLS CID builds only).
 *                 Each must not exceed SSL_CID_LEN_MIN.
 *
 * Returns:
 *   - 0 on success;
 *   - 1 for an unsupported cipher mode / tag_mode combination;
 *   - the CHK() failure code when an allocation or sanity check fails;
 *   - an Mbed TLS error converted from a PSA status.
 */
int mbedtls_test_ssl_build_transforms(mbedtls_ssl_transform *t_in,
                                      mbedtls_ssl_transform *t_out,
                                      int cipher_type, int hash_id,
                                      int etm, int tag_mode,
                                      mbedtls_ssl_protocol_version tls_version,
                                      size_t cid0_len,
                                      size_t cid1_len)
{
    mbedtls_cipher_mode_t cipher_mode = MBEDTLS_MODE_NONE;
    size_t key_bits = 0;
    int ret = 0;

    psa_key_type_t key_type;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    psa_algorithm_t alg;
    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;

    size_t keylen, maclen, ivlen = 0;
    unsigned char *key0 = NULL, *key1 = NULL;
    unsigned char *md0 = NULL, *md1 = NULL;
    unsigned char iv_enc[16], iv_dec[16];

#if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID)
    unsigned char cid0[SSL_CID_LEN_MIN];
    unsigned char cid1[SSL_CID_LEN_MIN];

    mbedtls_test_rnd_std_rand(NULL, cid0, sizeof(cid0));
    mbedtls_test_rnd_std_rand(NULL, cid1, sizeof(cid1));

    /* The CIDs are later copied out of these fixed-size arrays using the
     * caller-supplied lengths; reject lengths that would read past them. */
    CHK(cid0_len <= sizeof(cid0));
    CHK(cid1_len <= sizeof(cid1));
#else
    ((void) cid0_len);
    ((void) cid1_len);
#endif /* MBEDTLS_SSL_DTLS_CONNECTION_ID */

    maclen = 0;
    mbedtls_test_ssl_cipher_info_from_type((mbedtls_cipher_type_t) cipher_type,
                                           &cipher_mode, &key_bits, &ivlen);

    /* Pick keys */
    keylen = key_bits / 8;
    /* Allocate `keylen + 1` bytes to ensure that we get
     * a non-NULL pointers from `mbedtls_calloc` even if
     * `keylen == 0` in the case of the NULL cipher. */
    CHK((key0 = mbedtls_calloc(1, keylen + 1)) != NULL);
    CHK((key1 = mbedtls_calloc(1, keylen + 1)) != NULL);
    memset(key0, 0x1, keylen);
    memset(key1, 0x2, keylen);

    /* Setup MAC contexts */
#if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC)
    if (cipher_mode == MBEDTLS_MODE_CBC ||
        cipher_mode == MBEDTLS_MODE_STREAM) {
        maclen = mbedtls_md_get_size_from_type((mbedtls_md_type_t) hash_id);
        CHK(maclen != 0);
        /* Pick hash keys */
        CHK((md0 = mbedtls_calloc(1, maclen)) != NULL);
        CHK((md1 = mbedtls_calloc(1, maclen)) != NULL);
        memset(md0, 0x5, maclen);
        memset(md1, 0x6, maclen);

        alg = mbedtls_md_psa_alg_from_type(hash_id);

        CHK(alg != 0);

        t_out->psa_mac_alg = PSA_ALG_HMAC(alg);
        t_in->psa_mac_alg = PSA_ALG_HMAC(alg);
        t_in->psa_mac_enc = MBEDTLS_SVC_KEY_ID_INIT;
        t_out->psa_mac_enc = MBEDTLS_SVC_KEY_ID_INIT;
        t_in->psa_mac_dec = MBEDTLS_SVC_KEY_ID_INIT;
        t_out->psa_mac_dec = MBEDTLS_SVC_KEY_ID_INIT;

        psa_reset_key_attributes(&attributes);
        psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_SIGN_MESSAGE);
        psa_set_key_algorithm(&attributes, PSA_ALG_HMAC(alg));
        psa_set_key_type(&attributes, PSA_KEY_TYPE_HMAC);

        CHK(psa_import_key(&attributes,
                           md0, maclen,
                           &t_in->psa_mac_enc) == PSA_SUCCESS);

        CHK(psa_import_key(&attributes,
                           md1, maclen,
                           &t_out->psa_mac_enc) == PSA_SUCCESS);

        if (cipher_mode == MBEDTLS_MODE_STREAM ||
            etm == MBEDTLS_SSL_ETM_DISABLED) {
            /* mbedtls_ct_hmac() requires the key to be exportable */
            psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_EXPORT |
                                    PSA_KEY_USAGE_VERIFY_HASH);
        } else {
            psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_VERIFY_HASH);
        }

        CHK(psa_import_key(&attributes,
                           md1, maclen,
                           &t_in->psa_mac_dec) == PSA_SUCCESS);

        CHK(psa_import_key(&attributes,
                           md0, maclen,
                           &t_out->psa_mac_dec) == PSA_SUCCESS);
    }
#else
    ((void) hash_id);
#endif /* MBEDTLS_SSL_SOME_SUITES_USE_MAC */


    /* Pick IV's (regardless of whether they
     * are being used by the transform). */
    memset(iv_enc, 0x3, sizeof(iv_enc));
    memset(iv_dec, 0x4, sizeof(iv_dec));

    /*
     * Setup transforms
     */

#if defined(MBEDTLS_SSL_ENCRYPT_THEN_MAC) && \
    defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC)
    t_out->encrypt_then_mac = etm;
    t_in->encrypt_then_mac = etm;
#else
    ((void) etm);
#endif

    t_out->tls_version = tls_version;
    t_in->tls_version = tls_version;
    t_out->ivlen = ivlen;
    t_in->ivlen = ivlen;

    switch (cipher_mode) {
        case MBEDTLS_MODE_GCM:
        case MBEDTLS_MODE_CCM:
#if defined(MBEDTLS_SSL_PROTO_TLS1_3)
            if (tls_version == MBEDTLS_SSL_VERSION_TLS1_3) {
                t_out->fixed_ivlen = 12;
                t_in->fixed_ivlen  = 12;
            } else
#endif /* MBEDTLS_SSL_PROTO_TLS1_3 */
            {
                t_out->fixed_ivlen = 4;
                t_in->fixed_ivlen = 4;
            }
            t_out->maclen = 0;
            t_in->maclen = 0;
            switch (tag_mode) {
                case 0: /* Full tag */
                    t_out->taglen = 16;
                    t_in->taglen = 16;
                    break;
                case 1: /* Partial tag */
                    t_out->taglen = 8;
                    t_in->taglen = 8;
                    break;
                default:
                    ret = 1;
                    goto cleanup;
            }
            break;

        case MBEDTLS_MODE_CHACHAPOLY:
            t_out->fixed_ivlen = 12;
            t_in->fixed_ivlen = 12;
            t_out->maclen = 0;
            t_in->maclen = 0;
            switch (tag_mode) {
                case 0: /* Full tag */
                    t_out->taglen = 16;
                    t_in->taglen = 16;
                    break;
                case 1: /* Partial tag */
                    t_out->taglen = 8;
                    t_in->taglen = 8;
                    break;
                default:
                    ret = 1;
                    goto cleanup;
            }
            break;

        case MBEDTLS_MODE_STREAM:
        case MBEDTLS_MODE_CBC:
            t_out->fixed_ivlen = 0; /* redundant, must be 0 */
            t_in->fixed_ivlen = 0;  /* redundant, must be 0 */
            t_out->taglen = 0;
            t_in->taglen = 0;
            switch (tag_mode) {
                case 0: /* Full tag */
                    t_out->maclen = maclen;
                    t_in->maclen = maclen;
                    break;
                default:
                    ret = 1;
                    goto cleanup;
            }
            break;
        default:
            ret = 1;
            goto cleanup;
            break;
    }

    /* Setup IV's */

    memcpy(&t_in->iv_dec, iv_dec, sizeof(iv_dec));
    memcpy(&t_in->iv_enc, iv_enc, sizeof(iv_enc));
    memcpy(&t_out->iv_dec, iv_enc, sizeof(iv_enc));
    memcpy(&t_out->iv_enc, iv_dec, sizeof(iv_dec));

#if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID)
    /* Add CID (lengths were bounds-checked against the arrays above). */
    memcpy(&t_in->in_cid,  cid0, cid0_len);
    memcpy(&t_in->out_cid, cid1, cid1_len);
    t_in->in_cid_len = (uint8_t) cid0_len;
    t_in->out_cid_len = (uint8_t) cid1_len;
    memcpy(&t_out->in_cid,  cid1, cid1_len);
    memcpy(&t_out->out_cid, cid0, cid0_len);
    t_out->in_cid_len = (uint8_t) cid1_len;
    t_out->out_cid_len = (uint8_t) cid0_len;
#endif /* MBEDTLS_SSL_DTLS_CONNECTION_ID */

    status = mbedtls_ssl_cipher_to_psa(cipher_type,
                                       t_in->taglen,
                                       &alg,
                                       &key_type,
                                       &key_bits);

    if (status != PSA_SUCCESS) {
        ret = PSA_TO_MBEDTLS_ERR(status);
        goto cleanup;
    }

    t_in->psa_alg = alg;
    t_out->psa_alg = alg;

    if (alg != MBEDTLS_SSL_NULL_CIPHER) {
        psa_reset_key_attributes(&attributes);
        psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_ENCRYPT);
        psa_set_key_algorithm(&attributes, alg);
        psa_set_key_type(&attributes, key_type);

        status = psa_import_key(&attributes,
                                key0,
                                PSA_BITS_TO_BYTES(key_bits),
                                &t_in->psa_key_enc);

        if (status != PSA_SUCCESS) {
            ret = PSA_TO_MBEDTLS_ERR(status);
            goto cleanup;
        }

        status = psa_import_key(&attributes,
                                key1,
                                PSA_BITS_TO_BYTES(key_bits),
                                &t_out->psa_key_enc);

        if (status != PSA_SUCCESS) {
            ret = PSA_TO_MBEDTLS_ERR(status);
            goto cleanup;
        }

        psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_DECRYPT);

        status = psa_import_key(&attributes,
                                key1,
                                PSA_BITS_TO_BYTES(key_bits),
                                &t_in->psa_key_dec);

        if (status != PSA_SUCCESS) {
            ret = PSA_TO_MBEDTLS_ERR(status);
            goto cleanup;
        }

        status = psa_import_key(&attributes,
                                key0,
                                PSA_BITS_TO_BYTES(key_bits),
                                &t_out->psa_key_dec);

        if (status != PSA_SUCCESS) {
            ret = PSA_TO_MBEDTLS_ERR(status);
            goto cleanup;
        }
    }

cleanup:

    mbedtls_free(key0);
    mbedtls_free(key1);

    mbedtls_free(md0);
    mbedtls_free(md1);

    return ret;
}
1650 
1651 #if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC)
/*
 * Compute an HMAC over a plaintext record and append it to the payload.
 *
 * The MAC input is the serialized record header
 * (ctr[8] || type || ver[0] || ver[1] || data_len as 2 big-endian bytes)
 * followed by the payload at record->buf + record->data_offset.
 * transform_out->psa_mac_enc and transform_out->psa_mac_alg are used.
 * The first transform_out->maclen bytes of the MAC are written right
 * after the payload, and record->data_len grows by that amount.
 *
 * Returns 0 on success, -1 if any PSA call fails or the MAC does not fit.
 * On failure the MAC operation is aborted.
 */
int mbedtls_test_ssl_prepare_record_mac(mbedtls_record *record,
                                        mbedtls_ssl_transform *transform_out)
{
    psa_mac_operation_t operation = PSA_MAC_OPERATION_INIT;

    /* Serialized version of record header for MAC purposes */
    unsigned char add_data[13];
    memcpy(add_data, record->ctr, 8);
    add_data[8] = record->type;
    add_data[9] = record->ver[0];
    add_data[10] = record->ver[1];
    add_data[11] = (record->data_len >> 8) & 0xff;
    add_data[12] = (record->data_len >> 0) & 0xff;

    /* MAC with additional data */
    size_t sign_mac_length = 0;
    TEST_EQUAL(PSA_SUCCESS, psa_mac_sign_setup(&operation,
                                               transform_out->psa_mac_enc,
                                               transform_out->psa_mac_alg));
    TEST_EQUAL(PSA_SUCCESS, psa_mac_update(&operation, add_data,
                                           sizeof(add_data)));
    TEST_EQUAL(PSA_SUCCESS, psa_mac_update(&operation,
                                           record->buf + record->data_offset,
                                           record->data_len));
    /* Use a temporary buffer for the MAC, because with the truncated HMAC
     * extension, there might not be enough room in the record for the
     * full-length MAC. */
    unsigned char mac[PSA_HASH_MAX_SIZE];
    TEST_EQUAL(PSA_SUCCESS, psa_mac_sign_finish(&operation,
                                                mac, sizeof(mac),
                                                &sign_mac_length));

    /* Never copy more MAC bytes than were actually computed ... */
    TEST_ASSERT(transform_out->maclen <= sign_mac_length);
    /* ... and never write past the end of the record buffer. */
    TEST_ASSERT(record->data_offset + record->data_len +
                transform_out->maclen <= record->buf_len);

    memcpy(record->buf + record->data_offset + record->data_len, mac,
           transform_out->maclen);
    record->data_len += transform_out->maclen;

    return 0;

exit:
    psa_mac_abort(&operation);
    return -1;
}
1691 #endif /* MBEDTLS_SSL_SOME_SUITES_USE_MAC */
1692 
1693 #if defined(MBEDTLS_SSL_PROTO_TLS1_2)
#if defined(MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED) && defined(MBEDTLS_FS_IO)
/*
 * Load the certificate chain in \p crt_file and attach it to \p session:
 * either moved into session->peer_cert (MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
 * or reduced to a digest stored in session->peer_cert_digest.
 *
 * The temporary certificate is freed on every path, including errors.
 *
 * Returns 0 on success, -1 on allocation failure, or the error from
 * parsing/hashing.
 */
static int tls12_populate_session_peer_cert(mbedtls_ssl_session *session,
                                            const char *crt_file)
{
    mbedtls_x509_crt tmp_crt;
    int ret;

    mbedtls_x509_crt_init(&tmp_crt);
    ret = mbedtls_x509_crt_parse_file(&tmp_crt, crt_file);
    if (ret != 0) {
        /* Parsing may have left a partial chain behind: free it. */
        goto cleanup;
    }

#if defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
    /* Move temporary CRT. */
    session->peer_cert = mbedtls_calloc(1, sizeof(*session->peer_cert));
    if (session->peer_cert == NULL) {
        ret = -1;
        goto cleanup;
    }
    *session->peer_cert = tmp_crt;
    /* Ownership now belongs to session->peer_cert; reset the local copy so
     * the cleanup below does not free the moved chain. */
    memset(&tmp_crt, 0, sizeof(tmp_crt));
#else /* MBEDTLS_SSL_KEEP_PEER_CERTIFICATE */
    /* Calculate digest of temporary CRT. */
    session->peer_cert_digest =
        mbedtls_calloc(1, MBEDTLS_SSL_PEER_CERT_DIGEST_DFL_LEN);
    if (session->peer_cert_digest == NULL) {
        ret = -1;
        goto cleanup;
    }

    psa_algorithm_t psa_alg = mbedtls_md_psa_alg_from_type(
        MBEDTLS_SSL_PEER_CERT_DIGEST_DFL_TYPE);
    size_t hash_size = 0;
    psa_status_t status = psa_hash_compute(
        psa_alg, tmp_crt.raw.p,
        tmp_crt.raw.len,
        session->peer_cert_digest,
        MBEDTLS_SSL_PEER_CERT_DIGEST_DFL_LEN,
        &hash_size);
    ret = PSA_TO_MBEDTLS_ERR(status);
    if (ret != 0) {
        goto cleanup;
    }
    session->peer_cert_digest_type =
        MBEDTLS_SSL_PEER_CERT_DIGEST_DFL_TYPE;
    session->peer_cert_digest_len =
        MBEDTLS_SSL_PEER_CERT_DIGEST_DFL_LEN;
#endif /* MBEDTLS_SSL_KEEP_PEER_CERTIFICATE */

cleanup:
    mbedtls_x509_crt_free(&tmp_crt);
    return ret;
}
#endif /* MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED && MBEDTLS_FS_IO */

/*
 * Fill \p session with fixed TLS 1.2 test values for endpoint
 * \p endpoint_type (must be MBEDTLS_SSL_IS_CLIENT or MBEDTLS_SSL_IS_SERVER).
 * If \p crt_file is a non-empty path (and certificate/file support is
 * compiled in), the certificate it contains becomes the peer certificate.
 * With client-side session tickets, a ticket of \p ticket_len bytes is
 * allocated.
 *
 * Returns 0 on success and non-zero on failure, including when a
 * TEST_ASSERT() check fails.
 */
int mbedtls_test_ssl_tls12_populate_session(mbedtls_ssl_session *session,
                                            int ticket_len,
                                            int endpoint_type,
                                            const char *crt_file)
{
    /* Returned if a TEST_xxx() check jumps to exit. */
    int ret = -1;

    (void) ticket_len;

#if defined(MBEDTLS_HAVE_TIME)
    session->start = mbedtls_time(NULL) - 42;
#endif
    session->tls_version = MBEDTLS_SSL_VERSION_TLS1_2;

    TEST_ASSERT(endpoint_type == MBEDTLS_SSL_IS_CLIENT ||
                endpoint_type == MBEDTLS_SSL_IS_SERVER);

    session->endpoint = endpoint_type;
    session->ciphersuite = 0xabcd;
    session->id_len = sizeof(session->id);
    memset(session->id, 66, session->id_len);
    memset(session->master, 17, sizeof(session->master));

#if defined(MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED) && defined(MBEDTLS_FS_IO)
    if (crt_file != NULL && strlen(crt_file) != 0) {
        ret = tls12_populate_session_peer_cert(session, crt_file);
        if (ret != 0) {
            return ret;
        }
    }
#else /* MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED && MBEDTLS_FS_IO */
    (void) crt_file;
#endif /* MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED && MBEDTLS_FS_IO */
    session->verify_result = 0xdeadbeef;

#if defined(MBEDTLS_SSL_SESSION_TICKETS)
#if defined(MBEDTLS_SSL_CLI_C)
    if (ticket_len != 0) {
        session->ticket = mbedtls_calloc(1, ticket_len);
        if (session->ticket == NULL) {
            return -1;
        }
        memset(session->ticket, 33, ticket_len);
    }
    session->ticket_len = ticket_len;
    session->ticket_lifetime = 86401;
#endif /* MBEDTLS_SSL_CLI_C */

#if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_HAVE_TIME)
    if (session->endpoint == MBEDTLS_SSL_IS_SERVER) {
        session->ticket_creation_time = mbedtls_ms_time() - 42;
    }
#endif
#endif /* MBEDTLS_SSL_SESSION_TICKETS */

#if defined(MBEDTLS_SSL_MAX_FRAGMENT_LENGTH)
    session->mfl_code = 1;
#endif
#if defined(MBEDTLS_SSL_ENCRYPT_THEN_MAC)
    session->encrypt_then_mac = 1;
#endif

    ret = 0;

exit:
    return ret;
}
1798 #endif /* MBEDTLS_SSL_PROTO_TLS1_2 */
1799 
1800 #if defined(MBEDTLS_SSL_PROTO_TLS1_3)
/*
 * Fill \p session with fixed TLS 1.3 test values. Any \p endpoint_type other
 * than MBEDTLS_SSL_IS_CLIENT is treated as a server. Client sessions with
 * ticket support get a \p ticket_len byte ticket (and a hostname when SNI is
 * enabled); server sessions get ticket ALPN/creation-time data.
 *
 * Returns 0 on success, -1 if an allocation or setter fails.
 */
int mbedtls_test_ssl_tls13_populate_session(mbedtls_ssl_session *session,
                                            int ticket_len,
                                            int endpoint_type)
{
    const int is_client = (endpoint_type == MBEDTLS_SSL_IS_CLIENT);

    (void) ticket_len;

    session->tls_version = MBEDTLS_SSL_VERSION_TLS1_3;
    if (is_client) {
        session->endpoint = MBEDTLS_SSL_IS_CLIENT;
    } else {
        session->endpoint = MBEDTLS_SSL_IS_SERVER;
    }
    session->ciphersuite = 0xabcd;

#if defined(MBEDTLS_SSL_SESSION_TICKETS)
    session->ticket_age_add = 0x87654321;
    session->ticket_flags = 0x7;
    session->resumption_key_len = 32;
    memset(session->resumption_key, 0x99, sizeof(session->resumption_key));
#endif

#if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_SESSION_TICKETS)
    /* Server-only ticket data. */
    if (!is_client) {
#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN)
        if (mbedtls_ssl_session_set_ticket_alpn(session, "ALPNExample") != 0) {
            return -1;
        }
#endif
#if defined(MBEDTLS_HAVE_TIME)
        session->ticket_creation_time = mbedtls_ms_time() - 42;
#endif
    }
#endif /* MBEDTLS_SSL_SRV_C && MBEDTLS_SSL_SESSION_TICKETS */

#if defined(MBEDTLS_SSL_CLI_C) && defined(MBEDTLS_SSL_SESSION_TICKETS)
    /* Client-only ticket data. */
    if (is_client) {
#if defined(MBEDTLS_HAVE_TIME)
        session->ticket_reception_time = mbedtls_ms_time() - 40;
#endif
        session->ticket_lifetime = 0xfedcba98;

        session->ticket_len = ticket_len;
        if (ticket_len != 0) {
            session->ticket = mbedtls_calloc(1, ticket_len);
            if (session->ticket == NULL) {
                return -1;
            }
            memset(session->ticket, 33, ticket_len);
        }
#if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
        {
            static const char sni_name[] = "hostname example";

            session->hostname = mbedtls_calloc(1, sizeof(sni_name));
            if (session->hostname == NULL) {
                return -1;
            }
            memcpy(session->hostname, sni_name, sizeof(sni_name));
        }
#endif
    }
#endif /* MBEDTLS_SSL_CLI_C && MBEDTLS_SSL_SESSION_TICKETS */

#if defined(MBEDTLS_SSL_EARLY_DATA)
    session->max_early_data_size = 0x87654321;
#endif /* MBEDTLS_SSL_EARLY_DATA */

#if defined(MBEDTLS_SSL_RECORD_SIZE_LIMIT)
    session->record_size_limit = 2048;
#endif

    return 0;
}
1872 #endif /* MBEDTLS_SSL_PROTO_TLS1_3 */
1873 
/*
 * Send \p msg_len_1 bytes from \p ssl_1 to \p ssl_2 and \p msg_len_2 bytes
 * from \p ssl_2 to \p ssl_1, interleaving writes and reads, then check that
 * the data arrived intact and in the expected number of fragments. This is
 * done twice: once with constant-filled messages and once with generated
 * data.
 *
 * An expected fragment count of 0 means the corresponding write is expected
 * to fail with MBEDTLS_ERR_SSL_BAD_INPUT_DATA (message too large and not
 * fragmentable); that direction is then skipped.
 *
 * Returns 0 on success, non-zero if any check failed.
 */
int mbedtls_test_ssl_exchange_data(
    mbedtls_ssl_context *ssl_1,
    int msg_len_1, const int expected_fragments_1,
    mbedtls_ssl_context *ssl_2,
    int msg_len_2, const int expected_fragments_2)
{
    unsigned char *msg_buf_1 = NULL;
    unsigned char *msg_buf_2 = NULL;
    unsigned char *in_buf_1 = NULL;
    unsigned char *in_buf_2 = NULL;
    int msg_type, ret;
    /* Overall result: only cleared once every check has passed, so that a
     * TEST_xxx() failure never reports success to the caller. */
    int result = -1;

    TEST_ASSERT(msg_len_1 >= 0 && msg_len_2 >= 0);

    /* Allocate at least one byte so that zero-length messages still give
     * valid pointers to memset()/memcmp(). */
    msg_buf_1 = malloc(msg_len_1 > 0 ? (size_t) msg_len_1 : 1);
    msg_buf_2 = malloc(msg_len_2 > 0 ? (size_t) msg_len_2 : 1);
    in_buf_1 = malloc(msg_len_2 > 0 ? (size_t) msg_len_2 : 1);
    in_buf_2 = malloc(msg_len_1 > 0 ? (size_t) msg_len_1 : 1);
    TEST_ASSERT(msg_buf_1 != NULL && msg_buf_2 != NULL &&
                in_buf_1 != NULL && in_buf_2 != NULL);

    /* Perform this test with two message types. At first use a message
     * consisting of only 0x00 for the client and only 0xFF for the server.
     * At the second time use message with generated data */
    for (msg_type = 0; msg_type < 2; msg_type++) {
        int written_1 = 0;
        int written_2 = 0;
        int read_1 = 0;
        int read_2 = 0;
        int fragments_1 = 0;
        int fragments_2 = 0;

        if (msg_type == 0) {
            memset(msg_buf_1, 0x00, msg_len_1);
            memset(msg_buf_2, 0xff, msg_len_2);
        } else {
            int i, j = 0;
            for (i = 0; i < msg_len_1; i++) {
                msg_buf_1[i] = j++ & 0xFF;
            }
            for (i = 0; i < msg_len_2; i++) {
                msg_buf_2[i] = (j -= 5) & 0xFF;
            }
        }

        while (read_1 < msg_len_2 || read_2 < msg_len_1) {
            /* ssl_1 sending */
            if (msg_len_1 > written_1) {
                ret = mbedtls_ssl_write_fragment(ssl_1, msg_buf_1,
                                                 msg_len_1, &written_1,
                                                 expected_fragments_1);
                if (expected_fragments_1 == 0) {
                    /* This error is expected when the message is too large and
                     * cannot be fragmented */
                    TEST_EQUAL(ret, MBEDTLS_ERR_SSL_BAD_INPUT_DATA);
                    msg_len_1 = 0;
                } else {
                    TEST_EQUAL(ret, 0);
                }
            }

            /* ssl_2 sending */
            if (msg_len_2 > written_2) {
                ret = mbedtls_ssl_write_fragment(ssl_2, msg_buf_2,
                                                 msg_len_2, &written_2,
                                                 expected_fragments_2);
                if (expected_fragments_2 == 0) {
                    /* This error is expected when the message is too large and
                     * cannot be fragmented */
                    TEST_EQUAL(ret, MBEDTLS_ERR_SSL_BAD_INPUT_DATA);
                    msg_len_2 = 0;
                } else {
                    TEST_EQUAL(ret, 0);
                }
            }

            /* ssl_1 reading */
            if (read_1 < msg_len_2) {
                ret = mbedtls_ssl_read_fragment(ssl_1, in_buf_1,
                                                msg_len_2, &read_1,
                                                &fragments_2,
                                                expected_fragments_2);
                TEST_EQUAL(ret, 0);
            }

            /* ssl_2 reading */
            if (read_2 < msg_len_1) {
                ret = mbedtls_ssl_read_fragment(ssl_2, in_buf_2,
                                                msg_len_1, &read_2,
                                                &fragments_1,
                                                expected_fragments_1);
                TEST_EQUAL(ret, 0);
            }
        }

        TEST_EQUAL(0, memcmp(msg_buf_1, in_buf_2, msg_len_1));
        TEST_EQUAL(0, memcmp(msg_buf_2, in_buf_1, msg_len_2));
        TEST_EQUAL(fragments_1, expected_fragments_1);
        TEST_EQUAL(fragments_2, expected_fragments_2);
    }

    result = 0;

exit:
    free(msg_buf_1);
    free(in_buf_1);
    free(msg_buf_2);
    free(in_buf_2);

    return result;
}
1977 
1978 /*
 * Exchange 256 bytes of data in each direction between \p ssl_1 and
 * \p ssl_2, expecting each message to arrive in a single fragment. Both
 * endpoints must be initialized and connected beforehand.
1981  *
1982  * \retval  0 on success, otherwise error code.
1983  */
1984 #if defined(MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED) && \
1985     (defined(MBEDTLS_SSL_RENEGOTIATION)              || \
1986     defined(MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH))
static int exchange_data(mbedtls_ssl_context *ssl_1,
                         mbedtls_ssl_context *ssl_2)
{
    /* Same message size and expected fragment count in both directions. */
    enum { EXCHANGE_MSG_LEN = 256, EXCHANGE_FRAGMENTS = 1 };

    return mbedtls_test_ssl_exchange_data(ssl_1, EXCHANGE_MSG_LEN,
                                          EXCHANGE_FRAGMENTS,
                                          ssl_2, EXCHANGE_MSG_LEN,
                                          EXCHANGE_FRAGMENTS);
}
1993 #endif /* MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED &&
1994           (MBEDTLS_SSL_RENEGOTIATION              ||
1995           MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH) */
1996 
1997 #if defined(MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED)
/*
 * Check that \p client and \p server agree on the protocol version, that it
 * equals \p expected_negotiated_version, and that the client reports it
 * consistently through mbedtls_ssl_get_version() and
 * mbedtls_ssl_get_version_number() ("D"-prefixed name for DTLS).
 *
 * Returns 1 if every check passed, 0 otherwise.
 */
static int check_ssl_version(
    mbedtls_ssl_protocol_version expected_negotiated_version,
    const mbedtls_ssl_context *client,
    const mbedtls_ssl_context *server)
{
    const char *expected_name = NULL;
    const char *reported_name;
    mbedtls_ssl_protocol_version reported_number;

    /* If the peers disagree, nothing else is meaningful; otherwise checking
     * the client alone is representative of the connection. */
    TEST_EQUAL(client->tls_version, server->tls_version);

    reported_name = mbedtls_ssl_get_version(client);
    reported_number = mbedtls_ssl_get_version_number(client);

    TEST_EQUAL(client->tls_version, expected_negotiated_version);

    /* DTLS version names are the TLS names with a leading 'D'. */
    if (client->conf->transport == MBEDTLS_SSL_TRANSPORT_DATAGRAM) {
        TEST_EQUAL(reported_name[0], 'D');
        reported_name++;
    }

    if (expected_negotiated_version == MBEDTLS_SSL_VERSION_TLS1_2) {
        expected_name = "TLSv1.2";
    } else if (expected_negotiated_version == MBEDTLS_SSL_VERSION_TLS1_3) {
        expected_name = "TLSv1.3";
    } else {
        TEST_FAIL(
            "Version check not implemented for this protocol version");
    }

    TEST_EQUAL(reported_number, expected_negotiated_version);
    TEST_EQUAL(strcmp(reported_name, expected_name), 0);

    return 1;

exit:
    return 0;
}
2042 #endif /* MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED */
2043 
2044 #if defined(MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED)
/*
 * Reset \p options and both endpoints, restrict client and server to exactly
 * protocol version \p proto, connect them over a mock socket and drive the
 * handshake (server first, then client) to MBEDTLS_SSL_HANDSHAKE_OVER.
 *
 * Returns 0 once both sides report the handshake as over, the first hard
 * error otherwise (WANT_READ/WANT_WRITE are tolerated while driving), or -1
 * if a side did not finish.
 */
int mbedtls_test_ssl_do_handshake_with_endpoints(
    mbedtls_test_ssl_endpoint *server_ep,
    mbedtls_test_ssl_endpoint *client_ep,
    mbedtls_test_handshake_test_options *options,
    mbedtls_ssl_protocol_version proto)
{
    enum { BUFFSIZE = 1024 };
    /* Driving order: the server is advanced first, then the client. */
    mbedtls_ssl_context *const own[2] = { &server_ep->ssl, &client_ep->ssl };
    mbedtls_ssl_context *const peer[2] = { &client_ep->ssl, &server_ep->ssl };
    size_t step;
    int ret;

    mbedtls_platform_zeroize(server_ep, sizeof(*server_ep));
    mbedtls_platform_zeroize(client_ep, sizeof(*client_ep));

    mbedtls_test_init_handshake_options(options);
    options->client_min_version = proto;
    options->client_max_version = proto;
    options->server_min_version = proto;
    options->server_max_version = proto;

    ret = mbedtls_test_ssl_endpoint_init(client_ep, MBEDTLS_SSL_IS_CLIENT,
                                         options);
    if (ret == 0) {
        ret = mbedtls_test_ssl_endpoint_init(server_ep, MBEDTLS_SSL_IS_SERVER,
                                             options);
    }
    if (ret == 0) {
        ret = mbedtls_test_mock_socket_connect(&client_ep->socket,
                                               &server_ep->socket, BUFFSIZE);
    }
    if (ret != 0) {
        return ret;
    }

    for (step = 0; step < 2; step++) {
        ret = mbedtls_test_move_handshake_to_state(own[step], peer[step],
                                                   MBEDTLS_SSL_HANDSHAKE_OVER);
        /* Waiting on the peer is expected; anything else is fatal. */
        if (ret != 0 &&
            ret != MBEDTLS_ERR_SSL_WANT_READ &&
            ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
            return ret;
        }
    }

    return (mbedtls_ssl_is_handshake_over(&client_ep->ssl) &&
            mbedtls_ssl_is_handshake_over(&server_ep->ssl)) ? 0 : -1;
}
2097 #endif /* defined(MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED) */
2098 
2099 #if defined(MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED)
2100 
2101 #if defined(MBEDTLS_SSL_RENEGOTIATION)
/*
 * Over an established connection, run a server-initiated renegotiation and
 * then a client-initiated one, checking renego_status on both sides after
 * every step. With MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH and
 * options->resize_buffers set, also check the I/O buffer lengths.
 *
 * Returns 1 if all checks passed, 0 otherwise.
 */
static int test_renegotiation(const mbedtls_test_handshake_test_options *options,
                              mbedtls_test_ssl_endpoint *client,
                              mbedtls_test_ssl_endpoint *server)
{
    mbedtls_ssl_context *const cli = &client->ssl;
    mbedtls_ssl_context *const srv = &server->ssl;
    int rc = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
    int passed = 0;

    /* Only read when MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH is enabled. */
    (void) options;

    /* Nobody has started renegotiating yet. */
    TEST_EQUAL(srv->renego_status, MBEDTLS_SSL_INITIAL_HANDSHAKE);
    TEST_EQUAL(cli->renego_status, MBEDTLS_SSL_INITIAL_HANDSHAKE);

    /* Server side: this only sends a HelloRequest; the actual renegotiation
     * happens while data is being exchanged. */
    TEST_EQUAL(mbedtls_ssl_renegotiate(srv), 0);
    TEST_EQUAL(srv->renego_status, MBEDTLS_SSL_RENEGOTIATION_PENDING);
    TEST_EQUAL(cli->renego_status, MBEDTLS_SSL_INITIAL_HANDSHAKE);

    TEST_EQUAL(exchange_data(cli, srv), 0);
    TEST_EQUAL(srv->renego_status, MBEDTLS_SSL_RENEGOTIATION_DONE);
    TEST_EQUAL(cli->renego_status, MBEDTLS_SSL_RENEGOTIATION_DONE);

    /* Client side: the renegotiation would normally complete inside
     * mbedtls_ssl_renegotiate(), but the peer cannot answer concurrently in
     * this test, so a WANT_READ/WANT_WRITE is acceptable and the rest is
     * finished during the following data exchange. */
    rc = mbedtls_ssl_renegotiate(cli);
#if defined(MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH)
    if (options->resize_buffers) {
        /* Buffers must still have their default sizes before resizing. */
        TEST_EQUAL(cli->out_buf_len, MBEDTLS_SSL_OUT_BUFFER_LEN);
        TEST_EQUAL(cli->in_buf_len, MBEDTLS_SSL_IN_BUFFER_LEN);
    }
#endif
    TEST_ASSERT(rc == 0 ||
                rc == MBEDTLS_ERR_SSL_WANT_READ ||
                rc == MBEDTLS_ERR_SSL_WANT_WRITE);
    TEST_EQUAL(srv->renego_status, MBEDTLS_SSL_RENEGOTIATION_DONE);
    TEST_EQUAL(cli->renego_status, MBEDTLS_SSL_RENEGOTIATION_IN_PROGRESS);

    TEST_EQUAL(exchange_data(cli, srv), 0);
    TEST_EQUAL(srv->renego_status, MBEDTLS_SSL_RENEGOTIATION_DONE);
    TEST_EQUAL(cli->renego_status, MBEDTLS_SSL_RENEGOTIATION_DONE);
#if defined(MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH)
    if (options->resize_buffers) {
        /* After renegotiation the buffers must match the computed sizes. */
        TEST_EQUAL(cli->out_buf_len, mbedtls_ssl_get_output_buflen(cli));
        TEST_EQUAL(cli->in_buf_len, mbedtls_ssl_get_input_buflen(cli));
        TEST_EQUAL(srv->out_buf_len, mbedtls_ssl_get_output_buflen(srv));
        TEST_EQUAL(srv->in_buf_len, mbedtls_ssl_get_input_buflen(srv));
    }
#endif /* MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH */

    passed = 1;

exit:
    return passed;
}
2177 #endif /* MBEDTLS_SSL_RENEGOTIATION */
2178 
2179 #if defined(MBEDTLS_SSL_CONTEXT_SERIALIZATION)
/*
 * Save the server's SSL context, replace it with a freshly set-up context
 * (same config, mock DTLS BIO, user data and timer), load the saved state
 * into it, and, if messages are configured in \p options, verify that data
 * still flows. Only valid for DTLS connections.
 *
 * Returns 1 if all checks passed, 0 otherwise.
 */
static int test_serialization(const mbedtls_test_handshake_test_options *options,
                              mbedtls_test_ssl_endpoint *client,
                              mbedtls_test_ssl_endpoint *server)
{
    mbedtls_ssl_context *const srv_ssl = &server->ssl;
    unsigned char *saved = NULL;
    size_t saved_len = 0;
    int passed = 0;

    TEST_EQUAL(options->dtls, 1);

    /* Query the serialized size first, then save into a buffer that big. */
    TEST_EQUAL(mbedtls_ssl_context_save(srv_ssl, NULL, 0, &saved_len),
               MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL);

    saved = mbedtls_calloc(1, saved_len);
    TEST_ASSERT(saved != NULL);

    TEST_EQUAL(mbedtls_ssl_context_save(srv_ssl, saved, saved_len,
                                        &saved_len), 0);

    /* Rebuild the server context from scratch before restoring it. */
    mbedtls_ssl_free(srv_ssl);
    mbedtls_ssl_init(srv_ssl);
    TEST_EQUAL(mbedtls_ssl_setup(srv_ssl, &server->conf), 0);

    mbedtls_ssl_set_bio(srv_ssl, &server->dtls_context,
                        mbedtls_test_mock_tcp_send_msg,
                        mbedtls_test_mock_tcp_recv_msg,
                        NULL);
    mbedtls_ssl_set_user_data_p(srv_ssl, server);
#if defined(MBEDTLS_TIMING_C)
    mbedtls_ssl_set_timer_cb(srv_ssl, &server->timer,
                             mbedtls_timing_set_delay,
                             mbedtls_timing_get_delay);
#endif

#if defined(MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH)
    if (options->resize_buffers) {
        /* A fresh context starts with the default buffer sizes. */
        TEST_EQUAL(srv_ssl->out_buf_len, MBEDTLS_SSL_OUT_BUFFER_LEN);
        TEST_EQUAL(srv_ssl->in_buf_len, MBEDTLS_SSL_IN_BUFFER_LEN);
    }
#endif
    TEST_EQUAL(mbedtls_ssl_context_load(srv_ssl, saved, saved_len), 0);
#if defined(MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH)
    if (options->resize_buffers) {
        /* Loading must resize the buffers to the negotiated sizes. */
        TEST_EQUAL(srv_ssl->out_buf_len,
                   mbedtls_ssl_get_output_buflen(srv_ssl));
        TEST_EQUAL(srv_ssl->in_buf_len,
                   mbedtls_ssl_get_input_buflen(srv_ssl));
    }
#endif

    /* The restored context must still be able to carry application data. */
    if (options->cli_msg_len != 0 || options->srv_msg_len != 0) {
        TEST_EQUAL(mbedtls_test_ssl_exchange_data(
                       &client->ssl, options->cli_msg_len,
                       options->expected_cli_fragments,
                       srv_ssl, options->srv_msg_len,
                       options->expected_srv_fragments),
                   0);
    }

    passed = 1;

exit:
    mbedtls_free(saved);
    return passed;
}
2254 #endif /* MBEDTLS_SSL_CONTEXT_SERIALIZATION */
2255 
/*
 * Connect \p client and \p server over a mock socket and run the handshake
 * as described by \p options. If the handshake is expected to fail
 * (options->expected_handshake_result != 0, or no common version), only that
 * failure is checked. Otherwise, verify the negotiated version/ciphersuite,
 * buffer sizes, data exchange, and optionally serialization and
 * renegotiation.
 *
 * Returns 1 if all checks passed, 0 otherwise.
 */
int mbedtls_test_ssl_perform_connection(
    const mbedtls_test_handshake_test_options *options,
    mbedtls_test_ssl_endpoint *client,
    mbedtls_test_ssl_endpoint *server)
{
    enum { BUFFSIZE = 17000 };
    /* With no version in common, the handshake must fail with
     * MBEDTLS_ERR_SSL_BAD_PROTOCOL_VERSION whatever the option says. */
    const int expected_hs_ret =
        (options->expected_negotiated_version == MBEDTLS_SSL_VERSION_UNKNOWN) ?
        MBEDTLS_ERR_SSL_BAD_PROTOCOL_VERSION :
        options->expected_handshake_result;
    int passed = 0;

    TEST_EQUAL(mbedtls_test_mock_socket_connect(&client->socket,
                                                &server->socket,
                                                BUFFSIZE), 0);

#if defined(MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH)
    if (options->resize_buffers) {
        /* Before the handshake every buffer still has its default size. */
        TEST_EQUAL(client->ssl.out_buf_len, MBEDTLS_SSL_OUT_BUFFER_LEN);
        TEST_EQUAL(client->ssl.in_buf_len, MBEDTLS_SSL_IN_BUFFER_LEN);
        TEST_EQUAL(server->ssl.out_buf_len, MBEDTLS_SSL_OUT_BUFFER_LEN);
        TEST_EQUAL(server->ssl.in_buf_len, MBEDTLS_SSL_IN_BUFFER_LEN);
    }
#endif

    TEST_EQUAL(mbedtls_test_move_handshake_to_state(&client->ssl,
                                                    &server->ssl,
                                                    MBEDTLS_SSL_HANDSHAKE_OVER),
               expected_hs_ret);

    if (expected_hs_ret != 0) {
        /* The handshake failed as intended; nothing further to check. */
        return 1;
    }

    TEST_EQUAL(mbedtls_ssl_is_handshake_over(&client->ssl), 1);

    /* Let the server catch up to HANDSHAKE_OVER too. */
    TEST_EQUAL(mbedtls_test_move_handshake_to_state(&server->ssl,
                                                    &client->ssl,
                                                    MBEDTLS_SSL_HANDSHAKE_OVER),
               0);
    TEST_EQUAL(mbedtls_ssl_is_handshake_over(&server->ssl), 1);

    TEST_ASSERT(check_ssl_version(options->expected_negotiated_version,
                                  &client->ssl, &server->ssl));

    if (options->expected_ciphersuite != 0) {
        TEST_EQUAL(server->ssl.session->ciphersuite,
                   options->expected_ciphersuite);
    }

#if defined(MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH)
    if (options->resize_buffers) {
        /* Over DTLS the server may postpone resizing until it receives a
         * message, so push some data through first. */
        TEST_EQUAL(exchange_data(&client->ssl, &server->ssl), 0);

        TEST_EQUAL(client->ssl.out_buf_len,
                   mbedtls_ssl_get_output_buflen(&client->ssl));
        TEST_EQUAL(client->ssl.in_buf_len,
                   mbedtls_ssl_get_input_buflen(&client->ssl));
        TEST_EQUAL(server->ssl.out_buf_len,
                   mbedtls_ssl_get_output_buflen(&server->ssl));
        TEST_EQUAL(server->ssl.in_buf_len,
                   mbedtls_ssl_get_input_buflen(&server->ssl));
    }
#endif

    if (options->cli_msg_len != 0 || options->srv_msg_len != 0) {
        TEST_EQUAL(mbedtls_test_ssl_exchange_data(
                       &client->ssl, options->cli_msg_len,
                       options->expected_cli_fragments,
                       &server->ssl, options->srv_msg_len,
                       options->expected_srv_fragments),
                   0);
    }

#if defined(MBEDTLS_SSL_CONTEXT_SERIALIZATION)
    if (options->serialize == 1) {
        TEST_ASSERT(test_serialization(options, client, server));
    }
#endif /* MBEDTLS_SSL_CONTEXT_SERIALIZATION */

#if defined(MBEDTLS_SSL_RENEGOTIATION)
    if (options->renegotiate) {
        TEST_ASSERT(test_renegotiation(options, client, server));
    }
#endif /* MBEDTLS_SSL_RENEGOTIATION */

    passed = 1;

exit:
    return passed;
}
2357 
mbedtls_test_ssl_perform_handshake(const mbedtls_test_handshake_test_options * options)2358 void mbedtls_test_ssl_perform_handshake(
2359     const mbedtls_test_handshake_test_options *options)
2360 {
2361     mbedtls_test_ssl_endpoint client_struct;
2362     memset(&client_struct, 0, sizeof(client_struct));
2363     mbedtls_test_ssl_endpoint *const client = &client_struct;
2364     mbedtls_test_ssl_endpoint server_struct;
2365     memset(&server_struct, 0, sizeof(server_struct));
2366     mbedtls_test_ssl_endpoint *const server = &server_struct;
2367 
2368     MD_OR_USE_PSA_INIT();
2369 
2370 #if defined(MBEDTLS_DEBUG_C)
2371     if (options->cli_log_fun || options->srv_log_fun) {
2372         mbedtls_debug_set_threshold(4);
2373     }
2374 #endif
2375 
2376     /* Client side */
2377     TEST_EQUAL(mbedtls_test_ssl_endpoint_init(client,
2378                                               MBEDTLS_SSL_IS_CLIENT,
2379                                               options), 0);
2380 
2381     /* Server side */
2382     TEST_EQUAL(mbedtls_test_ssl_endpoint_init(server,
2383                                               MBEDTLS_SSL_IS_SERVER,
2384                                               options), 0);
2385 
2386     if (options->dtls) {
2387         TEST_EQUAL(mbedtls_test_ssl_dtls_join_endpoints(client, server), 0);
2388     }
2389 
2390     TEST_ASSERT(mbedtls_test_ssl_perform_connection(options, client, server));
2391 
2392     TEST_ASSERT(mbedtls_ssl_conf_get_user_data_p(&client->conf) == client);
2393     TEST_ASSERT(mbedtls_ssl_get_user_data_p(&client->ssl) == client);
2394     TEST_ASSERT(mbedtls_ssl_conf_get_user_data_p(&server->conf) == server);
2395     TEST_ASSERT(mbedtls_ssl_get_user_data_p(&server->ssl) == server);
2396 
2397 exit:
2398     mbedtls_test_ssl_endpoint_free(client);
2399     mbedtls_test_ssl_endpoint_free(server);
2400 #if defined(MBEDTLS_DEBUG_C)
2401     if (options->cli_log_fun || options->srv_log_fun) {
2402         mbedtls_debug_set_threshold(0);
2403     }
2404 #endif
2405     MD_OR_USE_PSA_DONE();
2406 }
2407 #endif /* MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED */
2408 
2409 #if defined(MBEDTLS_TEST_HOOKS)
/*
 * Corrupt one length field of a serialized TLS 1.3 Certificate message in
 * place, so that parsing it is expected to fail at a chosen bounds check.
 *
 * \param buf              Start of the Certificate message body: buf[0] is
 *                         read as the certificate_request_context length.
 * \param end              In/out pointer to the end of the message. Tweak 1
 *                         overwrites it (buf + 3) to truncate the message;
 *                         tweak 3 reads it.
 * \param tweak            Which corruption to apply, 1 to 7 (see the cases).
 * \param expected_result  Out: always set to MBEDTLS_ERR_SSL_DECODE_ERROR,
 *                         even when \p tweak is out of range.
 * \param args             Out: filled with set_chk_buf_ptr_args() with the
 *                         (pointer, end, length) of the check expected to
 *                         fail, or cleared with reset_chk_buf_ptr_args() for
 *                         tweak 2. NOTE(review): presumably compared against
 *                         the arguments of the failing buffer-pointer check
 *                         via the test hooks; confirm against the caller.
 *
 * \return 0 if the tweak was applied, -1 if \p tweak is not in 1..7 (the
 *         message is then left unmodified).
 */
int mbedtls_test_tweak_tls13_certificate_msg_vector_len(
    unsigned char *buf, unsigned char **end, int tweak,
    int *expected_result, mbedtls_ssl_chk_buf_ptr_args *args)
{
/*
 * The definition of the tweaks assumes that the certificate list contains
 * exactly one certificate. The lengths below are read from the message
 * without any bounds check, so the message must be well formed on entry.
 */

/*
 * struct {
 *     opaque cert_data<1..2^24-1>;
 *     Extension extensions<0..2^16-1>;
 * } CertificateEntry;
 *
 * struct {
 *     opaque certificate_request_context<0..2^8-1>;
 *     CertificateEntry certificate_list<0..2^24-1>;
 * } Certificate;
 */
    /* 1-byte certificate_request_context length. */
    unsigned char *p_certificate_request_context_len = buf;
    size_t certificate_request_context_len = buf[0];

    /* 3-byte certificate_list length, followed by the list itself. */
    unsigned char *p_certificate_list_len =
        buf + 1 + certificate_request_context_len;
    unsigned char *certificate_list = p_certificate_list_len + 3;
    size_t certificate_list_len =
        MBEDTLS_GET_UINT24_BE(p_certificate_list_len, 0);

    /* First (and, by assumption, only) CertificateEntry: 3-byte cert_data
     * length followed by the certificate data. */
    unsigned char *p_cert_data_len = certificate_list;
    unsigned char *cert_data = p_cert_data_len + 3;
    size_t cert_data_len = MBEDTLS_GET_UINT24_BE(p_cert_data_len, 0);

    /* 2-byte extensions length, followed by the extensions. */
    unsigned char *p_extensions_len = cert_data + cert_data_len;
    unsigned char *extensions = p_extensions_len + 2;
    size_t extensions_len = MBEDTLS_GET_UINT16_BE(p_extensions_len, 0);

    *expected_result = MBEDTLS_ERR_SSL_DECODE_ERROR;

    switch (tweak) {
        case 1:
            /* Failure when checking if the certificate request context length
             * and certificate list length can be read
             */
            /* Truncate the message to 3 bytes, while 4 are needed for the
             * 1-byte context length plus the 3-byte list length. */
            *end = buf + 3;
            set_chk_buf_ptr_args(args, buf, *end, 4);
            break;

        case 2:
            /* Invalid certificate request context length.
             */
            /* NOTE: the cast applies before "+ 1", so the sum is computed as
             * int and truncated modulo 256 when stored (255 becomes 0). */
            *p_certificate_request_context_len =
                (unsigned char) certificate_request_context_len + 1;
            reset_chk_buf_ptr_args(args);
            break;

        case 3:
            /* Failure when checking if certificate_list data can be read. */
            /* Claim one more byte of certificate_list than the message has. */
            MBEDTLS_PUT_UINT24_BE(certificate_list_len + 1,
                                  p_certificate_list_len, 0);
            set_chk_buf_ptr_args(args, certificate_list, *end,
                                 certificate_list_len + 1);
            break;

        case 4:
            /* Failure when checking if the cert_data length can be read. */
            /* A 2-byte certificate_list cannot hold the 3-byte cert_data
             * length. */
            MBEDTLS_PUT_UINT24_BE(2, p_certificate_list_len, 0);
            set_chk_buf_ptr_args(args, p_cert_data_len, certificate_list + 2, 3);
            break;

        case 5:
            /* Failure when checking if cert_data data can be read. */
            /* cert_data claims one byte more than what remains of the list
             * after its own 3-byte length field. */
            MBEDTLS_PUT_UINT24_BE(certificate_list_len - 3 + 1,
                                  p_cert_data_len, 0);
            set_chk_buf_ptr_args(args, cert_data,
                                 certificate_list + certificate_list_len,
                                 certificate_list_len - 3 + 1);
            break;

        case 6:
            /* Failure when checking if the extensions length can be read. */
            /* Shrink the list by extensions_len + 1 bytes so that it ends one
             * byte into the 2-byte extensions length field (single-entry
             * assumption). */
            MBEDTLS_PUT_UINT24_BE(certificate_list_len - extensions_len - 1,
                                  p_certificate_list_len, 0);
            set_chk_buf_ptr_args(
                args, p_extensions_len,
                certificate_list + certificate_list_len - extensions_len - 1, 2);
            break;

        case 7:
            /* Failure when checking if extensions data can be read. */
            /* The extensions claim one byte past the end of the list. */
            MBEDTLS_PUT_UINT16_BE(extensions_len + 1, p_extensions_len, 0);

            set_chk_buf_ptr_args(
                args, extensions,
                certificate_list + certificate_list_len, extensions_len + 1);
            break;

        default:
            return -1;
    }

    return 0;
}
2513 #endif /* MBEDTLS_TEST_HOOKS */
2514 
2515 /*
2516  * Functions for tests based on tickets. Implementations of the
2517  * write/parse ticket interfaces as defined by mbedtls_ssl_ticket_write/parse_t.
 * They are essentially the same as the implementations in ticket.c, minus the
 * encryption. That way, ticket characteristics can easily be tweaked to
 * simulate misbehaving peers.
2521  */
2522 #if defined(MBEDTLS_SSL_SESSION_TICKETS)
/*
 * Test implementation of the ticket "write" callback. The ticket is the plain
 * serialized session, with no encryption or authentication.
 *
 * The session is serialized into [start, end) by mbedtls_ssl_session_save(),
 * which also reports the ticket length through *tlen. On success *lifetime is
 * set to seven days. Any error from mbedtls_ssl_session_save() is returned
 * unchanged, and *lifetime is then not written.
 */
int mbedtls_test_ticket_write(
    void *p_ticket, const mbedtls_ssl_session *session,
    unsigned char *start, const unsigned char *end,
    size_t *tlen, uint32_t *lifetime)
{
    int err;

    (void) p_ticket;

    err = mbedtls_ssl_session_save(session, start, end - start, tlen);
    if (err != 0) {
        return err;
    }

    /* RFC 8446 caps ticket_lifetime at seven days; advertise the maximum. */
    *lifetime = 7 * 24 * 3600;
    return 0;
}
2541 
/*
 * Test implementation of the ticket "parse" callback, the counterpart of
 * mbedtls_test_ticket_write(). The ticket bytes are fed straight to
 * mbedtls_ssl_session_load() (no decryption, no authentication), and its
 * result is returned as is.
 */
int mbedtls_test_ticket_parse(void *p_ticket, mbedtls_ssl_session *session,
                              unsigned char *buf, size_t len)
{
    const int load_status = mbedtls_ssl_session_load(session, buf, len);

    (void) p_ticket;
    return load_status;
}
2549 #endif /* MBEDTLS_SSL_SESSION_TICKETS */
2550 
2551 #if defined(MBEDTLS_SSL_CLI_C) && defined(MBEDTLS_SSL_SRV_C) && \
2552     defined(MBEDTLS_SSL_PROTO_TLS1_3) && defined(MBEDTLS_SSL_SESSION_TICKETS) && \
2553     defined(MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED)
/*
 * Obtain a TLS 1.3 session carrying a ticket.
 *
 * A client configured with \p client_options and a server configured with
 * \p server_options complete a handshake over mock sockets. The server uses
 * mbedtls_test_ticket_write()/mbedtls_test_ticket_parse(), so tickets are
 * unencrypted. The client then reads until it reports a NewSessionTicket, and
 * its session is saved into \p session with mbedtls_ssl_get_session().
 *
 * \return 0 on success, a negative value on failure (the error returned by the
 *         failing step, or -1 if a test assertion failed without one).
 */
int mbedtls_test_get_tls13_ticket(
    mbedtls_test_handshake_test_options *client_options,
    mbedtls_test_handshake_test_options *server_options,
    mbedtls_ssl_session *session)
{
    /* Upper bound on client reads while waiting for the ticket. Once the
     * handshake is over the server has no ticket left to send (checked below),
     * so the ticket should already be available on the client side. The bound
     * only turns a would-be infinite loop into a test failure. */
    const int max_read_attempts = 16;
    int ret = -1;
    int ok = 0;
    unsigned char buf[64];
    mbedtls_test_ssl_endpoint client_ep, server_ep;

    mbedtls_platform_zeroize(&client_ep, sizeof(client_ep));
    mbedtls_platform_zeroize(&server_ep, sizeof(server_ep));

    ret = mbedtls_test_ssl_endpoint_init(&client_ep, MBEDTLS_SSL_IS_CLIENT,
                                         client_options);
    TEST_EQUAL(ret, 0);

    ret = mbedtls_test_ssl_endpoint_init(&server_ep, MBEDTLS_SSL_IS_SERVER,
                                         server_options);
    TEST_EQUAL(ret, 0);

    mbedtls_ssl_conf_session_tickets_cb(&server_ep.conf,
                                        mbedtls_test_ticket_write,
                                        mbedtls_test_ticket_parse,
                                        NULL);

    ret = mbedtls_test_mock_socket_connect(&(client_ep.socket),
                                           &(server_ep.socket), 1024);
    TEST_EQUAL(ret, 0);

    TEST_EQUAL(mbedtls_test_move_handshake_to_state(
                   &(server_ep.ssl), &(client_ep.ssl),
                   MBEDTLS_SSL_HANDSHAKE_OVER), 0);

    TEST_EQUAL(server_ep.ssl.handshake->new_session_tickets_count, 0);

    /* Read until the client reports the ticket. Stop early on a fatal error:
     * retrying would return the same error forever and hang the test. */
    for (int attempt = 0; attempt < max_read_attempts; attempt++) {
        ret = mbedtls_ssl_read(&(client_ep.ssl), buf, sizeof(buf));
        if (ret == MBEDTLS_ERR_SSL_RECEIVED_NEW_SESSION_TICKET) {
            break;
        }
        if (ret < 0 && ret != MBEDTLS_ERR_SSL_WANT_READ &&
            ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
            break;
        }
    }
    TEST_EQUAL(ret, MBEDTLS_ERR_SSL_RECEIVED_NEW_SESSION_TICKET);

    ret = mbedtls_ssl_get_session(&(client_ep.ssl), session);
    TEST_EQUAL(ret, 0);

    ok = 1;

exit:
    mbedtls_test_ssl_endpoint_free(&client_ep);
    mbedtls_test_ssl_endpoint_free(&server_ep);

    if (!ok && ret >= 0) {
        /* A test assertion failed without leaving an error code in ret
         * (e.g. ret is still 0, or holds a byte count from a read). */
        ret = -1;
    }
    return ret;
}
2609 #endif /* MBEDTLS_SSL_CLI_C && MBEDTLS_SSL_SRV_C &&
2610           MBEDTLS_SSL_PROTO_TLS1_3 && MBEDTLS_SSL_SESSION_TICKETS &&
2611           MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED */
2612 
2613 #endif /* MBEDTLS_SSL_TLS_C */
2614