1 /*
2  * Copyright (c) 2014 Brian Swetland
3  *
4  * Use of this source code is governed by a MIT-style
5  * license that can be found in the LICENSE file or at
6  * https://opensource.org/licenses/MIT
7  */
8 
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <unistd.h>
#include <errno.h>
#include <fcntl.h>
#include <sys/types.h>
16 
17 #include "network.h"
18 #include "../app/lkboot/lkboot_protocol.h"
19 
readx(int s,void * _data,int len)20 static int readx(int s, void *_data, int len) {
21     char *data = _data;
22     int r;
23     while (len > 0) {
24         r = read(s, data, len);
25         if (r == 0) {
26             fprintf(stderr, "error: eof during socket read\n");
27             return -1;
28         }
29         if (r < 0) {
30             if (errno == EINTR) continue;
31             fprintf(stderr, "error: %s during socket read\n", strerror(errno));
32             return -1;
33         }
34         data += r;
35         len -= r;
36     }
37     return 0;
38 }
39 
upload(int s,int txfd,size_t txlen,int do_endian_swap)40 static int upload(int s, int txfd, size_t txlen, int do_endian_swap) {
41     int err = 0;
42     msg_hdr_t hdr;
43 
44     char *buf = malloc(txlen);
45     if (!buf)
46         return -1;
47 
48     if (readx(txfd, buf, txlen)) {
49         fprintf(stderr, "error: reading from file\n");
50         err = -1;
51         goto done;
52     }
53 
54     /* 4 byte swap data if requested */
55     if (do_endian_swap) {
56         size_t i;
57         for (i = 0; i < txlen; i += 4) {
58             char temp = buf[i];
59             buf[i] = buf[i + 3];
60             buf[i + 3] = temp;
61 
62             temp = buf[i + 1];
63             buf[i + 1] = buf[i + 2];
64             buf[i + 2] = temp;
65         }
66     }
67 
68     size_t pos = 0;
69     while (pos < txlen) {
70         size_t xfer = (txlen - pos > 65536) ? 65536 : txlen - pos;
71 
72         hdr.opcode = MSG_SEND_DATA;
73         hdr.extra = 0;
74         hdr.length = xfer - 1;
75         if (write(s, &hdr, sizeof(hdr)) != sizeof(hdr)) {
76             fprintf(stderr, "error: writing socket\n");
77             err = -1;
78             goto done;
79         }
80         if (write(s, buf + pos, xfer) != xfer) {
81             fprintf(stderr, "error: writing socket\n");
82             err = -1;
83             goto done;
84         }
85         pos += xfer;
86     }
87 
88     hdr.opcode = MSG_END_DATA;
89     hdr.extra = 0;
90     hdr.length = 0;
91     if (write(s, &hdr, sizeof(hdr)) != sizeof(hdr)) {
92         fprintf(stderr, "error: writing socket\n");
93         err = -1;
94         goto done;
95     }
96 
97 done:
98     free(buf);
99 
100     return err;
101 }
102 
trim_fpga_image(int fd,off_t len)103 static off_t trim_fpga_image(int fd, off_t len) {
104     /* fd should be at start of bitfile, seek until the
105      * ff ff ff ff aa 99 55 66 pattern is found and subtract
106      * the number of bytes read until pattern found.
107      */
108     const unsigned char pat[] = { 0xff, 0xff, 0xff, 0xff, 0xaa, 0x99, 0x55, 0x66 };
109     unsigned char buf[sizeof(pat)];
110 
111     memset(buf, 0, sizeof(buf));
112 
113     off_t i;
114     for (i = 0; i < len; i++) {
115         memmove(buf, buf + 1, sizeof(buf) - 1);
116         if (read(fd, &buf[sizeof(buf) - 1], 1) < 1) {
117             return -1;
118         }
119 
120         /* look for pattern */
121         if (memcmp(pat, buf, sizeof(pat)) == 0) {
122             /* found it, rewind the fd and return the truncated length */
123             lseek(fd, -sizeof(pat), SEEK_CUR);
124             return len - (i + 1 - sizeof(pat));
125         }
126     }
127 
128     return -1;
129 }
130 
#define DCC_SUBPROCESS "zynq-dcc"

/* Spawn the DCC_SUBPROCESS helper (looked up via PATH) with its stdin and
 * stdout connected to pipes.
 *
 * On success returns 0, stores in *fd_in a descriptor reading the child's
 * stdout, and stores in *fd_out a descriptor writing the child's stdin.
 * Returns -1 if either pipe or the fork cannot be created.  In that case
 * every descriptor created here is closed again and *fd_in / *fd_out are
 * left untouched.
 *
 * If exec fails in the child, the child exits with status 127 instead of
 * returning into the caller's code.  The child is not reaped here.
 */
static int start_dcc_subprocess(int *fd_in, int *fd_out) {
    int outpipe[2];   /* parent writes outpipe[1] -> child's stdin  */
    int inpipe[2];    /* child's stdout -> inpipe[0] read by parent  */

    if (pipe(outpipe) != 0)
        return -1;

    if (pipe(inpipe) != 0) {
        close(outpipe[0]);
        close(outpipe[1]);
        return -1;
    }

    pid_t pid = fork();
    if (pid < 0) {
        close(outpipe[0]);
        close(outpipe[1]);
        close(inpipe[0]);
        close(inpipe[1]);
        return -1;
    }

    if (pid == 0) {
        /* we are the child */
        close(STDIN_FILENO);
        close(STDOUT_FILENO);

        dup2(outpipe[0], STDIN_FILENO);
        close(outpipe[1]);

        dup2(inpipe[1], STDOUT_FILENO);
        close(inpipe[0]);

        fprintf(stderr, "in the child\n");

        /* the variadic list must end with a null pointer of type char * */
        execlp(DCC_SUBPROCESS, DCC_SUBPROCESS, (char *)NULL);
        fprintf(stderr, "after exec, didn't work!\n");
        /* never return into the parent's transaction code from the child */
        _exit(127);
    }

    fprintf(stderr, "in parent, pid %ld\n", (long)pid);

    close(outpipe[0]);
    close(inpipe[1]);

    *fd_in = inpipe[0];
    *fd_out = outpipe[1];

    return 0;
}
170 
/* Static storage for MSG_SEND_DATA payloads received by lkboot_txn.
 * Payloads are appended at offset replylen; nothing in this block resets
 * it, so data from successive transactions accumulates. */
#define REPLYMAX (9 * 1024 * 1024)
static unsigned char replybuf[REPLYMAX];
static unsigned replylen;

/* Expose the accumulated reply data.
 *
 * Stores the address of the static reply buffer in *ptr and returns the
 * number of valid bytes in it.  The buffer is owned by this module and
 * must not be freed by the caller.
 */
unsigned lkboot_get_reply(void **ptr) {
    unsigned valid = replylen;
    *ptr = &replybuf[0];
    return valid;
}
179 
lkboot_txn(const char * host,const char * _cmd,int txfd,const char * args)180 int lkboot_txn(const char *host, const char *_cmd, int txfd, const char *args) {
181     msg_hdr_t hdr;
182     char cmd[128];
183     char tmp[65536];
184     off_t txlen = 0;
185     int do_endian_swap = 0;
186     int once = 1;
187     int len;
188     int fd_in, fd_out;
189     int ret = 0;
190 
191     if (txfd != -1) {
192         txlen = lseek(txfd, 0, SEEK_END);
193         if (txlen > (512*1024*1024)) {
194             fprintf(stderr, "error: file too large\n");
195             return -1;
196         }
197         lseek(txfd, 0, SEEK_SET);
198     }
199 
200     if (!strcmp(_cmd, "fpga")) {
201         /* if we were asked to send an fpga image, try to find the sync words and
202          * trim all the data before it
203          */
204         txlen = trim_fpga_image(txfd, txlen);
205         if (txlen < 0) {
206             fprintf(stderr, "error: fpga image doesn't contain sync pattern\n");
207             return -1;
208         }
209 
210         /* it'll need a 4 byte endian swap as well */
211         do_endian_swap = 1;
212     }
213 
214     len = snprintf(cmd, 128, "%s:%d:%s", _cmd, (int) txlen, args);
215     if (len > 127) {
216         fprintf(stderr, "error: command too large\n");
217         return -1;
218     }
219 
220     /* if host is -, use stdin/stdout */
221     if (!strcmp(host, "-")) {
222         fprintf(stderr, "using stdin/stdout for io\n");
223         fd_in = STDIN_FILENO;
224         fd_out = STDOUT_FILENO;
225     } else if (!strcasecmp(host, "jtag")) {
226         fprintf(stderr, "using zynq-dcc utility for io\n");
227         if (start_dcc_subprocess(&fd_in, &fd_out) < 0) {
228             fprintf(stderr, "error starting jtag subprocess, is it in your path?\n");
229             return -1;
230         }
231     } else {
232         in_addr_t addr = lookup_hostname(host);
233         if (addr == 0) {
234             fprintf(stderr, "error: cannot find host '%s'\n", host);
235             return -1;
236         }
237         while ((fd_in = tcp_connect(addr, 1023)) < 0) {
238             if (once) {
239                 fprintf(stderr, "error: cannot connect to host '%s'. retrying...\n", host);
240                 once = 0;
241             }
242             usleep(100000);
243         }
244         fd_out = fd_in;
245     }
246 
247     hdr.opcode = MSG_CMD;
248     hdr.extra = 0;
249     hdr.length = len;
250     if (write(fd_out, &hdr, sizeof(hdr)) != sizeof(hdr)) goto iofail;
251     if (write(fd_out, cmd, len) != len) goto iofail;
252 
253     for (;;) {
254         if (readx(fd_in, &hdr, sizeof(hdr))) goto iofail;
255         switch (hdr.opcode) {
256             case MSG_GO_AHEAD:
257                 if (upload(fd_out, txfd, txlen, do_endian_swap)) {
258                     ret = -1;
259                     goto out;
260                 }
261                 break;
262             case MSG_OKAY:
263                 ret = 0;
264                 goto out;
265             case MSG_FAIL:
266                 len = (hdr.length > 127) ? 127 : hdr.length;
267                 if (readx(fd_in, cmd, len)) {
268                     cmd[0] = 0;
269                 } else {
270                     cmd[len] = 0;
271                 }
272                 fprintf(stderr,"error: remote failure: %s\n", cmd);
273                 ret = -1;
274                 goto out;
275             case MSG_SEND_DATA:
276                 len = hdr.length + 1;
277                 if (readx(fd_in, tmp, len)) goto iofail;
278                 if (len > (REPLYMAX - replylen)) {
279                     fprintf(stderr, "error: too much reply data\n");
280                     ret = -1;
281                     goto out;
282                 }
283                 memcpy(replybuf + replylen, tmp, len);
284                 replylen += len;
285                 break;
286             default:
287                 fprintf(stderr, "error: unknown opcode %d\n", hdr.opcode);
288                 ret = -1;
289                 goto out;
290         }
291     }
292 
293 iofail:
294     fprintf(stderr, "error: socket io\n");
295     ret = -1;
296 
297 out:
298     close(fd_in);
299     if (fd_out != fd_in)
300         close(fd_out);
301     return ret;
302 }
303