1 /*
2  * Copyright (c) 2023 Nordic Semiconductor ASA
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <zephyr/logging/log.h>
8 LOG_MODULE_REGISTER(net_sock_svc, CONFIG_NET_SOCKETS_LOG_LEVEL);
9 
10 #include <zephyr/kernel.h>
11 #include <zephyr/init.h>
12 #include <zephyr/net/socket_service.h>
13 #include <zephyr/zvfs/eventfd.h>
14 
/* Thread start-up routine, defined further down in this file. */
static int init_socket_service(void);

/* Lifecycle of the socket service thread, as seen by registrants. */
enum SOCKET_SERVICE_THREAD_STATUS {
	/* Thread has not finished start-up yet; registrations wait on wait_start. */
	SOCKET_SERVICE_THREAD_UNINITIALIZED = 0,
	/* Start-up failed (e.g. no services, or too few poll entries). */
	SOCKET_SERVICE_THREAD_FAILED,
	/* Thread left its poll loop (e.g. poll error). */
	SOCKET_SERVICE_THREAD_STOPPED,
	/* Thread is polling; registrations are accepted. */
	SOCKET_SERVICE_THREAD_RUNNING,
};
static enum SOCKET_SERVICE_THREAD_STATUS thread_status;

/* Held by registrations while they check thread_status and update svc->pev[],
 * and by the service thread while it copies those events into ctx.events.
 * wait_start is broadcast by the service thread when it reaches RUNNING or
 * FAILED, waking registrants that arrived while it was UNINITIALIZED.
 */
static K_MUTEX_DEFINE(lock);
static K_CONDVAR_DEFINE(wait_start);

/* Bounds of the iterable section holding every net_socket_service_desc. */
STRUCT_SECTION_START_EXTERN(net_socket_service_desc);
STRUCT_SECTION_END_EXTERN(net_socket_service_desc);

/* Combined poll array used by the service thread. events[0] holds the
 * eventfd used to request a rebuild; each service's entries follow,
 * starting at get_idx(svc).
 */
static struct service {
	struct zsock_pollfd events[CONFIG_ZVFS_POLL_MAX];
	/* Number of slots in use: all service entries plus the eventfd slot. */
	int count;
} ctx;
35 
/* Lvalue for the index of @p svc's first slot in ctx.events (stored through
 * svc->idx). The argument is parenthesised so that expressions such as
 * get_idx(&desc) expand to (*((&desc)->idx)) rather than (*(&desc->idx)).
 */
#define get_idx(svc) (*((svc)->idx))
37 
/* Call @p cb once for every socket service descriptor placed in the
 * net_socket_service_desc iterable section, forwarding @p user_data.
 */
void net_socket_service_foreach(net_socket_service_cb_t cb, void *user_data)
{
	STRUCT_SECTION_FOREACH(net_socket_service_desc, desc) {
		cb(desc, user_data);
	}
}
44 
cleanup_svc_events(const struct net_socket_service_desc * svc)45 static void cleanup_svc_events(const struct net_socket_service_desc *svc)
46 {
47 	for (int i = 0; i < svc->pev_len; i++) {
48 		svc->pev[i].event.fd = -1;
49 		svc->pev[i].event.events = 0;
50 	}
51 }
52 
z_impl_net_socket_service_register(const struct net_socket_service_desc * svc,struct zsock_pollfd * fds,int len,void * user_data)53 int z_impl_net_socket_service_register(const struct net_socket_service_desc *svc,
54 				       struct zsock_pollfd *fds, int len,
55 				       void *user_data)
56 {
57 	int i, ret = -ENOENT;
58 
59 	k_mutex_lock(&lock, K_FOREVER);
60 
61 	if (thread_status == SOCKET_SERVICE_THREAD_UNINITIALIZED) {
62 		(void)k_condvar_wait(&wait_start, &lock, K_FOREVER);
63 	} else if (thread_status != SOCKET_SERVICE_THREAD_RUNNING) {
64 		NET_ERR("Socket service thread not running, service %p register fails.", svc);
65 		ret = -EIO;
66 		goto out;
67 	}
68 
69 	if (STRUCT_SECTION_START(net_socket_service_desc) > svc ||
70 	    STRUCT_SECTION_END(net_socket_service_desc) <= svc) {
71 		goto out;
72 	}
73 
74 	cleanup_svc_events(svc);
75 
76 	if (fds != NULL) {
77 		if (len > svc->pev_len) {
78 			NET_DBG("Too many file descriptors, "
79 				"max is %d for service %p",
80 				svc->pev_len, svc);
81 			ret = -ENOMEM;
82 			goto out;
83 		}
84 
85 		for (i = 0; i < len; i++) {
86 			svc->pev[i].event = fds[i];
87 			svc->pev[i].user_data = user_data;
88 		}
89 	}
90 
91 	/* Tell the thread to re-read the variables */
92 	zvfs_eventfd_write(ctx.events[0].fd, 1);
93 	ret = 0;
94 
95 out:
96 	k_mutex_unlock(&lock);
97 
98 	return ret;
99 }
100 
find_svc_and_event(struct zsock_pollfd * pev,struct net_socket_service_event ** event)101 static struct net_socket_service_desc *find_svc_and_event(
102 	struct zsock_pollfd *pev,
103 	struct net_socket_service_event **event)
104 {
105 	STRUCT_SECTION_FOREACH(net_socket_service_desc, svc) {
106 		for (int i = 0; i < svc->pev_len; i++) {
107 			if (svc->pev[i].event.fd == pev->fd) {
108 				*event = &svc->pev[i];
109 				return svc;
110 			}
111 		}
112 	}
113 
114 	return NULL;
115 }
116 
117 /* We do not set the user callback to our work struct because we need to
118  * hook into the flow and restore the global poll array so that the next poll
119  * round will not notice it and call the callback again while we are
120  * servicing the callback.
121  */
net_socket_service_callback(struct net_socket_service_event * pev)122 void net_socket_service_callback(struct net_socket_service_event *pev)
123 {
124 	struct net_socket_service_event ev = *pev;
125 
126 	ev.callback(&ev);
127 }
128 
call_work(struct zsock_pollfd * pev,struct net_socket_service_event * event)129 static int call_work(struct zsock_pollfd *pev, struct net_socket_service_event *event)
130 {
131 	int ret = 0;
132 	int fd = pev->fd;
133 
134 	/* Mark the global fd non pollable so that we do not
135 	 * call the callback second time.
136 	 */
137 	pev->fd = -1;
138 
139 	/* Synchronous call */
140 	net_socket_service_callback(event);
141 
142 	/* Restore the fd so that new data can be re-triggered */
143 	pev->fd = fd;
144 
145 	return ret;
146 }
147 
trigger_work(struct zsock_pollfd * pev)148 static int trigger_work(struct zsock_pollfd *pev)
149 {
150 	struct net_socket_service_event *event;
151 	struct net_socket_service_desc *svc;
152 
153 	svc = find_svc_and_event(pev, &event);
154 	if (svc == NULL) {
155 		return -ENOENT;
156 	}
157 
158 	event->svc = svc;
159 
160 	/* Copy the triggered event to our event so that we know what
161 	 * was actually causing the event.
162 	 */
163 	event->event = *pev;
164 
165 	return call_work(pev, event);
166 }
167 
socket_service_thread(void * p1,void * p2,void * p3)168 static void socket_service_thread(void *p1, void *p2, void *p3)
169 {
170 	ARG_UNUSED(p1);
171 	ARG_UNUSED(p2);
172 	ARG_UNUSED(p3);
173 
174 	int ret, i, fd, count = 0;
175 	zvfs_eventfd_t value;
176 
177 	STRUCT_SECTION_COUNT(net_socket_service_desc, &ret);
178 	if (ret == 0) {
179 		NET_INFO("No socket services found, service disabled.");
180 		goto fail;
181 	}
182 
183 	/* Create contiguous poll event array to enable socket polling */
184 	STRUCT_SECTION_FOREACH(net_socket_service_desc, svc) {
185 		NET_DBG("Service %s has %d pollable sockets",
186 			COND_CODE_1(CONFIG_NET_SOCKETS_LOG_LEVEL_DBG,
187 				    (svc->owner), ("")),
188 			svc->pev_len);
189 		get_idx(svc) = count + 1;
190 		count += svc->pev_len;
191 	}
192 
193 	if ((count + 1) > ARRAY_SIZE(ctx.events)) {
194 		NET_ERR("You have %d services to monitor but "
195 			"%zd poll entries configured.",
196 			count + 1, ARRAY_SIZE(ctx.events));
197 		NET_ERR("Please increase value of %s to at least %d",
198 			"CONFIG_ZVFS_POLL_MAX", count + 1);
199 		goto fail;
200 	}
201 
202 	NET_DBG("Monitoring %d socket entries", count);
203 
204 	ctx.count = count + 1;
205 
206 	/* Create an zvfs_eventfd that can be used to trigger events during polling */
207 	fd = zvfs_eventfd(0, 0);
208 	if (fd < 0) {
209 		fd = -errno;
210 		NET_ERR("zvfs_eventfd failed (%d)", fd);
211 		goto out;
212 	}
213 
214 	thread_status = SOCKET_SERVICE_THREAD_RUNNING;
215 	k_condvar_broadcast(&wait_start);
216 
217 	ctx.events[0].fd = fd;
218 	ctx.events[0].events = ZSOCK_POLLIN;
219 
220 restart:
221 	i = 1;
222 
223 	k_mutex_lock(&lock, K_FOREVER);
224 
225 	/* Copy individual events to the big array */
226 	STRUCT_SECTION_FOREACH(net_socket_service_desc, svc) {
227 		for (int j = 0; j < svc->pev_len; j++) {
228 			ctx.events[get_idx(svc) + j] = svc->pev[j].event;
229 		}
230 	}
231 
232 	k_mutex_unlock(&lock);
233 
234 	while (true) {
235 		ret = zsock_poll(ctx.events, count + 1, -1);
236 		if (ret < 0) {
237 			ret = -errno;
238 			NET_ERR("poll failed (%d)", ret);
239 			goto out;
240 		}
241 
242 		if (ret == 0) {
243 			/* should not happen because timeout is -1 */
244 			break;
245 		}
246 
247 		/* Process work here */
248 		for (i = 1; i < (count + 1); i++) {
249 			if (ctx.events[i].fd < 0) {
250 				continue;
251 			}
252 
253 			if (ctx.events[i].revents > 0) {
254 				ret = trigger_work(&ctx.events[i]);
255 				if (ret < 0) {
256 					NET_DBG("Triggering work failed (%d)", ret);
257 					goto restart;
258 				}
259 			}
260 		}
261 
262 		/* Relocate after trigger work so the work gets done before restarting */
263 		if (ret > 0 && ctx.events[0].revents) {
264 			zvfs_eventfd_read(ctx.events[0].fd, &value);
265 			ctx.events[0].revents = 0;
266 			NET_DBG("Received restart event.");
267 			goto restart;
268 		}
269 	}
270 
271 out:
272 	NET_DBG("Socket service thread stopped");
273 	thread_status = SOCKET_SERVICE_THREAD_STOPPED;
274 
275 	return;
276 
277 fail:
278 	thread_status = SOCKET_SERVICE_THREAD_FAILED;
279 	k_condvar_broadcast(&wait_start);
280 }
281 
init_socket_service(void)282 static int init_socket_service(void)
283 {
284 	k_tid_t ssm;
285 	static struct k_thread service_thread;
286 
287 	static K_THREAD_STACK_DEFINE(service_thread_stack,
288 				     CONFIG_NET_SOCKETS_SERVICE_STACK_SIZE);
289 
290 	ssm = k_thread_create(&service_thread,
291 			      service_thread_stack,
292 			      K_THREAD_STACK_SIZEOF(service_thread_stack),
293 			      (k_thread_entry_t)socket_service_thread, NULL, NULL, NULL,
294 			      CLAMP(CONFIG_NET_SOCKETS_SERVICE_THREAD_PRIO,
295 				    K_HIGHEST_APPLICATION_THREAD_PRIO,
296 				    K_LOWEST_APPLICATION_THREAD_PRIO), 0, K_NO_WAIT);
297 
298 	k_thread_name_set(ssm, "net_socket_service");
299 
300 	return 0;
301 }
302 
/* Launch the socket service thread; the start-up routine's status code
 * (always 0) is intentionally discarded.
 */
void socket_service_init(void)
{
	int rc = init_socket_service();

	(void)rc;
}
307