1 // Copyright 2005-2016 The OpenSSL Project Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <openssl/ssl.h>
16
17 #include <assert.h>
18 #include <limits.h>
19 #include <string.h>
20
21 #include <algorithm>
22
23 #include <openssl/err.h>
24 #include <openssl/evp.h>
25 #include <openssl/mem.h>
26 #include <openssl/rand.h>
27
28 #include "../crypto/internal.h"
29 #include "internal.h"
30
31
32 BSSL_NAMESPACE_BEGIN
33
// TODO(davidben): 28 comes from the size of the IP + UDP headers. Is this
// reasonable for these values? Notably, why is kMinMTU a function of the
// transport protocol's overhead rather than, say, what's needed to hold a
// minimally-sized handshake fragment plus protocol overhead?
38
// kMinMTU is the smallest MTU value that is accepted. The 28 bytes subtracted
// account for the IP + UDP headers.
static constexpr unsigned int kMinMTU = 256 - 28;

// kDefaultMTU is the MTU value used when neither the user nor the underlying
// BIO supplies one. The 28 bytes subtracted account for the IP + UDP headers.
static constexpr unsigned int kDefaultMTU = 1500 - 28;
45
// BitRange builds a one-byte mask in which bit positions |start| (inclusive)
// through |end| (exclusive) are set and every other bit is clear. Callers must
// ensure |start| <= |end| <= 8.
static uint8_t BitRange(size_t start, size_t end) {
  assert(start <= end && end <= 8);
  // Take all bits strictly below |end| and remove those strictly below
  // |start|.
  const unsigned below_end = (1u << end) - 1;
  const unsigned below_start = (1u << start) - 1;
  return static_cast<uint8_t>(below_end & ~below_start);
}
52
53 // FirstUnmarkedRangeInByte returns the first unmarked range in bits |b|.
FirstUnmarkedRangeInByte(uint8_t b)54 static DTLSMessageBitmap::Range FirstUnmarkedRangeInByte(uint8_t b) {
55 size_t start, end;
56 for (start = 0; start < 8; start++) {
57 if ((b & (1u << start)) == 0) {
58 break;
59 }
60 }
61 for (end = start; end < 8; end++) {
62 if ((b & (1u << end)) != 0) {
63 break;
64 }
65 }
66 return DTLSMessageBitmap::Range{start, end};
67 }
68
Init(size_t num_bits)69 bool DTLSMessageBitmap::Init(size_t num_bits) {
70 if (num_bits + 7 < num_bits) {
71 OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
72 return false;
73 }
74 size_t num_bytes = (num_bits + 7) / 8;
75 size_t bits_rounded = num_bytes * 8;
76 if (!bytes_.Init(num_bytes)) {
77 return false;
78 }
79 MarkRange(num_bits, bits_rounded);
80 first_unmarked_byte_ = 0;
81 return true;
82 }
83
MarkRange(size_t start,size_t end)84 void DTLSMessageBitmap::MarkRange(size_t start, size_t end) {
85 assert(start <= end);
86 // Don't bother touching bytes that have already been marked.
87 start = std::max(start, first_unmarked_byte_ << 3);
88 // Clamp everything within range.
89 start = std::min(start, bytes_.size() << 3);
90 end = std::min(end, bytes_.size() << 3);
91 if (start >= end) {
92 return;
93 }
94
95 if ((start >> 3) == (end >> 3)) {
96 bytes_[start >> 3] |= BitRange(start & 7, end & 7);
97 } else {
98 bytes_[start >> 3] |= BitRange(start & 7, 8);
99 for (size_t i = (start >> 3) + 1; i < (end >> 3); i++) {
100 bytes_[i] = 0xff;
101 }
102 if ((end & 7) != 0) {
103 bytes_[end >> 3] |= BitRange(0, end & 7);
104 }
105 }
106
107 // Maintain the |first_unmarked_byte_| invariant. This work is amortized
108 // across all |MarkRange| calls.
109 while (first_unmarked_byte_ < bytes_.size() &&
110 bytes_[first_unmarked_byte_] == 0xff) {
111 first_unmarked_byte_++;
112 }
113 // If the whole message is marked, we no longer need to spend memory on the
114 // bitmap.
115 if (first_unmarked_byte_ >= bytes_.size()) {
116 bytes_.Reset();
117 first_unmarked_byte_ = 0;
118 }
119 }
120
NextUnmarkedRange(size_t start) const121 DTLSMessageBitmap::Range DTLSMessageBitmap::NextUnmarkedRange(
122 size_t start) const {
123 // Don't bother looking at bytes that are known to be fully marked.
124 start = std::max(start, first_unmarked_byte_ << 3);
125
126 size_t idx = start >> 3;
127 if (idx >= bytes_.size()) {
128 return Range{0, 0};
129 }
130
131 // Look at the bits from |start| up to a byte boundary.
132 uint8_t byte = bytes_[idx] | BitRange(0, start & 7);
133 if (byte == 0xff) {
134 // Nothing unmarked at this byte. Keep searching for an unmarked bit.
135 for (idx = idx + 1; idx < bytes_.size(); idx++) {
136 if (bytes_[idx] != 0xff) {
137 byte = bytes_[idx];
138 break;
139 }
140 }
141 if (idx >= bytes_.size()) {
142 return Range{0, 0};
143 }
144 }
145
146 Range range = FirstUnmarkedRangeInByte(byte);
147 assert(!range.empty());
148 bool should_extend = range.end == 8;
149 range.start += idx << 3;
150 range.end += idx << 3;
151 if (!should_extend) {
152 // The range did not end at a byte boundary. We're done.
153 return range;
154 }
155
156 // Collect all fully unmarked bytes.
157 for (idx = idx + 1; idx < bytes_.size(); idx++) {
158 if (bytes_[idx] != 0) {
159 break;
160 }
161 }
162 range.end = idx << 3;
163
164 // Add any bits from the remaining byte, if any.
165 if (idx < bytes_.size()) {
166 Range extra = FirstUnmarkedRangeInByte(bytes_[idx]);
167 if (extra.start == 0) {
168 range.end += extra.end;
169 }
170 }
171
172 return range;
173 }
174
175 // Receiving handshake messages.
176
dtls_new_incoming_message(const struct hm_header_st * msg_hdr)177 static UniquePtr<DTLSIncomingMessage> dtls_new_incoming_message(
178 const struct hm_header_st *msg_hdr) {
179 ScopedCBB cbb;
180 UniquePtr<DTLSIncomingMessage> frag = MakeUnique<DTLSIncomingMessage>();
181 if (!frag) {
182 return nullptr;
183 }
184 frag->type = msg_hdr->type;
185 frag->seq = msg_hdr->seq;
186
187 // Allocate space for the reassembled message and fill in the header.
188 if (!frag->data.InitForOverwrite(DTLS1_HM_HEADER_LENGTH + msg_hdr->msg_len)) {
189 return nullptr;
190 }
191
192 if (!CBB_init_fixed(cbb.get(), frag->data.data(), DTLS1_HM_HEADER_LENGTH) ||
193 !CBB_add_u8(cbb.get(), msg_hdr->type) ||
194 !CBB_add_u24(cbb.get(), msg_hdr->msg_len) ||
195 !CBB_add_u16(cbb.get(), msg_hdr->seq) ||
196 !CBB_add_u24(cbb.get(), 0 /* frag_off */) ||
197 !CBB_add_u24(cbb.get(), msg_hdr->msg_len) ||
198 !CBB_finish(cbb.get(), NULL, NULL)) {
199 return nullptr;
200 }
201
202 if (!frag->reassembly.Init(msg_hdr->msg_len)) {
203 return nullptr;
204 }
205
206 return frag;
207 }
208
209 // dtls1_is_current_message_complete returns whether the current handshake
210 // message is complete.
dtls1_is_current_message_complete(const SSL * ssl)211 static bool dtls1_is_current_message_complete(const SSL *ssl) {
212 size_t idx = ssl->d1->handshake_read_seq % SSL_MAX_HANDSHAKE_FLIGHT;
213 DTLSIncomingMessage *frag = ssl->d1->incoming_messages[idx].get();
214 return frag != nullptr && frag->reassembly.IsComplete();
215 }
216
217 // dtls1_get_incoming_message returns the incoming message corresponding to
218 // |msg_hdr|. If none exists, it creates a new one and inserts it in the
219 // queue. Otherwise, it checks |msg_hdr| is consistent with the existing one. It
220 // returns NULL on failure. The caller does not take ownership of the result.
dtls1_get_incoming_message(SSL * ssl,uint8_t * out_alert,const struct hm_header_st * msg_hdr)221 static DTLSIncomingMessage *dtls1_get_incoming_message(
222 SSL *ssl, uint8_t *out_alert, const struct hm_header_st *msg_hdr) {
223 if (msg_hdr->seq < ssl->d1->handshake_read_seq ||
224 msg_hdr->seq - ssl->d1->handshake_read_seq >= SSL_MAX_HANDSHAKE_FLIGHT) {
225 *out_alert = SSL_AD_INTERNAL_ERROR;
226 return NULL;
227 }
228
229 size_t idx = msg_hdr->seq % SSL_MAX_HANDSHAKE_FLIGHT;
230 DTLSIncomingMessage *frag = ssl->d1->incoming_messages[idx].get();
231 if (frag != NULL) {
232 assert(frag->seq == msg_hdr->seq);
233 // The new fragment must be compatible with the previous fragments from this
234 // message.
235 if (frag->type != msg_hdr->type || //
236 frag->msg_len() != msg_hdr->msg_len) {
237 OPENSSL_PUT_ERROR(SSL, SSL_R_FRAGMENT_MISMATCH);
238 *out_alert = SSL_AD_ILLEGAL_PARAMETER;
239 return NULL;
240 }
241 return frag;
242 }
243
244 // This is the first fragment from this message.
245 ssl->d1->incoming_messages[idx] = dtls_new_incoming_message(msg_hdr);
246 if (!ssl->d1->incoming_messages[idx]) {
247 *out_alert = SSL_AD_INTERNAL_ERROR;
248 return NULL;
249 }
250 return ssl->d1->incoming_messages[idx].get();
251 }
252
dtls1_process_handshake_fragments(SSL * ssl,uint8_t * out_alert,DTLSRecordNumber record_number,Span<const uint8_t> record)253 bool dtls1_process_handshake_fragments(SSL *ssl, uint8_t *out_alert,
254 DTLSRecordNumber record_number,
255 Span<const uint8_t> record) {
256 bool implicit_ack = false;
257 bool skipped_fragments = false;
258 CBS cbs = record;
259 while (CBS_len(&cbs) > 0) {
260 // Read a handshake fragment.
261 struct hm_header_st msg_hdr;
262 CBS body;
263 if (!dtls1_parse_fragment(&cbs, &msg_hdr, &body)) {
264 OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_HANDSHAKE_RECORD);
265 *out_alert = SSL_AD_DECODE_ERROR;
266 return false;
267 }
268
269 const size_t frag_off = msg_hdr.frag_off;
270 const size_t frag_len = msg_hdr.frag_len;
271 const size_t msg_len = msg_hdr.msg_len;
272 if (frag_off > msg_len || frag_len > msg_len - frag_off) {
273 OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_HANDSHAKE_RECORD);
274 *out_alert = SSL_AD_ILLEGAL_PARAMETER;
275 return false;
276 }
277
278 if (msg_hdr.seq < ssl->d1->handshake_read_seq ||
279 ssl->d1->handshake_read_overflow) {
280 // Ignore fragments from the past. This is a retransmit of data we already
281 // received.
282 //
283 // TODO(crbug.com/42290594): Use this to drive retransmits.
284 continue;
285 }
286
287 if (record_number.epoch() != ssl->d1->read_epoch.epoch ||
288 ssl->d1->next_read_epoch != nullptr) {
289 // New messages can only arrive in the latest epoch. This can fail if the
290 // record came from |prev_read_epoch|, or if it came from |read_epoch| but
291 // |next_read_epoch| exists. (It cannot come from |next_read_epoch|
292 // because |next_read_epoch| becomes |read_epoch| once it receives a
293 // record.)
294 OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESS_HANDSHAKE_DATA);
295 *out_alert = SSL_AD_UNEXPECTED_MESSAGE;
296 return false;
297 }
298
299 if (msg_len > ssl_max_handshake_message_len(ssl)) {
300 OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
301 *out_alert = SSL_AD_ILLEGAL_PARAMETER;
302 return false;
303 }
304
305 if (SSL_in_init(ssl) && ssl_has_final_version(ssl) &&
306 ssl_protocol_version(ssl) >= TLS1_3_VERSION) {
307 // During the handshake, if we receive any portion of the next flight, the
308 // peer must have received our most recent flight. In DTLS 1.3, this is an
309 // implicit ACK. See RFC 9147, Section 7.1.
310 //
311 // This only applies during the handshake. After the handshake, the next
312 // message may be part of a post-handshake transaction. It also does not
313 // apply immediately after the handshake. As a client, receiving a
314 // KeyUpdate or NewSessionTicket does not imply the server has received
315 // our Finished. The server may have sent those messages in half-RTT.
316 implicit_ack = true;
317 }
318
319 if (msg_hdr.seq - ssl->d1->handshake_read_seq > SSL_MAX_HANDSHAKE_FLIGHT) {
320 // Ignore fragments too far in the future.
321 skipped_fragments = true;
322 continue;
323 }
324
325 DTLSIncomingMessage *frag =
326 dtls1_get_incoming_message(ssl, out_alert, &msg_hdr);
327 if (frag == nullptr) {
328 return false;
329 }
330 assert(frag->msg_len() == msg_len);
331
332 if (frag->reassembly.IsComplete()) {
333 // The message is already assembled.
334 continue;
335 }
336 assert(msg_len > 0);
337
338 // Copy the body into the fragment.
339 Span<uint8_t> dest = frag->msg().subspan(frag_off, CBS_len(&body));
340 OPENSSL_memcpy(dest.data(), CBS_data(&body), CBS_len(&body));
341 frag->reassembly.MarkRange(frag_off, frag_off + frag_len);
342 }
343
344 if (implicit_ack) {
345 dtls1_stop_timer(ssl);
346 dtls_clear_outgoing_messages(ssl);
347 }
348
349 if (!skipped_fragments) {
350 ssl->d1->records_to_ack.PushBack(record_number);
351
352 if (ssl_has_final_version(ssl) &&
353 ssl_protocol_version(ssl) >= TLS1_3_VERSION &&
354 !ssl->d1->ack_timer.IsSet() && !ssl->d1->sending_ack) {
355 // Schedule sending an ACK. The delay serves several purposes:
356 // - If there are more records to come, we send only one ACK.
357 // - If there are more records to come and the flight is now complete, we
358 // will send the reply (which implicitly ACKs the previous flight) and
359 // cancel the timer.
360 // - If there are more records to come, the flight is now complete, but
361 // generating the response is delayed (e.g. a slow, async private key),
362 // the timer will fire and we send an ACK anyway.
363 OPENSSL_timeval now = ssl_ctx_get_current_time(ssl->ctx.get());
364 ssl->d1->ack_timer.StartMicroseconds(
365 now, uint64_t{ssl->d1->timeout_duration_ms} * 1000 / 4);
366 }
367 }
368
369 return true;
370 }
371
dtls1_open_handshake(SSL * ssl,size_t * out_consumed,uint8_t * out_alert,Span<uint8_t> in)372 ssl_open_record_t dtls1_open_handshake(SSL *ssl, size_t *out_consumed,
373 uint8_t *out_alert, Span<uint8_t> in) {
374 uint8_t type;
375 DTLSRecordNumber record_number;
376 Span<uint8_t> record;
377 auto ret = dtls_open_record(ssl, &type, &record_number, &record, out_consumed,
378 out_alert, in);
379 if (ret != ssl_open_record_success) {
380 return ret;
381 }
382
383 switch (type) {
384 case SSL3_RT_APPLICATION_DATA:
385 // In DTLS 1.2, out-of-order application data may be received between
386 // ChangeCipherSpec and Finished. Discard it.
387 return ssl_open_record_discard;
388
389 case SSL3_RT_CHANGE_CIPHER_SPEC:
390 if (record.size() != 1u || record[0] != SSL3_MT_CCS) {
391 OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_CHANGE_CIPHER_SPEC);
392 *out_alert = SSL_AD_ILLEGAL_PARAMETER;
393 return ssl_open_record_error;
394 }
395
396 // We do not support renegotiation, so encrypted ChangeCipherSpec records
397 // are illegal.
398 if (record_number.epoch() != 0) {
399 OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD);
400 *out_alert = SSL_AD_UNEXPECTED_MESSAGE;
401 return ssl_open_record_error;
402 }
403
404 // Ignore ChangeCipherSpec from a previous epoch.
405 if (record_number.epoch() != ssl->d1->read_epoch.epoch) {
406 return ssl_open_record_discard;
407 }
408
409 // Flag the ChangeCipherSpec for later.
410 // TODO(crbug.com/42290594): Should we reject this in DTLS 1.3?
411 ssl->d1->has_change_cipher_spec = true;
412 ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_CHANGE_CIPHER_SPEC,
413 record);
414 return ssl_open_record_success;
415
416 case SSL3_RT_ACK:
417 return dtls1_process_ack(ssl, out_alert, record_number, record);
418
419 case SSL3_RT_HANDSHAKE:
420 if (!dtls1_process_handshake_fragments(ssl, out_alert, record_number,
421 record)) {
422 return ssl_open_record_error;
423 }
424 return ssl_open_record_success;
425
426 default:
427 OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD);
428 *out_alert = SSL_AD_UNEXPECTED_MESSAGE;
429 return ssl_open_record_error;
430 }
431 }
432
dtls1_get_message(const SSL * ssl,SSLMessage * out)433 bool dtls1_get_message(const SSL *ssl, SSLMessage *out) {
434 if (!dtls1_is_current_message_complete(ssl)) {
435 return false;
436 }
437
438 size_t idx = ssl->d1->handshake_read_seq % SSL_MAX_HANDSHAKE_FLIGHT;
439 const DTLSIncomingMessage *frag = ssl->d1->incoming_messages[idx].get();
440 out->type = frag->type;
441 out->raw = CBS(frag->data);
442 out->body = CBS(frag->msg());
443 out->is_v2_hello = false;
444 if (!ssl->s3->has_message) {
445 ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_HANDSHAKE, out->raw);
446 ssl->s3->has_message = true;
447 }
448 return true;
449 }
450
dtls1_next_message(SSL * ssl)451 void dtls1_next_message(SSL *ssl) {
452 assert(ssl->s3->has_message);
453 assert(dtls1_is_current_message_complete(ssl));
454 size_t index = ssl->d1->handshake_read_seq % SSL_MAX_HANDSHAKE_FLIGHT;
455 ssl->d1->incoming_messages[index].reset();
456 ssl->d1->handshake_read_seq++;
457 if (ssl->d1->handshake_read_seq == 0) {
458 ssl->d1->handshake_read_overflow = true;
459 }
460 ssl->s3->has_message = false;
461 // If we previously sent a flight, mark it as having a reply, so
462 // |on_handshake_complete| can manage post-handshake retransmission.
463 if (ssl->d1->outgoing_messages_complete) {
464 ssl->d1->flight_has_reply = true;
465 }
466 }
467
dtls_has_unprocessed_handshake_data(const SSL * ssl)468 bool dtls_has_unprocessed_handshake_data(const SSL *ssl) {
469 size_t current = ssl->d1->handshake_read_seq % SSL_MAX_HANDSHAKE_FLIGHT;
470 for (size_t i = 0; i < SSL_MAX_HANDSHAKE_FLIGHT; i++) {
471 // Skip the current message.
472 if (ssl->s3->has_message && i == current) {
473 assert(dtls1_is_current_message_complete(ssl));
474 continue;
475 }
476 if (ssl->d1->incoming_messages[i] != nullptr) {
477 return true;
478 }
479 }
480 return false;
481 }
482
dtls1_parse_fragment(CBS * cbs,struct hm_header_st * out_hdr,CBS * out_body)483 bool dtls1_parse_fragment(CBS *cbs, struct hm_header_st *out_hdr,
484 CBS *out_body) {
485 OPENSSL_memset(out_hdr, 0x00, sizeof(struct hm_header_st));
486
487 if (!CBS_get_u8(cbs, &out_hdr->type) ||
488 !CBS_get_u24(cbs, &out_hdr->msg_len) ||
489 !CBS_get_u16(cbs, &out_hdr->seq) ||
490 !CBS_get_u24(cbs, &out_hdr->frag_off) ||
491 !CBS_get_u24(cbs, &out_hdr->frag_len) ||
492 !CBS_get_bytes(cbs, out_body, out_hdr->frag_len)) {
493 return false;
494 }
495
496 return true;
497 }
498
dtls1_open_change_cipher_spec(SSL * ssl,size_t * out_consumed,uint8_t * out_alert,Span<uint8_t> in)499 ssl_open_record_t dtls1_open_change_cipher_spec(SSL *ssl, size_t *out_consumed,
500 uint8_t *out_alert,
501 Span<uint8_t> in) {
502 if (!ssl->d1->has_change_cipher_spec) {
503 // dtls1_open_handshake processes both handshake and ChangeCipherSpec.
504 auto ret = dtls1_open_handshake(ssl, out_consumed, out_alert, in);
505 if (ret != ssl_open_record_success) {
506 return ret;
507 }
508 }
509 if (ssl->d1->has_change_cipher_spec) {
510 ssl->d1->has_change_cipher_spec = false;
511 return ssl_open_record_success;
512 }
513 return ssl_open_record_discard;
514 }
515
516
517 // Sending handshake messages.
518
dtls_clear_outgoing_messages(SSL * ssl)519 void dtls_clear_outgoing_messages(SSL *ssl) {
520 ssl->d1->outgoing_messages.clear();
521 ssl->d1->sent_records = nullptr;
522 ssl->d1->outgoing_written = 0;
523 ssl->d1->outgoing_offset = 0;
524 ssl->d1->outgoing_messages_complete = false;
525 ssl->d1->flight_has_reply = false;
526 ssl->d1->sending_flight = false;
527 dtls_clear_unused_write_epochs(ssl);
528 }
529
dtls_clear_unused_write_epochs(SSL * ssl)530 void dtls_clear_unused_write_epochs(SSL *ssl) {
531 ssl->d1->extra_write_epochs.EraseIf(
532 [ssl](const UniquePtr<DTLSWriteEpoch> &write_epoch) -> bool {
533 // Non-current epochs may be discarded once there are no incomplete
534 // outgoing messages that reference them.
535 //
536 // TODO(crbug.com/42290594): Epoch 1 (0-RTT) should be retained until
537 // epoch 3 (app data) is available.
538 for (const auto &msg : ssl->d1->outgoing_messages) {
539 if (msg.epoch == write_epoch->epoch() && !msg.IsFullyAcked()) {
540 return false;
541 }
542 }
543 return true;
544 });
545 }
546
dtls1_init_message(const SSL * ssl,CBB * cbb,CBB * body,uint8_t type)547 bool dtls1_init_message(const SSL *ssl, CBB *cbb, CBB *body, uint8_t type) {
548 // Pick a modest size hint to save most of the |realloc| calls.
549 if (!CBB_init(cbb, 64) || //
550 !CBB_add_u8(cbb, type) || //
551 !CBB_add_u24(cbb, 0 /* length (filled in later) */) || //
552 !CBB_add_u16(cbb, ssl->d1->handshake_write_seq) || //
553 !CBB_add_u24(cbb, 0 /* offset */) || //
554 !CBB_add_u24_length_prefixed(cbb, body)) {
555 return false;
556 }
557
558 return true;
559 }
560
dtls1_finish_message(const SSL * ssl,CBB * cbb,Array<uint8_t> * out_msg)561 bool dtls1_finish_message(const SSL *ssl, CBB *cbb, Array<uint8_t> *out_msg) {
562 if (!CBBFinishArray(cbb, out_msg) ||
563 out_msg->size() < DTLS1_HM_HEADER_LENGTH) {
564 OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
565 return false;
566 }
567
568 // Fix up the header. Copy the fragment length into the total message
569 // length.
570 OPENSSL_memcpy(out_msg->data() + 1,
571 out_msg->data() + DTLS1_HM_HEADER_LENGTH - 3, 3);
572 return true;
573 }
574
575 // add_outgoing adds a new handshake message or ChangeCipherSpec to the current
576 // outgoing flight. It returns true on success and false on error.
add_outgoing(SSL * ssl,bool is_ccs,Array<uint8_t> data)577 static bool add_outgoing(SSL *ssl, bool is_ccs, Array<uint8_t> data) {
578 if (ssl->d1->outgoing_messages_complete) {
579 // If we've begun writing a new flight, we received the peer flight. Discard
580 // the timer and the our flight.
581 dtls1_stop_timer(ssl);
582 dtls_clear_outgoing_messages(ssl);
583 }
584
585 if (!is_ccs) {
586 if (ssl->d1->handshake_write_overflow) {
587 OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
588 return false;
589 }
590 // TODO(svaldez): Move this up a layer to fix abstraction for SSLTranscript
591 // on hs.
592 if (ssl->s3->hs != NULL && !ssl->s3->hs->transcript.Update(data)) {
593 OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
594 return false;
595 }
596 ssl->d1->handshake_write_seq++;
597 if (ssl->d1->handshake_write_seq == 0) {
598 ssl->d1->handshake_write_overflow = true;
599 }
600 }
601
602 DTLSOutgoingMessage msg;
603 msg.data = std::move(data);
604 msg.epoch = ssl->d1->write_epoch.epoch();
605 msg.is_ccs = is_ccs;
606 // Zero-length messages need 1 bit to track whether the peer has received the
607 // message header. (Normally the message header is implicitly received when
608 // any fragment of the message is received at all.)
609 if (!is_ccs && !msg.acked.Init(std::max(msg.msg_len(), size_t{1}))) {
610 return false;
611 }
612
613 // This should not fail if |SSL_MAX_HANDSHAKE_FLIGHT| was sized correctly.
614 //
615 // TODO(crbug.com/42290594): This can currently fail in DTLS 1.3. The caller
616 // can configure how many tickets to send, up to kMaxTickets. Additionally, if
617 // we send 0.5-RTT tickets in 0-RTT, we may even have tickets queued up with
618 // the server flight.
619 if (!ssl->d1->outgoing_messages.TryPushBack(std::move(msg))) {
620 assert(false);
621 OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
622 return false;
623 }
624
625 return true;
626 }
627
dtls1_add_message(SSL * ssl,Array<uint8_t> data)628 bool dtls1_add_message(SSL *ssl, Array<uint8_t> data) {
629 return add_outgoing(ssl, false /* handshake */, std::move(data));
630 }
631
dtls1_add_change_cipher_spec(SSL * ssl)632 bool dtls1_add_change_cipher_spec(SSL *ssl) {
633 // DTLS 1.3 disables compatibility mode, which means that DTLS 1.3 never sends
634 // a ChangeCipherSpec message.
635 if (ssl_protocol_version(ssl) > TLS1_2_VERSION) {
636 return true;
637 }
638 return add_outgoing(ssl, true /* ChangeCipherSpec */, Array<uint8_t>());
639 }
640
641 // dtls1_update_mtu updates the current MTU from the BIO, ensuring it is above
642 // the minimum.
dtls1_update_mtu(SSL * ssl)643 static void dtls1_update_mtu(SSL *ssl) {
644 // TODO(davidben): No consumer implements |BIO_CTRL_DGRAM_SET_MTU| and the
645 // only |BIO_CTRL_DGRAM_QUERY_MTU| implementation could use
646 // |SSL_set_mtu|. Does this need to be so complex?
647 if (ssl->d1->mtu < dtls1_min_mtu() &&
648 !(SSL_get_options(ssl) & SSL_OP_NO_QUERY_MTU)) {
649 long mtu = BIO_ctrl(ssl->wbio.get(), BIO_CTRL_DGRAM_QUERY_MTU, 0, NULL);
650 if (mtu >= 0 && mtu <= (1 << 30) && (unsigned)mtu >= dtls1_min_mtu()) {
651 ssl->d1->mtu = (unsigned)mtu;
652 } else {
653 ssl->d1->mtu = kDefaultMTU;
654 BIO_ctrl(ssl->wbio.get(), BIO_CTRL_DGRAM_SET_MTU, ssl->d1->mtu, NULL);
655 }
656 }
657
658 // The MTU should be above the minimum now.
659 assert(ssl->d1->mtu >= dtls1_min_mtu());
660 }
661
// seal_result_t is the outcome of |seal_next_record|.
enum seal_result_t {
  // Sealing failed.
  seal_error,
  // The next message could not be combined into this record; the caller should
  // call again to keep filling the packet.
  seal_continue,
  // The packet is complete, either because there are no more messages or
  // because the packet is full.
  seal_flush,
};
667
// seal_next_record seals one record's worth of messages to |out| and advances
// |ssl|'s internal state past the data that was sealed. If progress was made,
// it returns |seal_flush| or |seal_continue| and sets
// |*out_len| to the number of bytes written.
//
// If the function stopped because the next message could not be combined into
// this record, it returns |seal_continue| and the caller should loop again.
// Otherwise, it returns |seal_flush| and the packet is complete (either because
// there are no more messages or the packet is full).
//
// On failure, it returns |seal_error|. |*out_len| is always set to zero on
// entry, so it is zero whenever no record was produced.
static seal_result_t seal_next_record(SSL *ssl, Span<uint8_t> out,
                                      size_t *out_len) {
  *out_len = 0;

  // Skip any fully acked messages.
  while (ssl->d1->outgoing_written < ssl->d1->outgoing_messages.size() &&
         ssl->d1->outgoing_messages[ssl->d1->outgoing_written].IsFullyAcked()) {
    ssl->d1->outgoing_offset = 0;
    ssl->d1->outgoing_written++;
  }

  // There was nothing left to write.
  if (ssl->d1->outgoing_written >= ssl->d1->outgoing_messages.size()) {
    return seal_flush;
  }

  // The first message to be written determines the epoch of this record.
  const auto &first_msg = ssl->d1->outgoing_messages[ssl->d1->outgoing_written];
  size_t prefix_len = dtls_seal_prefix_len(ssl, first_msg.epoch);
  size_t max_in_len = dtls_seal_max_input_len(ssl, first_msg.epoch, out.size());
  if (max_in_len == 0) {
    // There is no room for a single record.
    return seal_flush;
  }

  // A ChangeCipherSpec is sent as its own record with a fixed one-byte body.
  if (first_msg.is_ccs) {
    static const uint8_t kChangeCipherSpec[1] = {SSL3_MT_CCS};
    DTLSRecordNumber record_number;
    if (!dtls_seal_record(ssl, &record_number, out.data(), out_len, out.size(),
                          SSL3_RT_CHANGE_CIPHER_SPEC, kChangeCipherSpec,
                          sizeof(kChangeCipherSpec), first_msg.epoch)) {
      return seal_error;
    }

    ssl_do_msg_callback(ssl, /*is_write=*/1, SSL3_RT_CHANGE_CIPHER_SPEC,
                        kChangeCipherSpec);
    ssl->d1->outgoing_offset = 0;
    ssl->d1->outgoing_written++;
    return seal_continue;
  }

  // TODO(crbug.com/374991962): For now, only send one message per record in
  // epoch 0. Sending multiple is allowed and more efficient, but breaks
  // b/378742138.
  const bool allow_multiple_messages = first_msg.epoch != 0;

  // Pack as many handshake fragments into one record as we can. We stage the
  // fragments in the output buffer, to be sealed in-place.
  bool should_continue = false;
  Span<uint8_t> fragments = out.subspan(prefix_len, max_in_len);
  CBB cbb;
  CBB_init_fixed(&cbb, fragments.data(), fragments.size());
  // |sent_record| describes which span of the flight this record covers, from
  // (first_msg, first_msg_start) to (last_msg, last_msg_end).
  DTLSSentRecord sent_record;
  sent_record.first_msg = ssl->d1->outgoing_written;
  sent_record.first_msg_start = ssl->d1->outgoing_offset;
  while (ssl->d1->outgoing_written < ssl->d1->outgoing_messages.size()) {
    const auto &msg = ssl->d1->outgoing_messages[ssl->d1->outgoing_written];
    if (msg.epoch != first_msg.epoch || msg.is_ccs) {
      // We can only pack messages if the epoch matches. There may be more room
      // in the packet, so tell the caller to keep going.
      should_continue = true;
      break;
    }

    // Decode |msg|'s header. Queued messages are expected to be stored as a
    // single, complete, unfragmented fragment; anything else is an internal
    // error.
    CBS cbs(msg.data), body_cbs;
    struct hm_header_st hdr;
    if (!dtls1_parse_fragment(&cbs, &hdr, &body_cbs) ||  //
        hdr.frag_off != 0 ||                             //
        hdr.frag_len != CBS_len(&body_cbs) ||            //
        hdr.msg_len != CBS_len(&body_cbs) ||             //
        CBS_len(&cbs) != 0) {
      OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
      return seal_error;
    }

    // Iterate over every un-acked range in the message, if any.
    Span<const uint8_t> body = body_cbs;
    for (;;) {
      auto range = msg.acked.NextUnmarkedRange(ssl->d1->outgoing_offset);
      if (range.empty()) {
        // Advance to the next message.
        ssl->d1->outgoing_offset = 0;
        ssl->d1->outgoing_written++;
        break;
      }

      // Determine how much progress can be made (minimum one byte of progress).
      size_t capacity = fragments.size() - CBB_len(&cbb);
      if (capacity < DTLS1_HM_HEADER_LENGTH + 1) {
        goto packet_full;
      }
      size_t todo = std::min(range.size(), capacity - DTLS1_HM_HEADER_LENGTH);

      // Empty messages are special-cased in ACK tracking. We act as if they
      // have one byte, but in reality that byte is tracking the header.
      Span<const uint8_t> frag;
      if (!body.empty()) {
        frag = body.subspan(range.start, todo);
      }

      // Assemble the fragment.
      size_t frag_start = CBB_len(&cbb);
      CBB child;
      if (!CBB_add_u8(&cbb, hdr.type) ||                     //
          !CBB_add_u24(&cbb, hdr.msg_len) ||                 //
          !CBB_add_u16(&cbb, hdr.seq) ||                     //
          !CBB_add_u24(&cbb, range.start) ||                 //
          !CBB_add_u24_length_prefixed(&cbb, &child) ||      //
          !CBB_add_bytes(&child, frag.data(), frag.size()) ||  //
          !CBB_flush(&cbb)) {
        OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
        return seal_error;
      }
      size_t frag_end = CBB_len(&cbb);

      // TODO(davidben): It is odd that, on output, we inform the caller of
      // retransmits and individual fragments, but on input we only inform the
      // caller of complete messages.
      ssl_do_msg_callback(ssl, /*is_write=*/1, SSL3_RT_HANDSHAKE,
                          fragments.subspan(frag_start, frag_end - frag_start));

      // Resume after the bytes just written.
      ssl->d1->outgoing_offset = range.start + todo;
      if (todo < range.size()) {
        // The packet was the limiting factor.
        goto packet_full;
      }
    }

    if (!allow_multiple_messages) {
      should_continue = true;
      break;
    }
  }

packet_full:
  sent_record.last_msg = ssl->d1->outgoing_written;
  sent_record.last_msg_end = ssl->d1->outgoing_offset;

  // We could not fit anything. Don't try to make a record.
  if (CBB_len(&cbb) == 0) {
    assert(!should_continue);
    return seal_flush;
  }

  // Seal the staged fragments in the epoch of the first message.
  if (!dtls_seal_record(ssl, &sent_record.number, out.data(), out_len,
                        out.size(), SSL3_RT_HANDSHAKE, CBB_data(&cbb),
                        CBB_len(&cbb), first_msg.epoch)) {
    return seal_error;
  }

  // If DTLS 1.3 (or if the version is not yet known and it may be DTLS 1.3),
  // save the record number to match against ACKs later.
  if (ssl->s3->version == 0 || ssl_protocol_version(ssl) >= TLS1_3_VERSION) {
    // The queue of sent records is allocated on first use.
    if (ssl->d1->sent_records == nullptr) {
      ssl->d1->sent_records =
          MakeUnique<MRUQueue<DTLSSentRecord, DTLS_MAX_ACK_BUFFER>>();
      if (ssl->d1->sent_records == nullptr) {
        return seal_error;
      }
    }
    ssl->d1->sent_records->PushBack(sent_record);
  }

  return should_continue ? seal_continue : seal_flush;
}
842
843 // seal_next_packet writes as much of the next flight as possible to |out| and
844 // advances |ssl->d1->outgoing_written| and |ssl->d1->outgoing_offset| as
845 // appropriate.
seal_next_packet(SSL * ssl,Span<uint8_t> out,size_t * out_len)846 static bool seal_next_packet(SSL *ssl, Span<uint8_t> out, size_t *out_len) {
847 size_t total = 0;
848 for (;;) {
849 size_t len;
850 seal_result_t ret = seal_next_record(ssl, out, &len);
851 switch (ret) {
852 case seal_error:
853 return false;
854
855 case seal_flush:
856 case seal_continue:
857 out = out.subspan(len);
858 total += len;
859 break;
860 }
861
862 if (ret == seal_flush) {
863 break;
864 }
865 }
866
867 *out_len = total;
868 return true;
869 }
870
send_flight(SSL * ssl)871 static int send_flight(SSL *ssl) {
872 if (ssl->s3->write_shutdown != ssl_shutdown_none) {
873 OPENSSL_PUT_ERROR(SSL, SSL_R_PROTOCOL_IS_SHUTDOWN);
874 return -1;
875 }
876
877 if (ssl->wbio == nullptr) {
878 OPENSSL_PUT_ERROR(SSL, SSL_R_BIO_NOT_SET);
879 return -1;
880 }
881
882 if (ssl->d1->num_timeouts > DTLS1_MAX_TIMEOUTS) {
883 OPENSSL_PUT_ERROR(SSL, SSL_R_READ_TIMEOUT_EXPIRED);
884 return -1;
885 }
886
887 dtls1_update_mtu(ssl);
888
889 Array<uint8_t> packet;
890 if (!packet.InitForOverwrite(ssl->d1->mtu)) {
891 return -1;
892 }
893
894 while (ssl->d1->outgoing_written < ssl->d1->outgoing_messages.size()) {
895 uint8_t old_written = ssl->d1->outgoing_written;
896 uint32_t old_offset = ssl->d1->outgoing_offset;
897
898 size_t packet_len;
899 if (!seal_next_packet(ssl, Span(packet), &packet_len)) {
900 return -1;
901 }
902
903 if (packet_len == 0 &&
904 ssl->d1->outgoing_written < ssl->d1->outgoing_messages.size()) {
905 // We made no progress with the packet size available, but did not reach
906 // the end.
907 OPENSSL_PUT_ERROR(SSL, SSL_R_MTU_TOO_SMALL);
908 return false;
909 }
910
911 if (packet_len != 0) {
912 int bio_ret = BIO_write(ssl->wbio.get(), packet.data(), packet_len);
913 if (bio_ret <= 0) {
914 // Retry this packet the next time around.
915 ssl->d1->outgoing_written = old_written;
916 ssl->d1->outgoing_offset = old_offset;
917 ssl->s3->rwstate = SSL_ERROR_WANT_WRITE;
918 return bio_ret;
919 }
920 }
921 }
922
923 ssl->d1->pending_flush = true;
924 return 1;
925 }
926
dtls1_finish_flight(SSL * ssl)927 void dtls1_finish_flight(SSL *ssl) {
928 if (ssl->d1->outgoing_messages.empty() ||
929 ssl->d1->outgoing_messages_complete) {
930 return; // Nothing to do.
931 }
932
933 if (ssl->d1->outgoing_messages[0].epoch <= 2) {
934 // DTLS 1.3 handshake messages (epoch 2 and below) implicitly ACK the
935 // previous flight, so there is no need to ACK previous records. This
936 // clears the ACK buffer slightly earlier than the specification suggests.
937 // See the discussion in
938 // https://mailarchive.ietf.org/arch/msg/tls/kjJnquJOVaWxu5hUCmNzB35eqY0/
939 ssl->d1->records_to_ack.Clear();
940 ssl->d1->ack_timer.Stop();
941 ssl->d1->sending_ack = false;
942 }
943
944 ssl->d1->outgoing_messages_complete = true;
945 ssl->d1->sending_flight = true;
946 // Stop retransmitting the previous flight. In DTLS 1.3, we'll have stopped
947 // the timer already, but DTLS 1.2 keeps it running until the next flight is
948 // ready.
949 dtls1_stop_timer(ssl);
950 }
951
dtls1_schedule_ack(SSL * ssl)952 void dtls1_schedule_ack(SSL *ssl) {
953 ssl->d1->ack_timer.Stop();
954 ssl->d1->sending_ack = !ssl->d1->records_to_ack.empty();
955 }
956
send_ack(SSL * ssl)957 static int send_ack(SSL *ssl) {
958 assert(ssl_protocol_version(ssl) >= TLS1_3_VERSION);
959
960 // Ensure we don't send so many ACKs that we overflow the MTU. There is a
961 // 2-byte length prefix and each ACK is 16 bytes.
962 dtls1_update_mtu(ssl);
963 size_t max_plaintext =
964 dtls_seal_max_input_len(ssl, ssl->d1->write_epoch.epoch(), ssl->d1->mtu);
965 if (max_plaintext < 2 + 16) {
966 OPENSSL_PUT_ERROR(SSL, SSL_R_MTU_TOO_SMALL); // No room for even one ACK.
967 return -1;
968 }
969 size_t num_acks =
970 std::min((max_plaintext - 2) / 16, ssl->d1->records_to_ack.size());
971
972 // Assemble the ACK. RFC 9147 says to sort ACKs numerically. It is unclear if
973 // other implementations do this, but go ahead and sort for now. See
974 // https://mailarchive.ietf.org/arch/msg/tls/kjJnquJOVaWxu5hUCmNzB35eqY0/.
975 // Remove this if rfc9147bis removes this requirement.
976 InplaceVector<DTLSRecordNumber, DTLS_MAX_ACK_BUFFER> sorted;
977 for (size_t i = ssl->d1->records_to_ack.size() - num_acks;
978 i < ssl->d1->records_to_ack.size(); i++) {
979 sorted.PushBack(ssl->d1->records_to_ack[i]);
980 }
981 std::sort(sorted.begin(), sorted.end());
982
983 uint8_t buf[2 + 16 * DTLS_MAX_ACK_BUFFER];
984 CBB cbb, child;
985 CBB_init_fixed(&cbb, buf, sizeof(buf));
986 BSSL_CHECK(CBB_add_u16_length_prefixed(&cbb, &child));
987 for (const auto &number : sorted) {
988 BSSL_CHECK(CBB_add_u64(&child, number.epoch()));
989 BSSL_CHECK(CBB_add_u64(&child, number.sequence()));
990 }
991 BSSL_CHECK(CBB_flush(&cbb));
992
993 // Encrypt it.
994 uint8_t record[DTLS1_3_RECORD_HEADER_WRITE_LENGTH + sizeof(buf) +
995 1 /* record type */ + EVP_AEAD_MAX_OVERHEAD];
996 size_t record_len;
997 DTLSRecordNumber record_number;
998 if (!dtls_seal_record(ssl, &record_number, record, &record_len,
999 sizeof(record), SSL3_RT_ACK, CBB_data(&cbb),
1000 CBB_len(&cbb), ssl->d1->write_epoch.epoch())) {
1001 return -1;
1002 }
1003
1004 ssl_do_msg_callback(ssl, /*is_write=*/1, SSL3_RT_ACK,
1005 Span(CBB_data(&cbb), CBB_len(&cbb)));
1006
1007 int bio_ret =
1008 BIO_write(ssl->wbio.get(), record, static_cast<int>(record_len));
1009 if (bio_ret <= 0) {
1010 ssl->s3->rwstate = SSL_ERROR_WANT_WRITE;
1011 return bio_ret;
1012 }
1013
1014 ssl->d1->pending_flush = true;
1015 return 1;
1016 }
1017
dtls1_flush(SSL * ssl)1018 int dtls1_flush(SSL *ssl) {
1019 // Send the pending ACK, if any.
1020 if (ssl->d1->sending_ack) {
1021 int ret = send_ack(ssl);
1022 if (ret <= 0) {
1023 return ret;
1024 }
1025 ssl->d1->sending_ack = false;
1026 }
1027
1028 // Send the pending flight, if any.
1029 if (ssl->d1->sending_flight) {
1030 int ret = send_flight(ssl);
1031 if (ret <= 0) {
1032 return ret;
1033 }
1034
1035 // Reset state for the next send.
1036 ssl->d1->outgoing_written = 0;
1037 ssl->d1->outgoing_offset = 0;
1038 ssl->d1->sending_flight = false;
1039
1040 // Schedule the next retransmit timer. In DTLS 1.3, we retransmit all
1041 // flights until ACKed. In DTLS 1.2, the final Finished flight is never
1042 // ACKed, so we do not keep the timer running after the handshake.
1043 if (SSL_in_init(ssl) || ssl_protocol_version(ssl) >= TLS1_3_VERSION) {
1044 if (ssl->d1->num_timeouts == 0) {
1045 ssl->d1->timeout_duration_ms = ssl->initial_timeout_duration_ms;
1046 } else {
1047 ssl->d1->timeout_duration_ms =
1048 std::min(ssl->d1->timeout_duration_ms * 2, uint32_t{60000});
1049 }
1050
1051 OPENSSL_timeval now = ssl_ctx_get_current_time(ssl->ctx.get());
1052 ssl->d1->retransmit_timer.StartMicroseconds(
1053 now, uint64_t{ssl->d1->timeout_duration_ms} * 1000);
1054 }
1055 }
1056
1057 if (ssl->d1->pending_flush) {
1058 if (BIO_flush(ssl->wbio.get()) <= 0) {
1059 ssl->s3->rwstate = SSL_ERROR_WANT_WRITE;
1060 return -1;
1061 }
1062 ssl->d1->pending_flush = false;
1063 }
1064
1065 return 1;
1066 }
1067
dtls1_min_mtu(void)1068 unsigned int dtls1_min_mtu(void) { return kMinMTU; }
1069
1070 BSSL_NAMESPACE_END
1071