1 // Copyright 2016 The Fuchsia Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #define _POSIX_C_SOURCE 200809L
6 
7 #include "netprotocol.h"
8 
9 #include <arpa/inet.h>
10 #include <ifaddrs.h>
11 #include <netinet/in.h>
12 #include <poll.h>
13 #include <sched.h>
14 #include <sys/param.h>
15 #include <sys/socket.h>
16 #include <sys/stat.h>
17 #include <sys/time.h>
18 
19 #include <fcntl.h>
20 #include <libgen.h>
21 #include <stdbool.h>
22 #include <stdio.h>
23 #include <stdlib.h>
24 #include <string.h>
25 #include <unistd.h>
26 
27 #include <errno.h>
28 #include <stdint.h>
29 
30 #include <tftp/tftp.h>
31 #include <zircon/boot/netboot.h>
32 
33 #define TFTP_BUF_SZ 2048
34 
// Local-file state for one transfer. A pointer to this is handed to the tftp
// library as the opaque file_cookie and is filled in by file_open_read() or
// file_open_write().
typedef struct {
    int fd;       // descriptor from open(); released by file_close()
    size_t size;  // st_size from fstat() when reading, or the size passed to file_open_write()
} file_info_t;
39 
// Network state for one transfer. A pointer to this is handed to the tftp
// library as the opaque transport_cookie.
typedef struct {
    int socket;                       // socket descriptor used for every send/recv
    bool connected;                   // set once transport_recv() has connect()ed the socket to the
                                      // first peer it heard from; before that, sends use sendto()
    uint32_t previous_timeout_ms;     // last receive timeout applied by transport_timeout_set(),
                                      // used to skip redundant setsockopt() calls (0 = none yet)
    struct sockaddr_in6 target_addr;  // destination for sendto(); replaced by the responding
                                      // peer's address on the first successful receive
} transport_info_t;
46 
// Program name (argv[0]); set in main() and used as the prefix of diagnostics.
static const char* appname;
48 
file_open_read(const char * filename,void * file_cookie)49 static ssize_t file_open_read(const char* filename, void* file_cookie) {
50     int fd = open(filename, O_RDONLY);
51     if (fd < 0) {
52         return TFTP_ERR_IO;
53     }
54     file_info_t* file_info = file_cookie;
55     file_info->fd = fd;
56     struct stat st;
57     if (fstat(file_info->fd, &st) < 0) {
58         close(fd);
59         return TFTP_ERR_IO;
60     }
61     file_info->size = st.st_size;
62     return st.st_size;
63 }
64 
file_open_write(const char * filename,size_t size,void * file_cookie)65 static tftp_status file_open_write(const char* filename, size_t size, void* file_cookie) {
66     int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH);
67     if (fd < 0) {
68         return TFTP_ERR_IO;
69     }
70     file_info_t* file_info = file_cookie;
71     file_info->fd = fd;
72     file_info->size = size;
73     return TFTP_NO_ERROR;
74 }
75 
file_read(void * data,size_t * length,off_t offset,void * file_cookie)76 static tftp_status file_read(void* data, size_t* length, off_t offset, void* file_cookie) {
77     int fd = ((file_info_t*)file_cookie)->fd;
78     ssize_t n = pread(fd, data, *length, offset);
79     if (n < 0) {
80         return TFTP_ERR_IO;
81     }
82     *length = n;
83     return TFTP_NO_ERROR;
84 }
85 
file_write(const void * data,size_t * length,off_t offset,void * file_cookie)86 static tftp_status file_write(const void* data, size_t* length, off_t offset, void* file_cookie) {
87     int fd = ((file_info_t*)file_cookie)->fd;
88     ssize_t n = pwrite(fd, data, *length, offset);
89     if (n < 0) {
90         return TFTP_ERR_IO;
91     }
92     *length = n;
93     return TFTP_NO_ERROR;
94 }
95 
file_close(void * file_cookie)96 static void file_close(void* file_cookie) {
97     close(((file_info_t*)file_cookie)->fd);
98 }
99 
// Longest time a single poll() in transport_send() waits for the socket to
// become writable. transport_send() polls again after every send attempt that
// fails with EAGAIN/EWOULDBLOCK (or ENOBUFS, after a sched_yield()), so one
// call can take longer than this in total.
#define MAX_SEND_TIME_MS 1000
102 
transport_send(void * data,size_t len,void * transport_cookie)103 static tftp_status transport_send(void* data, size_t len, void* transport_cookie) {
104     transport_info_t* transport_info = transport_cookie;
105     ssize_t send_result;
106     struct pollfd poll_fds = {.fd = transport_info->socket,
107                               .events = POLLOUT};
108     do {
109         int poll_result = poll(&poll_fds, 1, MAX_SEND_TIME_MS);
110         if (poll_result <= 0) {
111             // We'll treat a timeout as an IO error and not a TFTP_ERR_TIMED_OUT,
112             // since the latter is a timeout waiting for a response from the server.
113             return TFTP_ERR_IO;
114         }
115         if (!transport_info->connected) {
116             transport_info->target_addr.sin6_port = htons(NB_TFTP_INCOMING_PORT);
117             send_result = sendto(transport_info->socket, data, len, 0,
118                                  (struct sockaddr*)&transport_info->target_addr,
119                                  sizeof(transport_info->target_addr));
120         } else {
121             send_result = send(transport_info->socket, data, len, 0);
122         }
123     } while ((send_result < 0) &&
124              ((errno == EAGAIN) || (errno == EWOULDBLOCK) ||
125               (errno == ENOBUFS && sched_yield() == 0)));
126 
127     if (send_result < 0) {
128         return TFTP_ERR_IO;
129     }
130     return TFTP_NO_ERROR;
131 }
132 
transport_recv(void * data,size_t len,bool block,void * transport_cookie)133 static int transport_recv(void* data, size_t len, bool block, void* transport_cookie) {
134     transport_info_t* transport_info = transport_cookie;
135     int flags = fcntl(transport_info->socket, F_GETFL, 0);
136     if (flags < 0) {
137         return TFTP_ERR_IO;
138     }
139     if (block) {
140         flags &= ~O_NONBLOCK;
141     } else {
142         flags |= O_NONBLOCK;
143     }
144     if (fcntl(transport_info->socket, F_SETFL, flags)) {
145         return TFTP_ERR_IO;
146     }
147     ssize_t recv_result;
148     struct sockaddr_in6 connection_addr;
149     socklen_t addr_len = sizeof(connection_addr);
150     if (!transport_info->connected) {
151         recv_result = recvfrom(transport_info->socket, data, len, 0,
152                                (struct sockaddr*)&connection_addr,
153                                &addr_len);
154     } else {
155         recv_result = recv(transport_info->socket, data, len, 0);
156     }
157     if (recv_result < 0) {
158         if ((errno == EAGAIN) || (errno == EWOULDBLOCK)) {
159             return TFTP_ERR_TIMED_OUT;
160         }
161         return TFTP_ERR_INTERNAL;
162     }
163     if (!transport_info->connected) {
164         if (connect(transport_info->socket, (struct sockaddr*)&connection_addr,
165                     sizeof(connection_addr)) < 0) {
166             return TFTP_ERR_IO;
167         }
168         memcpy(&transport_info->target_addr, &connection_addr,
169                sizeof(transport_info->target_addr));
170         transport_info->connected = true;
171     }
172     return recv_result;
173 }
174 
transport_timeout_set(uint32_t timeout_ms,void * transport_cookie)175 static int transport_timeout_set(uint32_t timeout_ms, void* transport_cookie) {
176     transport_info_t* transport_info = transport_cookie;
177     if (transport_info->previous_timeout_ms != timeout_ms && timeout_ms > 0) {
178         transport_info->previous_timeout_ms = timeout_ms;
179         struct timeval tv;
180         tv.tv_sec = timeout_ms / 1000;
181         tv.tv_usec = 1000 * (timeout_ms - 1000 * tv.tv_sec);
182         return setsockopt(transport_info->socket, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
183     }
184     return 0;
185 }
186 
transfer_file(bool push,int s,struct sockaddr_in6 * addr,const char * dst,const char * src)187 static int transfer_file(bool push, int s, struct sockaddr_in6* addr, const char* dst,
188                          const char* src) {
189     // Initialize session
190     tftp_session* session = NULL;
191     size_t session_data_sz = tftp_sizeof_session();
192     void* session_data = calloc(session_data_sz, 1);
193     if (session_data == NULL) {
194         fprintf(stderr, "%s: unable to allocate tftp session memory\n", appname);
195         return 1;
196     }
197     if (tftp_init(&session, session_data, session_data_sz) != TFTP_NO_ERROR) {
198         fprintf(stderr, "%s: unable to initiate tftp session\n", appname);
199         free(session_data);
200         return 1;
201     }
202 
203     // Initialize file interface
204     file_info_t file_info;
205     tftp_file_interface file_ifc = {file_open_read, file_open_write,
206                                     file_read, file_write, file_close};
207     tftp_session_set_file_interface(session, &file_ifc);
208 
209     // Initialize transport interface
210     transport_info_t transport_info;
211     transport_info.previous_timeout_ms = 0;
212     transport_info.socket = s;
213     transport_info.connected = false;
214     memcpy(&transport_info.target_addr, addr, sizeof(transport_info.target_addr));
215     tftp_transport_interface transport_ifc = {transport_send, transport_recv,
216                                               transport_timeout_set};
217     tftp_session_set_transport_interface(session, &transport_ifc);
218 
219     // Set our preferred transport options
220     tftp_set_options(session, &tftp_block_size, NULL, &tftp_window_size);
221 
222     // Prepare buffers
223     char err_msg[128];
224     tftp_request_opts opts = {0};
225     opts.inbuf = malloc(TFTP_BUF_SZ);
226     opts.inbuf_sz = TFTP_BUF_SZ;
227     opts.outbuf = malloc(TFTP_BUF_SZ);
228     opts.outbuf_sz = TFTP_BUF_SZ;
229     opts.err_msg = err_msg;
230     opts.err_msg_sz = sizeof(err_msg);
231 
232     tftp_status status;
233     if (push) {
234         status = tftp_push_file(session, &transport_info, &file_info, src, dst, &opts);
235     } else {
236         status = tftp_pull_file(session, &transport_info, &file_info, dst, src, &opts);
237     }
238 
239     free(session_data);
240     free(opts.inbuf);
241     free(opts.outbuf);
242 
243     if (status < 0) {
244         fprintf(stderr, "%s: %s (status = %d)\n", appname, opts.err_msg, (int)status);
245         return 1;
246     }
247 
248     fprintf(stderr, "wrote %zu bytes\n", file_info.size);
249 
250     return 0;
251 }
252 
usage(void)253 static void usage(void) {
254     fprintf(stderr, "usage: %s [options] [hostname:]src [hostname:]dst\n", appname);
255     netboot_usage(true);
256 }
257 
main(int argc,char ** argv)258 int main(int argc, char** argv) {
259     appname = argv[0];
260 
261     int index = netboot_handle_getopt(argc, argv);
262     if (index < 0) {
263         usage();
264         return -1;
265     }
266     argv += index;
267     argc -= index;
268 
269     if (argc != 2) {
270         usage();
271         return -1;
272     }
273 
274     const char* src = argv[0];
275     const char* dst = argv[1];
276 
277     int push = -1;
278     char* pos;
279     const char* hostname;
280     if ((pos = strpbrk(src, ":")) != 0) {
281         push = 0;
282         hostname = src;
283         pos[0] = 0;
284         src = pos + 1;
285     }
286     if ((pos = strpbrk(dst, ":")) != 0) {
287         if (push == 0) {
288             fprintf(stderr, "%s: only one of src or dst can have a hostname\n", appname);
289             return -1;
290         }
291         push = 1;
292         hostname = dst;
293         pos[0] = 0;
294         dst = pos + 1;
295     }
296     if (push == -1) {
297         fprintf(stderr, "%s: either src or dst needs a hostname\n", appname);
298         return -1;
299     }
300 
301     int s;
302     struct sockaddr_in6 server_addr;
303     if ((s = netboot_open(hostname, NULL, &server_addr, false)) < 0) {
304         if (errno == ETIMEDOUT) {
305             fprintf(stderr, "%s: lookup of %s timed out\n", appname, hostname);
306         } else {
307             fprintf(stderr, "%s: failed to connect to %s: %d\n", appname, hostname, errno);
308         }
309         return -1;
310     }
311 
312     int ret = transfer_file(push, s, &server_addr, dst, src);
313     close(s);
314     return ret;
315 }
316