1 /*
2  * Copyright (C) 2015-2017 Alibaba Group Holding Limited
3  */
4 
5 #include "lwip/opt.h"
6 #include "lwip/udp.h"
7 #include "lwip/timeouts.h"
8 #include "lwip/debug.h"
9 #include "lwip/apps/tftp.h"
10 
11 #include <string.h>
12 #include <stdlib.h>
13 #include <stdio.h>
14 
/* Progress-log interval in bytes (100 KB): recv() emits a debug line whenever
 * the running byte count is a non-zero exact multiple of this value. */
#define PRINT_BASE_SIZE 102400
16 
/* Bookkeeping for the download handled by this file. */
typedef struct tftp_state_s {
    /* File-access callbacks (open/close/write are invoked in this file). */
    const tftp_context_t *ctx;
    /* UDP pcb created for the transfer; removed in close_handle(). */
    struct udp_pcb *upcb;
    /* Handle returned by ctx->open(); NULL means no transfer is active. */
    void           *handle;
    /* Optional completion callback invoked from close_handle(). */
    tftp_done_cb   cb;
    /* Server address the request was sent to. */
    ip_addr_t      addr;
    /* Server port: initially the request port, then the source port of the
     * first reply accepted from `addr`. */
    uint16_t       port;
    /* Next DATA block number expected; 0 until the first reply arrives. */
    uint16_t       seq_expect;
    /* Block number most recently acknowledged; re-ACKed on timeout. */
    uint16_t       seq_last;
    /* NOTE(review): not referenced by any code visible in this file. */
    uint16_t       seq;
    /* Number of payload bytes written so far. */
    int            flen;
    /* Incremented on every tftp_tmr() run. */
    uint16_t       tick;
    /* Value of `tick` when the last packet from the server was accepted. */
    uint16_t       last_tick;
    /* Number of ACK retransmissions triggered by timeouts. */
    uint16_t       retries;
    /* aos_now_ms() at start; overwritten with the elapsed ms at completion. */
    uint32_t       time;
} tftp_state_t;
33 
/* The one and only transfer state: this file tracks a single download. */
static tftp_state_t tftp_state;
/* Server port the read request is sent to; see tftp_client_set_server_port(). */
static uint16_t     tftp_port = TFTP_PORT;

/* 1 requests "octet" transfer mode, any other value "netascii";
 * see tftp_client_set_binary_mode(). */
static uint8_t      tftp_binary_mode = 0;

/* Periodic timeout/retry handler, defined below. */
static void tftp_tmr(void* arg);
/* Packet helpers; their definitions are not in this file. */
void tftp_send_error(struct udp_pcb *pcb, const ip_addr_t *addr, u16_t port,
                     tftp_error_t code, const char *str);
void tftp_send_ack(struct udp_pcb *pcb, const ip_addr_t *addr, u16_t port, u16_t blknum);
43 
44 static void
close_handle(int err)45 close_handle(int err)
46 {
47     tftp_state_t *pstate = &tftp_state;
48     sys_untimeout(tftp_tmr, NULL);
49     udp_remove(pstate->upcb);
50 
51     if (pstate->handle) {
52         pstate->ctx->close(pstate->handle);
53         pstate->handle = NULL;
54     }
55     //LWIP_DEBUGF(TFTP_DEBUG | LWIP_DBG_STATE, ("tftp: closing\n"));
56 
57     if (pstate->cb != NULL)
58         pstate->cb(err, err == 0 ? pstate->flen : -1);
59     memset(pstate, 0, sizeof(tftp_state_t));
60 }
61 
tftp_tmr(void * arg)62 static void tftp_tmr(void* arg)
63 {
64     tftp_state_t *pstate = &tftp_state;
65     tftp_state.tick++;
66 
67     if (tftp_state.handle == NULL) {
68         return;
69     }
70 
71     sys_timeout(TFTP_TIMER_MSECS, tftp_tmr, NULL);
72 
73     if ((pstate->tick - pstate->last_tick) > (TFTP_TIMEOUT_MSECS / TFTP_TIMER_MSECS)) {
74         if ((pstate->seq_expect > 1) && (pstate->retries < TFTP_MAX_RETRIES)) {
75             LWIP_DEBUGF(TFTP_DEBUG | LWIP_DBG_STATE, ("tftp: timeout, retrying...............\n"));
76             tftp_send_ack(pstate->upcb, &pstate->addr, pstate->port, pstate->seq_last);
77             pstate->retries++;
78         } else {
79             tftp_send_error(tftp_state.upcb, &tftp_state.addr, pstate->port,
80                     TFTP_ERROR_ILLEGAL_OPERATION, "wait packet timeout");
81             LWIP_DEBUGF(TFTP_DEBUG | LWIP_DBG_STATE, ("tftp: timeout\n"));
82             close_handle(-1);
83         }
84     }
85 }
86 
/**
 * UDP receive callback for the active download.
 *
 * The first packet coming from the server address fixes the server's
 * transfer port. Packets from any other address/port are answered with an
 * "unknown transfer ID" error. DATA blocks are written through the context's
 * write callback and acknowledged; a block shorter than 512 bytes ends the
 * transfer successfully. An ERROR packet ends it with failure.
 *
 * @param arg  unused.
 * @param upcb unused (the pcb stored in the transfer state is used).
 * @param p    received datagram; always freed here.
 * @param addr source address of the datagram.
 * @param port source port of the datagram.
 */
static void recv(void *arg, struct udp_pcb *upcb, struct pbuf *p, const ip_addr_t *addr, u16_t port)
{
    tftp_state_t  *pstate = &tftp_state;
    /* Payload size of a full block; no blksize option is sent in the request. */
    const uint16_t max_payload = 512;
    const u8_t    *hdr;
    uint16_t       opcode;
    uint16_t       blknum;
    uint16_t       blklen;
    int            wlen = 0;

    (void)arg;
    (void)upcb;

    if (p == NULL) {
        return;
    }

    /* The opcode and the block/error number are read from the first segment;
     * drop runt packets instead of reading past the payload and letting the
     * block length computation underflow. */
    if (p->len < TFTP_HEADER_LENGTH) {
        pbuf_free(p);
        return;
    }

    /* First reply from the server: remember the port it answers from. */
    if (pstate->seq_expect == 0 && ip_addr_cmp(&pstate->addr, addr)) {
        pstate->port = port;
        pstate->seq_expect = 1;
    }

    if (port != pstate->port || !ip_addr_cmp(&pstate->addr, addr)) {
        tftp_send_error(pstate->upcb, addr, port,
                TFTP_ERROR_UNKNOWN_TRFR_ID, "port or addr illegal");
        pbuf_free(p);
        return;
    }

    /* Read the big-endian header byte-wise; this does not depend on the
     * payload being 16-bit aligned. */
    hdr = (const u8_t *)p->payload;
    opcode = (uint16_t)((hdr[0] << 8) | hdr[1]);
    pstate->last_tick = pstate->tick;

    switch (opcode) {
        case TFTP_DATA:
            blknum = (uint16_t)((hdr[2] << 8) | hdr[3]);
            blklen = (uint16_t)(p->tot_len - TFTP_HEADER_LENGTH);

            if (blknum < pstate->seq_expect) {
                /* Duplicate of a block already stored: just acknowledge it again. */
                tftp_send_ack(pstate->upcb, &pstate->addr, port, blknum);
                pstate->seq_last = blknum;
                break;
            }

            if (blknum > pstate->seq_expect) {
                LWIP_DEBUGF(TFTP_DEBUG | LWIP_DBG_STATE,
                        ("received error block '%d', len='%u'\n", blknum, blklen ));
                tftp_send_error(pstate->upcb, addr, port,
                        TFTP_ERROR_ILLEGAL_OPERATION, "seqno error");
                close_handle(-1);
                break;
            }

            if (blklen > max_payload) {
                /* Larger than any legal block: refuse it rather than hand an
                 * unexpected amount of data to the write callback. */
                tftp_send_error(pstate->upcb, addr, port,
                        TFTP_ERROR_ILLEGAL_OPERATION, "block too large");
                LWIP_DEBUGF(TFTP_DEBUG | LWIP_DBG_STATE, ("oversized block\n"));
                close_handle(-1);
                break;
            }

            /* Print download progress every 100 KB. */
            if (pstate->flen / PRINT_BASE_SIZE > 0 && pstate->flen % PRINT_BASE_SIZE == 0) {
                LWIP_DEBUGF(TFTP_DEBUG | LWIP_DBG_STATE,
                        ("received total length='%uKB'\n", pstate->flen/1024));
            }

            /* Hide the TFTP header so the write callback sees only file data. */
            pbuf_header(p, -TFTP_HEADER_LENGTH);
            wlen = pstate->ctx->write(pstate->handle, p);
            if (wlen != blklen) {
                tftp_send_error(pstate->upcb, addr, port, TFTP_ERROR_DISK_FULL, "disk full");
                LWIP_DEBUGF(TFTP_DEBUG | LWIP_DBG_STATE, ("write block failed\n"));
                close_handle(-1);
                break;
            }

            pstate->seq_last = blknum;
            pstate->flen += blklen;
            pstate->seq_expect++;
            /* Progress was made: earlier timeouts must not count against
             * the retry budget of later blocks. */
            pstate->retries = 0;
            tftp_send_ack(pstate->upcb, &pstate->addr, port, blknum);

            if (blklen < max_payload) {
                /* A short block marks the end of the file. */
                pstate->time = aos_now_ms() - pstate->time;
                LWIP_DEBUGF(TFTP_DEBUG | LWIP_DBG_STATE,
                        ("Total received: receive %u bytes in %u mS\n", pstate->flen, pstate->time));
                close_handle(0);
            }
            break;
        case TFTP_ERROR:
            /* NOTE(review): the message is printed with %s although nothing
             * here verifies that the server NUL-terminated it. */
            LWIP_DEBUGF(TFTP_DEBUG | LWIP_DBG_STATE,
                    ("sever return error '%d', msg: '%s'\n",
                     (uint16_t)((hdr[2] << 8) | hdr[3]),
                     (const char *)(hdr + TFTP_HEADER_LENGTH)));
            close_handle(-1);
            break;
        default:
            break;
    }
    pbuf_free(p);
}
165 
/**
 * Default "open" callback: open a local file with stdio.
 *
 * @param fname path of the file to open.
 * @param mode  TFTP transfer mode; a string starting with "netascii" selects
 *              text mode, one starting with "octet" selects binary mode.
 * @param write 0 to open for reading, non-zero to open for writing.
 * @return the FILE pointer as an opaque handle, or NULL when the mode is not
 *         recognised or fopen() fails.
 */
static void* tftp_fopen(const char* fname, const char* mode, u8_t write)
{
    const char *flags = NULL;

    if (strncmp(mode, "netascii", 8) == 0) {
        flags = (write != 0) ? "w" : "r";
    } else if (strncmp(mode, "octet", 5) == 0) {
        flags = (write != 0) ? "wb" : "rb";
    }

    if (flags == NULL) {
        /* Unknown transfer mode: nothing is opened. */
        return NULL;
    }

    return (void *)fopen(fname, flags);
}
177 
/**
 * Default "close" callback: close the stdio stream behind the handle.
 *
 * @param handle FILE pointer previously returned by tftp_fopen().
 */
static void tftp_fclose(void* handle)
{
    FILE *stream = (FILE *)handle;

    fclose(stream);
}
182 
/**
 * Default "read" callback: read from the stdio stream behind the handle.
 *
 * @param handle FILE pointer previously returned by tftp_fopen().
 * @param buf    destination buffer with room for at least `bytes` bytes.
 * @param bytes  maximum number of bytes to read.
 * @return number of bytes actually read (may be less than `bytes`), or -1
 *         when an argument is invalid.
 */
static int tftp_fread(void* handle, void* buf, int bytes)
{
    size_t readbytes;

    /* A negative count would turn into a huge size_t in the fread() call. */
    if (handle == NULL || buf == NULL || bytes < 0) {
        return -1;
    }

    readbytes = fread(buf, 1, (size_t)bytes, (FILE*)handle);
    return (int)readbytes;
}
189 
tftp_fwrite(void * handle,struct pbuf * p)190 static int tftp_fwrite(void* handle, struct pbuf* p)
191 {
192     char buff[512];
193     size_t writebytes = -1;
194 
195     pbuf_copy_partial(p, buff, p->tot_len, 0);
196 
197     writebytes = fwrite(buff, 1, p->tot_len, (FILE *)handle);
198 
199     return (int)writebytes;
200 }
201 
/* Ready-made file-access context that maps the TFTP callbacks onto the
 * stdio-based helpers defined above. */
const tftp_context_t client_ctx = {
    .open = tftp_fopen,
    .close = tftp_fclose,
    .read = tftp_fread,
    .write = tftp_fwrite
};
208 
/**
 * Start downloading a file from a TFTP server.
 *
 * Opens the local destination through ctx->open(), creates a UDP pcb bound
 * to a random port in 49152..65535, sends the read request (RRQ) to the
 * configured server port and arms the timeout timer. The transfer then
 * continues asynchronously in recv()/tftp_tmr(); `cb` is called when it ends.
 *
 * Only one transfer can be active at a time. On every failure path the
 * resources acquired so far (file handle, pcb, pbuf) are released and the
 * transfer state is left idle.
 *
 * @param paddr  server address.
 * @param fname  name of the file on the server.
 * @param lfname name of the local file to write.
 * @param ctx    file-access callbacks (open, write and close are used).
 * @param cb     completion callback, may be NULL.
 * @return 0 when the request was sent; -1 on invalid arguments, when a
 *         transfer is already active or when the local file cannot be
 *         opened; otherwise a negative lwIP error code.
 */
int tftp_client_get(const ip_addr_t *paddr, const char *fname, const char *lfname,
                    tftp_context_t *ctx, tftp_done_cb cb)
{
    tftp_state_t   *pstate = &tftp_state;
    const char     *mode = (tftp_binary_mode == 1) ? "octet" : "netascii";
    struct udp_pcb *pcb = NULL;
    struct pbuf    *p = NULL;
    void           *handle = NULL;
    size_t          fname_len;
    size_t          mode_len;
    size_t          pkt_len;
    uint32_t        start_ms;
    uint16_t        port;
    char           *payload;
    err_t           ret;

    if (paddr == NULL || fname == NULL || lfname == NULL || ctx == NULL) {
        return -1;
    }

    /* A non-NULL handle marks an active transfer; starting another one would
     * overwrite its state and leak its file handle and pcb. */
    if (pstate->handle != NULL) {
        return -1;
    }

    /* RRQ layout: 2-byte opcode, file name, NUL, mode, NUL. */
    fname_len = strlen(fname);
    mode_len = strlen(mode);
    if (fname_len > (size_t)0xFFFFu - 4u - mode_len) {
        /* Would not fit the 16-bit pbuf length. */
        return -1;
    }
    pkt_len = 4u + fname_len + mode_len;

    start_ms = aos_now_ms();
    handle = ctx->open(lfname, mode, 1);
    if (handle == NULL) {
        LWIP_DEBUGF(TFTP_DEBUG | LWIP_DBG_STATE, ("error: open file '%s' failed\n", lfname));
        return -1;
    }

    pcb = udp_new_ip_type(IPADDR_TYPE_ANY);
    if (pcb == NULL) {
        ret = ERR_MEM;
        goto fail;
    }

    port = aos_rand() % 16384 + 49152;
    ret = udp_bind(pcb, IP4_ADDR_ANY, port);
    if (ret != ERR_OK) {
        LWIP_DEBUGF(TFTP_DEBUG | LWIP_DBG_STATE, ("error: bind to port '%u' failed\n", port));
        goto fail;
    }

    p = pbuf_alloc(PBUF_TRANSPORT, (u16_t)pkt_len, PBUF_RAM);
    if (p == NULL) {
        LWIP_DEBUGF(TFTP_DEBUG | LWIP_DBG_STATE, ("error: alloc pbuf failed\n"));
        ret = ERR_MEM;
        goto fail;
    }

    /* Zero-fill first: this supplies both NUL terminators and the high
     * opcode byte. */
    payload = (char *)p->payload;
    memset(payload, 0, pkt_len);
    payload[1] = TFTP_RRQ;
    memcpy(&payload[2], fname, fname_len);
    memcpy(&payload[3 + fname_len], mode, mode_len);

    /* Populate the state and register the receive callback BEFORE sending,
     * so a reply can never be processed against stale state. */
    pstate->ctx = ctx;
    pstate->cb = cb;
    pstate->upcb = pcb;
    pstate->handle = handle;
    pstate->addr = *paddr;
    pstate->port = tftp_port;
    pstate->seq_expect = 0;
    pstate->seq_last = 0;
    pstate->flen = 0;
    pstate->tick = 0;
    pstate->last_tick = 0;
    pstate->retries = 0;
    pstate->time = start_ms;
    udp_recv(pcb, recv, NULL);

    ret = udp_sendto(pcb, p, paddr, pstate->port);
    if (ret != ERR_OK) {
        LWIP_DEBUGF(TFTP_DEBUG | LWIP_DBG_STATE, ("error: send RRQ to server failed\n"));
        /* Back to idle; the resources themselves are released below. */
        memset(pstate, 0, sizeof(*pstate));
        goto fail;
    }

    /* udp_sendto() does not take ownership of the pbuf. */
    pbuf_free(p);
    sys_timeout(TFTP_TIMER_MSECS, tftp_tmr, NULL);
    return 0;

fail:
    if (p != NULL) {
        pbuf_free(p);
    }
    if (pcb != NULL) {
        udp_remove(pcb);
    }
    ctx->close(handle);
    return ret;
}
278 
/**
 * Set the server port that subsequent tftp_client_get() calls send their
 * read request to.
 *
 * @param port server UDP port.
 */
void tftp_client_set_server_port(uint16_t port)
{
    tftp_port = port;
}
283 
tftp_client_set_binary_mode(uint8_t binary_mode)284 void tftp_client_set_binary_mode(uint8_t binary_mode)
285 {
286     tftp_binary_mode = binary_mode;
287 }
288 
289