1 /** @file
2  * @brief Websocket client API
3  *
 * An API for applications to set up websocket connections.
5  */
6 
7 /*
8  * Copyright (c) 2019 Intel Corporation
9  *
10  * SPDX-License-Identifier: Apache-2.0
11  */
12 
13 #include <zephyr/logging/log.h>
14 LOG_MODULE_REGISTER(net_websocket, CONFIG_NET_WEBSOCKET_LOG_LEVEL);
15 
16 #include <zephyr/kernel.h>
17 #include <strings.h>
18 #include <errno.h>
19 #include <stdbool.h>
20 #include <stdlib.h>
21 
22 #include <zephyr/sys/fdtable.h>
23 #include <zephyr/net/net_core.h>
24 #include <zephyr/net/net_ip.h>
25 #if defined(CONFIG_POSIX_API)
26 #include <zephyr/posix/unistd.h>
27 #include <zephyr/posix/sys/socket.h>
28 #else
29 #include <zephyr/net/socket.h>
30 #endif
31 #include <zephyr/net/http/client.h>
32 #include <zephyr/net/websocket.h>
33 
34 #include <zephyr/random/random.h>
35 #include <zephyr/sys/byteorder.h>
36 #include <zephyr/sys/base64.h>
37 #include <mbedtls/sha1.h>
38 
39 #include "net_private.h"
40 #include "sockets_internal.h"
41 #include "websocket_internal.h"
42 
/* To see the raw bytes that are sent or received, enable debug logging
 * for this module and set the matching macro below to 1. This prints a
 * lot of data, so both are disabled by default.
 */
/* Hexdump each outgoing header/payload in websocket_prepare_and_send() */
#define HEXDUMP_SENT_PACKETS 0
/* Receive-side counterpart. NOTE(review): not referenced in the code
 * visible here; confirm it is still used elsewhere in this file.
 */
#define HEXDUMP_RECV_PACKETS 0

/* Static pool of websocket contexts. A slot counts as in use while its
 * refcount is non-zero (see websocket_context_is_used()).
 */
static struct websocket_context contexts[CONFIG_WEBSOCKET_MAX_CONTEXTS];

/* Serializes slot lookup/allocation in websocket_get() and
 * websocket_find(). Presumably initialized by a module init routine
 * that is not visible here.
 */
static struct k_sem contexts_lock;

/* Tentative definition of the fd vtable used for websocket descriptors.
 * It is referenced by websocket_connect() and websocket_poll_offload()
 * before its (presumed) full definition later in this file.
 */
static const struct socket_op_vtable websocket_fd_op_vtable;

#if defined(CONFIG_NET_TEST)
/* Provided by the unit test: in test builds websocket_prepare_and_send()
 * hands every outgoing message to this hook instead of a real socket.
 */
int verify_sent_and_received_msg(struct msghdr *msg, bool split_msg);
#endif
59 
opcode2str(enum websocket_opcode opcode)60 static const char *opcode2str(enum websocket_opcode opcode)
61 {
62 	switch (opcode) {
63 	case WEBSOCKET_OPCODE_DATA_TEXT:
64 		return "TEXT";
65 	case WEBSOCKET_OPCODE_DATA_BINARY:
66 		return "BIN";
67 	case WEBSOCKET_OPCODE_CONTINUE:
68 		return "CONT";
69 	case WEBSOCKET_OPCODE_CLOSE:
70 		return "CLOSE";
71 	case WEBSOCKET_OPCODE_PING:
72 		return "PING";
73 	case WEBSOCKET_OPCODE_PONG:
74 		return "PONG";
75 	default:
76 		break;
77 	}
78 
79 	return NULL;
80 }
81 
websocket_context_ref(struct websocket_context * ctx)82 static int websocket_context_ref(struct websocket_context *ctx)
83 {
84 	int old_rc = atomic_inc(&ctx->refcount);
85 
86 	return old_rc + 1;
87 }
88 
websocket_context_unref(struct websocket_context * ctx)89 static int websocket_context_unref(struct websocket_context *ctx)
90 {
91 	int old_rc = atomic_dec(&ctx->refcount);
92 
93 	if (old_rc != 1) {
94 		return old_rc - 1;
95 	}
96 
97 	return 0;
98 }
99 
websocket_context_is_used(struct websocket_context * ctx)100 static inline bool websocket_context_is_used(struct websocket_context *ctx)
101 {
102 	return !!atomic_get(&ctx->refcount);
103 }
104 
websocket_get(void)105 static struct websocket_context *websocket_get(void)
106 {
107 	struct websocket_context *ctx = NULL;
108 	int i;
109 
110 	k_sem_take(&contexts_lock, K_FOREVER);
111 
112 	for (i = 0; i < ARRAY_SIZE(contexts); i++) {
113 		if (websocket_context_is_used(&contexts[i])) {
114 			continue;
115 		}
116 
117 		websocket_context_ref(&contexts[i]);
118 		ctx = &contexts[i];
119 		break;
120 	}
121 
122 	k_sem_give(&contexts_lock);
123 
124 	return ctx;
125 }
126 
websocket_find(int real_sock)127 static struct websocket_context *websocket_find(int real_sock)
128 {
129 	struct websocket_context *ctx = NULL;
130 	int i;
131 
132 	k_sem_take(&contexts_lock, K_FOREVER);
133 
134 	for (i = 0; i < ARRAY_SIZE(contexts); i++) {
135 		if (!websocket_context_is_used(&contexts[i])) {
136 			continue;
137 		}
138 
139 		if (contexts[i].real_sock != real_sock) {
140 			continue;
141 		}
142 
143 		ctx = &contexts[i];
144 		break;
145 	}
146 
147 	k_sem_give(&contexts_lock);
148 
149 	return ctx;
150 }
151 
response_cb(struct http_response * rsp,enum http_final_call final_data,void * user_data)152 static int response_cb(struct http_response *rsp,
153 		       enum http_final_call final_data,
154 		       void *user_data)
155 {
156 	struct websocket_context *ctx = user_data;
157 
158 	if (final_data == HTTP_DATA_MORE) {
159 		NET_DBG("[%p] Partial data received (%zd bytes)", ctx,
160 			rsp->data_len);
161 		ctx->all_received = false;
162 	} else if (final_data == HTTP_DATA_FINAL) {
163 		NET_DBG("[%p] All the data received (%zd bytes)", ctx,
164 			rsp->data_len);
165 		ctx->all_received = true;
166 	}
167 
168 	return 0;
169 }
170 
on_header_field(struct http_parser * parser,const char * at,size_t length)171 static int on_header_field(struct http_parser *parser, const char *at,
172 			   size_t length)
173 {
174 	struct http_request *req = CONTAINER_OF(parser,
175 						struct http_request,
176 						internal.parser);
177 	struct websocket_context *ctx = req->internal.user_data;
178 	const char *ws_accept_str = "Sec-WebSocket-Accept";
179 	uint16_t len;
180 
181 	len = strlen(ws_accept_str);
182 	if (length >= len && strncasecmp(at, ws_accept_str, len) == 0) {
183 		ctx->sec_accept_present = true;
184 	}
185 
186 	if (ctx->http_cb && ctx->http_cb->on_header_field) {
187 		ctx->http_cb->on_header_field(parser, at, length);
188 	}
189 
190 	return 0;
191 }
192 
193 #define MAX_SEC_ACCEPT_LEN 32
194 
on_header_value(struct http_parser * parser,const char * at,size_t length)195 static int on_header_value(struct http_parser *parser, const char *at,
196 			   size_t length)
197 {
198 	struct http_request *req = CONTAINER_OF(parser,
199 						struct http_request,
200 						internal.parser);
201 	struct websocket_context *ctx = req->internal.user_data;
202 	char str[MAX_SEC_ACCEPT_LEN];
203 
204 	if (ctx->sec_accept_present) {
205 		int ret;
206 		size_t olen;
207 
208 		ctx->sec_accept_ok = false;
209 		ctx->sec_accept_present = false;
210 
211 		ret = base64_encode(str, sizeof(str) - 1, &olen,
212 				    ctx->sec_accept_key,
213 				    WS_SHA1_OUTPUT_LEN);
214 		if (ret == 0) {
215 			if (strncmp(at, str, length)) {
216 				NET_DBG("[%p] Security keys do not match "
217 					"%s vs %s", ctx, str, at);
218 			} else {
219 				ctx->sec_accept_ok = true;
220 			}
221 		}
222 	}
223 
224 	if (ctx->http_cb && ctx->http_cb->on_header_value) {
225 		ctx->http_cb->on_header_value(parser, at, length);
226 	}
227 
228 	return 0;
229 }
230 
websocket_connect(int sock,struct websocket_request * wreq,int32_t timeout,void * user_data)231 int websocket_connect(int sock, struct websocket_request *wreq,
232 		      int32_t timeout, void *user_data)
233 {
234 	/* This is the expected Sec-WebSocket-Accept key. We are storing a
235 	 * pointer to this in ctx but the value is only used for the duration
236 	 * of this function call so there is no issue even if this variable
237 	 * is allocated from stack.
238 	 */
239 	uint8_t sec_accept_key[WS_SHA1_OUTPUT_LEN];
240 	struct http_parser_settings http_parser_settings;
241 	struct websocket_context *ctx;
242 	struct http_request req;
243 	int ret, fd, key_len;
244 	size_t olen;
245 	char key_accept[MAX_SEC_ACCEPT_LEN + sizeof(WS_MAGIC)];
246 	uint32_t rnd_value = sys_rand32_get();
247 	char sec_ws_key[] =
248 		"Sec-WebSocket-Key: 0123456789012345678901==\r\n";
249 	char *headers[] = {
250 		sec_ws_key,
251 		"Upgrade: websocket\r\n",
252 		"Connection: Upgrade\r\n",
253 		"Sec-WebSocket-Version: 13\r\n",
254 		NULL
255 	};
256 
257 	fd = -1;
258 
259 	if (sock < 0 || wreq == NULL || wreq->host == NULL ||
260 	    wreq->url == NULL) {
261 		return -EINVAL;
262 	}
263 
264 	ctx = websocket_find(sock);
265 	if (ctx) {
266 		NET_DBG("[%p] Websocket for sock %d already exists!", ctx,
267 			sock);
268 		return -EEXIST;
269 	}
270 
271 	ctx = websocket_get();
272 	if (!ctx) {
273 		return -ENOENT;
274 	}
275 
276 	ctx->real_sock = sock;
277 	ctx->recv_buf.buf = wreq->tmp_buf;
278 	ctx->recv_buf.size = wreq->tmp_buf_len;
279 	ctx->sec_accept_key = sec_accept_key;
280 	ctx->http_cb = wreq->http_cb;
281 	ctx->is_client = 1;
282 
283 	mbedtls_sha1((const unsigned char *)&rnd_value, sizeof(rnd_value),
284 			 sec_accept_key);
285 
286 	ret = base64_encode(sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1,
287 			    sizeof(sec_ws_key) -
288 					sizeof("Sec-Websocket-Key: "),
289 			    &olen, sec_accept_key,
290 			    /* We are only interested in 16 first bytes so
291 			     * subtract 4 from the SHA-1 length
292 			     */
293 			    sizeof(sec_accept_key) - 4);
294 	if (ret) {
295 		NET_DBG("[%p] Cannot encode base64 (%d)", ctx, ret);
296 		goto out;
297 	}
298 
299 	if ((olen + sizeof("Sec-Websocket-Key: ") + 2) > sizeof(sec_ws_key)) {
300 		NET_DBG("[%p] Too long message (%zd > %zd)", ctx,
301 			olen + sizeof("Sec-Websocket-Key: ") + 2,
302 			sizeof(sec_ws_key));
303 		ret = -EMSGSIZE;
304 		goto out;
305 	}
306 
307 	memcpy(sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1 + olen,
308 	       HTTP_CRLF, sizeof(HTTP_CRLF));
309 
310 	memset(&req, 0, sizeof(req));
311 
312 	req.method = HTTP_GET;
313 	req.url = wreq->url;
314 	req.host = wreq->host;
315 	req.protocol = "HTTP/1.1";
316 	req.header_fields = (const char **)headers;
317 	req.optional_headers_cb = wreq->optional_headers_cb;
318 	req.optional_headers = wreq->optional_headers;
319 	req.response = response_cb;
320 	req.http_cb = &http_parser_settings;
321 	req.recv_buf = wreq->tmp_buf;
322 	req.recv_buf_len = wreq->tmp_buf_len;
323 
324 	/* We need to catch the Sec-WebSocket-Accept field in order to verify
325 	 * that it contains the stuff that we sent in Sec-WebSocket-Key field
326 	 * so setup HTTP callbacks so that we will get the needed fields.
327 	 */
328 	if (ctx->http_cb) {
329 		memcpy(&http_parser_settings, ctx->http_cb,
330 		       sizeof(http_parser_settings));
331 	} else {
332 		memset(&http_parser_settings, 0, sizeof(http_parser_settings));
333 	}
334 
335 	http_parser_settings.on_header_field = on_header_field;
336 	http_parser_settings.on_header_value = on_header_value;
337 
338 	/* Pre-calculate the expected Sec-Websocket-Accept field */
339 	key_len = MIN(sizeof(key_accept) - 1, olen);
340 	strncpy(key_accept, sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1,
341 		key_len);
342 
343 	olen = MIN(sizeof(key_accept) - 1 - key_len, sizeof(WS_MAGIC) - 1);
344 	strncpy(key_accept + key_len, WS_MAGIC, olen);
345 
346 	/* This SHA-1 value is then checked when we receive the response */
347 	mbedtls_sha1(key_accept, olen + key_len, sec_accept_key);
348 
349 	ret = http_client_req(sock, &req, timeout, ctx);
350 	if (ret < 0) {
351 		NET_DBG("[%p] Cannot connect to Websocket host %s", ctx,
352 			wreq->host);
353 		ret = -ECONNABORTED;
354 		goto out;
355 	}
356 
357 	if (!(ctx->all_received && ctx->sec_accept_ok)) {
358 		NET_DBG("[%p] WS handshake failed (%d/%d)", ctx,
359 			ctx->all_received, ctx->sec_accept_ok);
360 		ret = -ECONNABORTED;
361 		goto out;
362 	}
363 
364 	ctx->user_data = user_data;
365 
366 	fd = zvfs_reserve_fd();
367 	if (fd < 0) {
368 		ret = -ENOSPC;
369 		goto out;
370 	}
371 
372 	ctx->sock = fd;
373 	zvfs_finalize_typed_fd(fd, ctx, (const struct fd_op_vtable *)&websocket_fd_op_vtable,
374 			    ZVFS_MODE_IFSOCK);
375 
376 	/* Call the user specified callback and if it accepts the connection
377 	 * then continue.
378 	 */
379 	if (wreq->cb) {
380 		ret = wreq->cb(fd, &req, user_data);
381 		if (ret < 0) {
382 			NET_DBG("[%p] Connection aborted (%d)", ctx, ret);
383 			goto out;
384 		}
385 	}
386 
387 	NET_DBG("[%p] WS connection to peer established (fd %d)", ctx, fd);
388 
389 	/* We will re-use the temp buffer in receive function. If there were
390 	 * any leftover data from HTTP headers processing, we need to reflect
391 	 * this in the count variable.
392 	 */
393 	ctx->recv_buf.count = req.data_len;
394 
395 	/* Init parser FSM */
396 	ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
397 
398 	(void)sock_obj_core_alloc_find(ctx->real_sock, fd, SOCK_STREAM);
399 
400 	return fd;
401 
402 out:
403 	if (fd >= 0) {
404 		(void)zsock_close(fd);
405 	}
406 
407 	websocket_context_unref(ctx);
408 	return ret;
409 }
410 
websocket_disconnect(int ws_sock)411 int websocket_disconnect(int ws_sock)
412 {
413 	return zsock_close(ws_sock);
414 }
415 
websocket_interal_disconnect(struct websocket_context * ctx)416 static int websocket_interal_disconnect(struct websocket_context *ctx)
417 {
418 	int ret;
419 
420 	if (ctx == NULL) {
421 		return -ENOENT;
422 	}
423 
424 	NET_DBG("[%p] Disconnecting", ctx);
425 
426 	ret = websocket_send_msg(ctx->sock, NULL, 0, WEBSOCKET_OPCODE_CLOSE,
427 				 ctx->is_client, true, SYS_FOREVER_MS);
428 	if (ret < 0) {
429 		NET_DBG("[%p] Failed to send close message (err %d).", ctx, ret);
430 	}
431 
432 	(void)sock_obj_core_dealloc(ctx->sock);
433 
434 	websocket_context_unref(ctx);
435 
436 	return ret;
437 }
438 
websocket_close_vmeth(void * obj)439 static int websocket_close_vmeth(void *obj)
440 {
441 	struct websocket_context *ctx = obj;
442 	int ret;
443 
444 	ret = websocket_interal_disconnect(ctx);
445 	if (ret < 0) {
446 		/* Ignore error if we are not connected */
447 		if (ret != -ENOTCONN) {
448 			NET_DBG("[%p] Cannot close (%d)", obj, ret);
449 
450 			errno = -ret;
451 			return -1;
452 		}
453 
454 		ret = 0;
455 	}
456 
457 	return ret;
458 }
459 
websocket_poll_offload(struct zsock_pollfd * fds,int nfds,int timeout)460 static inline int websocket_poll_offload(struct zsock_pollfd *fds, int nfds,
461 					 int timeout)
462 {
463 	int fd_backup[CONFIG_ZVFS_POLL_MAX];
464 	const struct fd_op_vtable *vtable;
465 	void *ctx;
466 	int ret = 0;
467 	int i;
468 
469 	/* Overwrite websocket file descriptors with underlying ones. */
470 	for (i = 0; i < nfds; i++) {
471 		fd_backup[i] = fds[i].fd;
472 
473 		ctx = zvfs_get_fd_obj(fds[i].fd,
474 				   (const struct fd_op_vtable *)
475 						     &websocket_fd_op_vtable,
476 				   0);
477 		if (ctx == NULL) {
478 			continue;
479 		}
480 
481 		fds[i].fd = ((struct websocket_context *)ctx)->real_sock;
482 	}
483 
484 	/* Get offloaded sockets vtable. */
485 	ctx = zvfs_get_fd_obj_and_vtable(fds[0].fd,
486 				      (const struct fd_op_vtable **)&vtable,
487 				      NULL);
488 	if (ctx == NULL) {
489 		errno = EINVAL;
490 		ret = -1;
491 		goto exit;
492 	}
493 
494 	ret = zvfs_fdtable_call_ioctl(vtable, ctx, ZFD_IOCTL_POLL_OFFLOAD,
495 				   fds, nfds, timeout);
496 
497 exit:
498 	/* Restore original fds. */
499 	for (i = 0; i < nfds; i++) {
500 		fds[i].fd = fd_backup[i];
501 	}
502 
503 	return ret;
504 }
505 
websocket_ioctl_vmeth(void * obj,unsigned int request,va_list args)506 static int websocket_ioctl_vmeth(void *obj, unsigned int request, va_list args)
507 {
508 	struct websocket_context *ctx = obj;
509 
510 	switch (request) {
511 	case ZFD_IOCTL_POLL_OFFLOAD: {
512 		struct zsock_pollfd *fds;
513 		int nfds;
514 		int timeout;
515 
516 		fds = va_arg(args, struct zsock_pollfd *);
517 		nfds = va_arg(args, int);
518 		timeout = va_arg(args, int);
519 
520 		return websocket_poll_offload(fds, nfds, timeout);
521 	}
522 
523 	case ZFD_IOCTL_SET_LOCK:
524 		/* Ignore, don't want to overwrite underlying socket lock. */
525 		return 0;
526 
527 	default: {
528 		const struct fd_op_vtable *vtable;
529 		void *core_obj;
530 
531 		core_obj = zvfs_get_fd_obj_and_vtable(
532 				ctx->real_sock,
533 				(const struct fd_op_vtable **)&vtable,
534 				NULL);
535 		if (core_obj == NULL) {
536 			errno = EBADF;
537 			return -1;
538 		}
539 
540 		/* Pass the call to the core socket implementation. */
541 		return vtable->ioctl(core_obj, request, args);
542 	}
543 	}
544 
545 	return 0;
546 }
547 
548 #if !defined(CONFIG_NET_TEST)
sendmsg_all(int sock,const struct msghdr * message,int flags,const k_timepoint_t req_end_timepoint)549 static int sendmsg_all(int sock, const struct msghdr *message, int flags,
550 			const k_timepoint_t req_end_timepoint)
551 {
552 	int ret, i;
553 	size_t offset = 0;
554 	size_t total_len = 0;
555 
556 	for (i = 0; i < message->msg_iovlen; i++) {
557 		total_len += message->msg_iov[i].iov_len;
558 	}
559 
560 	while (offset < total_len) {
561 		ret = zsock_sendmsg(sock, message, flags);
562 
563 		if ((ret == 0) || (ret < 0 && errno == EAGAIN)) {
564 			struct zsock_pollfd pfd;
565 			int pollres;
566 			k_ticks_t req_timeout_ticks =
567 				sys_timepoint_timeout(req_end_timepoint).ticks;
568 			int req_timeout_ms = k_ticks_to_ms_floor32(req_timeout_ticks);
569 
570 			pfd.fd = sock;
571 			pfd.events = ZSOCK_POLLOUT;
572 			pollres = zsock_poll(&pfd, 1, req_timeout_ms);
573 			if (pollres == 0) {
574 				return -ETIMEDOUT;
575 			} else if (pollres > 0) {
576 				continue;
577 			} else {
578 				return -errno;
579 			}
580 		} else if (ret < 0) {
581 			return -errno;
582 		}
583 
584 		offset += ret;
585 		if (offset >= total_len) {
586 			break;
587 		}
588 
589 		/* Update msghdr for the next iteration. */
590 		for (i = 0; i < message->msg_iovlen; i++) {
591 			if (ret < message->msg_iov[i].iov_len) {
592 				message->msg_iov[i].iov_len -= ret;
593 				message->msg_iov[i].iov_base =
594 					(uint8_t *)message->msg_iov[i].iov_base + ret;
595 				break;
596 			}
597 
598 			ret -= message->msg_iov[i].iov_len;
599 			message->msg_iov[i].iov_len = 0;
600 		}
601 	}
602 
603 	return total_len;
604 }
605 #endif /* !defined(CONFIG_NET_TEST) */
606 
websocket_prepare_and_send(struct websocket_context * ctx,uint8_t * header,size_t header_len,uint8_t * payload,size_t payload_len,int32_t timeout)607 static int websocket_prepare_and_send(struct websocket_context *ctx,
608 				      uint8_t *header, size_t header_len,
609 				      uint8_t *payload, size_t payload_len,
610 				      int32_t timeout)
611 {
612 	struct iovec io_vector[2];
613 	struct msghdr msg;
614 
615 	io_vector[0].iov_base = header;
616 	io_vector[0].iov_len = header_len;
617 	io_vector[1].iov_base = payload;
618 	io_vector[1].iov_len = payload_len;
619 
620 	memset(&msg, 0, sizeof(msg));
621 
622 	msg.msg_iov = io_vector;
623 	msg.msg_iovlen = ARRAY_SIZE(io_vector);
624 
625 	if (HEXDUMP_SENT_PACKETS) {
626 		LOG_HEXDUMP_DBG(header, header_len, "Header");
627 		if ((payload != NULL) && (payload_len > 0)) {
628 			LOG_HEXDUMP_DBG(payload, payload_len, "Payload");
629 		} else {
630 			LOG_DBG("No payload");
631 		}
632 	}
633 
634 #if defined(CONFIG_NET_TEST)
635 	/* Simulate a case where the payload is split to two. The unit test
636 	 * does not set mask bit in this case.
637 	 */
638 	return verify_sent_and_received_msg(&msg, !(header[1] & BIT(7)));
639 #else
640 	k_timeout_t tout = K_FOREVER;
641 
642 	if (timeout != SYS_FOREVER_MS) {
643 		tout = K_MSEC(timeout);
644 	}
645 
646 	k_timeout_t req_timeout = K_MSEC(timeout);
647 	k_timepoint_t req_end_timepoint = sys_timepoint_calc(req_timeout);
648 
649 	return sendmsg_all(ctx->real_sock, &msg,
650 			   K_TIMEOUT_EQ(tout, K_NO_WAIT) ? ZSOCK_MSG_DONTWAIT : 0,
651 			   req_end_timepoint);
652 #endif /* CONFIG_NET_TEST */
653 }
654 
websocket_send_msg(int ws_sock,const uint8_t * payload,size_t payload_len,enum websocket_opcode opcode,bool mask,bool final,int32_t timeout)655 int websocket_send_msg(int ws_sock, const uint8_t *payload, size_t payload_len,
656 		       enum websocket_opcode opcode, bool mask, bool final,
657 		       int32_t timeout)
658 {
659 	struct websocket_context *ctx;
660 	uint8_t header[MAX_HEADER_LEN], hdr_len = 2;
661 	uint8_t *data_to_send = (uint8_t *)payload;
662 	int ret;
663 
664 	if (opcode != WEBSOCKET_OPCODE_DATA_TEXT &&
665 	    opcode != WEBSOCKET_OPCODE_DATA_BINARY &&
666 	    opcode != WEBSOCKET_OPCODE_CONTINUE &&
667 	    opcode != WEBSOCKET_OPCODE_CLOSE &&
668 	    opcode != WEBSOCKET_OPCODE_PING &&
669 	    opcode != WEBSOCKET_OPCODE_PONG) {
670 		return -EINVAL;
671 	}
672 
673 	ctx = zvfs_get_fd_obj(ws_sock, NULL, 0);
674 	if (ctx == NULL) {
675 		return -EBADF;
676 	}
677 
678 #if !defined(CONFIG_NET_TEST)
679 	/* Websocket unit test does not use context from pool but allocates
680 	 * its own, hence skip the check.
681 	 */
682 
683 	if (!PART_OF_ARRAY(contexts, ctx)) {
684 		return -ENOENT;
685 	}
686 #endif /* !defined(CONFIG_NET_TEST) */
687 
688 	NET_DBG("[%p] Len %zd %s/%d/%s", ctx, payload_len, opcode2str(opcode),
689 		mask, final ? "final" : "more");
690 
691 	memset(header, 0, sizeof(header));
692 
693 	/* Is this the last packet? */
694 	header[0] = final ? BIT(7) : 0;
695 
696 	/* Text, binary, ping, pong or close ? */
697 	header[0] |= opcode;
698 
699 	/* Masking */
700 	header[1] = mask ? BIT(7) : 0;
701 
702 	if (payload_len < 126) {
703 		header[1] |= payload_len;
704 	} else if (payload_len < 65536) {
705 		header[1] |= 126;
706 		header[2] = payload_len >> 8;
707 		header[3] = payload_len;
708 		hdr_len += 2;
709 	} else {
710 		header[1] |= 127;
711 		header[2] = 0;
712 		header[3] = 0;
713 		header[4] = 0;
714 		header[5] = 0;
715 		header[6] = payload_len >> 24;
716 		header[7] = payload_len >> 16;
717 		header[8] = payload_len >> 8;
718 		header[9] = payload_len;
719 		hdr_len += 8;
720 	}
721 
722 	/* Add masking value if needed */
723 	if (mask) {
724 		int i;
725 
726 		ctx->masking_value = sys_rand32_get();
727 
728 		header[hdr_len++] |= ctx->masking_value >> 24;
729 		header[hdr_len++] |= ctx->masking_value >> 16;
730 		header[hdr_len++] |= ctx->masking_value >> 8;
731 		header[hdr_len++] |= ctx->masking_value;
732 
733 		if ((payload != NULL) && (payload_len > 0)) {
734 			data_to_send = k_malloc(payload_len);
735 			if (!data_to_send) {
736 				return -ENOMEM;
737 			}
738 
739 			memcpy(data_to_send, payload, payload_len);
740 
741 			for (i = 0; i < payload_len; i++) {
742 				data_to_send[i] ^= ctx->masking_value >> (8 * (3 - i % 4));
743 			}
744 		}
745 	}
746 
747 	ret = websocket_prepare_and_send(ctx, header, hdr_len,
748 					 data_to_send, payload_len, timeout);
749 	if (ret < 0) {
750 		NET_DBG("Cannot send ws msg (%d)", -errno);
751 		goto quit;
752 	}
753 
754 quit:
755 	if (data_to_send != payload) {
756 		k_free(data_to_send);
757 	}
758 
759 	/* Do no math with 0 and error codes */
760 	if (ret <= 0) {
761 		return ret;
762 	}
763 
764 	return ret - hdr_len;
765 }
766 
websocket_opcode2flag(uint8_t data)767 static uint32_t websocket_opcode2flag(uint8_t data)
768 {
769 	switch (data & 0x0f) {
770 	case WEBSOCKET_OPCODE_DATA_TEXT:
771 		return WEBSOCKET_FLAG_TEXT;
772 	case WEBSOCKET_OPCODE_DATA_BINARY:
773 		return WEBSOCKET_FLAG_BINARY;
774 	case WEBSOCKET_OPCODE_CLOSE:
775 		return WEBSOCKET_FLAG_CLOSE;
776 	case WEBSOCKET_OPCODE_PING:
777 		return WEBSOCKET_FLAG_PING;
778 	case WEBSOCKET_OPCODE_PONG:
779 		return WEBSOCKET_FLAG_PONG;
780 	default:
781 		break;
782 	}
783 	return 0;
784 }
785 
websocket_parse(struct websocket_context * ctx,struct websocket_buffer * payload)786 static int websocket_parse(struct websocket_context *ctx, struct websocket_buffer *payload)
787 {
788 	int len;
789 	uint8_t data;
790 	size_t parsed_count = 0;
791 
792 	do {
793 		if (parsed_count >= ctx->recv_buf.count) {
794 			return parsed_count;
795 		}
796 		if (ctx->parser_state != WEBSOCKET_PARSER_STATE_PAYLOAD) {
797 			data = ctx->recv_buf.buf[parsed_count++];
798 
799 			switch (ctx->parser_state) {
800 			case WEBSOCKET_PARSER_STATE_OPCODE:
801 				ctx->message_type = websocket_opcode2flag(data);
802 				if ((data & 0x80) != 0) {
803 					ctx->message_type |= WEBSOCKET_FLAG_FINAL;
804 				}
805 				ctx->parser_state = WEBSOCKET_PARSER_STATE_LENGTH;
806 				break;
807 			case WEBSOCKET_PARSER_STATE_LENGTH:
808 				ctx->masked = (data & 0x80) != 0;
809 				len = data & 0x7f;
810 				if (len < 126) {
811 					ctx->message_len = len;
812 					if (ctx->masked) {
813 						ctx->masking_value = 0;
814 						ctx->parser_remaining = 4;
815 						ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK;
816 					} else {
817 						ctx->parser_remaining = ctx->message_len;
818 						ctx->parser_state =
819 							(ctx->parser_remaining == 0)
820 								? WEBSOCKET_PARSER_STATE_OPCODE
821 								: WEBSOCKET_PARSER_STATE_PAYLOAD;
822 					}
823 				} else {
824 					ctx->message_len = 0;
825 					ctx->parser_remaining = (len < 127) ? 2 : 8;
826 					ctx->parser_state = WEBSOCKET_PARSER_STATE_EXT_LEN;
827 				}
828 				break;
829 			case WEBSOCKET_PARSER_STATE_EXT_LEN:
830 				ctx->parser_remaining--;
831 				ctx->message_len |= ((uint64_t)data << (ctx->parser_remaining * 8));
832 				if (ctx->parser_remaining == 0) {
833 					if (ctx->masked) {
834 						ctx->masking_value = 0;
835 						ctx->parser_remaining = 4;
836 						ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK;
837 					} else {
838 						ctx->parser_remaining = ctx->message_len;
839 						ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD;
840 					}
841 				}
842 				break;
843 			case WEBSOCKET_PARSER_STATE_MASK:
844 				ctx->parser_remaining--;
845 				ctx->masking_value |=
846 					(uint32_t)((uint64_t)data << (ctx->parser_remaining * 8));
847 				if (ctx->parser_remaining == 0) {
848 					if (ctx->message_len == 0) {
849 						ctx->parser_remaining = 0;
850 						ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
851 					} else {
852 						ctx->parser_remaining = ctx->message_len;
853 						ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD;
854 					}
855 				}
856 				break;
857 			default:
858 				return -EFAULT;
859 			}
860 
861 #if (LOG_LEVEL >= LOG_LEVEL_DBG)
862 			if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_PAYLOAD) ||
863 			    ((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) &&
864 			     (ctx->message_len == 0))) {
865 				NET_DBG("[%p] %smasked, mask 0x%08x, type 0x%02x, msg %zd", ctx,
866 					ctx->masked ? "" : "un",
867 					ctx->masked ? ctx->masking_value : 0, ctx->message_type,
868 					(size_t)ctx->message_len);
869 			}
870 #endif
871 		} else {
872 			size_t remaining_in_recv_buf = ctx->recv_buf.count - parsed_count;
873 			size_t payload_in_recv_buf =
874 				MIN(remaining_in_recv_buf, ctx->parser_remaining);
875 			size_t free_in_payload_buf = payload->size - payload->count;
876 			size_t ready_to_copy = MIN(payload_in_recv_buf, free_in_payload_buf);
877 
878 			if (free_in_payload_buf == 0) {
879 				break;
880 			}
881 
882 			memcpy(&payload->buf[payload->count], &ctx->recv_buf.buf[parsed_count],
883 			       ready_to_copy);
884 			parsed_count += ready_to_copy;
885 			payload->count += ready_to_copy;
886 			ctx->parser_remaining -= ready_to_copy;
887 			if (ctx->parser_remaining == 0) {
888 				ctx->parser_remaining = 0;
889 				ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
890 			}
891 		}
892 
893 	} while (ctx->parser_state != WEBSOCKET_PARSER_STATE_OPCODE);
894 
895 	return parsed_count;
896 }
897 
898 #if !defined(CONFIG_NET_TEST)
wait_rx(int sock,int timeout)899 static int wait_rx(int sock, int timeout)
900 {
901 	struct zsock_pollfd fds = {
902 		.fd = sock,
903 		.events = ZSOCK_POLLIN,
904 	};
905 	int ret;
906 
907 	ret = zsock_poll(&fds, 1, timeout);
908 	if (ret < 0) {
909 		return ret;
910 	}
911 
912 	if (ret == 0) {
913 		/* Timeout */
914 		return -EAGAIN;
915 	}
916 
917 	if (fds.revents & ZSOCK_POLLNVAL) {
918 		return -EBADF;
919 	}
920 
921 	if (fds.revents & ZSOCK_POLLERR) {
922 		return -EIO;
923 	}
924 
925 	return 0;
926 }
927 
timeout_to_ms(k_timeout_t * timeout)928 static int timeout_to_ms(k_timeout_t *timeout)
929 {
930 	if (K_TIMEOUT_EQ(*timeout, K_NO_WAIT)) {
931 		return 0;
932 	} else if (K_TIMEOUT_EQ(*timeout, K_FOREVER)) {
933 		return SYS_FOREVER_MS;
934 	} else {
935 		return k_ticks_to_ms_floor32(timeout->ticks);
936 	}
937 }
938 
939 #endif /* !defined(CONFIG_NET_TEST) */
940 
websocket_recv_msg(int ws_sock,uint8_t * buf,size_t buf_len,uint32_t * message_type,uint64_t * remaining,int32_t timeout)941 int websocket_recv_msg(int ws_sock, uint8_t *buf, size_t buf_len,
942 		       uint32_t *message_type, uint64_t *remaining, int32_t timeout)
943 {
944 	struct websocket_context *ctx;
945 	int ret;
946 	k_timepoint_t end;
947 	k_timeout_t tout = K_FOREVER;
948 	struct websocket_buffer payload = {.buf = buf, .size = buf_len, .count = 0};
949 
950 	if (timeout != SYS_FOREVER_MS) {
951 		tout = K_MSEC(timeout);
952 	}
953 
954 	if ((buf == NULL) || (buf_len == 0)) {
955 		return -EINVAL;
956 	}
957 
958 	end = sys_timepoint_calc(tout);
959 
960 #if defined(CONFIG_NET_TEST)
961 	struct test_data *test_data = zvfs_get_fd_obj(ws_sock, NULL, 0);
962 
963 	if (test_data == NULL) {
964 		return -EBADF;
965 	}
966 
967 	ctx = test_data->ctx;
968 #else
969 	ctx = zvfs_get_fd_obj(ws_sock, NULL, 0);
970 	if (ctx == NULL) {
971 		return -EBADF;
972 	}
973 
974 	if (!PART_OF_ARRAY(contexts, ctx)) {
975 		return -ENOENT;
976 	}
977 #endif /* CONFIG_NET_TEST */
978 
979 	do {
980 		size_t parsed_count;
981 
982 		if (ctx->recv_buf.count == 0) {
983 #if defined(CONFIG_NET_TEST)
984 			size_t input_len = MIN(ctx->recv_buf.size,
985 					       test_data->input_len - test_data->input_pos);
986 
987 			if (input_len > 0) {
988 				memcpy(ctx->recv_buf.buf,
989 				       &test_data->input_buf[test_data->input_pos], input_len);
990 				test_data->input_pos += input_len;
991 				ret = input_len;
992 			} else {
993 				/* emulate timeout */
994 				ret = -EAGAIN;
995 			}
996 #else
997 			tout = sys_timepoint_timeout(end);
998 
999 			ret = wait_rx(ctx->real_sock, timeout_to_ms(&tout));
1000 			if (ret == 0) {
1001 				ret = zsock_recv(ctx->real_sock, ctx->recv_buf.buf,
1002 						 ctx->recv_buf.size, ZSOCK_MSG_DONTWAIT);
1003 				if (ret < 0) {
1004 					ret = -errno;
1005 				}
1006 			}
1007 #endif /* CONFIG_NET_TEST */
1008 
1009 			if (ret < 0) {
1010 				if ((ret == -EAGAIN) && (payload.count > 0)) {
1011 					/* go to unmasking */
1012 					break;
1013 				}
1014 				return ret;
1015 			}
1016 
1017 			if (ret == 0) {
1018 				/* Socket closed */
1019 				return -ENOTCONN;
1020 			}
1021 
1022 			ctx->recv_buf.count = ret;
1023 
1024 			NET_DBG("[%p] Received %d bytes", ctx, ret);
1025 		}
1026 
1027 		ret = websocket_parse(ctx, &payload);
1028 		if (ret < 0) {
1029 			return ret;
1030 		}
1031 		parsed_count = ret;
1032 
1033 		if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) ||
1034 		    (payload.count >= payload.size)) {
1035 			if (remaining != NULL) {
1036 				*remaining = ctx->parser_remaining;
1037 			}
1038 			if (message_type != NULL) {
1039 				*message_type = ctx->message_type;
1040 			}
1041 
1042 			size_t left = ctx->recv_buf.count - parsed_count;
1043 
1044 			if (left > 0) {
1045 				memmove(ctx->recv_buf.buf, &ctx->recv_buf.buf[parsed_count], left);
1046 			}
1047 			ctx->recv_buf.count = left;
1048 			break;
1049 		}
1050 
1051 		ctx->recv_buf.count -= parsed_count;
1052 
1053 	} while (true);
1054 
1055 	/* Unmask the data */
1056 	if (ctx->masked) {
1057 		uint8_t *mask_as_bytes = (uint8_t *)&ctx->masking_value;
1058 		size_t data_buf_offset = ctx->message_len - ctx->parser_remaining - payload.count;
1059 
1060 		for (size_t i = 0; i < payload.count; i++) {
1061 			size_t m = data_buf_offset % 4;
1062 
1063 			payload.buf[i] ^= mask_as_bytes[3 - m];
1064 			data_buf_offset++;
1065 		}
1066 	}
1067 
1068 	return payload.count;
1069 }
1070 
websocket_send(struct websocket_context * ctx,const uint8_t * buf,size_t buf_len,int32_t timeout)1071 static int websocket_send(struct websocket_context *ctx, const uint8_t *buf,
1072 			  size_t buf_len, int32_t timeout)
1073 {
1074 	int ret;
1075 
1076 	NET_DBG("[%p] Sending %zd bytes", ctx, buf_len);
1077 
1078 	ret = websocket_send_msg(ctx->sock, buf, buf_len, WEBSOCKET_OPCODE_DATA_TEXT,
1079 				 ctx->is_client, true, timeout);
1080 	if (ret < 0) {
1081 		errno = -ret;
1082 		return -1;
1083 	}
1084 
1085 	NET_DBG("[%p] Sent %d bytes", ctx, ret);
1086 
1087 	sock_obj_core_update_send_stats(ctx->sock, ret);
1088 
1089 	return ret;
1090 }
1091 
websocket_recv(struct websocket_context * ctx,uint8_t * buf,size_t buf_len,int32_t timeout)1092 static int websocket_recv(struct websocket_context *ctx, uint8_t *buf,
1093 			  size_t buf_len, int32_t timeout)
1094 {
1095 	uint32_t message_type;
1096 	uint64_t remaining;
1097 	int ret;
1098 
1099 	NET_DBG("[%p] Waiting data, buf len %zd bytes", ctx, buf_len);
1100 
1101 	/* TODO: add support for recvmsg() so that we could return the
1102 	 *       websocket specific information in ancillary data.
1103 	 */
1104 	ret = websocket_recv_msg(ctx->sock, buf, buf_len, &message_type,
1105 				 &remaining, timeout);
1106 	if (ret < 0) {
1107 		if (ret == -ENOTCONN) {
1108 			ret = 0;
1109 		} else {
1110 			errno = -ret;
1111 			return -1;
1112 		}
1113 	}
1114 
1115 	NET_DBG("[%p] Received %d bytes", ctx, ret);
1116 
1117 	sock_obj_core_update_recv_stats(ctx->sock, ret);
1118 
1119 	return ret;
1120 }
1121 
websocket_read_vmeth(void * obj,void * buffer,size_t count)1122 static ssize_t websocket_read_vmeth(void *obj, void *buffer, size_t count)
1123 {
1124 	return (ssize_t)websocket_recv(obj, buffer, count, SYS_FOREVER_MS);
1125 }
1126 
websocket_write_vmeth(void * obj,const void * buffer,size_t count)1127 static ssize_t websocket_write_vmeth(void *obj, const void *buffer,
1128 				     size_t count)
1129 {
1130 	return (ssize_t)websocket_send(obj, buffer, count, SYS_FOREVER_MS);
1131 }
1132 
/* sendto() handler for websocket fds.
 *
 * ZSOCK_MSG_DONTWAIT selects a zero timeout; otherwise the send waits
 * forever. dest_addr/addrlen are not used.
 */
static ssize_t websocket_sendto_ctx(void *obj, const void *buf, size_t len,
				    int flags,
				    const struct sockaddr *dest_addr,
				    socklen_t addrlen)
{
	int32_t timeout = (flags & ZSOCK_MSG_DONTWAIT) ? 0 : SYS_FOREVER_MS;

	ARG_UNUSED(dest_addr);
	ARG_UNUSED(addrlen);

	return (ssize_t)websocket_send((struct websocket_context *)obj, buf,
				       len, timeout);
}
1150 
/* recvfrom() handler for websocket fds.
 *
 * ZSOCK_MSG_DONTWAIT selects a zero timeout; otherwise the receive waits
 * forever. src_addr/addrlen are not filled in.
 */
static ssize_t websocket_recvfrom_ctx(void *obj, void *buf, size_t max_len,
				      int flags, struct sockaddr *src_addr,
				      socklen_t *addrlen)
{
	int32_t timeout = (flags & ZSOCK_MSG_DONTWAIT) ? 0 : SYS_FOREVER_MS;

	ARG_UNUSED(src_addr);
	ARG_UNUSED(addrlen);

	return (ssize_t)websocket_recv((struct websocket_context *)obj, buf,
				       max_len, timeout);
}
1167 
websocket_register(int sock,uint8_t * recv_buf,size_t recv_buf_len)1168 int websocket_register(int sock, uint8_t *recv_buf, size_t recv_buf_len)
1169 {
1170 	struct websocket_context *ctx;
1171 	int ret, fd;
1172 
1173 	if (sock < 0) {
1174 		return -EINVAL;
1175 	}
1176 
1177 	ctx = websocket_find(sock);
1178 	if (ctx) {
1179 		NET_DBG("[%p] Websocket for sock %d already exists!", ctx, sock);
1180 		return -EEXIST;
1181 	}
1182 
1183 	ctx = websocket_get();
1184 	if (!ctx) {
1185 		return -ENOENT;
1186 	}
1187 
1188 	ctx->real_sock = sock;
1189 	ctx->recv_buf.buf = recv_buf;
1190 	ctx->recv_buf.size = recv_buf_len;
1191 	ctx->is_client = 0;
1192 
1193 	fd = zvfs_reserve_fd();
1194 	if (fd < 0) {
1195 		ret = -ENOSPC;
1196 		goto out;
1197 	}
1198 
1199 	ctx->sock = fd;
1200 	zvfs_finalize_typed_fd(fd, ctx, (const struct fd_op_vtable *)&websocket_fd_op_vtable,
1201 			    ZVFS_MODE_IFSOCK);
1202 
1203 	NET_DBG("[%p] WS connection to peer established (fd %d)", ctx, fd);
1204 
1205 	ctx->recv_buf.count = 0;
1206 	ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
1207 
1208 	(void)sock_obj_core_alloc_find(ctx->real_sock, fd, SOCK_STREAM);
1209 
1210 	return fd;
1211 
1212 out:
1213 	websocket_context_unref(ctx);
1214 
1215 	return ret;
1216 }
1217 
websocket_search(int sock)1218 static struct websocket_context *websocket_search(int sock)
1219 {
1220 	struct websocket_context *ctx = NULL;
1221 	int i;
1222 
1223 	k_sem_take(&contexts_lock, K_FOREVER);
1224 
1225 	for (i = 0; i < ARRAY_SIZE(contexts); i++) {
1226 		if (!websocket_context_is_used(&contexts[i])) {
1227 			continue;
1228 		}
1229 
1230 		if (contexts[i].sock != sock) {
1231 			continue;
1232 		}
1233 
1234 		ctx = &contexts[i];
1235 		break;
1236 	}
1237 
1238 	k_sem_give(&contexts_lock);
1239 
1240 	return ctx;
1241 }
1242 
websocket_unregister(int sock)1243 int websocket_unregister(int sock)
1244 {
1245 	struct websocket_context *ctx;
1246 
1247 	if (sock < 0) {
1248 		return -EINVAL;
1249 	}
1250 
1251 	ctx = websocket_search(sock);
1252 	if (ctx == NULL) {
1253 		NET_DBG("[%p] Real socket for websocket sock %d not found!", ctx, sock);
1254 		return -ENOENT;
1255 	}
1256 
1257 	if (ctx->real_sock < 0) {
1258 		return -EALREADY;
1259 	}
1260 
1261 	(void)zsock_close(sock);
1262 	(void)zsock_close(ctx->real_sock);
1263 
1264 	ctx->real_sock = -1;
1265 	ctx->sock = -1;
1266 
1267 	return 0;
1268 }
1269 
1270 static const struct socket_op_vtable websocket_fd_op_vtable = {
1271 	.fd_vtable = {
1272 		.read = websocket_read_vmeth,
1273 		.write = websocket_write_vmeth,
1274 		.close = websocket_close_vmeth,
1275 		.ioctl = websocket_ioctl_vmeth,
1276 	},
1277 	.sendto = websocket_sendto_ctx,
1278 	.recvfrom = websocket_recvfrom_ctx,
1279 };
1280 
/* Invoke cb(ctx, user_data) for every in-use websocket context.
 *
 * The whole walk runs under contexts_lock, and each callback runs while
 * holding that context's own mutex.
 */
void websocket_context_foreach(websocket_context_cb_t cb, void *user_data)
{
	k_sem_take(&contexts_lock, K_FOREVER);

	for (size_t idx = 0; idx < ARRAY_SIZE(contexts); idx++) {
		struct websocket_context *wctx = &contexts[idx];

		if (websocket_context_is_used(wctx)) {
			k_mutex_lock(&wctx->lock, K_FOREVER);
			cb(wctx, user_data);
			k_mutex_unlock(&wctx->lock);
		}
	}

	k_sem_give(&contexts_lock);
}
1301 
websocket_init(void)1302 void websocket_init(void)
1303 {
1304 	k_sem_init(&contexts_lock, 1, K_SEM_MAX_LIMIT);
1305 }
1306