1 // Copyright 2014 The BoringSSL Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "packeted_bio.h"
16 
17 #include <assert.h>
18 #include <inttypes.h>
19 #include <limits.h>
20 #include <stdio.h>
21 #include <string.h>
22 
23 #include <functional>
24 #include <utility>
25 #include <vector>
26 
27 #include <openssl/bio.h>
28 #include <openssl/mem.h>
29 
30 #include "../../crypto/internal.h"
31 
32 
33 namespace {
34 
// Record opcodes used on the wrapped BIO. Every record begins with one of
// these bytes.

// Followed by a 4-byte big-endian length and that many payload bytes.
constexpr uint8_t kOpcodePacket{'P'};
// Followed by an 8-byte big-endian timeout in nanoseconds.
constexpr uint8_t kOpcodeTimeout{'T'};
// Written back to the peer after a kOpcodeTimeout record has been processed.
constexpr uint8_t kOpcodeTimeoutAck{'t'};
// Followed by a 4-byte big-endian MTU, which is handed to the MTU callback.
constexpr uint8_t kOpcodeMTU{'M'};
// Followed by an 8-byte big-endian expected timeout in nanoseconds, where
// UINT64_MAX means no timeout is expected.
constexpr uint8_t kOpcodeExpectNextTimeout{'E'};
40 
// PacketedBio is the per-BIO state of a packeted BIO. PacketedBioCreate
// attaches it with BIO_set_data, and PacketedFree deletes it.
struct PacketedBio {
  PacketedBio(timeval *clock_arg,
              std::function<bool(timeval *)> get_timeout_arg,
              std::function<bool(uint32_t)> set_mtu_arg)
      : clock(clock_arg),
        get_timeout(std::move(get_timeout_arg)),
        set_mtu(std::move(set_mtu_arg)) {}

  // HasTimeout reports whether a timeout received by PacketedRead is still
  // waiting to be applied by PacketedBioAdvanceClock. An all-zero |timeout|
  // means nothing is pending.
  bool HasTimeout() const {
    const bool none_pending = timeout.tv_sec == 0 && timeout.tv_usec == 0;
    return !none_pending;
  }

  // Pending timeout from the peer. Starts zeroed and is zeroed again once
  // applied.
  timeval timeout = {};
  // Caller-provided clock, advanced in place by PacketedBioAdvanceClock. It is
  // not freed by this struct.
  timeval *clock;
  // Queried when a kOpcodeExpectNextTimeout record arrives. Its return value
  // says whether a timeout is set; if so, the timeval it filled is compared.
  std::function<bool(timeval *)> get_timeout;
  // Receives the MTU from a kOpcodeMTU record. Returning false is treated as
  // an error.
  std::function<bool(uint32_t)> set_mtu;
};
60 
PacketedBioMethodType()61 static int PacketedBioMethodType() {
62   static int type = [] {
63     int idx = BIO_get_new_index();
64     BSSL_CHECK(idx > 0);
65     return idx | BIO_TYPE_FILTER;
66   }();
67   return type;
68 }
69 
GetData(BIO * bio)70 PacketedBio *GetData(BIO *bio) {
71   if (BIO_method_type(bio) != PacketedBioMethodType()) {
72     return NULL;
73   }
74   return static_cast<PacketedBio *>(BIO_get_data(bio));
75 }
76 
77 // ReadAll reads |len| bytes from |bio| into |out|. It returns 1 on success and
78 // 0 or -1 on error.
ReadAll(BIO * bio,uint8_t * out,size_t len)79 static int ReadAll(BIO *bio, uint8_t *out, size_t len) {
80   while (len > 0) {
81     int chunk_len = INT_MAX;
82     if (len <= INT_MAX) {
83       chunk_len = (int)len;
84     }
85     int ret = BIO_read(bio, out, chunk_len);
86     if (ret <= 0) {
87       return ret;
88     }
89     out += ret;
90     len -= ret;
91   }
92   return 1;
93 }
94 
PacketedWrite(BIO * bio,const char * in,int inl)95 static int PacketedWrite(BIO *bio, const char *in, int inl) {
96   BIO *next = BIO_next(bio);
97   if (next == nullptr) {
98     return -1;
99   }
100 
101   BIO_clear_retry_flags(bio);
102 
103   // Write the header.
104   uint8_t header[5];
105   header[0] = kOpcodePacket;
106   header[1] = (inl >> 24) & 0xff;
107   header[2] = (inl >> 16) & 0xff;
108   header[3] = (inl >> 8) & 0xff;
109   header[4] = inl & 0xff;
110   int ret = BIO_write(next, header, sizeof(header));
111   if (ret <= 0) {
112     BIO_copy_next_retry(bio);
113     return ret;
114   }
115 
116   // Write the buffer.
117   ret = BIO_write(next, in, inl);
118   if (ret < 0 || (inl > 0 && ret == 0)) {
119     BIO_copy_next_retry(bio);
120     return ret;
121   }
122   assert(ret == inl);
123   return ret;
124 }
125 
PacketedRead(BIO * bio,char * out,int outl)126 static int PacketedRead(BIO *bio, char *out, int outl) {
127   PacketedBio *data = GetData(bio);
128   BIO *next = BIO_next(bio);
129   if (next == nullptr) {
130     return -1;
131   }
132 
133   BIO_clear_retry_flags(bio);
134 
135   for (;;) {
136     // Read the opcode.
137     uint8_t opcode;
138     int ret = ReadAll(next, &opcode, sizeof(opcode));
139     if (ret <= 0) {
140       BIO_copy_next_retry(bio);
141       return ret;
142     }
143 
144     if (opcode == kOpcodeTimeout) {
145       // The caller is required to advance any pending timeouts before
146       // continuing.
147       if (data->HasTimeout()) {
148         fprintf(stderr, "Unprocessed timeout!\n");
149         return -1;
150       }
151 
152       // Process the timeout.
153       uint8_t buf[8];
154       ret = ReadAll(next, buf, sizeof(buf));
155       if (ret <= 0) {
156         BIO_copy_next_retry(bio);
157         return ret;
158       }
159       uint64_t timeout = CRYPTO_load_u64_be(buf);
160       timeout /= 1000;  // Convert nanoseconds to microseconds.
161 
162       data->timeout.tv_usec = timeout % 1000000;
163       data->timeout.tv_sec = timeout / 1000000;
164 
165       // Send an ACK to the peer.
166       ret = BIO_write(next, &kOpcodeTimeoutAck, 1);
167       if (ret <= 0) {
168         return ret;
169       }
170       assert(ret == 1);
171 
172       // Signal to the caller to retry the read, after advancing the clock.
173       BIO_set_retry_read(bio);
174       return -1;
175     }
176 
177     if (opcode == kOpcodeMTU) {
178       uint8_t buf[4];
179       ret = ReadAll(next, buf, sizeof(buf));
180       if (ret <= 0) {
181         BIO_copy_next_retry(bio);
182         return ret;
183       }
184       uint32_t mtu = CRYPTO_load_u32_be(buf);
185       if (!data->set_mtu(mtu)) {
186         fprintf(stderr, "Error setting MTU\n");
187         return -1;
188       }
189       // Continue reading.
190       continue;
191     }
192 
193     if (opcode == kOpcodeExpectNextTimeout) {
194       uint8_t buf[8];
195       ret = ReadAll(next, buf, sizeof(buf));
196       if (ret <= 0) {
197         BIO_copy_next_retry(bio);
198         return ret;
199       }
200       uint64_t expected = CRYPTO_load_u64_be(buf);
201       timeval timeout;
202       bool has_timeout = data->get_timeout(&timeout);
203       if (expected == UINT64_MAX) {
204         if (has_timeout) {
205           fprintf(stderr,
206                   "Expected no timeout, but got %" PRIu64 ".%06" PRIu64 "s.\n",
207                   static_cast<uint64_t>(timeout.tv_sec),
208                   static_cast<uint64_t>(timeout.tv_usec));
209           return -1;
210         }
211       } else {
212         expected /= 1000;  // Convert nanoseconds to microseconds.
213         uint64_t expected_sec = expected / 1000000;
214         uint64_t expected_usec = expected % 1000000;
215         if (!has_timeout) {
216           fprintf(stderr,
217                   "Expected timeout of %" PRIu64 ".%06" PRIu64
218                   "s, but got none.\n",
219                   expected_sec, expected_usec);
220           return -1;
221         }
222         if (static_cast<uint64_t>(timeout.tv_sec) != expected_sec ||
223             static_cast<uint64_t>(timeout.tv_usec) != expected_usec) {
224           fprintf(stderr,
225                   "Expected timeout of %" PRIu64 ".%06" PRIu64
226                   "s, but got %" PRIu64 ".%06" PRIu64 "s.\n",
227                   expected_sec, expected_usec,
228                   static_cast<uint64_t>(timeout.tv_sec),
229                   static_cast<uint64_t>(timeout.tv_usec));
230           return -1;
231         }
232       }
233       // Continue reading.
234       continue;
235     }
236 
237     if (opcode != kOpcodePacket) {
238       fprintf(stderr, "Unknown opcode, %u\n", opcode);
239       return -1;
240     }
241 
242     // Read the length prefix.
243     uint8_t len_bytes[4];
244     ret = ReadAll(next, len_bytes, sizeof(len_bytes));
245     if (ret <= 0) {
246       BIO_copy_next_retry(bio);
247       return ret;
248     }
249 
250     std::vector<uint8_t> buf(CRYPTO_load_u32_be(len_bytes), 0);
251     ret = ReadAll(next, buf.data(), buf.size());
252     if (ret <= 0) {
253       fprintf(stderr, "Packeted BIO was truncated\n");
254       return -1;
255     }
256 
257     if (static_cast<size_t>(outl) > buf.size()) {
258       outl = static_cast<int>(buf.size());
259     }
260     OPENSSL_memcpy(out, buf.data(), outl);
261     return outl;
262   }
263 }
264 
PacketedCtrl(BIO * bio,int cmd,long num,void * ptr)265 static long PacketedCtrl(BIO *bio, int cmd, long num, void *ptr) {
266   BIO *next = BIO_next(bio);
267   if (next == nullptr) {
268     return 0;
269   }
270 
271   BIO_clear_retry_flags(bio);
272   long ret = BIO_ctrl(next, cmd, num, ptr);
273   BIO_copy_next_retry(bio);
274   return ret;
275 }
276 
PacketedNew(BIO * bio)277 static int PacketedNew(BIO *bio) {
278   BIO_set_init(bio, 1);
279   return 1;
280 }
281 
PacketedFree(BIO * bio)282 static int PacketedFree(BIO *bio) {
283   if (bio == nullptr) {
284     return 0;
285   }
286 
287   delete GetData(bio);
288   return 1;
289 }
290 
PacketedCallbackCtrl(BIO * bio,int cmd,BIO_info_cb * fp)291 static long PacketedCallbackCtrl(BIO *bio, int cmd, BIO_info_cb *fp) {
292   BIO *next = BIO_next(bio);
293   if (next == nullptr) {
294     return 0;
295   }
296   return BIO_callback_ctrl(next, cmd, fp);
297 }
298 
PacketedBioMethod()299 static const BIO_METHOD *PacketedBioMethod() {
300   static const BIO_METHOD *method = [] {
301     BIO_METHOD *ret = BIO_meth_new(PacketedBioMethodType(), "packeted bio");
302     BSSL_CHECK(ret);
303     BSSL_CHECK(BIO_meth_set_write(ret, PacketedWrite));
304     BSSL_CHECK(BIO_meth_set_read(ret, PacketedRead));
305     BSSL_CHECK(BIO_meth_set_ctrl(ret, PacketedCtrl));
306     BSSL_CHECK(BIO_meth_set_create(ret, PacketedNew));
307     BSSL_CHECK(BIO_meth_set_destroy(ret, PacketedFree));
308     BSSL_CHECK(BIO_meth_set_callback_ctrl(ret, PacketedCallbackCtrl));
309     return ret;
310   }();
311   return method;
312 }
313 
314 }  // namespace
315 
PacketedBioCreate(timeval * clock,std::function<bool (timeval *)> get_timeout,std::function<bool (uint32_t)> set_mtu)316 bssl::UniquePtr<BIO> PacketedBioCreate(
317     timeval *clock, std::function<bool(timeval *)> get_timeout,
318     std::function<bool(uint32_t)> set_mtu) {
319   bssl::UniquePtr<BIO> bio(BIO_new(PacketedBioMethod()));
320   if (!bio) {
321     return nullptr;
322   }
323   BIO_set_data(bio.get(), new PacketedBio(clock, std::move(get_timeout),
324                                           std::move(set_mtu)));
325   return bio;
326 }
327 
PacketedBioAdvanceClock(BIO * bio)328 bool PacketedBioAdvanceClock(BIO *bio) {
329   PacketedBio *data = GetData(bio);
330   if (data == nullptr) {
331     return false;
332   }
333 
334   if (!data->HasTimeout()) {
335     return false;
336   }
337 
338   data->clock->tv_usec += data->timeout.tv_usec;
339   data->clock->tv_sec += data->clock->tv_usec / 1000000;
340   data->clock->tv_usec %= 1000000;
341   data->clock->tv_sec += data->timeout.tv_sec;
342   OPENSSL_memset(&data->timeout, 0, sizeof(data->timeout));
343   return true;
344 }
345 
PacketedBioGetClock(BIO * bio)346 timeval *PacketedBioGetClock(BIO *bio) {
347   PacketedBio *data = GetData(bio);
348   if (data == nullptr) {
349     return nullptr;
350   }
351   return data->clock;
352 }
353