1 // Copyright 2016 The Fuchsia Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #define _POSIX_C_SOURCE 200809L
6 #define _DARWIN_C_SOURCE
7 #define _GNU_SOURCE
8 
9 #include "netprotocol.h"
10 
11 #include <zircon/boot/netboot.h>
12 
13 #include <arpa/inet.h>
14 #include <ifaddrs.h>
15 #include <netinet/in.h>
16 #include <sys/socket.h>
17 #include <sys/time.h>
18 
19 #include <fcntl.h>
20 #include <getopt.h>
21 #include <poll.h>
22 #include <stdio.h>
23 #include <stdlib.h>
24 #include <string.h>
25 #include <unistd.h>
26 
#include <errno.h>
#include <limits.h>
#include <stdint.h>
29 
// TFTP transfer parameters; overridden by --block-size / --window-size in
// netboot_handle_custom_getopt().
uint16_t tftp_block_size = TFTP_DEFAULT_BLOCK_SZ;
uint16_t tftp_window_size = TFTP_DEFAULT_WINDOW_SZ;

// Request cookie: pre-incremented for every outgoing query/request and
// compared against the cookie carried in replies to match them to requests.
static uint32_t cookie = 0x12345678;
// Discovery timeout in milliseconds (set by --timeout).
static int netboot_timeout = 250;
// When true (cleared by --nowait), netboot_discover() allows up to one hour
// (3600000 ms) for the first poll before switching to netboot_timeout.
static bool netboot_wait = true;
36 
// Returns an absolute wall-clock deadline lying |msec| milliseconds after the
// current time of day (as reported by gettimeofday()).
static struct timeval netboot_timeout_init(int msec) {
    struct timeval deadline;
    gettimeofday(&deadline, NULL);

    // Same arithmetic as timeradd(): add whole seconds and microseconds
    // separately, then carry one second if the microsecond field overflows.
    deadline.tv_sec += msec / 1000;
    deadline.tv_usec += (msec % 1000) * 1000;
    if (deadline.tv_usec >= 1000000) {
        deadline.tv_sec++;
        deadline.tv_usec -= 1000000;
    }
    return deadline;
}
48 
// Returns the milliseconds remaining from now until the deadline |*end_tv|.
// For normalized timevals the result is negative once the deadline has passed.
static int netboot_timeout_get_msec(const struct timeval* end_tv) {
    struct timeval now;
    gettimeofday(&now, NULL);

    // Same arithmetic as timersub(): subtract field-wise, then borrow one
    // second if the microsecond difference went negative.
    struct timeval remaining = {
        .tv_sec = end_tv->tv_sec - now.tv_sec,
        .tv_usec = end_tv->tv_usec - now.tv_usec,
    };
    if (remaining.tv_usec < 0) {
        remaining.tv_sec--;
        remaining.tv_usec += 1000000;
    }
    return remaining.tv_sec * 1000 + remaining.tv_usec / 1000;
}
56 
netboot_bind_to_cmd_port(int socket)57 static int netboot_bind_to_cmd_port(int socket) {
58     struct sockaddr_in6 addr;
59     memset(&addr, 0, sizeof(addr));
60     addr.sin6_family = AF_INET6;
61 
62     for (uint16_t port = NB_CMD_PORT_START; port <= NB_CMD_PORT_END; port++) {
63         addr.sin6_port = htons(port);
64         if (bind(socket, (void*)&addr, sizeof(addr)) == 0) {
65             return 0;
66         }
67     }
68     return -1;
69 }
70 
netboot_send_query(int socket,unsigned port,const char * ifname)71 static int netboot_send_query(int socket, unsigned port, const char* ifname) {
72     const char* hostname = "*";
73     size_t hostname_len = strlen(hostname) + 1;
74 
75     msg m;
76     m.hdr.magic = NB_MAGIC;
77     m.hdr.cookie = ++cookie;
78     m.hdr.cmd = NB_QUERY;
79     m.hdr.arg = 0;
80     memcpy(m.data, hostname, hostname_len);
81 
82     struct sockaddr_in6 addr;
83     memset(&addr, 0, sizeof(addr));
84     addr.sin6_family = AF_INET6;
85     addr.sin6_port = htons(port);
86     inet_pton(AF_INET6, "ff02::1", &addr.sin6_addr);
87 
88     struct ifaddrs* ifa;
89     if (getifaddrs(&ifa) < 0) {
90         fprintf(stderr, "error: cannot enumerate network interfaces\n");
91         return -1;
92     }
93 
94     for (; ifa != NULL; ifa = ifa->ifa_next) {
95         if (ifa->ifa_addr == NULL) {
96             continue;
97         }
98         if (ifa->ifa_addr->sa_family != AF_INET6) {
99             continue;
100         }
101         struct sockaddr_in6* in6 = (void*)ifa->ifa_addr;
102         if (in6->sin6_scope_id == 0) {
103             continue;
104         }
105         if (ifname && ifname[0] != 0 && strcmp(ifname, ifa->ifa_name))
106             continue;
107         // printf("tx %s (sid=%d)\n", ifa->ifa_name, in6->sin6_scope_id);
108         size_t sz = sizeof(nbmsg) + hostname_len;
109         addr.sin6_scope_id = in6->sin6_scope_id;
110 
111         ssize_t r = sendto(socket, &m, sz, 0,
112                            (struct sockaddr*)&addr, sizeof(addr));
113         if (r < 0) {
114             fprintf(stderr, "error: sendto: %s\n", strerror(errno));
115         } else if ((size_t)r != sz) {
116             fprintf(stderr, "error: sendto: short count %zu != %zu\n", r, sz);
117         }
118     }
119 
120     return 0;
121 }
122 
// Reads one datagram from |socket|. If it is an NB_ACK carrying the current
// query cookie, and its source address does not render with a "::" prefix,
// it is reported to |callback| as a zero-initialized device_info_t.
//
// Returns the callback's result when a device is reported, or false when
// recvfrom() fails. It returns true (keep listening) for datagrams that are
// not a valid reply, so a stray or stale packet does not end discovery early.
static bool netboot_receive_query(int socket, on_device_cb callback, void* data) {
    struct sockaddr_in6 ra;
    socklen_t rlen = sizeof(ra);
    memset(&ra, 0, sizeof(ra));
    msg m;
    ssize_t r = recvfrom(socket, &m, sizeof(m), 0, (void*)&ra, &rlen);
    if (r < 0) {
        fprintf(stderr, "error: recvfrom: %s\n", strerror(errno));
        return false;
    }
    if ((size_t)r <= sizeof(nbmsg)) {
        // Header-only or truncated datagram: nothing to report.
        return true;
    }

    // NUL-terminate the payload in place. A datagram that fills the whole
    // receive buffer leaves no spare byte, so overwrite its last byte rather
    // than writing one past the end of m.data.
    size_t payload_len = (size_t)r - sizeof(nbmsg);
    if (payload_len >= sizeof(m.data)) {
        payload_len = sizeof(m.data) - 1;
    }
    m.data[payload_len] = 0;

    if ((m.hdr.magic != NB_MAGIC) ||
        (m.hdr.cookie != cookie) ||
        (m.hdr.cmd != NB_ACK)) {
        // Not a reply to the current query (e.g. a late reply to an old cookie).
        return true;
    }

    char tmp[INET6_ADDRSTRLEN];
    if (inet_ntop(AF_INET6, &ra.sin6_addr, tmp, sizeof(tmp)) == NULL) {
        strcpy(tmp, "???");
    }
    if (strncmp("::", tmp, 2) == 0) {
        return true;
    }

    device_info_t info;
    // Zero every field so the callback never sees uninitialized members.
    memset(&info, 0, sizeof(info));
    // snprintf always NUL-terminates, unlike strncpy with a long nodename.
    snprintf(info.nodename, sizeof(info.nodename), "%s", (const char*)m.data);
    snprintf(info.inet6_addr_s, INET6_ADDRSTRLEN, "%s", tmp);
    memcpy(&info.inet6_addr, &ra, sizeof(ra));
    info.state = DEVICE;
    return callback(&info, data);
}
154 
// Long options shared by every netboot tool (described by netboot_usage()).
// The |val| codes are the characters netboot_handle_custom_getopt() switches on.
static struct option default_opts[] = {
    {.name = "help", .has_arg = no_argument, .flag = NULL, .val = 'h'},
    {.name = "timeout", .has_arg = required_argument, .flag = NULL, .val = 't'},
    {.name = "nowait", .has_arg = no_argument, .flag = NULL, .val = 'n'},
    {.name = "block-size", .has_arg = required_argument, .flag = NULL, .val = 'b'},
    {.name = "window-size", .has_arg = required_argument, .flag = NULL, .val = 'w'},
    // All-zero terminator required by getopt_long().
    {.name = NULL, .has_arg = 0, .flag = NULL, .val = 0},
};
163 
// Returns the number of entries in the getopt_long()-style array |opts| that
// precede its terminator, or 0 when |opts| is NULL.
//
// The terminator is recognized by a NULL |name|. The previous approach
// compared whole structs with memcmp(), which also compares padding bytes;
// padding is not guaranteed to be zero in arrays with automatic storage, so
// a caller's terminator could be missed and the scan could run off the array.
static size_t netboot_count_opts(const struct option* opts) {
    if (!opts) {
        return 0;
    }
    size_t count = 0;
    while (opts[count].name != NULL) {
        count++;
    }
    return count;
}

// Copies the entries of |src_opts| before its terminator into |dst_opts|,
// which must have room for netboot_count_opts(src_opts) entries. The
// terminator itself is not copied. A NULL |src_opts| copies nothing.
static void netboot_copy_opts(struct option* dst_opts, const struct option* src_opts) {
    size_t count = netboot_count_opts(src_opts);
    if (count > 0) {
        memcpy(dst_opts, src_opts, count * sizeof(*src_opts));
    }
}
186 
netboot_handle_custom_getopt(int argc,char * const * argv,const struct option * custom_opts,size_t num_custom_opts0,bool (* opt_callback)(int ch,int argc,char * const * argv))187 int netboot_handle_custom_getopt(int argc, char* const* argv,
188                                  const struct option* custom_opts,
189                                  size_t num_custom_opts0,
190                                  bool (*opt_callback)(int ch, int argc, char* const* argv)) {
191     size_t num_default_opts = netboot_count_opts(default_opts);
192     size_t num_custom_opts = netboot_count_opts(custom_opts);
193 
194     struct option* combined_opts;
195     combined_opts = (struct option*)malloc(sizeof(struct option) *
196                                            (num_default_opts + num_custom_opts + 1));
197 
198     netboot_copy_opts(combined_opts, default_opts);
199     netboot_copy_opts(combined_opts + num_default_opts, custom_opts);
200     memset(&combined_opts[num_default_opts + num_custom_opts], 0x0,
201            sizeof(struct option));
202 
203     int retval = -1;
204     int ch;
205     while ((ch = getopt_long_only(argc, argv, "t:", combined_opts, NULL)) != -1) {
206         switch (ch) {
207         case 't':
208             netboot_timeout = atoi(optarg);
209             break;
210         case 'n':
211             netboot_wait = false;
212             break;
213         case 'b':
214             tftp_block_size = atoi(optarg);
215             break;
216         case 'w':
217             tftp_window_size = atoi(optarg);
218             break;
219         default:
220             if (opt_callback && opt_callback(ch, argc, argv)) {
221                 break;
222             } else {
223                 goto err;
224             }
225         }
226     }
227     retval = optind;
228 err:
229     free(combined_opts);
230     return retval;
231 }
232 
// Parses only the shared netboot options. This is
// netboot_handle_custom_getopt() with no extra options and no callback, so
// any unrecognized option makes it return -1.
int netboot_handle_getopt(int argc, char* const* argv) {
    const struct option* const no_extra_opts = NULL;
    bool (*const no_callback)(int, int, char* const*) = NULL;
    return netboot_handle_custom_getopt(argc, argv, no_extra_opts, 0, no_callback);
}
236 
netboot_usage(bool show_tftp_opts)237 void netboot_usage(bool show_tftp_opts) {
238     fprintf(stderr, "options:\n");
239     fprintf(stderr, "    --help              Print this message.\n");
240     fprintf(stderr, "    --timeout=<msec>    Set discovery timeout to <msec>.\n");
241     fprintf(stderr, "    --nowait            Do not wait for first packet before timing out.\n");
242     if (show_tftp_opts) {
243         fprintf(stderr, "    --block-size=<sz>   Set tftp block size (default=%d).\n",
244                 TFTP_DEFAULT_BLOCK_SZ);
245         fprintf(stderr, "    --window-size=<sz>  Set tftp window size (default=%d).\n",
246                 TFTP_DEFAULT_WINDOW_SZ);
247     }
248 }
249 
// Sends a netboot query on |port| (optionally restricted to interface
// |ifname|) and hands every received reply to netboot_receive_query(), which
// invokes |callback|. Discovery stops when that handler returns false or
// the timeout expires.
//
// With netboot_wait set, the first poll may wait up to one hour; after the
// first poll returns, the deadline becomes netboot_timeout milliseconds from
// that moment.
//
// Returns 0 if at least one packet was received. Otherwise returns -1 with
// errno set: EINVAL for a NULL callback, ETIMEDOUT when nothing arrived, or
// the error from socket setup / poll().
int netboot_discover(unsigned port, const char* ifname, on_device_cb callback, void* data) {
    if (!callback) {
        errno = EINVAL;
        return -1;
    }

    int s = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP);
    if (s < 0) {
        fprintf(stderr, "error: cannot create socket: %s\n", strerror(errno));
        return -1;
    }

    if (netboot_bind_to_cmd_port(s) < 0) {
        fprintf(stderr, "error: cannot bind to command port: %s\n", strerror(errno));
        close(s);
        return -1;
    }

    if (netboot_send_query(s, port, ifname) < 0) {
        fprintf(stderr, "error: failed to send netboot query\n");
        close(s);
        return -1;
    }

    struct pollfd fds = {.fd = s, .events = POLLIN};
    bool received_packets = false;
    bool first_wait = netboot_wait;

    struct timeval end_tv = netboot_timeout_init(first_wait ? 3600000 : netboot_timeout);
    for (;;) {
        int wait_ms = netboot_timeout_get_msec(&end_tv);
        if (wait_ms < 0) {
            // Expired.
            break;
        }

        int r = poll(&fds, 1, wait_ms);
        if (r > 0 && (fds.revents & POLLIN)) {
            received_packets = true;
            if (!netboot_receive_query(s, callback, data)) {
                break;
            }
        } else if (r < 0 && errno != EAGAIN && errno != EINTR) {
            // Save errno before fprintf()/close() can overwrite it, and do not
            // leak the socket on this error path.
            int saved_errno = errno;
            fprintf(stderr, "poll returned error: %s\n", strerror(saved_errno));
            close(s);
            errno = saved_errno;
            return -1;
        }
        if (first_wait) {
            end_tv = netboot_timeout_init(netboot_timeout);
            first_wait = false;
        }
    }

    close(s);
    if (received_packets) {
        return 0;
    }
    errno = ETIMEDOUT;
    return -1;
}
312 
// Per-call search state that netboot_open() passes to netboot_open_callback()
// through netboot_discover().
struct netboot_open_cookie {
    struct sockaddr_in6 addr;  // address of the selected device (zeroed until found)
    const char* hostname;      // nodename to match, or "*" to accept any device
    uint32_t index;            // counter kept by netboot_open_callback(); 0 is "not found"
};
typedef struct netboot_open_cookie netboot_open_cookie_t;
318 
// Discovery callback used by netboot_open(); |data| is a
// netboot_open_cookie_t*.
//
// Returns true (keep discovering) when |device| does not match the requested
// hostname. Otherwise it records the device's address, increments |index|,
// and returns false to stop discovery.
static bool netboot_open_callback(device_info_t* device, void* data) {
    // Named |ctx| so it does not shadow the file-scope request cookie.
    netboot_open_cookie_t* ctx = data;
    bool match_any = strcmp(ctx->hostname, "*") == 0;
    if (!match_any && strcmp(ctx->hostname, device->nodename) != 0) {
        return true;
    }
    // Count only selected devices. netboot_open() treats index == 0 as
    // "device not found"; counting every reply would let a lookup for an
    // absent hostname proceed with an all-zero address.
    ctx->index++;
    memcpy(&ctx->addr, &device->inet6_addr, sizeof(device->inet6_addr));
    return false;
}
328 
netboot_open(const char * hostname,const char * ifname,struct sockaddr_in6 * addr,bool make_connection)329 int netboot_open(const char* hostname, const char* ifname,
330                  struct sockaddr_in6* addr, bool make_connection) {
331     if ((hostname == NULL) || (hostname[0] == 0)) {
332         char* envname = getenv("ZIRCON_NODENAME");
333         hostname = envname && envname[0] != 0 ? envname : "*";
334     }
335     size_t hostname_len = strlen(hostname) + 1;
336     if (hostname_len > MAXSIZE) {
337         errno = EINVAL;
338         return -1;
339     }
340 
341     netboot_open_cookie_t cookie;
342     socklen_t rlen = sizeof(cookie.addr);
343     memset(&(cookie.addr), 0, sizeof(cookie.addr));
344     cookie.index = 0;
345     cookie.hostname = hostname;
346     if (netboot_discover(NB_SERVER_PORT, ifname, netboot_open_callback, &cookie) < 0) {
347         return -1;
348     }
349     // Device not found
350     if (cookie.index == 0) {
351         errno = EINVAL;
352         return -1;
353     }
354 
355     int s;
356     if ((s = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP)) < 0) {
357         fprintf(stderr, "error: cannot create socket: %s\n", strerror(errno));
358         return -1;
359     }
360 
361     if (netboot_bind_to_cmd_port(s) < 0) {
362         fprintf(stderr, "cannot bind to command port: %s\n", strerror(errno));
363         return -1;
364     }
365 
366     struct timeval tv;
367     tv.tv_sec = 0;
368     tv.tv_usec = 250 * 1000;
369     setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
370 
371     if (addr) {
372         memcpy(addr, &cookie.addr, sizeof(cookie.addr));
373     }
374 
375     if (make_connection && connect(s, (void*)&cookie.addr, rlen) < 0) {
376         fprintf(stderr, "error: cannot connect UDP port\n");
377         close(s);
378         return -1;
379     }
380     return s;
381 }
382 
// The netboot protocol ignores response packets that are invalid,
// retransmits requests if responses don't arrive in a timely
// fashion, and only returns an error upon eventual timeout, a
// non-timeout socket receive error, or a specific (correctly formed)
// remote error packet.
// Stamps |out| with the netboot magic and a fresh cookie, sends its first
// |outlen| bytes on socket |s|, and waits for a matching NB_ACK in |in|.
// A recv() timeout triggers a retransmission, up to five times, before the
// call fails with ETIMEDOUT. Short packets and packets with a bad header are
// logged and skipped. A reply whose |arg| is negative is returned as -1 with
// errno = -arg.
//
// Returns the size of the accepted reply, or -1 with errno set.
int netboot_txn(int s, msg* in, msg* out, int outlen) {
    out->hdr.magic = NB_MAGIC;
    out->hdr.cookie = ++cookie;

    // One initial transmission plus up to five retransmissions.
    for (int resends_left = 5;; resends_left--) {
        write(s, out, outlen);
        for (;;) {
            ssize_t got = recv(s, in, sizeof(*in), 0);
            if (got < 0) {
                bool timed_out = (errno == EAGAIN) || (errno == EWOULDBLOCK);
                if (!timed_out) {
                    return -1;
                }
                if (resends_left > 0) {
                    break;  // go back and retransmit the request
                }
                errno = ETIMEDOUT;
                return -1;
            }
            if (got < (ssize_t)sizeof(in->hdr)) {
                fprintf(stderr, "netboot: response too short\n");
                continue;
            }
            bool is_our_ack = (in->hdr.magic == NB_MAGIC) &&
                              (in->hdr.cookie == out->hdr.cookie) &&
                              (in->hdr.cmd == NB_ACK);
            if (!is_our_ack) {
                fprintf(stderr, "netboot: bad ack header"
                                " (magic=0x%x, cookie=%x/%x, cmd=%d)\n",
                        in->hdr.magic, in->hdr.cookie, cookie, in->hdr.cmd);
                continue;
            }
            int arg = in->hdr.arg;
            if (arg < 0) {
                errno = -arg;
                return -1;
            }
            return got;
        }
    }
}
426