1 /* SPDX-License-Identifier: GPL-2.0 */
2
3 #ifndef __SOCKET_HELPERS__
4 #define __SOCKET_HELPERS__
5
#include <limits.h>
#include <poll.h>
#include <sys/un.h>
#include <linux/vm_sockets.h>
8
/* include/linux/net.h */
/*
 * Mask that extracts the base socket type from a type argument which may
 * also carry extra flag bits (create_pair() switches on sotype & this mask).
 */
#define SOCK_TYPE_MASK 0xf

/* Timeout in seconds used by xaccept_nonblock(), xrecv_nonblock() and the
 * connect wait in create_pair().
 */
#define IO_TIMEOUT_SEC 30
/* Size of the buffer FAIL_LIBBPF() formats libbpf error strings into. */
#define MAX_STRERR_LEN 256

/* workaround for older vm_sockets.h */
/* Used by init_addr_loopback_vsock() as the vsock loopback CID. */
#ifndef VMADDR_CID_LOCAL
#define VMADDR_CID_LOCAL 1
#endif
19
/* include/linux/cleanup.h */
/*
 * Evaluate to the current value of lvalue @p and store @nullvalue into it.
 * @p is evaluated once (its address is taken). This moves a value out of a
 * variable so that a cleanup handler attached to that variable no longer
 * sees the original value.
 */
#define __get_and_null(p, nullvalue) \
	({ \
		__auto_type __ptr = &(p); \
		__auto_type __val = *__ptr; \
		*__ptr = nullvalue; \
		__val; \
	})

/*
 * Move a file descriptor out of @fd and leave -EBADF behind. A __close_fd
 * cleanup on @fd then skips it, because close_fd() only closes values >= 0.
 */
#define take_fd(fd) __get_and_null(fd, -EBADF)
30
/* Wrappers that fail the test on error and report it. */

/*
 * Report a failure through error_at_line() with status 0, tagged with the
 * calling function (passed in error_at_line()'s file-name argument) and line.
 * When @errnum is non-zero, error_at_line() appends its error string. The
 * current test is then marked failed with CHECK_FAIL(true). CHECK_FAIL is
 * not defined here; presumably it comes from the test framework header.
 */
#define _FAIL(errnum, fmt...) \
	({ \
		error_at_line(0, (errnum), __func__, __LINE__, fmt); \
		CHECK_FAIL(true); \
	})
/* Report a failure without an errno string. */
#define FAIL(fmt...) _FAIL(0, fmt)
/* Report a failure with the current errno's string appended. */
#define FAIL_ERRNO(fmt...) _FAIL(errno, fmt)
/* Report "@msg: <libbpf description of @err>" as a failure. */
#define FAIL_LIBBPF(err, msg) \
	({ \
		char __buf[MAX_STRERR_LEN]; \
		libbpf_strerror((err), __buf, sizeof(__buf)); \
		FAIL("%s: %s", (msg), __buf); \
	})
46
47
/*
 * Each x*() wrapper below evaluates to the return value of the call it
 * wraps. When that value is -1, the wrapper first reports the failure with
 * FAIL_ERRNO(), naming the call.
 */

/* accept(), waiting at most IO_TIMEOUT_SEC via accept_timeout(). */
#define xaccept_nonblock(fd, addr, len) \
	({ \
		int __ret = \
			accept_timeout((fd), (addr), (len), IO_TIMEOUT_SEC); \
		if (__ret == -1) \
			FAIL_ERRNO("accept"); \
		__ret; \
	})

#define xbind(fd, addr, len) \
	({ \
		int __ret = bind((fd), (addr), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("bind"); \
		__ret; \
	})

#define xclose(fd) \
	({ \
		int __ret = close((fd)); \
		if (__ret == -1) \
			FAIL_ERRNO("close"); \
		__ret; \
	})

#define xconnect(fd, addr, len) \
	({ \
		int __ret = connect((fd), (addr), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("connect"); \
		__ret; \
	})

#define xgetsockname(fd, addr, len) \
	({ \
		int __ret = getsockname((fd), (addr), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("getsockname"); \
		__ret; \
	})

/* The option name is stringified into the failure message. */
#define xgetsockopt(fd, level, name, val, len) \
	({ \
		int __ret = getsockopt((fd), (level), (name), (val), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("getsockopt(" #name ")"); \
		__ret; \
	})

#define xlisten(fd, backlog) \
	({ \
		int __ret = listen((fd), (backlog)); \
		if (__ret == -1) \
			FAIL_ERRNO("listen"); \
		__ret; \
	})

/* The option name is stringified into the failure message. */
#define xsetsockopt(fd, level, name, val, len) \
	({ \
		int __ret = setsockopt((fd), (level), (name), (val), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("setsockopt(" #name ")"); \
		__ret; \
	})

#define xsend(fd, buf, len, flags) \
	({ \
		ssize_t __ret = send((fd), (buf), (len), (flags)); \
		if (__ret == -1) \
			FAIL_ERRNO("send"); \
		__ret; \
	})

/* recv(), waiting at most IO_TIMEOUT_SEC via recv_timeout(). */
#define xrecv_nonblock(fd, buf, len, flags) \
	({ \
		ssize_t __ret = recv_timeout((fd), (buf), (len), (flags), \
					     IO_TIMEOUT_SEC); \
		if (__ret == -1) \
			FAIL_ERRNO("recv"); \
		__ret; \
	})

#define xsocket(family, sotype, flags) \
	({ \
		int __ret = socket(family, sotype, flags); \
		if (__ret == -1) \
			FAIL_ERRNO("socket"); \
		__ret; \
	})
137
/*
 * Cleanup handler behind __close_fd. It closes the descriptor stored at @fd
 * (reporting a close() failure) and skips negative values such as -1 or the
 * -EBADF left by take_fd().
 */
static inline void close_fd(int *fd)
{
	int victim = *fd;

	if (victim < 0)
		return;

	xclose(victim);
}

/* Attach close_fd() to an int local so it runs when the local leaves scope. */
#define __close_fd __attribute__((cleanup(close_fd)))
145
/* View a sockaddr_storage as the generic struct sockaddr the socket calls take. */
static inline struct sockaddr *sockaddr(struct sockaddr_storage *ss)
{
	struct sockaddr *generic = (struct sockaddr *)ss;

	return generic;
}
150
/*
 * Zero @ss, then fill it with AF_INET, port 0 and the IPv4 loopback address
 * (INADDR_LOOPBACK, stored in network byte order). The length of a
 * struct sockaddr_in is written to @len.
 */
static inline void init_addr_loopback4(struct sockaddr_storage *ss,
				       socklen_t *len)
{
	struct sockaddr_in *sin = (struct sockaddr_in *)ss;

	memset(ss, 0, sizeof(*ss));
	sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
	sin->sin_port = 0;
	sin->sin_family = AF_INET;
	*len = sizeof(struct sockaddr_in);
}
161
/*
 * Zero @ss, then fill it with AF_INET6, port 0 and the IPv6 loopback
 * address (in6addr_loopback). The length of a struct sockaddr_in6 is
 * written to @len.
 */
static inline void init_addr_loopback6(struct sockaddr_storage *ss,
				       socklen_t *len)
{
	struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)ss;

	memset(ss, 0, sizeof(*ss));
	sin6->sin6_addr = in6addr_loopback;
	sin6->sin6_port = 0;
	sin6->sin6_family = AF_INET6;
	*len = sizeof(struct sockaddr_in6);
}
172
/*
 * Zero @ss and set only the family to AF_UNIX. @len is set to
 * sizeof(sa_family_t), so the address carries no path. Per unix(7), a bind()
 * with this length autobinds the socket to an abstract address on Linux.
 */
static inline void init_addr_loopback_unix(struct sockaddr_storage *ss,
					   socklen_t *len)
{
	struct sockaddr_un *uaddr;

	uaddr = memset(ss, 0, sizeof(*ss));
	uaddr->sun_family = AF_UNIX;
	*len = (socklen_t)sizeof(sa_family_t);
}
181
/*
 * Zero @ss, then fill it with AF_VSOCK, port VMADDR_PORT_ANY and CID
 * VMADDR_CID_LOCAL. The length of a struct sockaddr_vm is written to @len.
 */
static inline void init_addr_loopback_vsock(struct sockaddr_storage *ss,
					    socklen_t *len)
{
	struct sockaddr_vm *vm = (struct sockaddr_vm *)ss;

	memset(ss, 0, sizeof(*ss));
	vm->svm_cid = VMADDR_CID_LOCAL;
	vm->svm_port = VMADDR_PORT_ANY;
	vm->svm_family = AF_VSOCK;
	*len = sizeof(struct sockaddr_vm);
}
192
/*
 * Fill @ss with the loopback address of @family and store its length in
 * @len. The work is dispatched to the per-family initializers for AF_INET,
 * AF_INET6, AF_UNIX and AF_VSOCK.
 *
 * Any other family fails the test. In that case @ss is zeroed and *@len is
 * set to 0, so callers that carry on (e.g. with an uninitialized local
 * length) never hand uninitialized memory to bind()/connect().
 */
static inline void init_addr_loopback(int family, struct sockaddr_storage *ss,
				      socklen_t *len)
{
	switch (family) {
	case AF_INET:
		init_addr_loopback4(ss, len);
		return;
	case AF_INET6:
		init_addr_loopback6(ss, len);
		return;
	case AF_UNIX:
		init_addr_loopback_unix(ss, len);
		return;
	case AF_VSOCK:
		init_addr_loopback_vsock(ss, len);
		return;
	default:
		/* Leave well-defined outputs behind before reporting. */
		memset(ss, 0, sizeof(*ss));
		*len = 0;
		FAIL("unsupported address family %d", family);
	}
}
213
enable_reuseport(int s,int progfd)214 static inline int enable_reuseport(int s, int progfd)
215 {
216 int err, one = 1;
217
218 err = xsetsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
219 if (err)
220 return -1;
221 err = xsetsockopt(s, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, &progfd,
222 sizeof(progfd));
223 if (err)
224 return -1;
225
226 return 0;
227 }
228
/*
 * Create a @family/@sotype socket bound to the family's loopback address
 * (see init_addr_loopback()).
 *
 * If @progfd >= 0, SO_REUSEPORT is enabled and @progfd is attached as the
 * reuseport BPF program before bind(). A socket whose type has the
 * SOCK_DGRAM bit set is returned right after bind(). Any other socket is
 * put into the listening state with a SOMAXCONN backlog.
 *
 * Returns the socket fd. Returns -1 on failure, after the failure has been
 * reported and the socket (if any) closed.
 */
static inline int socket_loopback_reuseport(int family, int sotype, int progfd)
{
	struct sockaddr_storage addr;
	socklen_t len = 0;
	int err, s;

	init_addr_loopback(family, &addr, &len);

	s = xsocket(family, sotype, 0);
	if (s == -1)
		return -1;

	/*
	 * The caller asked for a reuseport group steered by @progfd. A socket
	 * without that setup is not what it wants, so treat a failure here
	 * like any other setup error instead of silently carrying on.
	 */
	if (progfd >= 0) {
		err = enable_reuseport(s, progfd);
		if (err)
			goto close;
	}

	err = xbind(s, sockaddr(&addr), len);
	if (err)
		goto close;

	/* Datagram sockets do not listen(). */
	if (sotype & SOCK_DGRAM)
		return s;

	err = xlisten(s, SOMAXCONN);
	if (err)
		goto close;

	return s;
close:
	xclose(s);
	return -1;
}
260
/* Same as socket_loopback_reuseport() but with no reuseport program attached. */
static inline int socket_loopback(int family, int sotype)
{
	const int no_prog = -1;

	return socket_loopback_reuseport(family, sotype, no_prog);
}
265
/*
 * Wait up to @timeout_sec seconds for @fd to become writable, i.e. for a
 * non-blocking connect() on it to finish, and then check its outcome.
 *
 * Returns 0 when the socket reports no pending error (SO_ERROR == 0).
 * Returns -1 otherwise, with errno set to:
 *   ETIME            on timeout,
 *   the SO_ERROR     value if the connect failed,
 *   whatever poll()/getsockopt() set if those calls fail.
 *
 * poll() is used instead of select() because FD_SET() on a descriptor
 * >= FD_SETSIZE writes past the end of the fd_set. poll() has no such limit.
 */
static inline int poll_connect(int fd, unsigned int timeout_sec)
{
	struct pollfd pfd = { .fd = fd, .events = POLLOUT };
	/* poll() takes an int millisecond timeout; clamp to avoid overflow. */
	int timeout_ms = timeout_sec > (unsigned int)(INT_MAX / 1000) ?
			 INT_MAX : (int)timeout_sec * 1000;
	socklen_t esize;
	int r, eval;

	r = poll(&pfd, 1, timeout_ms);
	if (r == 0)
		errno = ETIME;
	if (r != 1)
		return -1;

	/* Writable (or errored): the connect result is in SO_ERROR. */
	esize = sizeof(eval);
	if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &eval, &esize) < 0)
		return -1;
	if (eval != 0) {
		errno = eval;
		return -1;
	}

	return 0;
}
291
/*
 * Wait up to @timeout_sec seconds for @fd to become readable (data, a
 * pending connection, or an error/hangup condition).
 *
 * Returns 0 when @fd is ready. Returns -1 otherwise, with errno set to ETIME
 * on timeout or to the poll() error.
 *
 * poll() is used instead of select() because FD_SET() on a descriptor
 * >= FD_SETSIZE writes past the end of the fd_set. poll() has no such limit.
 */
static inline int poll_read(int fd, unsigned int timeout_sec)
{
	struct pollfd pfd = { .fd = fd, .events = POLLIN };
	/* poll() takes an int millisecond timeout; clamp to avoid overflow. */
	int timeout_ms = timeout_sec > (unsigned int)(INT_MAX / 1000) ?
			 INT_MAX : (int)timeout_sec * 1000;
	int r;

	r = poll(&pfd, 1, timeout_ms);
	if (r == 0)
		errno = ETIME;

	return r == 1 ? 0 : -1;
}
307
/*
 * Call accept() on @fd, but only after poll_read() has seen @fd become
 * readable within @timeout_sec seconds. Returns accept()'s result. Returns
 * -1 if the wait fails; poll_read() sets errno to ETIME on timeout.
 */
static inline int accept_timeout(int fd, struct sockaddr *addr, socklen_t *len,
				 unsigned int timeout_sec)
{
	return poll_read(fd, timeout_sec) ? -1 : accept(fd, addr, len);
}
316
/*
 * Call recv() on @fd, but only after poll_read() has seen @fd become
 * readable within @timeout_sec seconds. Returns recv()'s result. Returns
 * -1 if the wait fails; poll_read() sets errno to ETIME on timeout.
 */
static inline int recv_timeout(int fd, void *buf, size_t len, int flags,
			       unsigned int timeout_sec)
{
	int wait_err = poll_read(fd, timeout_sec);

	if (wait_err != 0)
		return -1;
	return recv(fd, buf, len, flags);
}
325
326
create_pair(int family,int sotype,int * p0,int * p1)327 static inline int create_pair(int family, int sotype, int *p0, int *p1)
328 {
329 __close_fd int s, c = -1, p = -1;
330 struct sockaddr_storage addr;
331 socklen_t len;
332 int err;
333
334 s = socket_loopback(family, sotype);
335 if (s < 0)
336 return s;
337
338 c = xsocket(family, sotype, 0);
339 if (c < 0)
340 return c;
341
342 init_addr_loopback(family, &addr, &len);
343 err = xbind(c, sockaddr(&addr), len);
344 if (err)
345 return err;
346
347 len = sizeof(addr);
348 err = xgetsockname(s, sockaddr(&addr), &len);
349 if (err)
350 return err;
351
352 err = connect(c, sockaddr(&addr), len);
353 if (err) {
354 if (errno != EINPROGRESS) {
355 FAIL_ERRNO("connect");
356 return err;
357 }
358
359 err = poll_connect(c, IO_TIMEOUT_SEC);
360 if (err) {
361 FAIL_ERRNO("poll_connect");
362 return err;
363 }
364 }
365
366 switch (sotype & SOCK_TYPE_MASK) {
367 case SOCK_DGRAM:
368 err = xgetsockname(c, sockaddr(&addr), &len);
369 if (err)
370 return err;
371
372 err = xconnect(s, sockaddr(&addr), len);
373 if (err)
374 return err;
375
376 *p0 = take_fd(s);
377 break;
378 case SOCK_STREAM:
379 case SOCK_SEQPACKET:
380 p = xaccept_nonblock(s, NULL, NULL);
381 if (p < 0)
382 return p;
383
384 *p0 = take_fd(p);
385 break;
386 default:
387 FAIL("Unsupported socket type %#x", sotype);
388 return -EOPNOTSUPP;
389 }
390
391 *p1 = take_fd(c);
392 return 0;
393 }
394
/*
 * Create two independent socket pairs with create_pair(): one into
 * (*c0, *p0) and one into (*c1, *p1). If the second pair fails, the fds of
 * the first pair are closed again; *c0/*p0 still hold those closed numbers.
 * Returns 0, or the error from the failing create_pair() call.
 */
static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1,
				      int *p0, int *p1)
{
	int ret = create_pair(family, sotype, c0, p0);

	if (ret)
		return ret;

	ret = create_pair(family, sotype, c1, p1);
	if (ret == 0)
		return 0;

	/* Undo the first pair so a partial result does not leak fds. */
	close(*c0);
	close(*p0);
	return ret;
}
412
socket_kind_to_str(int sock_fd)413 static inline const char *socket_kind_to_str(int sock_fd)
414 {
415 socklen_t opt_len;
416 int domain, type;
417
418 opt_len = sizeof(domain);
419 if (getsockopt(sock_fd, SOL_SOCKET, SO_DOMAIN, &domain, &opt_len))
420 FAIL_ERRNO("getsockopt(SO_DOMAIN)");
421
422 opt_len = sizeof(type);
423 if (getsockopt(sock_fd, SOL_SOCKET, SO_TYPE, &type, &opt_len))
424 FAIL_ERRNO("getsockopt(SO_TYPE)");
425
426 switch (domain) {
427 case AF_INET:
428 switch (type) {
429 case SOCK_STREAM:
430 return "tcp4";
431 case SOCK_DGRAM:
432 return "udp4";
433 }
434 break;
435 case AF_INET6:
436 switch (type) {
437 case SOCK_STREAM:
438 return "tcp6";
439 case SOCK_DGRAM:
440 return "udp6";
441 }
442 break;
443 case AF_UNIX:
444 switch (type) {
445 case SOCK_STREAM:
446 return "u_str";
447 case SOCK_DGRAM:
448 return "u_dgr";
449 case SOCK_SEQPACKET:
450 return "u_seq";
451 }
452 break;
453 case AF_VSOCK:
454 switch (type) {
455 case SOCK_STREAM:
456 return "v_str";
457 case SOCK_DGRAM:
458 return "v_dgr";
459 case SOCK_SEQPACKET:
460 return "v_seq";
461 }
462 break;
463 }
464
465 return "???";
466 }
467
468 #endif // __SOCKET_HELPERS__
469