1 // Copyright 2016 The Fuchsia Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #define _POSIX_C_SOURCE 200809L
6
7 #include "netprotocol.h"
8
9 #include <arpa/inet.h>
10 #include <ifaddrs.h>
11 #include <netinet/in.h>
12 #include <poll.h>
13 #include <sched.h>
14 #include <sys/param.h>
15 #include <sys/socket.h>
16 #include <sys/stat.h>
17 #include <sys/time.h>
18
19 #include <fcntl.h>
20 #include <libgen.h>
21 #include <stdbool.h>
22 #include <stdio.h>
23 #include <stdlib.h>
24 #include <string.h>
25 #include <unistd.h>
26
27 #include <errno.h>
28 #include <stdint.h>
29
30 #include <tftp/tftp.h>
31 #include <zircon/boot/netboot.h>
32
// Size in bytes of each of the two packet buffers (inbound and outbound) that
// transfer_file allocates for a tftp request.
#define TFTP_BUF_SZ 2048
34
// State for the local file of a transfer; handed to the tftp file callbacks
// (file_open_read, file_open_write, file_read, file_write, file_close) as
// their file_cookie.
typedef struct {
    // Descriptor opened by file_open_read or file_open_write.
    int fd;
    // File size: st_size from fstat when reading, or the size passed to
    // file_open_write when writing.
    size_t size;
} file_info_t;
39
// State for the network side of a transfer; handed to the tftp transport
// callbacks (transport_send, transport_recv, transport_timeout_set) as their
// transport_cookie.
typedef struct {
    // Socket used for every send and receive.
    int socket;
    // False until the first datagram is received; transport_recv then
    // connect()s the socket to the sender and sets this to true.
    bool connected;
    // Last receive timeout applied with SO_RCVTIMEO (0 until one is applied).
    uint32_t previous_timeout_ms;
    // Peer address. While unconnected, transport_send sends here with the port
    // forced to NB_TFTP_INCOMING_PORT; after connecting it holds the address
    // the first reply came from.
    struct sockaddr_in6 target_addr;
} transport_info_t;
46
// Program name (argv[0]); set at the start of main and used as the prefix of
// diagnostic messages.
static const char* appname;
48
file_open_read(const char * filename,void * file_cookie)49 static ssize_t file_open_read(const char* filename, void* file_cookie) {
50 int fd = open(filename, O_RDONLY);
51 if (fd < 0) {
52 return TFTP_ERR_IO;
53 }
54 file_info_t* file_info = file_cookie;
55 file_info->fd = fd;
56 struct stat st;
57 if (fstat(file_info->fd, &st) < 0) {
58 close(fd);
59 return TFTP_ERR_IO;
60 }
61 file_info->size = st.st_size;
62 return st.st_size;
63 }
64
file_open_write(const char * filename,size_t size,void * file_cookie)65 static tftp_status file_open_write(const char* filename, size_t size, void* file_cookie) {
66 int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH);
67 if (fd < 0) {
68 return TFTP_ERR_IO;
69 }
70 file_info_t* file_info = file_cookie;
71 file_info->fd = fd;
72 file_info->size = size;
73 return TFTP_NO_ERROR;
74 }
75
file_read(void * data,size_t * length,off_t offset,void * file_cookie)76 static tftp_status file_read(void* data, size_t* length, off_t offset, void* file_cookie) {
77 int fd = ((file_info_t*)file_cookie)->fd;
78 ssize_t n = pread(fd, data, *length, offset);
79 if (n < 0) {
80 return TFTP_ERR_IO;
81 }
82 *length = n;
83 return TFTP_NO_ERROR;
84 }
85
file_write(const void * data,size_t * length,off_t offset,void * file_cookie)86 static tftp_status file_write(const void* data, size_t* length, off_t offset, void* file_cookie) {
87 int fd = ((file_info_t*)file_cookie)->fd;
88 ssize_t n = pwrite(fd, data, *length, offset);
89 if (n < 0) {
90 return TFTP_ERR_IO;
91 }
92 *length = n;
93 return TFTP_NO_ERROR;
94 }
95
file_close(void * file_cookie)96 static void file_close(void* file_cookie) {
97 close(((file_info_t*)file_cookie)->fd);
98 }
99
100 // Longest time we will wait for a send operation to succeed
101 #define MAX_SEND_TIME_MS 1000
102
transport_send(void * data,size_t len,void * transport_cookie)103 static tftp_status transport_send(void* data, size_t len, void* transport_cookie) {
104 transport_info_t* transport_info = transport_cookie;
105 ssize_t send_result;
106 struct pollfd poll_fds = {.fd = transport_info->socket,
107 .events = POLLOUT};
108 do {
109 int poll_result = poll(&poll_fds, 1, MAX_SEND_TIME_MS);
110 if (poll_result <= 0) {
111 // We'll treat a timeout as an IO error and not a TFTP_ERR_TIMED_OUT,
112 // since the latter is a timeout waiting for a response from the server.
113 return TFTP_ERR_IO;
114 }
115 if (!transport_info->connected) {
116 transport_info->target_addr.sin6_port = htons(NB_TFTP_INCOMING_PORT);
117 send_result = sendto(transport_info->socket, data, len, 0,
118 (struct sockaddr*)&transport_info->target_addr,
119 sizeof(transport_info->target_addr));
120 } else {
121 send_result = send(transport_info->socket, data, len, 0);
122 }
123 } while ((send_result < 0) &&
124 ((errno == EAGAIN) || (errno == EWOULDBLOCK) ||
125 (errno == ENOBUFS && sched_yield() == 0)));
126
127 if (send_result < 0) {
128 return TFTP_ERR_IO;
129 }
130 return TFTP_NO_ERROR;
131 }
132
transport_recv(void * data,size_t len,bool block,void * transport_cookie)133 static int transport_recv(void* data, size_t len, bool block, void* transport_cookie) {
134 transport_info_t* transport_info = transport_cookie;
135 int flags = fcntl(transport_info->socket, F_GETFL, 0);
136 if (flags < 0) {
137 return TFTP_ERR_IO;
138 }
139 if (block) {
140 flags &= ~O_NONBLOCK;
141 } else {
142 flags |= O_NONBLOCK;
143 }
144 if (fcntl(transport_info->socket, F_SETFL, flags)) {
145 return TFTP_ERR_IO;
146 }
147 ssize_t recv_result;
148 struct sockaddr_in6 connection_addr;
149 socklen_t addr_len = sizeof(connection_addr);
150 if (!transport_info->connected) {
151 recv_result = recvfrom(transport_info->socket, data, len, 0,
152 (struct sockaddr*)&connection_addr,
153 &addr_len);
154 } else {
155 recv_result = recv(transport_info->socket, data, len, 0);
156 }
157 if (recv_result < 0) {
158 if ((errno == EAGAIN) || (errno == EWOULDBLOCK)) {
159 return TFTP_ERR_TIMED_OUT;
160 }
161 return TFTP_ERR_INTERNAL;
162 }
163 if (!transport_info->connected) {
164 if (connect(transport_info->socket, (struct sockaddr*)&connection_addr,
165 sizeof(connection_addr)) < 0) {
166 return TFTP_ERR_IO;
167 }
168 memcpy(&transport_info->target_addr, &connection_addr,
169 sizeof(transport_info->target_addr));
170 transport_info->connected = true;
171 }
172 return recv_result;
173 }
174
transport_timeout_set(uint32_t timeout_ms,void * transport_cookie)175 static int transport_timeout_set(uint32_t timeout_ms, void* transport_cookie) {
176 transport_info_t* transport_info = transport_cookie;
177 if (transport_info->previous_timeout_ms != timeout_ms && timeout_ms > 0) {
178 transport_info->previous_timeout_ms = timeout_ms;
179 struct timeval tv;
180 tv.tv_sec = timeout_ms / 1000;
181 tv.tv_usec = 1000 * (timeout_ms - 1000 * tv.tv_sec);
182 return setsockopt(transport_info->socket, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
183 }
184 return 0;
185 }
186
transfer_file(bool push,int s,struct sockaddr_in6 * addr,const char * dst,const char * src)187 static int transfer_file(bool push, int s, struct sockaddr_in6* addr, const char* dst,
188 const char* src) {
189 // Initialize session
190 tftp_session* session = NULL;
191 size_t session_data_sz = tftp_sizeof_session();
192 void* session_data = calloc(session_data_sz, 1);
193 if (session_data == NULL) {
194 fprintf(stderr, "%s: unable to allocate tftp session memory\n", appname);
195 return 1;
196 }
197 if (tftp_init(&session, session_data, session_data_sz) != TFTP_NO_ERROR) {
198 fprintf(stderr, "%s: unable to initiate tftp session\n", appname);
199 free(session_data);
200 return 1;
201 }
202
203 // Initialize file interface
204 file_info_t file_info;
205 tftp_file_interface file_ifc = {file_open_read, file_open_write,
206 file_read, file_write, file_close};
207 tftp_session_set_file_interface(session, &file_ifc);
208
209 // Initialize transport interface
210 transport_info_t transport_info;
211 transport_info.previous_timeout_ms = 0;
212 transport_info.socket = s;
213 transport_info.connected = false;
214 memcpy(&transport_info.target_addr, addr, sizeof(transport_info.target_addr));
215 tftp_transport_interface transport_ifc = {transport_send, transport_recv,
216 transport_timeout_set};
217 tftp_session_set_transport_interface(session, &transport_ifc);
218
219 // Set our preferred transport options
220 tftp_set_options(session, &tftp_block_size, NULL, &tftp_window_size);
221
222 // Prepare buffers
223 char err_msg[128];
224 tftp_request_opts opts = {0};
225 opts.inbuf = malloc(TFTP_BUF_SZ);
226 opts.inbuf_sz = TFTP_BUF_SZ;
227 opts.outbuf = malloc(TFTP_BUF_SZ);
228 opts.outbuf_sz = TFTP_BUF_SZ;
229 opts.err_msg = err_msg;
230 opts.err_msg_sz = sizeof(err_msg);
231
232 tftp_status status;
233 if (push) {
234 status = tftp_push_file(session, &transport_info, &file_info, src, dst, &opts);
235 } else {
236 status = tftp_pull_file(session, &transport_info, &file_info, dst, src, &opts);
237 }
238
239 free(session_data);
240 free(opts.inbuf);
241 free(opts.outbuf);
242
243 if (status < 0) {
244 fprintf(stderr, "%s: %s (status = %d)\n", appname, opts.err_msg, (int)status);
245 return 1;
246 }
247
248 fprintf(stderr, "wrote %zu bytes\n", file_info.size);
249
250 return 0;
251 }
252
usage(void)253 static void usage(void) {
254 fprintf(stderr, "usage: %s [options] [hostname:]src [hostname:]dst\n", appname);
255 netboot_usage(true);
256 }
257
main(int argc,char ** argv)258 int main(int argc, char** argv) {
259 appname = argv[0];
260
261 int index = netboot_handle_getopt(argc, argv);
262 if (index < 0) {
263 usage();
264 return -1;
265 }
266 argv += index;
267 argc -= index;
268
269 if (argc != 2) {
270 usage();
271 return -1;
272 }
273
274 const char* src = argv[0];
275 const char* dst = argv[1];
276
277 int push = -1;
278 char* pos;
279 const char* hostname;
280 if ((pos = strpbrk(src, ":")) != 0) {
281 push = 0;
282 hostname = src;
283 pos[0] = 0;
284 src = pos + 1;
285 }
286 if ((pos = strpbrk(dst, ":")) != 0) {
287 if (push == 0) {
288 fprintf(stderr, "%s: only one of src or dst can have a hostname\n", appname);
289 return -1;
290 }
291 push = 1;
292 hostname = dst;
293 pos[0] = 0;
294 dst = pos + 1;
295 }
296 if (push == -1) {
297 fprintf(stderr, "%s: either src or dst needs a hostname\n", appname);
298 return -1;
299 }
300
301 int s;
302 struct sockaddr_in6 server_addr;
303 if ((s = netboot_open(hostname, NULL, &server_addr, false)) < 0) {
304 if (errno == ETIMEDOUT) {
305 fprintf(stderr, "%s: lookup of %s timed out\n", appname, hostname);
306 } else {
307 fprintf(stderr, "%s: failed to connect to %s: %d\n", appname, hostname, errno);
308 }
309 return -1;
310 }
311
312 int ret = transfer_file(push, s, &server_addr, dst, src);
313 close(s);
314 return ret;
315 }
316