1 // Copyright 2005-2016 The OpenSSL Project Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <openssl/ssl.h>
16 
17 #include <assert.h>
18 #include <limits.h>
19 #include <string.h>
20 
21 #include <algorithm>
22 
23 #include <openssl/err.h>
24 #include <openssl/evp.h>
25 #include <openssl/mem.h>
26 #include <openssl/rand.h>
27 
28 #include "../crypto/internal.h"
29 #include "internal.h"
30 
31 
32 BSSL_NAMESPACE_BEGIN
33 
34 // TODO(davidben): 28 comes from the size of IP + UDP header. Is this reasonable
35 // for these values? Notably, why is kMinMTU a function of the transport
36 // protocol's overhead rather than, say, what's needed to hold a minimally-sized
37 // handshake fragment plus protocol overhead.
38 
// kMinMTU is the smallest MTU value accepted: a 256-byte datagram less the 28
// bytes of IP + UDP header overhead.
static constexpr unsigned int kMinMTU = 256 - 28;

// kDefaultMTU is the MTU assumed when neither the user nor the underlying BIO
// supplies one: a 1500-byte datagram less the same 28 bytes of overhead.
static constexpr unsigned int kDefaultMTU = 1500 - 28;
45 
// BitRange returns a |uint8_t| mask in which bits |start| (inclusive) through
// |end| (exclusive) are set and every other bit is clear.
static uint8_t BitRange(size_t start, size_t end) {
  assert(start <= end && end <= 8);
  const unsigned below_end = (1u << end) - 1;      // bits [0, end)
  const unsigned below_start = (1u << start) - 1;  // bits [0, start)
  return static_cast<uint8_t>(below_end & ~below_start);
}
52 
53 // FirstUnmarkedRangeInByte returns the first unmarked range in bits |b|.
FirstUnmarkedRangeInByte(uint8_t b)54 static DTLSMessageBitmap::Range FirstUnmarkedRangeInByte(uint8_t b) {
55   size_t start, end;
56   for (start = 0; start < 8; start++) {
57     if ((b & (1u << start)) == 0) {
58       break;
59     }
60   }
61   for (end = start; end < 8; end++) {
62     if ((b & (1u << end)) != 0) {
63       break;
64     }
65   }
66   return DTLSMessageBitmap::Range{start, end};
67 }
68 
Init(size_t num_bits)69 bool DTLSMessageBitmap::Init(size_t num_bits) {
70   if (num_bits + 7 < num_bits) {
71     OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
72     return false;
73   }
74   size_t num_bytes = (num_bits + 7) / 8;
75   size_t bits_rounded = num_bytes * 8;
76   if (!bytes_.Init(num_bytes)) {
77     return false;
78   }
79   MarkRange(num_bits, bits_rounded);
80   first_unmarked_byte_ = 0;
81   return true;
82 }
83 
MarkRange(size_t start,size_t end)84 void DTLSMessageBitmap::MarkRange(size_t start, size_t end) {
85   assert(start <= end);
86   // Don't bother touching bytes that have already been marked.
87   start = std::max(start, first_unmarked_byte_ << 3);
88   // Clamp everything within range.
89   start = std::min(start, bytes_.size() << 3);
90   end = std::min(end, bytes_.size() << 3);
91   if (start >= end) {
92     return;
93   }
94 
95   if ((start >> 3) == (end >> 3)) {
96     bytes_[start >> 3] |= BitRange(start & 7, end & 7);
97   } else {
98     bytes_[start >> 3] |= BitRange(start & 7, 8);
99     for (size_t i = (start >> 3) + 1; i < (end >> 3); i++) {
100       bytes_[i] = 0xff;
101     }
102     if ((end & 7) != 0) {
103       bytes_[end >> 3] |= BitRange(0, end & 7);
104     }
105   }
106 
107   // Maintain the |first_unmarked_byte_| invariant. This work is amortized
108   // across all |MarkRange| calls.
109   while (first_unmarked_byte_ < bytes_.size() &&
110          bytes_[first_unmarked_byte_] == 0xff) {
111     first_unmarked_byte_++;
112   }
113   // If the whole message is marked, we no longer need to spend memory on the
114   // bitmap.
115   if (first_unmarked_byte_ >= bytes_.size()) {
116     bytes_.Reset();
117     first_unmarked_byte_ = 0;
118   }
119 }
120 
// NextUnmarkedRange finds the first unmarked bit at or after bit |start| and
// returns the run of consecutive unmarked bits beginning there. The result is
// a half-open [start, end) range of bit indices. If no bit at or after |start|
// is unmarked, it returns the empty range {0, 0}.
DTLSMessageBitmap::Range DTLSMessageBitmap::NextUnmarkedRange(
    size_t start) const {
  // Don't bother looking at bytes that are known to be fully marked.
  start = std::max(start, first_unmarked_byte_ << 3);

  size_t idx = start >> 3;
  if (idx >= bytes_.size()) {
    // Past the end of the bitmap. This includes the case where |MarkRange|
    // released |bytes_| because every bit was marked.
    return Range{0, 0};
  }

  // Look at the bits from |start| up to a byte boundary. Bits below |start|
  // in this byte are treated as marked so they are not reported.
  uint8_t byte = bytes_[idx] | BitRange(0, start & 7);
  if (byte == 0xff) {
    // Nothing unmarked at this byte. Keep searching for an unmarked bit.
    for (idx = idx + 1; idx < bytes_.size(); idx++) {
      if (bytes_[idx] != 0xff) {
        byte = bytes_[idx];
        break;
      }
    }
    if (idx >= bytes_.size()) {
      return Range{0, 0};
    }
  }

  // |byte| has at least one clear bit here, so the in-byte range is non-empty.
  Range range = FirstUnmarkedRangeInByte(byte);
  assert(!range.empty());
  // A run that reaches the top bit of this byte may continue into later bytes.
  bool should_extend = range.end == 8;
  // Convert in-byte bit positions into bitmap bit positions.
  range.start += idx << 3;
  range.end += idx << 3;
  if (!should_extend) {
    // The range did not end at a byte boundary. We're done.
    return range;
  }

  // Collect all fully unmarked bytes.
  for (idx = idx + 1; idx < bytes_.size(); idx++) {
    if (bytes_[idx] != 0) {
      break;
    }
  }
  range.end = idx << 3;

  // Add any bits from the remaining byte, if any. Only a run that starts at
  // bit 0 of that byte is contiguous with |range|.
  if (idx < bytes_.size()) {
    Range extra = FirstUnmarkedRangeInByte(bytes_[idx]);
    if (extra.start == 0) {
      range.end += extra.end;
    }
  }

  return range;
}
174 
175 // Receiving handshake messages.
176 
dtls_new_incoming_message(const struct hm_header_st * msg_hdr)177 static UniquePtr<DTLSIncomingMessage> dtls_new_incoming_message(
178     const struct hm_header_st *msg_hdr) {
179   ScopedCBB cbb;
180   UniquePtr<DTLSIncomingMessage> frag = MakeUnique<DTLSIncomingMessage>();
181   if (!frag) {
182     return nullptr;
183   }
184   frag->type = msg_hdr->type;
185   frag->seq = msg_hdr->seq;
186 
187   // Allocate space for the reassembled message and fill in the header.
188   if (!frag->data.InitForOverwrite(DTLS1_HM_HEADER_LENGTH + msg_hdr->msg_len)) {
189     return nullptr;
190   }
191 
192   if (!CBB_init_fixed(cbb.get(), frag->data.data(), DTLS1_HM_HEADER_LENGTH) ||
193       !CBB_add_u8(cbb.get(), msg_hdr->type) ||
194       !CBB_add_u24(cbb.get(), msg_hdr->msg_len) ||
195       !CBB_add_u16(cbb.get(), msg_hdr->seq) ||
196       !CBB_add_u24(cbb.get(), 0 /* frag_off */) ||
197       !CBB_add_u24(cbb.get(), msg_hdr->msg_len) ||
198       !CBB_finish(cbb.get(), NULL, NULL)) {
199     return nullptr;
200   }
201 
202   if (!frag->reassembly.Init(msg_hdr->msg_len)) {
203     return nullptr;
204   }
205 
206   return frag;
207 }
208 
209 // dtls1_is_current_message_complete returns whether the current handshake
210 // message is complete.
dtls1_is_current_message_complete(const SSL * ssl)211 static bool dtls1_is_current_message_complete(const SSL *ssl) {
212   size_t idx = ssl->d1->handshake_read_seq % SSL_MAX_HANDSHAKE_FLIGHT;
213   DTLSIncomingMessage *frag = ssl->d1->incoming_messages[idx].get();
214   return frag != nullptr && frag->reassembly.IsComplete();
215 }
216 
217 // dtls1_get_incoming_message returns the incoming message corresponding to
218 // |msg_hdr|. If none exists, it creates a new one and inserts it in the
219 // queue. Otherwise, it checks |msg_hdr| is consistent with the existing one. It
220 // returns NULL on failure. The caller does not take ownership of the result.
dtls1_get_incoming_message(SSL * ssl,uint8_t * out_alert,const struct hm_header_st * msg_hdr)221 static DTLSIncomingMessage *dtls1_get_incoming_message(
222     SSL *ssl, uint8_t *out_alert, const struct hm_header_st *msg_hdr) {
223   if (msg_hdr->seq < ssl->d1->handshake_read_seq ||
224       msg_hdr->seq - ssl->d1->handshake_read_seq >= SSL_MAX_HANDSHAKE_FLIGHT) {
225     *out_alert = SSL_AD_INTERNAL_ERROR;
226     return NULL;
227   }
228 
229   size_t idx = msg_hdr->seq % SSL_MAX_HANDSHAKE_FLIGHT;
230   DTLSIncomingMessage *frag = ssl->d1->incoming_messages[idx].get();
231   if (frag != NULL) {
232     assert(frag->seq == msg_hdr->seq);
233     // The new fragment must be compatible with the previous fragments from this
234     // message.
235     if (frag->type != msg_hdr->type ||  //
236         frag->msg_len() != msg_hdr->msg_len) {
237       OPENSSL_PUT_ERROR(SSL, SSL_R_FRAGMENT_MISMATCH);
238       *out_alert = SSL_AD_ILLEGAL_PARAMETER;
239       return NULL;
240     }
241     return frag;
242   }
243 
244   // This is the first fragment from this message.
245   ssl->d1->incoming_messages[idx] = dtls_new_incoming_message(msg_hdr);
246   if (!ssl->d1->incoming_messages[idx]) {
247     *out_alert = SSL_AD_INTERNAL_ERROR;
248     return NULL;
249   }
250   return ssl->d1->incoming_messages[idx].get();
251 }
252 
dtls1_process_handshake_fragments(SSL * ssl,uint8_t * out_alert,DTLSRecordNumber record_number,Span<const uint8_t> record)253 bool dtls1_process_handshake_fragments(SSL *ssl, uint8_t *out_alert,
254                                        DTLSRecordNumber record_number,
255                                        Span<const uint8_t> record) {
256   bool implicit_ack = false;
257   bool skipped_fragments = false;
258   CBS cbs = record;
259   while (CBS_len(&cbs) > 0) {
260     // Read a handshake fragment.
261     struct hm_header_st msg_hdr;
262     CBS body;
263     if (!dtls1_parse_fragment(&cbs, &msg_hdr, &body)) {
264       OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_HANDSHAKE_RECORD);
265       *out_alert = SSL_AD_DECODE_ERROR;
266       return false;
267     }
268 
269     const size_t frag_off = msg_hdr.frag_off;
270     const size_t frag_len = msg_hdr.frag_len;
271     const size_t msg_len = msg_hdr.msg_len;
272     if (frag_off > msg_len || frag_len > msg_len - frag_off) {
273       OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_HANDSHAKE_RECORD);
274       *out_alert = SSL_AD_ILLEGAL_PARAMETER;
275       return false;
276     }
277 
278     if (msg_hdr.seq < ssl->d1->handshake_read_seq ||
279         ssl->d1->handshake_read_overflow) {
280       // Ignore fragments from the past. This is a retransmit of data we already
281       // received.
282       //
283       // TODO(crbug.com/42290594): Use this to drive retransmits.
284       continue;
285     }
286 
287     if (record_number.epoch() != ssl->d1->read_epoch.epoch ||
288         ssl->d1->next_read_epoch != nullptr) {
289       // New messages can only arrive in the latest epoch. This can fail if the
290       // record came from |prev_read_epoch|, or if it came from |read_epoch| but
291       // |next_read_epoch| exists. (It cannot come from |next_read_epoch|
292       // because |next_read_epoch| becomes |read_epoch| once it receives a
293       // record.)
294       OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESS_HANDSHAKE_DATA);
295       *out_alert = SSL_AD_UNEXPECTED_MESSAGE;
296       return false;
297     }
298 
299     if (msg_len > ssl_max_handshake_message_len(ssl)) {
300       OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
301       *out_alert = SSL_AD_ILLEGAL_PARAMETER;
302       return false;
303     }
304 
305     if (SSL_in_init(ssl) && ssl_has_final_version(ssl) &&
306         ssl_protocol_version(ssl) >= TLS1_3_VERSION) {
307       // During the handshake, if we receive any portion of the next flight, the
308       // peer must have received our most recent flight. In DTLS 1.3, this is an
309       // implicit ACK. See RFC 9147, Section 7.1.
310       //
311       // This only applies during the handshake. After the handshake, the next
312       // message may be part of a post-handshake transaction. It also does not
313       // apply immediately after the handshake. As a client, receiving a
314       // KeyUpdate or NewSessionTicket does not imply the server has received
315       // our Finished. The server may have sent those messages in half-RTT.
316       implicit_ack = true;
317     }
318 
319     if (msg_hdr.seq - ssl->d1->handshake_read_seq > SSL_MAX_HANDSHAKE_FLIGHT) {
320       // Ignore fragments too far in the future.
321       skipped_fragments = true;
322       continue;
323     }
324 
325     DTLSIncomingMessage *frag =
326         dtls1_get_incoming_message(ssl, out_alert, &msg_hdr);
327     if (frag == nullptr) {
328       return false;
329     }
330     assert(frag->msg_len() == msg_len);
331 
332     if (frag->reassembly.IsComplete()) {
333       // The message is already assembled.
334       continue;
335     }
336     assert(msg_len > 0);
337 
338     // Copy the body into the fragment.
339     Span<uint8_t> dest = frag->msg().subspan(frag_off, CBS_len(&body));
340     OPENSSL_memcpy(dest.data(), CBS_data(&body), CBS_len(&body));
341     frag->reassembly.MarkRange(frag_off, frag_off + frag_len);
342   }
343 
344   if (implicit_ack) {
345     dtls1_stop_timer(ssl);
346     dtls_clear_outgoing_messages(ssl);
347   }
348 
349   if (!skipped_fragments) {
350     ssl->d1->records_to_ack.PushBack(record_number);
351 
352     if (ssl_has_final_version(ssl) &&
353         ssl_protocol_version(ssl) >= TLS1_3_VERSION &&
354         !ssl->d1->ack_timer.IsSet() && !ssl->d1->sending_ack) {
355       // Schedule sending an ACK. The delay serves several purposes:
356       // - If there are more records to come, we send only one ACK.
357       // - If there are more records to come and the flight is now complete, we
358       //   will send the reply (which implicitly ACKs the previous flight) and
359       //   cancel the timer.
360       // - If there are more records to come, the flight is now complete, but
361       //   generating the response is delayed (e.g. a slow, async private key),
362       //   the timer will fire and we send an ACK anyway.
363       OPENSSL_timeval now = ssl_ctx_get_current_time(ssl->ctx.get());
364       ssl->d1->ack_timer.StartMicroseconds(
365           now, uint64_t{ssl->d1->timeout_duration_ms} * 1000 / 4);
366     }
367   }
368 
369   return true;
370 }
371 
dtls1_open_handshake(SSL * ssl,size_t * out_consumed,uint8_t * out_alert,Span<uint8_t> in)372 ssl_open_record_t dtls1_open_handshake(SSL *ssl, size_t *out_consumed,
373                                        uint8_t *out_alert, Span<uint8_t> in) {
374   uint8_t type;
375   DTLSRecordNumber record_number;
376   Span<uint8_t> record;
377   auto ret = dtls_open_record(ssl, &type, &record_number, &record, out_consumed,
378                               out_alert, in);
379   if (ret != ssl_open_record_success) {
380     return ret;
381   }
382 
383   switch (type) {
384     case SSL3_RT_APPLICATION_DATA:
385       // In DTLS 1.2, out-of-order application data may be received between
386       // ChangeCipherSpec and Finished. Discard it.
387       return ssl_open_record_discard;
388 
389     case SSL3_RT_CHANGE_CIPHER_SPEC:
390       if (record.size() != 1u || record[0] != SSL3_MT_CCS) {
391         OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_CHANGE_CIPHER_SPEC);
392         *out_alert = SSL_AD_ILLEGAL_PARAMETER;
393         return ssl_open_record_error;
394       }
395 
396       // We do not support renegotiation, so encrypted ChangeCipherSpec records
397       // are illegal.
398       if (record_number.epoch() != 0) {
399         OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD);
400         *out_alert = SSL_AD_UNEXPECTED_MESSAGE;
401         return ssl_open_record_error;
402       }
403 
404       // Ignore ChangeCipherSpec from a previous epoch.
405       if (record_number.epoch() != ssl->d1->read_epoch.epoch) {
406         return ssl_open_record_discard;
407       }
408 
409       // Flag the ChangeCipherSpec for later.
410       // TODO(crbug.com/42290594): Should we reject this in DTLS 1.3?
411       ssl->d1->has_change_cipher_spec = true;
412       ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_CHANGE_CIPHER_SPEC,
413                           record);
414       return ssl_open_record_success;
415 
416     case SSL3_RT_ACK:
417       return dtls1_process_ack(ssl, out_alert, record_number, record);
418 
419     case SSL3_RT_HANDSHAKE:
420       if (!dtls1_process_handshake_fragments(ssl, out_alert, record_number,
421                                              record)) {
422         return ssl_open_record_error;
423       }
424       return ssl_open_record_success;
425 
426     default:
427       OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD);
428       *out_alert = SSL_AD_UNEXPECTED_MESSAGE;
429       return ssl_open_record_error;
430   }
431 }
432 
dtls1_get_message(const SSL * ssl,SSLMessage * out)433 bool dtls1_get_message(const SSL *ssl, SSLMessage *out) {
434   if (!dtls1_is_current_message_complete(ssl)) {
435     return false;
436   }
437 
438   size_t idx = ssl->d1->handshake_read_seq % SSL_MAX_HANDSHAKE_FLIGHT;
439   const DTLSIncomingMessage *frag = ssl->d1->incoming_messages[idx].get();
440   out->type = frag->type;
441   out->raw = CBS(frag->data);
442   out->body = CBS(frag->msg());
443   out->is_v2_hello = false;
444   if (!ssl->s3->has_message) {
445     ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_HANDSHAKE, out->raw);
446     ssl->s3->has_message = true;
447   }
448   return true;
449 }
450 
dtls1_next_message(SSL * ssl)451 void dtls1_next_message(SSL *ssl) {
452   assert(ssl->s3->has_message);
453   assert(dtls1_is_current_message_complete(ssl));
454   size_t index = ssl->d1->handshake_read_seq % SSL_MAX_HANDSHAKE_FLIGHT;
455   ssl->d1->incoming_messages[index].reset();
456   ssl->d1->handshake_read_seq++;
457   if (ssl->d1->handshake_read_seq == 0) {
458     ssl->d1->handshake_read_overflow = true;
459   }
460   ssl->s3->has_message = false;
461   // If we previously sent a flight, mark it as having a reply, so
462   // |on_handshake_complete| can manage post-handshake retransmission.
463   if (ssl->d1->outgoing_messages_complete) {
464     ssl->d1->flight_has_reply = true;
465   }
466 }
467 
dtls_has_unprocessed_handshake_data(const SSL * ssl)468 bool dtls_has_unprocessed_handshake_data(const SSL *ssl) {
469   size_t current = ssl->d1->handshake_read_seq % SSL_MAX_HANDSHAKE_FLIGHT;
470   for (size_t i = 0; i < SSL_MAX_HANDSHAKE_FLIGHT; i++) {
471     // Skip the current message.
472     if (ssl->s3->has_message && i == current) {
473       assert(dtls1_is_current_message_complete(ssl));
474       continue;
475     }
476     if (ssl->d1->incoming_messages[i] != nullptr) {
477       return true;
478     }
479   }
480   return false;
481 }
482 
dtls1_parse_fragment(CBS * cbs,struct hm_header_st * out_hdr,CBS * out_body)483 bool dtls1_parse_fragment(CBS *cbs, struct hm_header_st *out_hdr,
484                           CBS *out_body) {
485   OPENSSL_memset(out_hdr, 0x00, sizeof(struct hm_header_st));
486 
487   if (!CBS_get_u8(cbs, &out_hdr->type) ||
488       !CBS_get_u24(cbs, &out_hdr->msg_len) ||
489       !CBS_get_u16(cbs, &out_hdr->seq) ||
490       !CBS_get_u24(cbs, &out_hdr->frag_off) ||
491       !CBS_get_u24(cbs, &out_hdr->frag_len) ||
492       !CBS_get_bytes(cbs, out_body, out_hdr->frag_len)) {
493     return false;
494   }
495 
496   return true;
497 }
498 
dtls1_open_change_cipher_spec(SSL * ssl,size_t * out_consumed,uint8_t * out_alert,Span<uint8_t> in)499 ssl_open_record_t dtls1_open_change_cipher_spec(SSL *ssl, size_t *out_consumed,
500                                                 uint8_t *out_alert,
501                                                 Span<uint8_t> in) {
502   if (!ssl->d1->has_change_cipher_spec) {
503     // dtls1_open_handshake processes both handshake and ChangeCipherSpec.
504     auto ret = dtls1_open_handshake(ssl, out_consumed, out_alert, in);
505     if (ret != ssl_open_record_success) {
506       return ret;
507     }
508   }
509   if (ssl->d1->has_change_cipher_spec) {
510     ssl->d1->has_change_cipher_spec = false;
511     return ssl_open_record_success;
512   }
513   return ssl_open_record_discard;
514 }
515 
516 
517 // Sending handshake messages.
518 
dtls_clear_outgoing_messages(SSL * ssl)519 void dtls_clear_outgoing_messages(SSL *ssl) {
520   ssl->d1->outgoing_messages.clear();
521   ssl->d1->sent_records = nullptr;
522   ssl->d1->outgoing_written = 0;
523   ssl->d1->outgoing_offset = 0;
524   ssl->d1->outgoing_messages_complete = false;
525   ssl->d1->flight_has_reply = false;
526   ssl->d1->sending_flight = false;
527   dtls_clear_unused_write_epochs(ssl);
528 }
529 
dtls_clear_unused_write_epochs(SSL * ssl)530 void dtls_clear_unused_write_epochs(SSL *ssl) {
531   ssl->d1->extra_write_epochs.EraseIf(
532       [ssl](const UniquePtr<DTLSWriteEpoch> &write_epoch) -> bool {
533         // Non-current epochs may be discarded once there are no incomplete
534         // outgoing messages that reference them.
535         //
536         // TODO(crbug.com/42290594): Epoch 1 (0-RTT) should be retained until
537         // epoch 3 (app data) is available.
538         for (const auto &msg : ssl->d1->outgoing_messages) {
539           if (msg.epoch == write_epoch->epoch() && !msg.IsFullyAcked()) {
540             return false;
541           }
542         }
543         return true;
544       });
545 }
546 
dtls1_init_message(const SSL * ssl,CBB * cbb,CBB * body,uint8_t type)547 bool dtls1_init_message(const SSL *ssl, CBB *cbb, CBB *body, uint8_t type) {
548   // Pick a modest size hint to save most of the |realloc| calls.
549   if (!CBB_init(cbb, 64) ||                                   //
550       !CBB_add_u8(cbb, type) ||                               //
551       !CBB_add_u24(cbb, 0 /* length (filled in later) */) ||  //
552       !CBB_add_u16(cbb, ssl->d1->handshake_write_seq) ||      //
553       !CBB_add_u24(cbb, 0 /* offset */) ||                    //
554       !CBB_add_u24_length_prefixed(cbb, body)) {
555     return false;
556   }
557 
558   return true;
559 }
560 
dtls1_finish_message(const SSL * ssl,CBB * cbb,Array<uint8_t> * out_msg)561 bool dtls1_finish_message(const SSL *ssl, CBB *cbb, Array<uint8_t> *out_msg) {
562   if (!CBBFinishArray(cbb, out_msg) ||
563       out_msg->size() < DTLS1_HM_HEADER_LENGTH) {
564     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
565     return false;
566   }
567 
568   // Fix up the header. Copy the fragment length into the total message
569   // length.
570   OPENSSL_memcpy(out_msg->data() + 1,
571                  out_msg->data() + DTLS1_HM_HEADER_LENGTH - 3, 3);
572   return true;
573 }
574 
575 // add_outgoing adds a new handshake message or ChangeCipherSpec to the current
576 // outgoing flight. It returns true on success and false on error.
add_outgoing(SSL * ssl,bool is_ccs,Array<uint8_t> data)577 static bool add_outgoing(SSL *ssl, bool is_ccs, Array<uint8_t> data) {
578   if (ssl->d1->outgoing_messages_complete) {
579     // If we've begun writing a new flight, we received the peer flight. Discard
580     // the timer and the our flight.
581     dtls1_stop_timer(ssl);
582     dtls_clear_outgoing_messages(ssl);
583   }
584 
585   if (!is_ccs) {
586     if (ssl->d1->handshake_write_overflow) {
587       OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
588       return false;
589     }
590     // TODO(svaldez): Move this up a layer to fix abstraction for SSLTranscript
591     // on hs.
592     if (ssl->s3->hs != NULL && !ssl->s3->hs->transcript.Update(data)) {
593       OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
594       return false;
595     }
596     ssl->d1->handshake_write_seq++;
597     if (ssl->d1->handshake_write_seq == 0) {
598       ssl->d1->handshake_write_overflow = true;
599     }
600   }
601 
602   DTLSOutgoingMessage msg;
603   msg.data = std::move(data);
604   msg.epoch = ssl->d1->write_epoch.epoch();
605   msg.is_ccs = is_ccs;
606   // Zero-length messages need 1 bit to track whether the peer has received the
607   // message header. (Normally the message header is implicitly received when
608   // any fragment of the message is received at all.)
609   if (!is_ccs && !msg.acked.Init(std::max(msg.msg_len(), size_t{1}))) {
610     return false;
611   }
612 
613   // This should not fail if |SSL_MAX_HANDSHAKE_FLIGHT| was sized correctly.
614   //
615   // TODO(crbug.com/42290594): This can currently fail in DTLS 1.3. The caller
616   // can configure how many tickets to send, up to kMaxTickets. Additionally, if
617   // we send 0.5-RTT tickets in 0-RTT, we may even have tickets queued up with
618   // the server flight.
619   if (!ssl->d1->outgoing_messages.TryPushBack(std::move(msg))) {
620     assert(false);
621     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
622     return false;
623   }
624 
625   return true;
626 }
627 
dtls1_add_message(SSL * ssl,Array<uint8_t> data)628 bool dtls1_add_message(SSL *ssl, Array<uint8_t> data) {
629   return add_outgoing(ssl, false /* handshake */, std::move(data));
630 }
631 
dtls1_add_change_cipher_spec(SSL * ssl)632 bool dtls1_add_change_cipher_spec(SSL *ssl) {
633   // DTLS 1.3 disables compatibility mode, which means that DTLS 1.3 never sends
634   // a ChangeCipherSpec message.
635   if (ssl_protocol_version(ssl) > TLS1_2_VERSION) {
636     return true;
637   }
638   return add_outgoing(ssl, true /* ChangeCipherSpec */, Array<uint8_t>());
639 }
640 
641 // dtls1_update_mtu updates the current MTU from the BIO, ensuring it is above
642 // the minimum.
dtls1_update_mtu(SSL * ssl)643 static void dtls1_update_mtu(SSL *ssl) {
644   // TODO(davidben): No consumer implements |BIO_CTRL_DGRAM_SET_MTU| and the
645   // only |BIO_CTRL_DGRAM_QUERY_MTU| implementation could use
646   // |SSL_set_mtu|. Does this need to be so complex?
647   if (ssl->d1->mtu < dtls1_min_mtu() &&
648       !(SSL_get_options(ssl) & SSL_OP_NO_QUERY_MTU)) {
649     long mtu = BIO_ctrl(ssl->wbio.get(), BIO_CTRL_DGRAM_QUERY_MTU, 0, NULL);
650     if (mtu >= 0 && mtu <= (1 << 30) && (unsigned)mtu >= dtls1_min_mtu()) {
651       ssl->d1->mtu = (unsigned)mtu;
652     } else {
653       ssl->d1->mtu = kDefaultMTU;
654       BIO_ctrl(ssl->wbio.get(), BIO_CTRL_DGRAM_SET_MTU, ssl->d1->mtu, NULL);
655     }
656   }
657 
658   // The MTU should be above the minimum now.
659   assert(ssl->d1->mtu >= dtls1_min_mtu());
660 }
661 
// seal_result_t is the outcome of sealing one record with |seal_next_record|.
enum seal_result_t {
  // A fatal error occurred.
  seal_error = 0,
  // The record stopped early because the next message could not be combined
  // into it. The caller should try again.
  seal_continue = 1,
  // The packet is done: nothing is left to write, or there is no more room.
  seal_flush = 2,
};
667 
// seal_next_record seals one record's worth of messages to |out| and advances
// |ssl|'s internal state past the data that was sealed. It first resets
// |*out_len| to zero. When a record is sealed, |dtls_seal_record| sets
// |*out_len| to the number of bytes written to |out|.
//
// It returns |seal_error| on a fatal error. It returns |seal_continue| if the
// function stopped because the next message could not be combined into this
// record: a ChangeCipherSpec (always sealed alone), a message in a different
// epoch, or, in epoch 0, any second message. The caller should then loop again.
// Otherwise, it returns |seal_flush> and the packet is complete, either because
// there are no more messages or because the packet is full. |seal_flush| may
// be returned with |*out_len| still zero if nothing could be written.
static seal_result_t seal_next_record(SSL *ssl, Span<uint8_t> out,
                                      size_t *out_len) {
  *out_len = 0;

  // Skip any fully acked messages.
  while (ssl->d1->outgoing_written < ssl->d1->outgoing_messages.size() &&
         ssl->d1->outgoing_messages[ssl->d1->outgoing_written].IsFullyAcked()) {
    ssl->d1->outgoing_offset = 0;
    ssl->d1->outgoing_written++;
  }

  // There was nothing left to write.
  if (ssl->d1->outgoing_written >= ssl->d1->outgoing_messages.size()) {
    return seal_flush;
  }

  // The first pending message fixes the epoch, and therefore the record
  // overhead, for this record.
  const auto &first_msg = ssl->d1->outgoing_messages[ssl->d1->outgoing_written];
  size_t prefix_len = dtls_seal_prefix_len(ssl, first_msg.epoch);
  size_t max_in_len = dtls_seal_max_input_len(ssl, first_msg.epoch, out.size());
  if (max_in_len == 0) {
    // There is no room for a single record.
    return seal_flush;
  }

  if (first_msg.is_ccs) {
    // A ChangeCipherSpec is sent as its own record with a fixed one-byte body.
    static const uint8_t kChangeCipherSpec[1] = {SSL3_MT_CCS};
    DTLSRecordNumber record_number;
    if (!dtls_seal_record(ssl, &record_number, out.data(), out_len, out.size(),
                          SSL3_RT_CHANGE_CIPHER_SPEC, kChangeCipherSpec,
                          sizeof(kChangeCipherSpec), first_msg.epoch)) {
      return seal_error;
    }

    ssl_do_msg_callback(ssl, /*is_write=*/1, SSL3_RT_CHANGE_CIPHER_SPEC,
                        kChangeCipherSpec);
    ssl->d1->outgoing_offset = 0;
    ssl->d1->outgoing_written++;
    return seal_continue;
  }

  // TODO(crbug.com/374991962): For now, only send one message per record in
  // epoch 0. Sending multiple is allowed and more efficient, but breaks
  // b/378742138.
  const bool allow_multiple_messages = first_msg.epoch != 0;

  // Pack as many handshake fragments into one record as we can. We stage the
  // fragments in the output buffer, to be sealed in-place.
  bool should_continue = false;
  Span<uint8_t> fragments = out.subspan(prefix_len, max_in_len);
  CBB cbb;
  CBB_init_fixed(&cbb, fragments.data(), fragments.size());
  // Remember where this record starts in the flight, so that a later ACK of the
  // record can be mapped back onto message byte ranges.
  DTLSSentRecord sent_record;
  sent_record.first_msg = ssl->d1->outgoing_written;
  sent_record.first_msg_start = ssl->d1->outgoing_offset;
  while (ssl->d1->outgoing_written < ssl->d1->outgoing_messages.size()) {
    const auto &msg = ssl->d1->outgoing_messages[ssl->d1->outgoing_written];
    if (msg.epoch != first_msg.epoch || msg.is_ccs) {
      // We can only pack messages if the epoch matches. There may be more room
      // in the packet, so tell the caller to keep going.
      should_continue = true;
      break;
    }

    // Decode |msg|'s header. Stored messages are always a single, complete
    // fragment with nothing trailing; anything else is an internal error.
    CBS cbs(msg.data), body_cbs;
    struct hm_header_st hdr;
    if (!dtls1_parse_fragment(&cbs, &hdr, &body_cbs) ||  //
        hdr.frag_off != 0 ||                             //
        hdr.frag_len != CBS_len(&body_cbs) ||            //
        hdr.msg_len != CBS_len(&body_cbs) ||             //
        CBS_len(&cbs) != 0) {
      OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
      return seal_error;
    }

    // Iterate over every un-acked range in the message, if any.
    Span<const uint8_t> body = body_cbs;
    for (;;) {
      auto range = msg.acked.NextUnmarkedRange(ssl->d1->outgoing_offset);
      if (range.empty()) {
        // Advance to the next message.
        ssl->d1->outgoing_offset = 0;
        ssl->d1->outgoing_written++;
        break;
      }

      // Determine how much progress can be made (minimum one byte of progress).
      size_t capacity = fragments.size() - CBB_len(&cbb);
      if (capacity < DTLS1_HM_HEADER_LENGTH + 1) {
        goto packet_full;
      }
      size_t todo = std::min(range.size(), capacity - DTLS1_HM_HEADER_LENGTH);

      // Empty messages are special-cased in ACK tracking. We act as if they
      // have one byte, but in reality that byte is tracking the header.
      Span<const uint8_t> frag;
      if (!body.empty()) {
        frag = body.subspan(range.start, todo);
      }

      // Assemble the fragment: a handshake header that describes this slice,
      // followed by the slice itself.
      size_t frag_start = CBB_len(&cbb);
      CBB child;
      if (!CBB_add_u8(&cbb, hdr.type) ||                       //
          !CBB_add_u24(&cbb, hdr.msg_len) ||                   //
          !CBB_add_u16(&cbb, hdr.seq) ||                       //
          !CBB_add_u24(&cbb, range.start) ||                   //
          !CBB_add_u24_length_prefixed(&cbb, &child) ||        //
          !CBB_add_bytes(&child, frag.data(), frag.size()) ||  //
          !CBB_flush(&cbb)) {
        OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
        return seal_error;
      }
      size_t frag_end = CBB_len(&cbb);

      // TODO(davidben): It is odd that, on output, we inform the caller of
      // retransmits and individual fragments, but on input we only inform the
      // caller of complete messages.
      ssl_do_msg_callback(ssl, /*is_write=*/1, SSL3_RT_HANDSHAKE,
                          fragments.subspan(frag_start, frag_end - frag_start));

      ssl->d1->outgoing_offset = range.start + todo;
      if (todo < range.size()) {
        // The packet was the limiting factor.
        goto packet_full;
      }
    }

    if (!allow_multiple_messages) {
      should_continue = true;
      break;
    }
  }

packet_full:
  // Remember where this record ends in the flight.
  sent_record.last_msg = ssl->d1->outgoing_written;
  sent_record.last_msg_end = ssl->d1->outgoing_offset;

  // We could not fit anything. Don't try to make a record.
  if (CBB_len(&cbb) == 0) {
    assert(!should_continue);
    return seal_flush;
  }

  // Seal the staged fragments in place; they already sit at |prefix_len| in
  // |out|.
  if (!dtls_seal_record(ssl, &sent_record.number, out.data(), out_len,
                        out.size(), SSL3_RT_HANDSHAKE, CBB_data(&cbb),
                        CBB_len(&cbb), first_msg.epoch)) {
    return seal_error;
  }

  // If DTLS 1.3 (or if the version is not yet known and it may be DTLS 1.3),
  // save the record number to match against ACKs later.
  if (ssl->s3->version == 0 || ssl_protocol_version(ssl) >= TLS1_3_VERSION) {
    if (ssl->d1->sent_records == nullptr) {
      ssl->d1->sent_records =
          MakeUnique<MRUQueue<DTLSSentRecord, DTLS_MAX_ACK_BUFFER>>();
      if (ssl->d1->sent_records == nullptr) {
        return seal_error;
      }
    }
    ssl->d1->sent_records->PushBack(sent_record);
  }

  return should_continue ? seal_continue : seal_flush;
}
842 
843 // seal_next_packet writes as much of the next flight as possible to |out| and
844 // advances |ssl->d1->outgoing_written| and |ssl->d1->outgoing_offset| as
845 // appropriate.
seal_next_packet(SSL * ssl,Span<uint8_t> out,size_t * out_len)846 static bool seal_next_packet(SSL *ssl, Span<uint8_t> out, size_t *out_len) {
847   size_t total = 0;
848   for (;;) {
849     size_t len;
850     seal_result_t ret = seal_next_record(ssl, out, &len);
851     switch (ret) {
852       case seal_error:
853         return false;
854 
855       case seal_flush:
856       case seal_continue:
857         out = out.subspan(len);
858         total += len;
859         break;
860     }
861 
862     if (ret == seal_flush) {
863       break;
864     }
865   }
866 
867   *out_len = total;
868   return true;
869 }
870 
send_flight(SSL * ssl)871 static int send_flight(SSL *ssl) {
872   if (ssl->s3->write_shutdown != ssl_shutdown_none) {
873     OPENSSL_PUT_ERROR(SSL, SSL_R_PROTOCOL_IS_SHUTDOWN);
874     return -1;
875   }
876 
877   if (ssl->wbio == nullptr) {
878     OPENSSL_PUT_ERROR(SSL, SSL_R_BIO_NOT_SET);
879     return -1;
880   }
881 
882   if (ssl->d1->num_timeouts > DTLS1_MAX_TIMEOUTS) {
883     OPENSSL_PUT_ERROR(SSL, SSL_R_READ_TIMEOUT_EXPIRED);
884     return -1;
885   }
886 
887   dtls1_update_mtu(ssl);
888 
889   Array<uint8_t> packet;
890   if (!packet.InitForOverwrite(ssl->d1->mtu)) {
891     return -1;
892   }
893 
894   while (ssl->d1->outgoing_written < ssl->d1->outgoing_messages.size()) {
895     uint8_t old_written = ssl->d1->outgoing_written;
896     uint32_t old_offset = ssl->d1->outgoing_offset;
897 
898     size_t packet_len;
899     if (!seal_next_packet(ssl, Span(packet), &packet_len)) {
900       return -1;
901     }
902 
903     if (packet_len == 0 &&
904         ssl->d1->outgoing_written < ssl->d1->outgoing_messages.size()) {
905       // We made no progress with the packet size available, but did not reach
906       // the end.
907       OPENSSL_PUT_ERROR(SSL, SSL_R_MTU_TOO_SMALL);
908       return false;
909     }
910 
911     if (packet_len != 0) {
912       int bio_ret = BIO_write(ssl->wbio.get(), packet.data(), packet_len);
913       if (bio_ret <= 0) {
914         // Retry this packet the next time around.
915         ssl->d1->outgoing_written = old_written;
916         ssl->d1->outgoing_offset = old_offset;
917         ssl->s3->rwstate = SSL_ERROR_WANT_WRITE;
918         return bio_ret;
919       }
920     }
921   }
922 
923   ssl->d1->pending_flush = true;
924   return 1;
925 }
926 
dtls1_finish_flight(SSL * ssl)927 void dtls1_finish_flight(SSL *ssl) {
928   if (ssl->d1->outgoing_messages.empty() ||
929       ssl->d1->outgoing_messages_complete) {
930     return;  // Nothing to do.
931   }
932 
933   if (ssl->d1->outgoing_messages[0].epoch <= 2) {
934     // DTLS 1.3 handshake messages (epoch 2 and below) implicitly ACK the
935     // previous flight, so there is no need to ACK previous records. This
936     // clears the ACK buffer slightly earlier than the specification suggests.
937     // See the discussion in
938     // https://mailarchive.ietf.org/arch/msg/tls/kjJnquJOVaWxu5hUCmNzB35eqY0/
939     ssl->d1->records_to_ack.Clear();
940     ssl->d1->ack_timer.Stop();
941     ssl->d1->sending_ack = false;
942   }
943 
944   ssl->d1->outgoing_messages_complete = true;
945   ssl->d1->sending_flight = true;
946   // Stop retransmitting the previous flight. In DTLS 1.3, we'll have stopped
947   // the timer already, but DTLS 1.2 keeps it running until the next flight is
948   // ready.
949   dtls1_stop_timer(ssl);
950 }
951 
dtls1_schedule_ack(SSL * ssl)952 void dtls1_schedule_ack(SSL *ssl) {
953   ssl->d1->ack_timer.Stop();
954   ssl->d1->sending_ack = !ssl->d1->records_to_ack.empty();
955 }
956 
send_ack(SSL * ssl)957 static int send_ack(SSL *ssl) {
958   assert(ssl_protocol_version(ssl) >= TLS1_3_VERSION);
959 
960   // Ensure we don't send so many ACKs that we overflow the MTU. There is a
961   // 2-byte length prefix and each ACK is 16 bytes.
962   dtls1_update_mtu(ssl);
963   size_t max_plaintext =
964       dtls_seal_max_input_len(ssl, ssl->d1->write_epoch.epoch(), ssl->d1->mtu);
965   if (max_plaintext < 2 + 16) {
966     OPENSSL_PUT_ERROR(SSL, SSL_R_MTU_TOO_SMALL);  // No room for even one ACK.
967     return -1;
968   }
969   size_t num_acks =
970       std::min((max_plaintext - 2) / 16, ssl->d1->records_to_ack.size());
971 
972   // Assemble the ACK. RFC 9147 says to sort ACKs numerically. It is unclear if
973   // other implementations do this, but go ahead and sort for now. See
974   // https://mailarchive.ietf.org/arch/msg/tls/kjJnquJOVaWxu5hUCmNzB35eqY0/.
975   // Remove this if rfc9147bis removes this requirement.
976   InplaceVector<DTLSRecordNumber, DTLS_MAX_ACK_BUFFER> sorted;
977   for (size_t i = ssl->d1->records_to_ack.size() - num_acks;
978        i < ssl->d1->records_to_ack.size(); i++) {
979     sorted.PushBack(ssl->d1->records_to_ack[i]);
980   }
981   std::sort(sorted.begin(), sorted.end());
982 
983   uint8_t buf[2 + 16 * DTLS_MAX_ACK_BUFFER];
984   CBB cbb, child;
985   CBB_init_fixed(&cbb, buf, sizeof(buf));
986   BSSL_CHECK(CBB_add_u16_length_prefixed(&cbb, &child));
987   for (const auto &number : sorted) {
988     BSSL_CHECK(CBB_add_u64(&child, number.epoch()));
989     BSSL_CHECK(CBB_add_u64(&child, number.sequence()));
990   }
991   BSSL_CHECK(CBB_flush(&cbb));
992 
993   // Encrypt it.
994   uint8_t record[DTLS1_3_RECORD_HEADER_WRITE_LENGTH + sizeof(buf) +
995                  1 /* record type */ + EVP_AEAD_MAX_OVERHEAD];
996   size_t record_len;
997   DTLSRecordNumber record_number;
998   if (!dtls_seal_record(ssl, &record_number, record, &record_len,
999                         sizeof(record), SSL3_RT_ACK, CBB_data(&cbb),
1000                         CBB_len(&cbb), ssl->d1->write_epoch.epoch())) {
1001     return -1;
1002   }
1003 
1004   ssl_do_msg_callback(ssl, /*is_write=*/1, SSL3_RT_ACK,
1005                       Span(CBB_data(&cbb), CBB_len(&cbb)));
1006 
1007   int bio_ret =
1008       BIO_write(ssl->wbio.get(), record, static_cast<int>(record_len));
1009   if (bio_ret <= 0) {
1010     ssl->s3->rwstate = SSL_ERROR_WANT_WRITE;
1011     return bio_ret;
1012   }
1013 
1014   ssl->d1->pending_flush = true;
1015   return 1;
1016 }
1017 
dtls1_flush(SSL * ssl)1018 int dtls1_flush(SSL *ssl) {
1019   // Send the pending ACK, if any.
1020   if (ssl->d1->sending_ack) {
1021     int ret = send_ack(ssl);
1022     if (ret <= 0) {
1023       return ret;
1024     }
1025     ssl->d1->sending_ack = false;
1026   }
1027 
1028   // Send the pending flight, if any.
1029   if (ssl->d1->sending_flight) {
1030     int ret = send_flight(ssl);
1031     if (ret <= 0) {
1032       return ret;
1033     }
1034 
1035     // Reset state for the next send.
1036     ssl->d1->outgoing_written = 0;
1037     ssl->d1->outgoing_offset = 0;
1038     ssl->d1->sending_flight = false;
1039 
1040     // Schedule the next retransmit timer. In DTLS 1.3, we retransmit all
1041     // flights until ACKed. In DTLS 1.2, the final Finished flight is never
1042     // ACKed, so we do not keep the timer running after the handshake.
1043     if (SSL_in_init(ssl) || ssl_protocol_version(ssl) >= TLS1_3_VERSION) {
1044       if (ssl->d1->num_timeouts == 0) {
1045         ssl->d1->timeout_duration_ms = ssl->initial_timeout_duration_ms;
1046       } else {
1047         ssl->d1->timeout_duration_ms =
1048             std::min(ssl->d1->timeout_duration_ms * 2, uint32_t{60000});
1049       }
1050 
1051       OPENSSL_timeval now = ssl_ctx_get_current_time(ssl->ctx.get());
1052       ssl->d1->retransmit_timer.StartMicroseconds(
1053           now, uint64_t{ssl->d1->timeout_duration_ms} * 1000);
1054     }
1055   }
1056 
1057   if (ssl->d1->pending_flush) {
1058     if (BIO_flush(ssl->wbio.get()) <= 0) {
1059       ssl->s3->rwstate = SSL_ERROR_WANT_WRITE;
1060       return -1;
1061     }
1062     ssl->d1->pending_flush = false;
1063   }
1064 
1065   return 1;
1066 }
1067 
dtls1_min_mtu(void)1068 unsigned int dtls1_min_mtu(void) { return kMinMTU; }
1069 
1070 BSSL_NAMESPACE_END
1071