1 /** @file
2 * @brief Websocket client API
3 *
 * An API for applications to set up websocket connections.
5 */
6
7 /*
8 * Copyright (c) 2019 Intel Corporation
9 *
10 * SPDX-License-Identifier: Apache-2.0
11 */
12
13 #include <zephyr/logging/log.h>
14 LOG_MODULE_REGISTER(net_websocket, CONFIG_NET_WEBSOCKET_LOG_LEVEL);
15
16 #include <zephyr/kernel.h>
17 #include <strings.h>
18 #include <errno.h>
19 #include <stdbool.h>
20 #include <stdlib.h>
21
22 #include <zephyr/sys/fdtable.h>
23 #include <zephyr/net/net_core.h>
24 #include <zephyr/net/net_ip.h>
25 #if defined(CONFIG_POSIX_API)
26 #include <zephyr/posix/unistd.h>
27 #include <zephyr/posix/sys/socket.h>
28 #else
29 #include <zephyr/net/socket.h>
30 #endif
31 #include <zephyr/net/http/client.h>
32 #include <zephyr/net/websocket.h>
33
34 #include <zephyr/random/random.h>
35 #include <zephyr/sys/byteorder.h>
36 #include <zephyr/sys/base64.h>
37 #include <mbedtls/sha1.h>
38
39 #include "net_private.h"
40 #include "sockets_internal.h"
41 #include "websocket_internal.h"
42
/* Compile-time debug switches. To hexdump the frames that are sent or
 * received, enable debug logging for this module and set the corresponding
 * macro below to 1. The output is very verbose, so both are disabled by
 * default.
 */
#define HEXDUMP_SENT_PACKETS 0
#define HEXDUMP_RECV_PACKETS 0

/* Pool of websocket contexts. A slot counts as in use while its refcount is
 * non-zero (see websocket_context_is_used()).
 */
static struct websocket_context contexts[CONFIG_WEBSOCKET_MAX_CONTEXTS];

/* Taken while scanning or allocating slots in contexts[]. */
static struct k_sem contexts_lock;

/* Forward declaration; the vtable is defined after the socket methods. */
static const struct socket_op_vtable websocket_fd_op_vtable;

#if defined(CONFIG_NET_TEST)
/* Test hook that replaces the real socket send when CONFIG_NET_TEST is set;
 * it is defined outside this file.
 */
int verify_sent_and_received_msg(struct msghdr *msg, bool split_msg);
#endif
59
opcode2str(enum websocket_opcode opcode)60 static const char *opcode2str(enum websocket_opcode opcode)
61 {
62 switch (opcode) {
63 case WEBSOCKET_OPCODE_DATA_TEXT:
64 return "TEXT";
65 case WEBSOCKET_OPCODE_DATA_BINARY:
66 return "BIN";
67 case WEBSOCKET_OPCODE_CONTINUE:
68 return "CONT";
69 case WEBSOCKET_OPCODE_CLOSE:
70 return "CLOSE";
71 case WEBSOCKET_OPCODE_PING:
72 return "PING";
73 case WEBSOCKET_OPCODE_PONG:
74 return "PONG";
75 default:
76 break;
77 }
78
79 return NULL;
80 }
81
websocket_context_ref(struct websocket_context * ctx)82 static int websocket_context_ref(struct websocket_context *ctx)
83 {
84 int old_rc = atomic_inc(&ctx->refcount);
85
86 return old_rc + 1;
87 }
88
websocket_context_unref(struct websocket_context * ctx)89 static int websocket_context_unref(struct websocket_context *ctx)
90 {
91 int old_rc = atomic_dec(&ctx->refcount);
92
93 if (old_rc != 1) {
94 return old_rc - 1;
95 }
96
97 return 0;
98 }
99
websocket_context_is_used(struct websocket_context * ctx)100 static inline bool websocket_context_is_used(struct websocket_context *ctx)
101 {
102 return !!atomic_get(&ctx->refcount);
103 }
104
websocket_get(void)105 static struct websocket_context *websocket_get(void)
106 {
107 struct websocket_context *ctx = NULL;
108 int i;
109
110 k_sem_take(&contexts_lock, K_FOREVER);
111
112 for (i = 0; i < ARRAY_SIZE(contexts); i++) {
113 if (websocket_context_is_used(&contexts[i])) {
114 continue;
115 }
116
117 websocket_context_ref(&contexts[i]);
118 ctx = &contexts[i];
119 break;
120 }
121
122 k_sem_give(&contexts_lock);
123
124 return ctx;
125 }
126
websocket_find(int real_sock)127 static struct websocket_context *websocket_find(int real_sock)
128 {
129 struct websocket_context *ctx = NULL;
130 int i;
131
132 k_sem_take(&contexts_lock, K_FOREVER);
133
134 for (i = 0; i < ARRAY_SIZE(contexts); i++) {
135 if (!websocket_context_is_used(&contexts[i])) {
136 continue;
137 }
138
139 if (contexts[i].real_sock != real_sock) {
140 continue;
141 }
142
143 ctx = &contexts[i];
144 break;
145 }
146
147 k_sem_give(&contexts_lock);
148
149 return ctx;
150 }
151
response_cb(struct http_response * rsp,enum http_final_call final_data,void * user_data)152 static int response_cb(struct http_response *rsp,
153 enum http_final_call final_data,
154 void *user_data)
155 {
156 struct websocket_context *ctx = user_data;
157
158 if (final_data == HTTP_DATA_MORE) {
159 NET_DBG("[%p] Partial data received (%zd bytes)", ctx,
160 rsp->data_len);
161 ctx->all_received = false;
162 } else if (final_data == HTTP_DATA_FINAL) {
163 NET_DBG("[%p] All the data received (%zd bytes)", ctx,
164 rsp->data_len);
165 ctx->all_received = true;
166 }
167
168 return 0;
169 }
170
on_header_field(struct http_parser * parser,const char * at,size_t length)171 static int on_header_field(struct http_parser *parser, const char *at,
172 size_t length)
173 {
174 struct http_request *req = CONTAINER_OF(parser,
175 struct http_request,
176 internal.parser);
177 struct websocket_context *ctx = req->internal.user_data;
178 const char *ws_accept_str = "Sec-WebSocket-Accept";
179 uint16_t len;
180
181 len = strlen(ws_accept_str);
182 if (length >= len && strncasecmp(at, ws_accept_str, len) == 0) {
183 ctx->sec_accept_present = true;
184 }
185
186 if (ctx->http_cb && ctx->http_cb->on_header_field) {
187 ctx->http_cb->on_header_field(parser, at, length);
188 }
189
190 return 0;
191 }
192
193 #define MAX_SEC_ACCEPT_LEN 32
194
on_header_value(struct http_parser * parser,const char * at,size_t length)195 static int on_header_value(struct http_parser *parser, const char *at,
196 size_t length)
197 {
198 struct http_request *req = CONTAINER_OF(parser,
199 struct http_request,
200 internal.parser);
201 struct websocket_context *ctx = req->internal.user_data;
202 char str[MAX_SEC_ACCEPT_LEN];
203
204 if (ctx->sec_accept_present) {
205 int ret;
206 size_t olen;
207
208 ctx->sec_accept_ok = false;
209 ctx->sec_accept_present = false;
210
211 ret = base64_encode(str, sizeof(str) - 1, &olen,
212 ctx->sec_accept_key,
213 WS_SHA1_OUTPUT_LEN);
214 if (ret == 0) {
215 if (strncmp(at, str, length)) {
216 NET_DBG("[%p] Security keys do not match "
217 "%s vs %s", ctx, str, at);
218 } else {
219 ctx->sec_accept_ok = true;
220 }
221 }
222 }
223
224 if (ctx->http_cb && ctx->http_cb->on_header_value) {
225 ctx->http_cb->on_header_value(parser, at, length);
226 }
227
228 return 0;
229 }
230
websocket_connect(int sock,struct websocket_request * wreq,int32_t timeout,void * user_data)231 int websocket_connect(int sock, struct websocket_request *wreq,
232 int32_t timeout, void *user_data)
233 {
234 /* This is the expected Sec-WebSocket-Accept key. We are storing a
235 * pointer to this in ctx but the value is only used for the duration
236 * of this function call so there is no issue even if this variable
237 * is allocated from stack.
238 */
239 uint8_t sec_accept_key[WS_SHA1_OUTPUT_LEN];
240 struct http_parser_settings http_parser_settings;
241 struct websocket_context *ctx;
242 struct http_request req;
243 int ret, fd, key_len;
244 size_t olen;
245 char key_accept[MAX_SEC_ACCEPT_LEN + sizeof(WS_MAGIC)];
246 uint32_t rnd_value = sys_rand32_get();
247 char sec_ws_key[] =
248 "Sec-WebSocket-Key: 0123456789012345678901==\r\n";
249 char *headers[] = {
250 sec_ws_key,
251 "Upgrade: websocket\r\n",
252 "Connection: Upgrade\r\n",
253 "Sec-WebSocket-Version: 13\r\n",
254 NULL
255 };
256
257 fd = -1;
258
259 if (sock < 0 || wreq == NULL || wreq->host == NULL ||
260 wreq->url == NULL) {
261 return -EINVAL;
262 }
263
264 ctx = websocket_find(sock);
265 if (ctx) {
266 NET_DBG("[%p] Websocket for sock %d already exists!", ctx,
267 sock);
268 return -EEXIST;
269 }
270
271 ctx = websocket_get();
272 if (!ctx) {
273 return -ENOENT;
274 }
275
276 ctx->real_sock = sock;
277 ctx->recv_buf.buf = wreq->tmp_buf;
278 ctx->recv_buf.size = wreq->tmp_buf_len;
279 ctx->sec_accept_key = sec_accept_key;
280 ctx->http_cb = wreq->http_cb;
281 ctx->is_client = 1;
282
283 mbedtls_sha1((const unsigned char *)&rnd_value, sizeof(rnd_value),
284 sec_accept_key);
285
286 ret = base64_encode(sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1,
287 sizeof(sec_ws_key) -
288 sizeof("Sec-Websocket-Key: "),
289 &olen, sec_accept_key,
290 /* We are only interested in 16 first bytes so
291 * subtract 4 from the SHA-1 length
292 */
293 sizeof(sec_accept_key) - 4);
294 if (ret) {
295 NET_DBG("[%p] Cannot encode base64 (%d)", ctx, ret);
296 goto out;
297 }
298
299 if ((olen + sizeof("Sec-Websocket-Key: ") + 2) > sizeof(sec_ws_key)) {
300 NET_DBG("[%p] Too long message (%zd > %zd)", ctx,
301 olen + sizeof("Sec-Websocket-Key: ") + 2,
302 sizeof(sec_ws_key));
303 ret = -EMSGSIZE;
304 goto out;
305 }
306
307 memcpy(sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1 + olen,
308 HTTP_CRLF, sizeof(HTTP_CRLF));
309
310 memset(&req, 0, sizeof(req));
311
312 req.method = HTTP_GET;
313 req.url = wreq->url;
314 req.host = wreq->host;
315 req.protocol = "HTTP/1.1";
316 req.header_fields = (const char **)headers;
317 req.optional_headers_cb = wreq->optional_headers_cb;
318 req.optional_headers = wreq->optional_headers;
319 req.response = response_cb;
320 req.http_cb = &http_parser_settings;
321 req.recv_buf = wreq->tmp_buf;
322 req.recv_buf_len = wreq->tmp_buf_len;
323
324 /* We need to catch the Sec-WebSocket-Accept field in order to verify
325 * that it contains the stuff that we sent in Sec-WebSocket-Key field
326 * so setup HTTP callbacks so that we will get the needed fields.
327 */
328 if (ctx->http_cb) {
329 memcpy(&http_parser_settings, ctx->http_cb,
330 sizeof(http_parser_settings));
331 } else {
332 memset(&http_parser_settings, 0, sizeof(http_parser_settings));
333 }
334
335 http_parser_settings.on_header_field = on_header_field;
336 http_parser_settings.on_header_value = on_header_value;
337
338 /* Pre-calculate the expected Sec-Websocket-Accept field */
339 key_len = MIN(sizeof(key_accept) - 1, olen);
340 strncpy(key_accept, sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1,
341 key_len);
342
343 olen = MIN(sizeof(key_accept) - 1 - key_len, sizeof(WS_MAGIC) - 1);
344 strncpy(key_accept + key_len, WS_MAGIC, olen);
345
346 /* This SHA-1 value is then checked when we receive the response */
347 mbedtls_sha1(key_accept, olen + key_len, sec_accept_key);
348
349 ret = http_client_req(sock, &req, timeout, ctx);
350 if (ret < 0) {
351 NET_DBG("[%p] Cannot connect to Websocket host %s", ctx,
352 wreq->host);
353 ret = -ECONNABORTED;
354 goto out;
355 }
356
357 if (!(ctx->all_received && ctx->sec_accept_ok)) {
358 NET_DBG("[%p] WS handshake failed (%d/%d)", ctx,
359 ctx->all_received, ctx->sec_accept_ok);
360 ret = -ECONNABORTED;
361 goto out;
362 }
363
364 ctx->user_data = user_data;
365
366 fd = zvfs_reserve_fd();
367 if (fd < 0) {
368 ret = -ENOSPC;
369 goto out;
370 }
371
372 ctx->sock = fd;
373 zvfs_finalize_typed_fd(fd, ctx, (const struct fd_op_vtable *)&websocket_fd_op_vtable,
374 ZVFS_MODE_IFSOCK);
375
376 /* Call the user specified callback and if it accepts the connection
377 * then continue.
378 */
379 if (wreq->cb) {
380 ret = wreq->cb(fd, &req, user_data);
381 if (ret < 0) {
382 NET_DBG("[%p] Connection aborted (%d)", ctx, ret);
383 goto out;
384 }
385 }
386
387 NET_DBG("[%p] WS connection to peer established (fd %d)", ctx, fd);
388
389 /* We will re-use the temp buffer in receive function. If there were
390 * any leftover data from HTTP headers processing, we need to reflect
391 * this in the count variable.
392 */
393 ctx->recv_buf.count = req.data_len;
394
395 /* Init parser FSM */
396 ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
397
398 (void)sock_obj_core_alloc_find(ctx->real_sock, fd, SOCK_STREAM);
399
400 return fd;
401
402 out:
403 if (fd >= 0) {
404 (void)zsock_close(fd);
405 }
406
407 websocket_context_unref(ctx);
408 return ret;
409 }
410
/* Close the websocket file descriptor @p ws_sock.
 *
 * zsock_close() dispatches to the fd's close method, which for websocket fds
 * is websocket_close_vmeth() (sends a CLOSE frame and frees the context).
 *
 * @return the result of zsock_close().
 */
int websocket_disconnect(int ws_sock)
{
	int result = zsock_close(ws_sock);

	return result;
}
415
websocket_interal_disconnect(struct websocket_context * ctx)416 static int websocket_interal_disconnect(struct websocket_context *ctx)
417 {
418 int ret;
419
420 if (ctx == NULL) {
421 return -ENOENT;
422 }
423
424 NET_DBG("[%p] Disconnecting", ctx);
425
426 ret = websocket_send_msg(ctx->sock, NULL, 0, WEBSOCKET_OPCODE_CLOSE,
427 ctx->is_client, true, SYS_FOREVER_MS);
428 if (ret < 0) {
429 NET_DBG("[%p] Failed to send close message (err %d).", ctx, ret);
430 }
431
432 (void)sock_obj_core_dealloc(ctx->sock);
433
434 websocket_context_unref(ctx);
435
436 return ret;
437 }
438
/* fd_op_vtable close method for websocket fds.
 *
 * @return the disconnect result when non-negative, 0 when the socket was
 *         not connected (-ENOTCONN), otherwise -1 with errno set.
 */
static int websocket_close_vmeth(void *obj)
{
	struct websocket_context *ctx = obj;
	int ret = websocket_interal_disconnect(ctx);

	if (ret >= 0) {
		return ret;
	}

	/* Closing a socket that is not connected is not an error */
	if (ret == -ENOTCONN) {
		return 0;
	}

	NET_DBG("[%p] Cannot close (%d)", obj, ret);
	errno = -ret;

	return -1;
}
459
websocket_poll_offload(struct zsock_pollfd * fds,int nfds,int timeout)460 static inline int websocket_poll_offload(struct zsock_pollfd *fds, int nfds,
461 int timeout)
462 {
463 int fd_backup[CONFIG_ZVFS_POLL_MAX];
464 const struct fd_op_vtable *vtable;
465 void *ctx;
466 int ret = 0;
467 int i;
468
469 /* Overwrite websocket file descriptors with underlying ones. */
470 for (i = 0; i < nfds; i++) {
471 fd_backup[i] = fds[i].fd;
472
473 ctx = zvfs_get_fd_obj(fds[i].fd,
474 (const struct fd_op_vtable *)
475 &websocket_fd_op_vtable,
476 0);
477 if (ctx == NULL) {
478 continue;
479 }
480
481 fds[i].fd = ((struct websocket_context *)ctx)->real_sock;
482 }
483
484 /* Get offloaded sockets vtable. */
485 ctx = zvfs_get_fd_obj_and_vtable(fds[0].fd,
486 (const struct fd_op_vtable **)&vtable,
487 NULL);
488 if (ctx == NULL) {
489 errno = EINVAL;
490 ret = -1;
491 goto exit;
492 }
493
494 ret = zvfs_fdtable_call_ioctl(vtable, ctx, ZFD_IOCTL_POLL_OFFLOAD,
495 fds, nfds, timeout);
496
497 exit:
498 /* Restore original fds. */
499 for (i = 0; i < nfds; i++) {
500 fds[i].fd = fd_backup[i];
501 }
502
503 return ret;
504 }
505
websocket_ioctl_vmeth(void * obj,unsigned int request,va_list args)506 static int websocket_ioctl_vmeth(void *obj, unsigned int request, va_list args)
507 {
508 struct websocket_context *ctx = obj;
509
510 switch (request) {
511 case ZFD_IOCTL_POLL_OFFLOAD: {
512 struct zsock_pollfd *fds;
513 int nfds;
514 int timeout;
515
516 fds = va_arg(args, struct zsock_pollfd *);
517 nfds = va_arg(args, int);
518 timeout = va_arg(args, int);
519
520 return websocket_poll_offload(fds, nfds, timeout);
521 }
522
523 case ZFD_IOCTL_SET_LOCK:
524 /* Ignore, don't want to overwrite underlying socket lock. */
525 return 0;
526
527 default: {
528 const struct fd_op_vtable *vtable;
529 void *core_obj;
530
531 core_obj = zvfs_get_fd_obj_and_vtable(
532 ctx->real_sock,
533 (const struct fd_op_vtable **)&vtable,
534 NULL);
535 if (core_obj == NULL) {
536 errno = EBADF;
537 return -1;
538 }
539
540 /* Pass the call to the core socket implementation. */
541 return vtable->ioctl(core_obj, request, args);
542 }
543 }
544
545 return 0;
546 }
547
548 #if !defined(CONFIG_NET_TEST)
sendmsg_all(int sock,const struct msghdr * message,int flags,const k_timepoint_t req_end_timepoint)549 static int sendmsg_all(int sock, const struct msghdr *message, int flags,
550 const k_timepoint_t req_end_timepoint)
551 {
552 int ret, i;
553 size_t offset = 0;
554 size_t total_len = 0;
555
556 for (i = 0; i < message->msg_iovlen; i++) {
557 total_len += message->msg_iov[i].iov_len;
558 }
559
560 while (offset < total_len) {
561 ret = zsock_sendmsg(sock, message, flags);
562
563 if ((ret == 0) || (ret < 0 && errno == EAGAIN)) {
564 struct zsock_pollfd pfd;
565 int pollres;
566 k_ticks_t req_timeout_ticks =
567 sys_timepoint_timeout(req_end_timepoint).ticks;
568 int req_timeout_ms = k_ticks_to_ms_floor32(req_timeout_ticks);
569
570 pfd.fd = sock;
571 pfd.events = ZSOCK_POLLOUT;
572 pollres = zsock_poll(&pfd, 1, req_timeout_ms);
573 if (pollres == 0) {
574 return -ETIMEDOUT;
575 } else if (pollres > 0) {
576 continue;
577 } else {
578 return -errno;
579 }
580 } else if (ret < 0) {
581 return -errno;
582 }
583
584 offset += ret;
585 if (offset >= total_len) {
586 break;
587 }
588
589 /* Update msghdr for the next iteration. */
590 for (i = 0; i < message->msg_iovlen; i++) {
591 if (ret < message->msg_iov[i].iov_len) {
592 message->msg_iov[i].iov_len -= ret;
593 message->msg_iov[i].iov_base =
594 (uint8_t *)message->msg_iov[i].iov_base + ret;
595 break;
596 }
597
598 ret -= message->msg_iov[i].iov_len;
599 message->msg_iov[i].iov_len = 0;
600 }
601 }
602
603 return total_len;
604 }
605 #endif /* !defined(CONFIG_NET_TEST) */
606
websocket_prepare_and_send(struct websocket_context * ctx,uint8_t * header,size_t header_len,uint8_t * payload,size_t payload_len,int32_t timeout)607 static int websocket_prepare_and_send(struct websocket_context *ctx,
608 uint8_t *header, size_t header_len,
609 uint8_t *payload, size_t payload_len,
610 int32_t timeout)
611 {
612 struct iovec io_vector[2];
613 struct msghdr msg;
614
615 io_vector[0].iov_base = header;
616 io_vector[0].iov_len = header_len;
617 io_vector[1].iov_base = payload;
618 io_vector[1].iov_len = payload_len;
619
620 memset(&msg, 0, sizeof(msg));
621
622 msg.msg_iov = io_vector;
623 msg.msg_iovlen = ARRAY_SIZE(io_vector);
624
625 if (HEXDUMP_SENT_PACKETS) {
626 LOG_HEXDUMP_DBG(header, header_len, "Header");
627 if ((payload != NULL) && (payload_len > 0)) {
628 LOG_HEXDUMP_DBG(payload, payload_len, "Payload");
629 } else {
630 LOG_DBG("No payload");
631 }
632 }
633
634 #if defined(CONFIG_NET_TEST)
635 /* Simulate a case where the payload is split to two. The unit test
636 * does not set mask bit in this case.
637 */
638 return verify_sent_and_received_msg(&msg, !(header[1] & BIT(7)));
639 #else
640 k_timeout_t tout = K_FOREVER;
641
642 if (timeout != SYS_FOREVER_MS) {
643 tout = K_MSEC(timeout);
644 }
645
646 k_timeout_t req_timeout = K_MSEC(timeout);
647 k_timepoint_t req_end_timepoint = sys_timepoint_calc(req_timeout);
648
649 return sendmsg_all(ctx->real_sock, &msg,
650 K_TIMEOUT_EQ(tout, K_NO_WAIT) ? ZSOCK_MSG_DONTWAIT : 0,
651 req_end_timepoint);
652 #endif /* CONFIG_NET_TEST */
653 }
654
websocket_send_msg(int ws_sock,const uint8_t * payload,size_t payload_len,enum websocket_opcode opcode,bool mask,bool final,int32_t timeout)655 int websocket_send_msg(int ws_sock, const uint8_t *payload, size_t payload_len,
656 enum websocket_opcode opcode, bool mask, bool final,
657 int32_t timeout)
658 {
659 struct websocket_context *ctx;
660 uint8_t header[MAX_HEADER_LEN], hdr_len = 2;
661 uint8_t *data_to_send = (uint8_t *)payload;
662 int ret;
663
664 if (opcode != WEBSOCKET_OPCODE_DATA_TEXT &&
665 opcode != WEBSOCKET_OPCODE_DATA_BINARY &&
666 opcode != WEBSOCKET_OPCODE_CONTINUE &&
667 opcode != WEBSOCKET_OPCODE_CLOSE &&
668 opcode != WEBSOCKET_OPCODE_PING &&
669 opcode != WEBSOCKET_OPCODE_PONG) {
670 return -EINVAL;
671 }
672
673 ctx = zvfs_get_fd_obj(ws_sock, NULL, 0);
674 if (ctx == NULL) {
675 return -EBADF;
676 }
677
678 #if !defined(CONFIG_NET_TEST)
679 /* Websocket unit test does not use context from pool but allocates
680 * its own, hence skip the check.
681 */
682
683 if (!PART_OF_ARRAY(contexts, ctx)) {
684 return -ENOENT;
685 }
686 #endif /* !defined(CONFIG_NET_TEST) */
687
688 NET_DBG("[%p] Len %zd %s/%d/%s", ctx, payload_len, opcode2str(opcode),
689 mask, final ? "final" : "more");
690
691 memset(header, 0, sizeof(header));
692
693 /* Is this the last packet? */
694 header[0] = final ? BIT(7) : 0;
695
696 /* Text, binary, ping, pong or close ? */
697 header[0] |= opcode;
698
699 /* Masking */
700 header[1] = mask ? BIT(7) : 0;
701
702 if (payload_len < 126) {
703 header[1] |= payload_len;
704 } else if (payload_len < 65536) {
705 header[1] |= 126;
706 header[2] = payload_len >> 8;
707 header[3] = payload_len;
708 hdr_len += 2;
709 } else {
710 header[1] |= 127;
711 header[2] = 0;
712 header[3] = 0;
713 header[4] = 0;
714 header[5] = 0;
715 header[6] = payload_len >> 24;
716 header[7] = payload_len >> 16;
717 header[8] = payload_len >> 8;
718 header[9] = payload_len;
719 hdr_len += 8;
720 }
721
722 /* Add masking value if needed */
723 if (mask) {
724 int i;
725
726 ctx->masking_value = sys_rand32_get();
727
728 header[hdr_len++] |= ctx->masking_value >> 24;
729 header[hdr_len++] |= ctx->masking_value >> 16;
730 header[hdr_len++] |= ctx->masking_value >> 8;
731 header[hdr_len++] |= ctx->masking_value;
732
733 if ((payload != NULL) && (payload_len > 0)) {
734 data_to_send = k_malloc(payload_len);
735 if (!data_to_send) {
736 return -ENOMEM;
737 }
738
739 memcpy(data_to_send, payload, payload_len);
740
741 for (i = 0; i < payload_len; i++) {
742 data_to_send[i] ^= ctx->masking_value >> (8 * (3 - i % 4));
743 }
744 }
745 }
746
747 ret = websocket_prepare_and_send(ctx, header, hdr_len,
748 data_to_send, payload_len, timeout);
749 if (ret < 0) {
750 NET_DBG("Cannot send ws msg (%d)", -errno);
751 goto quit;
752 }
753
754 quit:
755 if (data_to_send != payload) {
756 k_free(data_to_send);
757 }
758
759 /* Do no math with 0 and error codes */
760 if (ret <= 0) {
761 return ret;
762 }
763
764 return ret - hdr_len;
765 }
766
websocket_opcode2flag(uint8_t data)767 static uint32_t websocket_opcode2flag(uint8_t data)
768 {
769 switch (data & 0x0f) {
770 case WEBSOCKET_OPCODE_DATA_TEXT:
771 return WEBSOCKET_FLAG_TEXT;
772 case WEBSOCKET_OPCODE_DATA_BINARY:
773 return WEBSOCKET_FLAG_BINARY;
774 case WEBSOCKET_OPCODE_CLOSE:
775 return WEBSOCKET_FLAG_CLOSE;
776 case WEBSOCKET_OPCODE_PING:
777 return WEBSOCKET_FLAG_PING;
778 case WEBSOCKET_OPCODE_PONG:
779 return WEBSOCKET_FLAG_PONG;
780 default:
781 break;
782 }
783 return 0;
784 }
785
/* Incremental websocket frame parser.
 *
 * Consumes bytes from ctx->recv_buf.buf[0 .. ctx->recv_buf.count) and
 * advances the per-context state machine
 * OPCODE -> LENGTH -> [EXT_LEN] -> [MASK] -> [PAYLOAD] -> OPCODE.
 * Header fields are stored in ctx (message_type, masked, message_len,
 * masking_value, parser_remaining). Payload bytes are appended to @p payload
 * unchanged, i.e. still masked if the frame is masked; the caller unmasks.
 *
 * Parsing stops when the receive buffer is exhausted, when @p payload has no
 * free space left, or when one frame has been completed (state back at
 * OPCODE). A single call therefore never copies payload from two frames.
 *
 * @return number of bytes consumed from ctx->recv_buf, or -EFAULT if the
 *         parser is in an unknown state.
 */
static int websocket_parse(struct websocket_context *ctx, struct websocket_buffer *payload)
{
	int len;
	uint8_t data;
	size_t parsed_count = 0;

	do {
		/* Input exhausted: report how much was consumed */
		if (parsed_count >= ctx->recv_buf.count) {
			return parsed_count;
		}
		if (ctx->parser_state != WEBSOCKET_PARSER_STATE_PAYLOAD) {
			/* Header states consume exactly one byte per iteration */
			data = ctx->recv_buf.buf[parsed_count++];

			switch (ctx->parser_state) {
			case WEBSOCKET_PARSER_STATE_OPCODE:
				/* First header byte: FIN bit (0x80) and opcode */
				ctx->message_type = websocket_opcode2flag(data);
				if ((data & 0x80) != 0) {
					ctx->message_type |= WEBSOCKET_FLAG_FINAL;
				}
				ctx->parser_state = WEBSOCKET_PARSER_STATE_LENGTH;
				break;
			case WEBSOCKET_PARSER_STATE_LENGTH:
				/* Second header byte: MASK bit (0x80) and a 7-bit length.
				 * 126 announces a 2-byte and 127 an 8-byte extended length.
				 */
				ctx->masked = (data & 0x80) != 0;
				len = data & 0x7f;
				if (len < 126) {
					ctx->message_len = len;
					if (ctx->masked) {
						ctx->masking_value = 0;
						ctx->parser_remaining = 4;
						ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK;
					} else {
						ctx->parser_remaining = ctx->message_len;
						/* An empty unmasked frame is already complete */
						ctx->parser_state =
							(ctx->parser_remaining == 0)
								? WEBSOCKET_PARSER_STATE_OPCODE
								: WEBSOCKET_PARSER_STATE_PAYLOAD;
					}
				} else {
					ctx->message_len = 0;
					ctx->parser_remaining = (len < 127) ? 2 : 8;
					ctx->parser_state = WEBSOCKET_PARSER_STATE_EXT_LEN;
				}
				break;
			case WEBSOCKET_PARSER_STATE_EXT_LEN:
				/* Extended length, most significant byte first */
				ctx->parser_remaining--;
				ctx->message_len |= ((uint64_t)data << (ctx->parser_remaining * 8));
				if (ctx->parser_remaining == 0) {
					if (ctx->masked) {
						ctx->masking_value = 0;
						ctx->parser_remaining = 4;
						ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK;
					} else {
						/* NOTE(review): a zero extended length still enters
						 * PAYLOAD; that branch moves back to OPCODE the next
						 * time it runs with free space in the payload buffer.
						 */
						ctx->parser_remaining = ctx->message_len;
						ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD;
					}
				}
				break;
			case WEBSOCKET_PARSER_STATE_MASK:
				/* 4-byte masking key, assembled most significant byte first */
				ctx->parser_remaining--;
				ctx->masking_value |=
					(uint32_t)((uint64_t)data << (ctx->parser_remaining * 8));
				if (ctx->parser_remaining == 0) {
					if (ctx->message_len == 0) {
						ctx->parser_remaining = 0;
						ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
					} else {
						ctx->parser_remaining = ctx->message_len;
						ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD;
					}
				}
				break;
			default:
				return -EFAULT;
			}

#if (LOG_LEVEL >= LOG_LEVEL_DBG)
			/* Log the frame header once it has been fully parsed */
			if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_PAYLOAD) ||
			    ((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) &&
			     (ctx->message_len == 0))) {
				NET_DBG("[%p] %smasked, mask 0x%08x, type 0x%02x, msg %zd", ctx,
					ctx->masked ? "" : "un",
					ctx->masked ? ctx->masking_value : 0, ctx->message_type,
					(size_t)ctx->message_len);
			}
#endif
		} else {
			/* Payload: copy as much as is available in recv_buf, still
			 * belongs to this frame, and fits in the caller's buffer.
			 */
			size_t remaining_in_recv_buf = ctx->recv_buf.count - parsed_count;
			size_t payload_in_recv_buf =
				MIN(remaining_in_recv_buf, ctx->parser_remaining);
			size_t free_in_payload_buf = payload->size - payload->count;
			size_t ready_to_copy = MIN(payload_in_recv_buf, free_in_payload_buf);

			/* Caller's buffer is full; resume on the next call */
			if (free_in_payload_buf == 0) {
				break;
			}

			memcpy(&payload->buf[payload->count], &ctx->recv_buf.buf[parsed_count],
			       ready_to_copy);
			parsed_count += ready_to_copy;
			payload->count += ready_to_copy;
			ctx->parser_remaining -= ready_to_copy;
			if (ctx->parser_remaining == 0) {
				ctx->parser_remaining = 0;
				ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
			}
		}

	} while (ctx->parser_state != WEBSOCKET_PARSER_STATE_OPCODE);

	return parsed_count;
}
897
898 #if !defined(CONFIG_NET_TEST)
wait_rx(int sock,int timeout)899 static int wait_rx(int sock, int timeout)
900 {
901 struct zsock_pollfd fds = {
902 .fd = sock,
903 .events = ZSOCK_POLLIN,
904 };
905 int ret;
906
907 ret = zsock_poll(&fds, 1, timeout);
908 if (ret < 0) {
909 return ret;
910 }
911
912 if (ret == 0) {
913 /* Timeout */
914 return -EAGAIN;
915 }
916
917 if (fds.revents & ZSOCK_POLLNVAL) {
918 return -EBADF;
919 }
920
921 if (fds.revents & ZSOCK_POLLERR) {
922 return -EIO;
923 }
924
925 return 0;
926 }
927
timeout_to_ms(k_timeout_t * timeout)928 static int timeout_to_ms(k_timeout_t *timeout)
929 {
930 if (K_TIMEOUT_EQ(*timeout, K_NO_WAIT)) {
931 return 0;
932 } else if (K_TIMEOUT_EQ(*timeout, K_FOREVER)) {
933 return SYS_FOREVER_MS;
934 } else {
935 return k_ticks_to_ms_floor32(timeout->ticks);
936 }
937 }
938
939 #endif /* !defined(CONFIG_NET_TEST) */
940
websocket_recv_msg(int ws_sock,uint8_t * buf,size_t buf_len,uint32_t * message_type,uint64_t * remaining,int32_t timeout)941 int websocket_recv_msg(int ws_sock, uint8_t *buf, size_t buf_len,
942 uint32_t *message_type, uint64_t *remaining, int32_t timeout)
943 {
944 struct websocket_context *ctx;
945 int ret;
946 k_timepoint_t end;
947 k_timeout_t tout = K_FOREVER;
948 struct websocket_buffer payload = {.buf = buf, .size = buf_len, .count = 0};
949
950 if (timeout != SYS_FOREVER_MS) {
951 tout = K_MSEC(timeout);
952 }
953
954 if ((buf == NULL) || (buf_len == 0)) {
955 return -EINVAL;
956 }
957
958 end = sys_timepoint_calc(tout);
959
960 #if defined(CONFIG_NET_TEST)
961 struct test_data *test_data = zvfs_get_fd_obj(ws_sock, NULL, 0);
962
963 if (test_data == NULL) {
964 return -EBADF;
965 }
966
967 ctx = test_data->ctx;
968 #else
969 ctx = zvfs_get_fd_obj(ws_sock, NULL, 0);
970 if (ctx == NULL) {
971 return -EBADF;
972 }
973
974 if (!PART_OF_ARRAY(contexts, ctx)) {
975 return -ENOENT;
976 }
977 #endif /* CONFIG_NET_TEST */
978
979 do {
980 size_t parsed_count;
981
982 if (ctx->recv_buf.count == 0) {
983 #if defined(CONFIG_NET_TEST)
984 size_t input_len = MIN(ctx->recv_buf.size,
985 test_data->input_len - test_data->input_pos);
986
987 if (input_len > 0) {
988 memcpy(ctx->recv_buf.buf,
989 &test_data->input_buf[test_data->input_pos], input_len);
990 test_data->input_pos += input_len;
991 ret = input_len;
992 } else {
993 /* emulate timeout */
994 ret = -EAGAIN;
995 }
996 #else
997 tout = sys_timepoint_timeout(end);
998
999 ret = wait_rx(ctx->real_sock, timeout_to_ms(&tout));
1000 if (ret == 0) {
1001 ret = zsock_recv(ctx->real_sock, ctx->recv_buf.buf,
1002 ctx->recv_buf.size, ZSOCK_MSG_DONTWAIT);
1003 if (ret < 0) {
1004 ret = -errno;
1005 }
1006 }
1007 #endif /* CONFIG_NET_TEST */
1008
1009 if (ret < 0) {
1010 if ((ret == -EAGAIN) && (payload.count > 0)) {
1011 /* go to unmasking */
1012 break;
1013 }
1014 return ret;
1015 }
1016
1017 if (ret == 0) {
1018 /* Socket closed */
1019 return -ENOTCONN;
1020 }
1021
1022 ctx->recv_buf.count = ret;
1023
1024 NET_DBG("[%p] Received %d bytes", ctx, ret);
1025 }
1026
1027 ret = websocket_parse(ctx, &payload);
1028 if (ret < 0) {
1029 return ret;
1030 }
1031 parsed_count = ret;
1032
1033 if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) ||
1034 (payload.count >= payload.size)) {
1035 if (remaining != NULL) {
1036 *remaining = ctx->parser_remaining;
1037 }
1038 if (message_type != NULL) {
1039 *message_type = ctx->message_type;
1040 }
1041
1042 size_t left = ctx->recv_buf.count - parsed_count;
1043
1044 if (left > 0) {
1045 memmove(ctx->recv_buf.buf, &ctx->recv_buf.buf[parsed_count], left);
1046 }
1047 ctx->recv_buf.count = left;
1048 break;
1049 }
1050
1051 ctx->recv_buf.count -= parsed_count;
1052
1053 } while (true);
1054
1055 /* Unmask the data */
1056 if (ctx->masked) {
1057 uint8_t *mask_as_bytes = (uint8_t *)&ctx->masking_value;
1058 size_t data_buf_offset = ctx->message_len - ctx->parser_remaining - payload.count;
1059
1060 for (size_t i = 0; i < payload.count; i++) {
1061 size_t m = data_buf_offset % 4;
1062
1063 payload.buf[i] ^= mask_as_bytes[3 - m];
1064 data_buf_offset++;
1065 }
1066 }
1067
1068 return payload.count;
1069 }
1070
websocket_send(struct websocket_context * ctx,const uint8_t * buf,size_t buf_len,int32_t timeout)1071 static int websocket_send(struct websocket_context *ctx, const uint8_t *buf,
1072 size_t buf_len, int32_t timeout)
1073 {
1074 int ret;
1075
1076 NET_DBG("[%p] Sending %zd bytes", ctx, buf_len);
1077
1078 ret = websocket_send_msg(ctx->sock, buf, buf_len, WEBSOCKET_OPCODE_DATA_TEXT,
1079 ctx->is_client, true, timeout);
1080 if (ret < 0) {
1081 errno = -ret;
1082 return -1;
1083 }
1084
1085 NET_DBG("[%p] Sent %d bytes", ctx, ret);
1086
1087 sock_obj_core_update_send_stats(ctx->sock, ret);
1088
1089 return ret;
1090 }
1091
websocket_recv(struct websocket_context * ctx,uint8_t * buf,size_t buf_len,int32_t timeout)1092 static int websocket_recv(struct websocket_context *ctx, uint8_t *buf,
1093 size_t buf_len, int32_t timeout)
1094 {
1095 uint32_t message_type;
1096 uint64_t remaining;
1097 int ret;
1098
1099 NET_DBG("[%p] Waiting data, buf len %zd bytes", ctx, buf_len);
1100
1101 /* TODO: add support for recvmsg() so that we could return the
1102 * websocket specific information in ancillary data.
1103 */
1104 ret = websocket_recv_msg(ctx->sock, buf, buf_len, &message_type,
1105 &remaining, timeout);
1106 if (ret < 0) {
1107 if (ret == -ENOTCONN) {
1108 ret = 0;
1109 } else {
1110 errno = -ret;
1111 return -1;
1112 }
1113 }
1114
1115 NET_DBG("[%p] Received %d bytes", ctx, ret);
1116
1117 sock_obj_core_update_recv_stats(ctx->sock, ret);
1118
1119 return ret;
1120 }
1121
websocket_read_vmeth(void * obj,void * buffer,size_t count)1122 static ssize_t websocket_read_vmeth(void *obj, void *buffer, size_t count)
1123 {
1124 return (ssize_t)websocket_recv(obj, buffer, count, SYS_FOREVER_MS);
1125 }
1126
websocket_write_vmeth(void * obj,const void * buffer,size_t count)1127 static ssize_t websocket_write_vmeth(void *obj, const void *buffer,
1128 size_t count)
1129 {
1130 return (ssize_t)websocket_send(obj, buffer, count, SYS_FOREVER_MS);
1131 }
1132
websocket_sendto_ctx(void * obj,const void * buf,size_t len,int flags,const struct sockaddr * dest_addr,socklen_t addrlen)1133 static ssize_t websocket_sendto_ctx(void *obj, const void *buf, size_t len,
1134 int flags,
1135 const struct sockaddr *dest_addr,
1136 socklen_t addrlen)
1137 {
1138 struct websocket_context *ctx = obj;
1139 int32_t timeout = SYS_FOREVER_MS;
1140
1141 if (flags & ZSOCK_MSG_DONTWAIT) {
1142 timeout = 0;
1143 }
1144
1145 ARG_UNUSED(dest_addr);
1146 ARG_UNUSED(addrlen);
1147
1148 return (ssize_t)websocket_send(ctx, buf, len, timeout);
1149 }
1150
websocket_recvfrom_ctx(void * obj,void * buf,size_t max_len,int flags,struct sockaddr * src_addr,socklen_t * addrlen)1151 static ssize_t websocket_recvfrom_ctx(void *obj, void *buf, size_t max_len,
1152 int flags, struct sockaddr *src_addr,
1153 socklen_t *addrlen)
1154 {
1155 struct websocket_context *ctx = obj;
1156 int32_t timeout = SYS_FOREVER_MS;
1157
1158 if (flags & ZSOCK_MSG_DONTWAIT) {
1159 timeout = 0;
1160 }
1161
1162 ARG_UNUSED(src_addr);
1163 ARG_UNUSED(addrlen);
1164
1165 return (ssize_t)websocket_recv(ctx, buf, max_len, timeout);
1166 }
1167
websocket_register(int sock,uint8_t * recv_buf,size_t recv_buf_len)1168 int websocket_register(int sock, uint8_t *recv_buf, size_t recv_buf_len)
1169 {
1170 struct websocket_context *ctx;
1171 int ret, fd;
1172
1173 if (sock < 0) {
1174 return -EINVAL;
1175 }
1176
1177 ctx = websocket_find(sock);
1178 if (ctx) {
1179 NET_DBG("[%p] Websocket for sock %d already exists!", ctx, sock);
1180 return -EEXIST;
1181 }
1182
1183 ctx = websocket_get();
1184 if (!ctx) {
1185 return -ENOENT;
1186 }
1187
1188 ctx->real_sock = sock;
1189 ctx->recv_buf.buf = recv_buf;
1190 ctx->recv_buf.size = recv_buf_len;
1191 ctx->is_client = 0;
1192
1193 fd = zvfs_reserve_fd();
1194 if (fd < 0) {
1195 ret = -ENOSPC;
1196 goto out;
1197 }
1198
1199 ctx->sock = fd;
1200 zvfs_finalize_typed_fd(fd, ctx, (const struct fd_op_vtable *)&websocket_fd_op_vtable,
1201 ZVFS_MODE_IFSOCK);
1202
1203 NET_DBG("[%p] WS connection to peer established (fd %d)", ctx, fd);
1204
1205 ctx->recv_buf.count = 0;
1206 ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
1207
1208 (void)sock_obj_core_alloc_find(ctx->real_sock, fd, SOCK_STREAM);
1209
1210 return fd;
1211
1212 out:
1213 websocket_context_unref(ctx);
1214
1215 return ret;
1216 }
1217
websocket_search(int sock)1218 static struct websocket_context *websocket_search(int sock)
1219 {
1220 struct websocket_context *ctx = NULL;
1221 int i;
1222
1223 k_sem_take(&contexts_lock, K_FOREVER);
1224
1225 for (i = 0; i < ARRAY_SIZE(contexts); i++) {
1226 if (!websocket_context_is_used(&contexts[i])) {
1227 continue;
1228 }
1229
1230 if (contexts[i].sock != sock) {
1231 continue;
1232 }
1233
1234 ctx = &contexts[i];
1235 break;
1236 }
1237
1238 k_sem_give(&contexts_lock);
1239
1240 return ctx;
1241 }
1242
websocket_unregister(int sock)1243 int websocket_unregister(int sock)
1244 {
1245 struct websocket_context *ctx;
1246
1247 if (sock < 0) {
1248 return -EINVAL;
1249 }
1250
1251 ctx = websocket_search(sock);
1252 if (ctx == NULL) {
1253 NET_DBG("[%p] Real socket for websocket sock %d not found!", ctx, sock);
1254 return -ENOENT;
1255 }
1256
1257 if (ctx->real_sock < 0) {
1258 return -EALREADY;
1259 }
1260
1261 (void)zsock_close(sock);
1262 (void)zsock_close(ctx->real_sock);
1263
1264 ctx->real_sock = -1;
1265 ctx->sock = -1;
1266
1267 return 0;
1268 }
1269
1270 static const struct socket_op_vtable websocket_fd_op_vtable = {
1271 .fd_vtable = {
1272 .read = websocket_read_vmeth,
1273 .write = websocket_write_vmeth,
1274 .close = websocket_close_vmeth,
1275 .ioctl = websocket_ioctl_vmeth,
1276 },
1277 .sendto = websocket_sendto_ctx,
1278 .recvfrom = websocket_recvfrom_ctx,
1279 };
1280
/* Invoke @p cb for every in-use websocket context.
 *
 * The pool lock is held for the whole walk, and each context's own mutex is
 * held while @p cb runs on it.
 */
void websocket_context_foreach(websocket_context_cb_t cb, void *user_data)
{
	struct websocket_context *const end = contexts + ARRAY_SIZE(contexts);

	k_sem_take(&contexts_lock, K_FOREVER);

	for (struct websocket_context *it = contexts; it < end; it++) {
		if (!websocket_context_is_used(it)) {
			continue;
		}

		k_mutex_lock(&it->lock, K_FOREVER);
		cb(it, user_data);
		k_mutex_unlock(&it->lock);
	}

	k_sem_give(&contexts_lock);
}
1301
websocket_init(void)1302 void websocket_init(void)
1303 {
1304 k_sem_init(&contexts_lock, 1, K_SEM_MAX_LIMIT);
1305 }
1306