1 // Copyright 1995-2016 The OpenSSL Project Authors. All Rights Reserved.
2 // Copyright (c) 2002, Oracle and/or its affiliates. All rights reserved.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     https://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 
16 #include <openssl/ssl.h>
17 
18 #include <assert.h>
19 #include <limits.h>
20 #include <string.h>
21 
22 #include <tuple>
23 
24 #include <openssl/buf.h>
25 #include <openssl/bytestring.h>
26 #include <openssl/err.h>
27 #include <openssl/evp.h>
28 #include <openssl/md5.h>
29 #include <openssl/mem.h>
30 #include <openssl/nid.h>
31 #include <openssl/rand.h>
32 #include <openssl/sha2.h>
33 
34 #include "../crypto/internal.h"
35 #include "internal.h"
36 
37 
38 BSSL_NAMESPACE_BEGIN
39 
add_record_to_flight(SSL * ssl,uint8_t type,Span<const uint8_t> in)40 static bool add_record_to_flight(SSL *ssl, uint8_t type,
41                                  Span<const uint8_t> in) {
42   // The caller should have flushed |pending_hs_data| first.
43   assert(!ssl->s3->pending_hs_data);
44   // We'll never add a flight while in the process of writing it out.
45   assert(ssl->s3->pending_flight_offset == 0);
46 
47   if (ssl->s3->pending_flight == nullptr) {
48     ssl->s3->pending_flight.reset(BUF_MEM_new());
49     if (ssl->s3->pending_flight == nullptr) {
50       return false;
51     }
52   }
53 
54   size_t max_out = in.size() + SSL_max_seal_overhead(ssl);
55   size_t new_cap = ssl->s3->pending_flight->length + max_out;
56   if (max_out < in.size() || new_cap < max_out) {
57     OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
58     return false;
59   }
60 
61   size_t len;
62   if (!BUF_MEM_reserve(ssl->s3->pending_flight.get(), new_cap) ||
63       !tls_seal_record(ssl,
64                        (uint8_t *)ssl->s3->pending_flight->data +
65                            ssl->s3->pending_flight->length,
66                        &len, max_out, type, in.data(), in.size())) {
67     return false;
68   }
69 
70   ssl->s3->pending_flight->length += len;
71   return true;
72 }
73 
tls_init_message(const SSL * ssl,CBB * cbb,CBB * body,uint8_t type)74 bool tls_init_message(const SSL *ssl, CBB *cbb, CBB *body, uint8_t type) {
75   // Pick a modest size hint to save most of the |realloc| calls.
76   if (!CBB_init(cbb, 64) ||      //
77       !CBB_add_u8(cbb, type) ||  //
78       !CBB_add_u24_length_prefixed(cbb, body)) {
79     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
80     CBB_cleanup(cbb);
81     return false;
82   }
83 
84   return true;
85 }
86 
tls_finish_message(const SSL * ssl,CBB * cbb,Array<uint8_t> * out_msg)87 bool tls_finish_message(const SSL *ssl, CBB *cbb, Array<uint8_t> *out_msg) {
88   return CBBFinishArray(cbb, out_msg);
89 }
90 
tls_add_message(SSL * ssl,Array<uint8_t> msg)91 bool tls_add_message(SSL *ssl, Array<uint8_t> msg) {
92   // Pack handshake data into the minimal number of records. This avoids
93   // unnecessary encryption overhead, notably in TLS 1.3 where we send several
94   // encrypted messages in a row. For now, we do not do this for the null
95   // cipher. The benefit is smaller and there is a risk of breaking buggy
96   // implementations.
97   //
98   // TODO(crbug.com/374991962): See if we can do this uniformly.
99   Span<const uint8_t> rest = msg;
100   if (!SSL_is_quic(ssl) && ssl->s3->aead_write_ctx->is_null_cipher()) {
101     while (!rest.empty()) {
102       Span<const uint8_t> chunk = rest.subspan(0, ssl->max_send_fragment);
103       rest = rest.subspan(chunk.size());
104 
105       if (!add_record_to_flight(ssl, SSL3_RT_HANDSHAKE, chunk)) {
106         return false;
107       }
108     }
109   } else {
110     while (!rest.empty()) {
111       // Flush if |pending_hs_data| is full.
112       if (ssl->s3->pending_hs_data &&
113           ssl->s3->pending_hs_data->length >= ssl->max_send_fragment &&
114           !tls_flush_pending_hs_data(ssl)) {
115         return false;
116       }
117 
118       size_t pending_len =
119           ssl->s3->pending_hs_data ? ssl->s3->pending_hs_data->length : 0;
120       Span<const uint8_t> chunk =
121           rest.subspan(0, ssl->max_send_fragment - pending_len);
122       assert(!chunk.empty());
123       rest = rest.subspan(chunk.size());
124 
125       if (!ssl->s3->pending_hs_data) {
126         ssl->s3->pending_hs_data.reset(BUF_MEM_new());
127       }
128       if (!ssl->s3->pending_hs_data ||
129           !BUF_MEM_append(ssl->s3->pending_hs_data.get(), chunk.data(),
130                           chunk.size())) {
131         return false;
132       }
133     }
134   }
135 
136   ssl_do_msg_callback(ssl, 1 /* write */, SSL3_RT_HANDSHAKE, msg);
137   // TODO(svaldez): Move this up a layer to fix abstraction for SSLTranscript on
138   // hs.
139   if (ssl->s3->hs != NULL &&  //
140       !ssl->s3->hs->transcript.Update(msg)) {
141     return false;
142   }
143   return true;
144 }
145 
tls_flush_pending_hs_data(SSL * ssl)146 bool tls_flush_pending_hs_data(SSL *ssl) {
147   if (!ssl->s3->pending_hs_data || ssl->s3->pending_hs_data->length == 0) {
148     return true;
149   }
150 
151   UniquePtr<BUF_MEM> pending_hs_data = std::move(ssl->s3->pending_hs_data);
152   auto data = Span(reinterpret_cast<const uint8_t *>(pending_hs_data->data),
153                    pending_hs_data->length);
154   if (SSL_is_quic(ssl)) {
155     if ((ssl->s3->hs == nullptr || !ssl->s3->hs->hints_requested) &&
156         !ssl->quic_method->add_handshake_data(ssl, ssl->s3->quic_write_level,
157                                               data.data(), data.size())) {
158       OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR);
159       return false;
160     }
161     return true;
162   }
163 
164   return add_record_to_flight(ssl, SSL3_RT_HANDSHAKE, data);
165 }
166 
tls_add_change_cipher_spec(SSL * ssl)167 bool tls_add_change_cipher_spec(SSL *ssl) {
168   if (SSL_is_quic(ssl)) {
169     return true;
170   }
171 
172   static const uint8_t kChangeCipherSpec[1] = {SSL3_MT_CCS};
173   if (!tls_flush_pending_hs_data(ssl) ||
174       !add_record_to_flight(ssl, SSL3_RT_CHANGE_CIPHER_SPEC,
175                             kChangeCipherSpec)) {
176     return false;
177   }
178 
179   ssl_do_msg_callback(ssl, 1 /* write */, SSL3_RT_CHANGE_CIPHER_SPEC,
180                       kChangeCipherSpec);
181   return true;
182 }
183 
tls_flush(SSL * ssl)184 int tls_flush(SSL *ssl) {
185   if (!tls_flush_pending_hs_data(ssl)) {
186     return -1;
187   }
188 
189   if (SSL_is_quic(ssl)) {
190     if (ssl->s3->write_shutdown != ssl_shutdown_none) {
191       OPENSSL_PUT_ERROR(SSL, SSL_R_PROTOCOL_IS_SHUTDOWN);
192       return -1;
193     }
194 
195     if (!ssl->quic_method->flush_flight(ssl)) {
196       OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR);
197       return -1;
198     }
199   }
200 
201   if (ssl->s3->pending_flight == nullptr) {
202     return 1;
203   }
204 
205   if (ssl->s3->write_shutdown != ssl_shutdown_none) {
206     OPENSSL_PUT_ERROR(SSL, SSL_R_PROTOCOL_IS_SHUTDOWN);
207     return -1;
208   }
209 
210   static_assert(INT_MAX <= 0xffffffff, "int is larger than 32 bits");
211   if (ssl->s3->pending_flight->length > INT_MAX) {
212     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
213     return -1;
214   }
215 
216   // If there is pending data in the write buffer, it must be flushed out before
217   // any new data in pending_flight.
218   if (!ssl->s3->write_buffer.empty()) {
219     int ret = ssl_write_buffer_flush(ssl);
220     if (ret <= 0) {
221       ssl->s3->rwstate = SSL_ERROR_WANT_WRITE;
222       return ret;
223     }
224   }
225 
226   if (ssl->wbio == nullptr) {
227     OPENSSL_PUT_ERROR(SSL, SSL_R_BIO_NOT_SET);
228     return -1;
229   }
230 
231   // Write the pending flight.
232   while (ssl->s3->pending_flight_offset < ssl->s3->pending_flight->length) {
233     int ret = BIO_write(
234         ssl->wbio.get(),
235         ssl->s3->pending_flight->data + ssl->s3->pending_flight_offset,
236         ssl->s3->pending_flight->length - ssl->s3->pending_flight_offset);
237     if (ret <= 0) {
238       ssl->s3->rwstate = SSL_ERROR_WANT_WRITE;
239       return ret;
240     }
241 
242     ssl->s3->pending_flight_offset += ret;
243   }
244 
245   if (BIO_flush(ssl->wbio.get()) <= 0) {
246     ssl->s3->rwstate = SSL_ERROR_WANT_WRITE;
247     return -1;
248   }
249 
250   ssl->s3->pending_flight.reset();
251   ssl->s3->pending_flight_offset = 0;
252   return 1;
253 }
254 
read_v2_client_hello(SSL * ssl,size_t * out_consumed,Span<const uint8_t> in)255 static ssl_open_record_t read_v2_client_hello(SSL *ssl, size_t *out_consumed,
256                                               Span<const uint8_t> in) {
257   *out_consumed = 0;
258   assert(in.size() >= SSL3_RT_HEADER_LENGTH);
259   // Determine the length of the V2ClientHello.
260   size_t msg_length = ((in[0] & 0x7f) << 8) | in[1];
261   if (msg_length > (1024 * 4)) {
262     OPENSSL_PUT_ERROR(SSL, SSL_R_RECORD_TOO_LARGE);
263     return ssl_open_record_error;
264   }
265   if (msg_length < SSL3_RT_HEADER_LENGTH - 2) {
266     // Reject lengths that are too short early. We have already read
267     // |SSL3_RT_HEADER_LENGTH| bytes, so we should not attempt to process an
268     // (invalid) V2ClientHello which would be shorter than that.
269     OPENSSL_PUT_ERROR(SSL, SSL_R_RECORD_LENGTH_MISMATCH);
270     return ssl_open_record_error;
271   }
272 
273   // Ask for the remainder of the V2ClientHello.
274   if (in.size() < 2 + msg_length) {
275     *out_consumed = 2 + msg_length;
276     return ssl_open_record_partial;
277   }
278 
279   CBS v2_client_hello = CBS(in.subspan(2, msg_length));
280   // The V2ClientHello without the length is incorporated into the handshake
281   // hash. This is only ever called at the start of the handshake, so hs is
282   // guaranteed to be non-NULL.
283   if (!ssl->s3->hs->transcript.Update(v2_client_hello)) {
284     return ssl_open_record_error;
285   }
286 
287   ssl_do_msg_callback(ssl, 0 /* read */, 0 /* V2ClientHello */,
288                       v2_client_hello);
289 
290   uint8_t msg_type;
291   uint16_t version, cipher_spec_length, session_id_length, challenge_length;
292   CBS cipher_specs, session_id, challenge;
293   if (!CBS_get_u8(&v2_client_hello, &msg_type) ||
294       !CBS_get_u16(&v2_client_hello, &version) ||
295       !CBS_get_u16(&v2_client_hello, &cipher_spec_length) ||
296       !CBS_get_u16(&v2_client_hello, &session_id_length) ||
297       !CBS_get_u16(&v2_client_hello, &challenge_length) ||
298       !CBS_get_bytes(&v2_client_hello, &cipher_specs, cipher_spec_length) ||
299       !CBS_get_bytes(&v2_client_hello, &session_id, session_id_length) ||
300       !CBS_get_bytes(&v2_client_hello, &challenge, challenge_length) ||
301       CBS_len(&v2_client_hello) != 0) {
302     OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR);
303     return ssl_open_record_error;
304   }
305 
306   // msg_type has already been checked.
307   assert(msg_type == SSL2_MT_CLIENT_HELLO);
308 
309   // The client_random is the V2ClientHello challenge. Truncate or left-pad with
310   // zeros as needed.
311   size_t rand_len = CBS_len(&challenge);
312   if (rand_len > SSL3_RANDOM_SIZE) {
313     rand_len = SSL3_RANDOM_SIZE;
314   }
315   uint8_t random[SSL3_RANDOM_SIZE];
316   OPENSSL_memset(random, 0, SSL3_RANDOM_SIZE);
317   OPENSSL_memcpy(random + (SSL3_RANDOM_SIZE - rand_len), CBS_data(&challenge),
318                  rand_len);
319 
320   // Write out an equivalent TLS ClientHello directly to the handshake buffer.
321   size_t max_v3_client_hello = SSL3_HM_HEADER_LENGTH + 2 /* version */ +
322                                SSL3_RANDOM_SIZE + 1 /* session ID length */ +
323                                2 /* cipher list length */ +
324                                CBS_len(&cipher_specs) / 3 * 2 +
325                                1 /* compression length */ + 1 /* compression */;
326   ScopedCBB client_hello;
327   CBB hello_body, cipher_suites;
328   if (!ssl->s3->hs_buf) {
329     ssl->s3->hs_buf.reset(BUF_MEM_new());
330   }
331   if (!ssl->s3->hs_buf ||
332       !BUF_MEM_reserve(ssl->s3->hs_buf.get(), max_v3_client_hello) ||
333       !CBB_init_fixed(client_hello.get(), (uint8_t *)ssl->s3->hs_buf->data,
334                       ssl->s3->hs_buf->max) ||
335       !CBB_add_u8(client_hello.get(), SSL3_MT_CLIENT_HELLO) ||
336       !CBB_add_u24_length_prefixed(client_hello.get(), &hello_body) ||
337       !CBB_add_u16(&hello_body, version) ||
338       !CBB_add_bytes(&hello_body, random, SSL3_RANDOM_SIZE) ||
339       // No session id.
340       !CBB_add_u8(&hello_body, 0) ||
341       !CBB_add_u16_length_prefixed(&hello_body, &cipher_suites)) {
342     return ssl_open_record_error;
343   }
344 
345   // Copy the cipher suites.
346   while (CBS_len(&cipher_specs) > 0) {
347     uint32_t cipher_spec;
348     if (!CBS_get_u24(&cipher_specs, &cipher_spec)) {
349       OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR);
350       return ssl_open_record_error;
351     }
352 
353     // Skip SSLv2 ciphers.
354     if ((cipher_spec & 0xff0000) != 0) {
355       continue;
356     }
357     if (!CBB_add_u16(&cipher_suites, cipher_spec)) {
358       OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
359       return ssl_open_record_error;
360     }
361   }
362 
363   // Add the null compression scheme and finish.
364   if (!CBB_add_u8(&hello_body, 1) ||  //
365       !CBB_add_u8(&hello_body, 0) ||  //
366       !CBB_finish(client_hello.get(), NULL, &ssl->s3->hs_buf->length)) {
367     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
368     return ssl_open_record_error;
369   }
370 
371   *out_consumed = 2 + msg_length;
372   ssl->s3->is_v2_hello = true;
373   return ssl_open_record_success;
374 }
375 
parse_message(const SSL * ssl,SSLMessage * out,size_t * out_bytes_needed)376 static bool parse_message(const SSL *ssl, SSLMessage *out,
377                           size_t *out_bytes_needed) {
378   if (!ssl->s3->hs_buf) {
379     *out_bytes_needed = 4;
380     return false;
381   }
382 
383   CBS cbs;
384   uint32_t len;
385   CBS_init(&cbs, reinterpret_cast<const uint8_t *>(ssl->s3->hs_buf->data),
386            ssl->s3->hs_buf->length);
387   if (!CBS_get_u8(&cbs, &out->type) ||  //
388       !CBS_get_u24(&cbs, &len)) {
389     *out_bytes_needed = 4;
390     return false;
391   }
392 
393   if (!CBS_get_bytes(&cbs, &out->body, len)) {
394     *out_bytes_needed = 4 + len;
395     return false;
396   }
397 
398   CBS_init(&out->raw, reinterpret_cast<const uint8_t *>(ssl->s3->hs_buf->data),
399            4 + len);
400   out->is_v2_hello = ssl->s3->is_v2_hello;
401   return true;
402 }
403 
tls_get_message(const SSL * ssl,SSLMessage * out)404 bool tls_get_message(const SSL *ssl, SSLMessage *out) {
405   size_t unused;
406   if (!parse_message(ssl, out, &unused)) {
407     return false;
408   }
409   if (!ssl->s3->has_message) {
410     if (!out->is_v2_hello) {
411       ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_HANDSHAKE, out->raw);
412     }
413     ssl->s3->has_message = true;
414   }
415   return true;
416 }
417 
tls_can_accept_handshake_data(const SSL * ssl,uint8_t * out_alert)418 bool tls_can_accept_handshake_data(const SSL *ssl, uint8_t *out_alert) {
419   // If there is a complete message, the caller must have consumed it first.
420   SSLMessage msg;
421   size_t bytes_needed;
422   if (parse_message(ssl, &msg, &bytes_needed)) {
423     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
424     *out_alert = SSL_AD_INTERNAL_ERROR;
425     return false;
426   }
427 
428   // Enforce the limit so the peer cannot force us to buffer 16MB.
429   if (bytes_needed > 4 + ssl_max_handshake_message_len(ssl)) {
430     OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
431     *out_alert = SSL_AD_ILLEGAL_PARAMETER;
432     return false;
433   }
434 
435   return true;
436 }
437 
tls_has_unprocessed_handshake_data(const SSL * ssl)438 bool tls_has_unprocessed_handshake_data(const SSL *ssl) {
439   size_t msg_len = 0;
440   if (ssl->s3->has_message) {
441     SSLMessage msg;
442     size_t unused;
443     if (parse_message(ssl, &msg, &unused)) {
444       msg_len = CBS_len(&msg.raw);
445     }
446   }
447 
448   return ssl->s3->hs_buf && ssl->s3->hs_buf->length > msg_len;
449 }
450 
tls_append_handshake_data(SSL * ssl,Span<const uint8_t> data)451 bool tls_append_handshake_data(SSL *ssl, Span<const uint8_t> data) {
452   // Re-create the handshake buffer if needed.
453   if (!ssl->s3->hs_buf) {
454     ssl->s3->hs_buf.reset(BUF_MEM_new());
455   }
456   return ssl->s3->hs_buf &&
457          BUF_MEM_append(ssl->s3->hs_buf.get(), data.data(), data.size());
458 }
459 
tls_open_handshake(SSL * ssl,size_t * out_consumed,uint8_t * out_alert,Span<uint8_t> in)460 ssl_open_record_t tls_open_handshake(SSL *ssl, size_t *out_consumed,
461                                      uint8_t *out_alert, Span<uint8_t> in) {
462   *out_consumed = 0;
463   // Bypass the record layer for the first message to handle V2ClientHello.
464   if (ssl->server && !ssl->s3->v2_hello_done) {
465     // Ask for the first 5 bytes, the size of the TLS record header. This is
466     // sufficient to detect a V2ClientHello and ensures that we never read
467     // beyond the first record.
468     if (in.size() < SSL3_RT_HEADER_LENGTH) {
469       *out_consumed = SSL3_RT_HEADER_LENGTH;
470       return ssl_open_record_partial;
471     }
472 
473     // Some dedicated error codes for protocol mixups should the application
474     // wish to interpret them differently. (These do not overlap with
475     // ClientHello or V2ClientHello.)
476     auto str = bssl::BytesAsStringView(in);
477     if (str.substr(0, 4) == "GET " ||   //
478         str.substr(0, 5) == "POST " ||  //
479         str.substr(0, 5) == "HEAD " ||  //
480         str.substr(0, 4) == "PUT ") {
481       OPENSSL_PUT_ERROR(SSL, SSL_R_HTTP_REQUEST);
482       *out_alert = 0;
483       return ssl_open_record_error;
484     }
485     if (str.substr(0, 5) == "CONNE") {
486       OPENSSL_PUT_ERROR(SSL, SSL_R_HTTPS_PROXY_REQUEST);
487       *out_alert = 0;
488       return ssl_open_record_error;
489     }
490 
491     // Check for a V2ClientHello.
492     if ((in[0] & 0x80) != 0 && in[2] == SSL2_MT_CLIENT_HELLO &&
493         in[3] == SSL3_VERSION_MAJOR) {
494       auto ret = read_v2_client_hello(ssl, out_consumed, in);
495       if (ret == ssl_open_record_error) {
496         *out_alert = 0;
497       } else if (ret == ssl_open_record_success) {
498         ssl->s3->v2_hello_done = true;
499       }
500       return ret;
501     }
502 
503     ssl->s3->v2_hello_done = true;
504   }
505 
506   uint8_t type;
507   Span<uint8_t> body;
508   auto ret = tls_open_record(ssl, &type, &body, out_consumed, out_alert, in);
509   if (ret != ssl_open_record_success) {
510     return ret;
511   }
512 
513   if (type != SSL3_RT_HANDSHAKE) {
514     OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD);
515     *out_alert = SSL_AD_UNEXPECTED_MESSAGE;
516     return ssl_open_record_error;
517   }
518 
519   // Append the entire handshake record to the buffer.
520   if (!tls_append_handshake_data(ssl, body)) {
521     *out_alert = SSL_AD_INTERNAL_ERROR;
522     return ssl_open_record_error;
523   }
524 
525   return ssl_open_record_success;
526 }
527 
tls_next_message(SSL * ssl)528 void tls_next_message(SSL *ssl) {
529   SSLMessage msg;
530   if (!tls_get_message(ssl, &msg) ||  //
531       !ssl->s3->hs_buf ||             //
532       ssl->s3->hs_buf->length < CBS_len(&msg.raw)) {
533     assert(0);
534     return;
535   }
536 
537   OPENSSL_memmove(ssl->s3->hs_buf->data,
538                   ssl->s3->hs_buf->data + CBS_len(&msg.raw),
539                   ssl->s3->hs_buf->length - CBS_len(&msg.raw));
540   ssl->s3->hs_buf->length -= CBS_len(&msg.raw);
541   ssl->s3->is_v2_hello = false;
542   ssl->s3->has_message = false;
543 
544   // Post-handshake messages are rare, so release the buffer after every
545   // message. During the handshake, |on_handshake_complete| will release it.
546   if (!SSL_in_init(ssl) && ssl->s3->hs_buf->length == 0) {
547     ssl->s3->hs_buf.reset();
548   }
549 }
550 
551 namespace {
552 
553 class CipherScorer {
554  public:
555   using Score = int;
556   static constexpr Score kMinScore = 0;
557 
558   virtual ~CipherScorer() = default;
559 
560   virtual Score Evaluate(const SSL_CIPHER *cipher) const = 0;
561 };
562 
563 // AesHwCipherScorer scores cipher suites based on whether AES is supported in
564 // hardware.
565 class AesHwCipherScorer : public CipherScorer {
566  public:
AesHwCipherScorer(bool has_aes_hw)567   explicit AesHwCipherScorer(bool has_aes_hw) : aes_is_fine_(has_aes_hw) {}
568 
569   virtual ~AesHwCipherScorer() override = default;
570 
Evaluate(const SSL_CIPHER * a) const571   Score Evaluate(const SSL_CIPHER *a) const override {
572     return
573         // Something is always preferable to nothing.
574         1 +
575         // Either AES is fine, or else ChaCha20 is preferred.
576         ((aes_is_fine_ || a->algorithm_enc == SSL_CHACHA20POLY1305) ? 1 : 0);
577   }
578 
579  private:
580   const bool aes_is_fine_;
581 };
582 
583 // CNsaCipherScorer prefers AES-256-GCM over AES-128-GCM over anything else.
584 class CNsaCipherScorer : public CipherScorer {
585  public:
586   virtual ~CNsaCipherScorer() override = default;
587 
Evaluate(const SSL_CIPHER * a) const588   Score Evaluate(const SSL_CIPHER *a) const override {
589     if (a->id == TLS1_3_CK_AES_256_GCM_SHA384) {
590       return 3;
591     } else if (a->id == TLS1_3_CK_AES_128_GCM_SHA256) {
592       return 2;
593     }
594     return 1;
595   }
596 };
597 
598 }  // namespace
599 
ssl_tls13_cipher_meets_policy(uint16_t cipher_id,enum ssl_compliance_policy_t policy)600 bool ssl_tls13_cipher_meets_policy(uint16_t cipher_id,
601                                    enum ssl_compliance_policy_t policy) {
602   switch (policy) {
603     case ssl_compliance_policy_none:
604     case ssl_compliance_policy_cnsa_202407:
605       return true;
606 
607     case ssl_compliance_policy_fips_202205:
608       switch (cipher_id) {
609         case TLS1_3_CK_AES_128_GCM_SHA256 & 0xffff:
610         case TLS1_3_CK_AES_256_GCM_SHA384 & 0xffff:
611           return true;
612         case TLS1_3_CK_CHACHA20_POLY1305_SHA256 & 0xffff:
613           return false;
614         default:
615           assert(false);
616           return false;
617       }
618 
619     case ssl_compliance_policy_wpa3_192_202304:
620       switch (cipher_id) {
621         case TLS1_3_CK_AES_256_GCM_SHA384 & 0xffff:
622           return true;
623         case TLS1_3_CK_AES_128_GCM_SHA256 & 0xffff:
624         case TLS1_3_CK_CHACHA20_POLY1305_SHA256 & 0xffff:
625           return false;
626         default:
627           assert(false);
628           return false;
629       }
630   }
631 
632   assert(false);
633   return false;
634 }
635 
ssl_choose_tls13_cipher(CBS cipher_suites,bool has_aes_hw,uint16_t version,enum ssl_compliance_policy_t policy)636 const SSL_CIPHER *ssl_choose_tls13_cipher(CBS cipher_suites, bool has_aes_hw,
637                                           uint16_t version,
638                                           enum ssl_compliance_policy_t policy) {
639   if (CBS_len(&cipher_suites) % 2 != 0) {
640     return nullptr;
641   }
642 
643   const SSL_CIPHER *best = nullptr;
644   AesHwCipherScorer aes_hw_scorer(has_aes_hw);
645   CNsaCipherScorer cnsa_scorer;
646   CipherScorer *const scorer =
647       (policy == ssl_compliance_policy_cnsa_202407)
648           ? static_cast<CipherScorer *>(&cnsa_scorer)
649           : static_cast<CipherScorer *>(&aes_hw_scorer);
650   CipherScorer::Score best_score = CipherScorer::kMinScore;
651 
652   while (CBS_len(&cipher_suites) > 0) {
653     uint16_t cipher_suite;
654     if (!CBS_get_u16(&cipher_suites, &cipher_suite)) {
655       return nullptr;
656     }
657 
658     // Limit to TLS 1.3 ciphers we know about.
659     const SSL_CIPHER *candidate = SSL_get_cipher_by_value(cipher_suite);
660     if (candidate == nullptr ||
661         SSL_CIPHER_get_min_version(candidate) > version ||
662         SSL_CIPHER_get_max_version(candidate) < version) {
663       continue;
664     }
665 
666     if (!ssl_tls13_cipher_meets_policy(SSL_CIPHER_get_protocol_id(candidate),
667                                        policy)) {
668       continue;
669     }
670 
671     const CipherScorer::Score candidate_score = scorer->Evaluate(candidate);
672     // |candidate_score| must be larger to displace the current choice. That way
673     // the client's order controls between ciphers with an equal score.
674     if (candidate_score > best_score) {
675       best = candidate;
676       best_score = candidate_score;
677     }
678   }
679 
680   return best;
681 }
682 
683 BSSL_NAMESPACE_END
684