1 // Copyright 2017 The Fuchsia Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include <arpa/inet.h>
6 #include <netinet/in.h>
7 #include <sys/socket.h>
8 #include <sys/stat.h>
9 #include <sys/time.h>
10 
11 #include <fcntl.h>
12 #include <inttypes.h>
13 #include <stdbool.h>
14 #include <stdio.h>
15 #include <stdlib.h>
16 #include <string.h>
17 #include <time.h>
18 #include <unistd.h>
19 
20 #include <errno.h>
21 #include <stdint.h>
22 
23 #include <zircon/boot/netboot.h>
24 
25 #include "bootserver.h"
26 
// NOTE(review): not referenced anywhere in this file; presumably the default
// for the `us_between_packets` delay used by netboot_xfer() -- confirm.
#define DEFAULT_US_BETWEEN_PACKETS 20

// Cookie stamped on the next outgoing message; io() post-increments it each
// time it stamps a message, so every sent message gets a fresh value.
static uint32_t cookie = 1;

// Attempt limits: io_rcv() gives up after MAX_READ_RETRIES reads, and
// io_send() (Apple builds only) after MAX_SEND_RETRIES writes.
enum {
    MAX_READ_RETRIES = 10,
    MAX_SEND_RETRIES = 10000,
};
32 
// Reads one reply datagram from socket `s` into `ack` and classifies it.
//
// `ack` must point to at least 2048 bytes (that is how much is read).
// `msg`, when non-NULL, is the request being answered: a reply whose cookie
// is newer than msg->cookie is rejected.
//
// Returns:
//    0  `ack` holds an NB_ACK or NB_FILE_RECEIVED reply, or the datagram was
//       rejected (bad magic / bad cookie). Rejected datagrams have their
//       cookie, cmd and arg cleared -- the same "no reply" state io() sets
//       up before receiving -- so callers never act on untrusted contents.
//   -1  read failure, short datagram, or an error reply from the target
//       (a message describing it has been printed to stderr).
static int io_rcv(int s, nbmsg* msg, nbmsg* ack) {
    for (int i = 0; i < MAX_READ_RETRIES; i++) {
        bool retry_allowed = i + 1 < MAX_READ_RETRIES;

        ssize_t r = read(s, ack, 2048);
        if (r < 0) {
            // A read that times out under SO_RCVTIMEO fails with EAGAIN or
            // EWOULDBLOCK (POSIX allows either); both mean "try again".
            if (retry_allowed && (errno == EAGAIN || errno == EWOULDBLOCK)) {
                continue;
            }
            fprintf(stderr, "\n%s: error: Socket read error %d\n", appname, errno);
            return -1;
        }
        if ((size_t)r < sizeof(nbmsg)) {
            fprintf(stderr, "\n%s: error: Read too short\n", appname);
            return -1;
        }
#ifdef DEBUG
        fprintf(stdout, " < magic = %08x, cookie = %08x, cmd = %08x, arg = %08x\n",
                ack->magic, ack->cookie, ack->cmd, ack->arg);
#endif

        // Invalid datagrams are reported but not fatal. Wipe the fields the
        // callers inspect so they are not mistaken for a real reply.
        if (ack->magic != NB_MAGIC) {
            fprintf(stderr, "\n%s: error: Bad magic\n", appname);
            ack->cookie = 0;
            ack->cmd = 0;
            ack->arg = 0;
            return 0;
        }
        if (msg && ack->cookie > msg->cookie) {
            fprintf(stderr, "\n%s: error: Bad cookie\n", appname);
            ack->cookie = 0;
            ack->cmd = 0;
            ack->arg = 0;
            return 0;
        }

        if (ack->cmd == NB_ACK || ack->cmd == NB_FILE_RECEIVED) {
            return 0;
        }

        // Anything else is an error reported by the target.
        switch (ack->cmd) {
        case NB_ERROR:
            fprintf(stderr, "\n%s: error: Generic error\n", appname);
            break;
        case NB_ERROR_BAD_CMD:
            fprintf(stderr, "\n%s: error: Bad command\n", appname);
            break;
        case NB_ERROR_BAD_PARAM:
            fprintf(stderr, "\n%s: error: Bad parameter\n", appname);
            break;
        case NB_ERROR_TOO_LARGE:
            fprintf(stderr, "\n%s: error: File too large\n", appname);
            break;
        case NB_ERROR_BAD_FILE:
            fprintf(stderr, "\n%s: error: Bad file\n", appname);
            break;
        default:
            fprintf(stderr, "\n%s: error: Unknown command 0x%08X\n", appname, ack->cmd);
        }
        return -1;
    }
    fprintf(stderr, "\n%s: error: Unexpected code path\n", appname);
    return -1;
}
93 
// Writes the `len`-byte message at `msg` to the connected socket `s`.
//
// Only Apple builds retry. There, a write failing with ENOBUFS (the Darwin
// ethernet driver can be overrun) is retried after a 50us pause, for at most
// MAX_SEND_RETRIES attempts in total. Any other failure, and an ENOBUFS on
// the final attempt, is reported on stderr.
//
// Returns 0 as soon as a write succeeds, -1 otherwise.
static int io_send(int s, nbmsg* msg, size_t len) {
    for (int attempt = 1; attempt <= MAX_SEND_RETRIES; attempt++) {
        if (write(s, msg, len) >= 0) {
            return 0;
        }
#if defined(__APPLE__)
        bool can_retry = attempt < MAX_SEND_RETRIES;
        if (can_retry && errno == ENOBUFS) {
            const struct timespec backoff = {.tv_sec = 0, .tv_nsec = 50 * 1000};
            nanosleep(&backoff, NULL);
            continue;
        }
#endif
        fprintf(stderr, "\n%s: error: Socket write error %d\n", appname, errno);
        return -1;
    }
    fprintf(stderr, "\n%s: error: Unexpected code path\n", appname);
    return -1;
}
120 
// Performs one send/receive step on the connected socket `s`.
//
// The cookie, cmd and arg fields of `ack` are cleared first. `ack` must
// point to at least 2048 bytes, because io_rcv() reads that much into it.
// If `msg` is non-NULL and `len` > 0, the message is stamped with NB_MAGIC
// and the next cookie, and is then sent.
//
// wait_reply == false: select() (10.5s timeout) waits for writability and
//   for any reply already pending. A pending reply is read into `ack` first.
//   If that reply is an NB_ACK with a non-zero cookie, the send is skipped
//   so the caller can handle it.
// wait_reply == true: a message, if given, is sent once select() reports
//   the socket writable. io_rcv() then waits for the reply.
//
// Returns 0 on success and -1 on failure (an error has been printed).
static int io(int s, nbmsg* msg, size_t len, nbmsg* ack, bool wait_reply) {
    ack->cookie = 0;
    ack->cmd = 0;
    ack->arg = 0;

    fd_set readable;
    fd_set* read_set = NULL;
    FD_ZERO(&readable);
    if (!wait_reply) {
        FD_SET(s, &readable);
        read_set = &readable;
    }

    fd_set writable;
    fd_set* write_set = NULL;
    FD_ZERO(&writable);
    if (msg != NULL && len > 0) {
        msg->magic = NB_MAGIC;
        msg->cookie = cookie++;
        FD_SET(s, &writable);
        write_set = &writable;
    }

    if (read_set != NULL || write_set != NULL) {
        struct timeval timeout = {.tv_sec = 10, .tv_usec = 500000};
        int ready = select(s + 1, read_set, write_set, NULL, &timeout);
        if (ready == -1) {
            fprintf(stderr, "\n%s: error: Select failed %d\n", appname, errno);
            return -1;
        }
        if (ready == 0) {
            fprintf(stderr, "\n%s: error: Select timed out\n", appname);
            return -1;
        }

        int result = FD_ISSET(s, &readable) ? io_rcv(s, msg, ack) : 0;

        // A pending ACK (which is really a NACK) takes priority over sending:
        // return and let the caller deal with it first.
        bool pending_nack = ack->cookie != 0 && ack->cmd == NB_ACK;
        if (result == 0 && FD_ISSET(s, &writable) && !pending_nack) {
            result = io_send(s, msg, len);
        }

        if (result != 0 || !wait_reply) {
            return result;
        }
    }

    if (wait_reply) {
        return io_rcv(s, msg, ack);
    }
    // Defensive only: without wait_reply the socket is always polled for
    // reading, so the branch above has already returned.
    fprintf(stderr, "\n%s: error: Select triggered without events\n", appname);
    return -1;
}
184 
// Sequential byte source for a transfer: an open file when `fp` is non-NULL,
// otherwise the in-memory buffer described by the remaining fields.
typedef struct {
    FILE* fp;          // file being sent, or NULL for in-memory data
    const char* data;  // start of the in-memory data (fp == NULL only)
    size_t datalen;    // total number of in-memory bytes
    const char* ptr;   // next in-memory byte to hand out
    size_t avail;      // in-memory bytes left from `ptr` to the end
} xferdata;

// Copies up to `len` bytes from the current position of `xd` into `data`
// and advances that position. Returns the number of bytes copied, 0 once the
// data is exhausted, or -1 when the underlying file reports a read error.
static ssize_t xread(xferdata* xd, void* data, size_t len) {
    if (xd->fp != NULL) {
        size_t got = fread(data, 1, len, xd->fp);
        if (got != 0) {
            return (ssize_t)got;
        }
        return ferror(xd->fp) ? -1 : 0;
    }

    size_t take = (len < xd->avail) ? len : xd->avail;
    memcpy(data, xd->ptr, take);
    xd->ptr += take;
    xd->avail -= take;
    return (ssize_t)take;
}

// Moves the read position of `xd` to the absolute byte offset `off`.
// For file sources, returns fseek()'s result. For in-memory sources, returns
// 0, or -1 (leaving the position unchanged) if `off` is past the end.
static int xseek(xferdata* xd, size_t off) {
    if (xd->fp != NULL) {
        return fseek(xd->fp, off, SEEK_SET);
    }
    if (off > xd->datalen) {
        return -1;
    }
    xd->ptr = xd->data + off;
    xd->avail = xd->datalen - off;
    return 0;
}
223 
// Maximum number of file bytes carried by one NB_DATA / NB_LAST_DATA packet
// in netboot_xfer(). Each packet is an nbmsg header followed by up to this
// many bytes, and it must fit in that function's 2048-byte message buffer.
//
// For reference, the largest UDP payload in a single Ethernet frame over IPv6:
//   UDP6_MAX_PAYLOAD = ETH_MTU - ETH_HDR_LEN - IP6_HDR_LEN - UDP_HDR_LEN
//              1452 =    1514 -          14 -          40 -           8
// NOTE(review): original note kept as-is; the `nbfile` structure it refers to
// is not visible here -- confirm it is still accurate:
// nbfile is PAYLOAD_SIZE + 2 * sizeof(size_t)

// Some EFI network stacks have problems with larger packets, so a smaller
// size than the maximum above is used; 1280 is friendlier to them.
#define PAYLOAD_SIZE 1280
231 
netboot_xfer(struct sockaddr_in6 * addr,const char * fn,const char * name)232 int netboot_xfer(struct sockaddr_in6* addr, const char* fn, const char* name) {
233     xferdata xd;
234     char msgbuf[2048];
235     char ackbuf[2048];
236     char tmp[INET6_ADDRSTRLEN];
237     struct timeval tv;
238     nbmsg* msg = (void*)msgbuf;
239     nbmsg* ack = (void*)ackbuf;
240     int s;
241     int status = -1;
242     size_t current_pos = 0;
243     size_t sz = 0;
244 
245     if (!strcmp(fn, "(cmdline)")) {
246         xd.fp = NULL;
247         xd.data = name;
248         xd.datalen = strlen(name) + 1;
249         xd.ptr = xd.data;
250         xd.avail = xd.datalen;
251         name = NB_CMDLINE_FILENAME;
252         sz = xd.datalen;
253     } else {
254         if ((xd.fp = fopen(fn, "rb")) == NULL) {
255             fprintf(stderr, "%s: error: Could not open file %s\n", appname, fn);
256             return -1;
257         }
258         if (fseek(xd.fp, 0L, SEEK_END)) {
259             fprintf(stderr, "%s: error: Could not determine size of %s\n", appname, fn);
260         } else if ((sz = ftell(xd.fp)) < 0) {
261             fprintf(stderr, "%s: error: Could not determine size of %s\n", appname, fn);
262             sz = 0;
263         } else if (fseek(xd.fp, 0L, SEEK_SET)) {
264             fprintf(stderr, "%s: error: Failed to rewind %s\n", appname, fn);
265             return -1;
266         }
267     }
268 
269     if (sz > 0) {
270         initialize_status(xd.data, sz);
271     }
272 
273     if ((s = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP)) < 0) {
274         fprintf(stderr, "%s: error: Cannot create socket %d\n", appname, errno);
275         goto done;
276     }
277     tv.tv_sec = 0;
278     tv.tv_usec = 250 * 1000;
279     setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
280     if (connect(s, (void*)addr, sizeof(*addr)) < 0) {
281         fprintf(stderr, "%s: error: Cannot connect to [%s]%d\n", appname,
282                 inet_ntop(AF_INET6, &addr->sin6_addr, tmp, sizeof(tmp)),
283                 ntohs(addr->sin6_port));
284         goto done;
285     }
286 
287     msg->cmd = NB_SEND_FILE;
288     msg->arg = sz;
289     strcpy((void*)msg->data, name);
290     if (io(s, msg, sizeof(nbmsg) + strlen(name) + 1, ack, true)) {
291         fprintf(stderr, "%s: error: Failed to start transfer\n", appname);
292         goto done;
293     }
294 
295     msg->cmd = NB_DATA;
296     msg->arg = 0;
297 
298     bool completed = false;
299     do {
300         struct timeval packet_start_time;
301         gettimeofday(&packet_start_time, NULL);
302 
303         ssize_t r = xread(&xd, msg->data, PAYLOAD_SIZE);
304         if (r < 0) {
305             fprintf(stderr, "\n%s: error: Reading '%s'\n", appname, fn);
306             goto done;
307         }
308 
309         update_status(msg->arg);
310 
311         if (r == 0) {
312             fprintf(stderr, "\n%s: Reached end of file, waiting for confirmation.\n", appname);
313             // Do not send anything, but keep waiting on incoming messages
314             if (io(s, NULL, 0, ack, true)) {
315                 goto done;
316             }
317         } else {
318             if (current_pos + (size_t)r >= sz) {
319                 msg->cmd = NB_LAST_DATA;
320             } else {
321                 msg->cmd = NB_DATA;
322             }
323 
324             if (io(s, msg, sizeof(nbmsg) + r, ack, false)) {
325                 goto done;
326             }
327 
328             // Some UEFI netstacks can lose back-to-back packets at max speed
329             // so throttle output.
330             // At 1280 bytes per packet, we should at least have 10 microseconds
331             // between packets, to be safe using 20 microseconds here.
332             // 1280 bytes * (1,000,000/10) seconds = 128,000,000 bytes/seconds = 122MB/s = 976Mb/s
333             // We wait as a busy wait as the context switching a sleep can cause
334             // will often degrade performance significantly.
335             int64_t us_since_last_packet;
336             do {
337                 struct timeval now;
338                 gettimeofday(&now, NULL);
339                 us_since_last_packet = (int64_t)(now.tv_sec - packet_start_time.tv_sec) * 1000000 +
340                                        ((int64_t)now.tv_usec - (int64_t)packet_start_time.tv_usec);
341             } while (us_since_last_packet < us_between_packets);
342         }
343 
344         // ACKs really are NACKs
345         if (ack->cookie > 0 && ack->cmd == NB_ACK) {
346             // ACKs tend to be generated in groups, since a dropped packet will cause ACKs for all
347             // outstanding packets. Therefore briefly sleep when we receive an ACK with a different
348             // position, to let things settle and prevent ourselves from fighting subsequent acks.
349             if (ack->arg != current_pos) {
350                 fprintf(stderr, "\n%s: need to reset to %d from %zu\n",
351                         appname, ack->arg, current_pos);
352                 current_pos = ack->arg;
353 
354                 tv.tv_usec = 100000;
355                 select(0, NULL, NULL, NULL, &tv);
356             }
357             if (xseek(&xd, current_pos)) {
358                 fprintf(stderr, "\n%s: error: Failed to rewind '%s' to %zu\n",
359                         appname, fn, current_pos);
360                 goto done;
361             }
362         } else if (ack->cmd == NB_FILE_RECEIVED) {
363             current_pos += r;
364             completed = true;
365         } else {
366             current_pos += r;
367         }
368 
369         msg->arg = current_pos;
370     } while (!completed);
371 
372     status = 0;
373     update_status(msg->arg);
374 done:
375     if (s >= 0) {
376         close(s);
377     }
378     if (xd.fp != NULL) {
379         fclose(xd.fp);
380     }
381     return status;
382 }
383