1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (c) 2020 Cloudflare
3 /*
4  * Tests for sockmap/sockhash holding kTLS sockets.
5  */
6 #include <error.h>
7 #include <netinet/tcp.h>
8 #include <linux/tls.h>
9 #include "test_progs.h"
10 #include "sockmap_helpers.h"
11 #include "test_skmsg_load_helpers.skel.h"
12 #include "test_sockmap_ktls.skel.h"
13 
14 #define MAX_TEST_NAME 80
15 #define TCP_ULP 31
16 
init_ktls_pairs(int c,int p)17 static int init_ktls_pairs(int c, int p)
18 {
19 	int err;
20 	struct tls12_crypto_info_aes_gcm_128 crypto_rx;
21 	struct tls12_crypto_info_aes_gcm_128 crypto_tx;
22 
23 	err = setsockopt(c, IPPROTO_TCP, TCP_ULP, "tls", strlen("tls"));
24 	if (!ASSERT_OK(err, "setsockopt(TCP_ULP)"))
25 		goto out;
26 
27 	err = setsockopt(p, IPPROTO_TCP, TCP_ULP, "tls", strlen("tls"));
28 	if (!ASSERT_OK(err, "setsockopt(TCP_ULP)"))
29 		goto out;
30 
31 	memset(&crypto_rx, 0, sizeof(crypto_rx));
32 	memset(&crypto_tx, 0, sizeof(crypto_tx));
33 	crypto_rx.info.version = TLS_1_2_VERSION;
34 	crypto_tx.info.version = TLS_1_2_VERSION;
35 	crypto_rx.info.cipher_type = TLS_CIPHER_AES_GCM_128;
36 	crypto_tx.info.cipher_type = TLS_CIPHER_AES_GCM_128;
37 
38 	err = setsockopt(c, SOL_TLS, TLS_TX, &crypto_tx, sizeof(crypto_tx));
39 	if (!ASSERT_OK(err, "setsockopt(TLS_TX)"))
40 		goto out;
41 
42 	err = setsockopt(p, SOL_TLS, TLS_RX, &crypto_rx, sizeof(crypto_rx));
43 	if (!ASSERT_OK(err, "setsockopt(TLS_RX)"))
44 		goto out;
45 	return 0;
46 out:
47 	return -1;
48 }
49 
/*
 * Create a socket pair with create_pair() and enable kTLS on it with
 * init_ktls_pairs().
 *
 * The descriptors are stored through @c and @p by create_pair().
 * Returns 0 when both steps succeed and -1 otherwise. If only the kTLS
 * setup fails, this function does not close the sockets.
 */
static int create_ktls_pairs(int family, int sotype, int *c, int *p)
{
	int ret = create_pair(family, sotype, c, p);

	if (!ASSERT_OK(ret, "create_pair()"))
		return -1;

	ret = init_ktls_pairs(*c, *p);

	return ASSERT_OK(ret, "init_ktls_pairs(c, p)") ? 0 : -1;
}
63 
test_sockmap_ktls_update_fails_when_sock_has_ulp(int family,int map)64 static void test_sockmap_ktls_update_fails_when_sock_has_ulp(int family, int map)
65 {
66 	struct sockaddr_storage addr = {};
67 	socklen_t len = sizeof(addr);
68 	struct sockaddr_in6 *v6;
69 	struct sockaddr_in *v4;
70 	int err, s, zero = 0;
71 
72 	switch (family) {
73 	case AF_INET:
74 		v4 = (struct sockaddr_in *)&addr;
75 		v4->sin_family = AF_INET;
76 		break;
77 	case AF_INET6:
78 		v6 = (struct sockaddr_in6 *)&addr;
79 		v6->sin6_family = AF_INET6;
80 		break;
81 	default:
82 		PRINT_FAIL("unsupported socket family %d", family);
83 		return;
84 	}
85 
86 	s = socket(family, SOCK_STREAM, 0);
87 	if (!ASSERT_GE(s, 0, "socket"))
88 		return;
89 
90 	err = bind(s, (struct sockaddr *)&addr, len);
91 	if (!ASSERT_OK(err, "bind"))
92 		goto close;
93 
94 	err = getsockname(s, (struct sockaddr *)&addr, &len);
95 	if (!ASSERT_OK(err, "getsockname"))
96 		goto close;
97 
98 	err = connect(s, (struct sockaddr *)&addr, len);
99 	if (!ASSERT_OK(err, "connect"))
100 		goto close;
101 
102 	/* save sk->sk_prot and set it to tls_prots */
103 	err = setsockopt(s, IPPROTO_TCP, TCP_ULP, "tls", strlen("tls"));
104 	if (!ASSERT_OK(err, "setsockopt(TCP_ULP)"))
105 		goto close;
106 
107 	/* sockmap update should not affect saved sk_prot */
108 	err = bpf_map_update_elem(map, &zero, &s, BPF_ANY);
109 	if (!ASSERT_ERR(err, "sockmap update elem"))
110 		goto close;
111 
112 	/* call sk->sk_prot->setsockopt to dispatch to saved sk_prot */
113 	err = setsockopt(s, IPPROTO_TCP, TCP_NODELAY, &zero, sizeof(zero));
114 	ASSERT_OK(err, "setsockopt(TCP_NODELAY)");
115 
116 close:
117 	close(s);
118 }
119 
fmt_test_name(const char * subtest_name,int family,enum bpf_map_type map_type)120 static const char *fmt_test_name(const char *subtest_name, int family,
121 				 enum bpf_map_type map_type)
122 {
123 	const char *map_type_str = BPF_MAP_TYPE_SOCKMAP ? "SOCKMAP" : "SOCKHASH";
124 	const char *family_str = AF_INET ? "IPv4" : "IPv6";
125 	static char test_name[MAX_TEST_NAME];
126 
127 	snprintf(test_name, MAX_TEST_NAME,
128 		 "sockmap_ktls %s %s %s",
129 		 subtest_name, family_str, map_type_str);
130 
131 	return test_name;
132 }
133 
/*
 * Send one short message over a kTLS socket pair and check that it
 * arrives unchanged.
 *
 * @family: address family passed on to create_ktls_pairs().
 * @sotype: socket type passed on to create_ktls_pairs().
 */
static void test_sockmap_ktls_offload(int family, int sotype)
{
	/* -1 marks "not opened"; 0 would be a valid descriptor. */
	int c = -1, p = -1, sent, recvd;
	char msg[12] = "hello world\0";
	char rcv[13] = {0};
	int err;

	err = create_ktls_pairs(family, sotype, &c, &p);
	if (!ASSERT_OK(err, "create_ktls_pairs()"))
		goto out;

	/* Check the transfer results themselves, not the stale 'err'. */
	sent = send(c, msg, sizeof(msg), 0);
	if (!ASSERT_EQ(sent, (int)sizeof(msg), "send(msg)"))
		goto out;

	recvd = recv(p, rcv, sizeof(rcv), 0);
	if (!ASSERT_GE(recvd, 0, "recv(msg)") ||
	    !ASSERT_EQ(recvd, sent, "length mismatch"))
		goto out;

	ASSERT_OK(memcmp(msg, rcv, sizeof(msg)), "data mismatch");

out:
	if (c != -1)
		close(c);
	if (p != -1)
		close(p);
}
162 
test_sockmap_ktls_tx_cork(int family,int sotype,bool push)163 static void test_sockmap_ktls_tx_cork(int family, int sotype, bool push)
164 {
165 	int err, off;
166 	int i, j;
167 	int start_push = 0, push_len = 0;
168 	int c = 0, p = 0, one = 1, sent, recvd;
169 	int prog_fd, map_fd;
170 	char msg[12] = "hello world\0";
171 	char rcv[20] = {0};
172 	struct test_sockmap_ktls *skel;
173 
174 	skel = test_sockmap_ktls__open_and_load();
175 	if (!ASSERT_TRUE(skel, "open ktls skel"))
176 		return;
177 
178 	err = create_pair(family, sotype, &c, &p);
179 	if (!ASSERT_OK(err, "create_pair()"))
180 		goto out;
181 
182 	prog_fd = bpf_program__fd(skel->progs.prog_sk_policy);
183 	map_fd = bpf_map__fd(skel->maps.sock_map);
184 
185 	err = bpf_prog_attach(prog_fd, map_fd, BPF_SK_MSG_VERDICT, 0);
186 	if (!ASSERT_OK(err, "bpf_prog_attach sk msg"))
187 		goto out;
188 
189 	err = bpf_map_update_elem(map_fd, &one, &c, BPF_NOEXIST);
190 	if (!ASSERT_OK(err, "bpf_map_update_elem(c)"))
191 		goto out;
192 
193 	err = init_ktls_pairs(c, p);
194 	if (!ASSERT_OK(err, "init_ktls_pairs(c, p)"))
195 		goto out;
196 
197 	skel->bss->cork_byte = sizeof(msg);
198 	if (push) {
199 		start_push = 1;
200 		push_len = 2;
201 	}
202 	skel->bss->push_start = start_push;
203 	skel->bss->push_end = push_len;
204 
205 	off = sizeof(msg) / 2;
206 	sent = send(c, msg, off, 0);
207 	if (!ASSERT_EQ(sent, off, "send(msg)"))
208 		goto out;
209 
210 	recvd = recv_timeout(p, rcv, sizeof(rcv), MSG_DONTWAIT, 1);
211 	if (!ASSERT_EQ(-1, recvd, "expected no data"))
212 		goto out;
213 
214 	/* send remaining msg */
215 	sent = send(c, msg + off, sizeof(msg) - off, 0);
216 	if (!ASSERT_EQ(sent, sizeof(msg) - off, "send remaining data"))
217 		goto out;
218 
219 	recvd = recv_timeout(p, rcv, sizeof(rcv), MSG_DONTWAIT, 1);
220 	if (!ASSERT_OK(err, "recv(msg)") ||
221 	    !ASSERT_EQ(recvd, sizeof(msg) + push_len, "check length mismatch"))
222 		goto out;
223 
224 	for (i = 0, j = 0; i < recvd;) {
225 		/* skip checking the data that has been pushed in */
226 		if (i >= start_push && i <= start_push + push_len - 1) {
227 			i++;
228 			continue;
229 		}
230 		if (!ASSERT_EQ(rcv[i], msg[j], "data mismatch"))
231 			goto out;
232 		i++;
233 		j++;
234 	}
235 out:
236 	if (c)
237 		close(c);
238 	if (p)
239 		close(p);
240 	test_sockmap_ktls__destroy(skel);
241 }
242 
test_sockmap_ktls_tx_no_buf(int family,int sotype,bool push)243 static void test_sockmap_ktls_tx_no_buf(int family, int sotype, bool push)
244 {
245 	int c = -1, p = -1, one = 1, two = 2;
246 	struct test_sockmap_ktls *skel;
247 	unsigned char *data = NULL;
248 	struct msghdr msg = {0};
249 	struct iovec iov[2];
250 	int prog_fd, map_fd;
251 	int txrx_buf = 1024;
252 	int iov_length = 8192;
253 	int err;
254 
255 	skel = test_sockmap_ktls__open_and_load();
256 	if (!ASSERT_TRUE(skel, "open ktls skel"))
257 		return;
258 
259 	err = create_pair(family, sotype, &c, &p);
260 	if (!ASSERT_OK(err, "create_pair()"))
261 		goto out;
262 
263 	err = setsockopt(c, SOL_SOCKET, SO_RCVBUFFORCE, &txrx_buf, sizeof(int));
264 	err |= setsockopt(p, SOL_SOCKET, SO_SNDBUFFORCE, &txrx_buf, sizeof(int));
265 	if (!ASSERT_OK(err, "set buf limit"))
266 		goto out;
267 
268 	prog_fd = bpf_program__fd(skel->progs.prog_sk_policy_redir);
269 	map_fd = bpf_map__fd(skel->maps.sock_map);
270 
271 	err = bpf_prog_attach(prog_fd, map_fd, BPF_SK_MSG_VERDICT, 0);
272 	if (!ASSERT_OK(err, "bpf_prog_attach sk msg"))
273 		goto out;
274 
275 	err = bpf_map_update_elem(map_fd, &one, &c, BPF_NOEXIST);
276 	if (!ASSERT_OK(err, "bpf_map_update_elem(c)"))
277 		goto out;
278 
279 	err = bpf_map_update_elem(map_fd, &two, &p, BPF_NOEXIST);
280 	if (!ASSERT_OK(err, "bpf_map_update_elem(p)"))
281 		goto out;
282 
283 	skel->bss->apply_bytes = 1024;
284 
285 	err = init_ktls_pairs(c, p);
286 	if (!ASSERT_OK(err, "init_ktls_pairs(c, p)"))
287 		goto out;
288 
289 	data = calloc(iov_length, sizeof(char));
290 	if (!data)
291 		goto out;
292 
293 	iov[0].iov_base = data;
294 	iov[0].iov_len = iov_length;
295 	iov[1].iov_base = data;
296 	iov[1].iov_len = iov_length;
297 	msg.msg_iov = iov;
298 	msg.msg_iovlen = 2;
299 
300 	for (;;) {
301 		err = sendmsg(c, &msg, MSG_DONTWAIT);
302 		if (err <= 0)
303 			break;
304 	}
305 
306 out:
307 	if (data)
308 		free(data);
309 	if (c != -1)
310 		close(c);
311 	if (p != -1)
312 		close(p);
313 
314 	test_sockmap_ktls__destroy(skel);
315 }
316 
test_sockmap_ktls_tx_pop(int family,int sotype)317 static void test_sockmap_ktls_tx_pop(int family, int sotype)
318 {
319 	char msg[37] = "0123456789abcdefghijklmnopqrstuvwxyz\0";
320 	int c = 0, p = 0, one = 1, sent, recvd;
321 	struct test_sockmap_ktls *skel;
322 	int prog_fd, map_fd;
323 	char rcv[50] = {0};
324 	int err;
325 	int i, m, r;
326 
327 	skel = test_sockmap_ktls__open_and_load();
328 	if (!ASSERT_TRUE(skel, "open ktls skel"))
329 		return;
330 
331 	err = create_pair(family, sotype, &c, &p);
332 	if (!ASSERT_OK(err, "create_pair()"))
333 		goto out;
334 
335 	prog_fd = bpf_program__fd(skel->progs.prog_sk_policy);
336 	map_fd = bpf_map__fd(skel->maps.sock_map);
337 
338 	err = bpf_prog_attach(prog_fd, map_fd, BPF_SK_MSG_VERDICT, 0);
339 	if (!ASSERT_OK(err, "bpf_prog_attach sk msg"))
340 		goto out;
341 
342 	err = bpf_map_update_elem(map_fd, &one, &c, BPF_NOEXIST);
343 	if (!ASSERT_OK(err, "bpf_map_update_elem(c)"))
344 		goto out;
345 
346 	err = init_ktls_pairs(c, p);
347 	if (!ASSERT_OK(err, "init_ktls_pairs(c, p)"))
348 		goto out;
349 
350 	struct {
351 		int	pop_start;
352 		int	pop_len;
353 	} pop_policy[] = {
354 		/* trim the start */
355 		{0, 2},
356 		{0, 10},
357 		{1, 2},
358 		{1, 10},
359 		/* trim the end */
360 		{35, 2},
361 		/* New entries should be added before this line */
362 		{-1, -1},
363 	};
364 
365 	i = 0;
366 	while (pop_policy[i].pop_start >= 0) {
367 		skel->bss->pop_start = pop_policy[i].pop_start;
368 		skel->bss->pop_end =  pop_policy[i].pop_len;
369 
370 		sent = send(c, msg, sizeof(msg), 0);
371 		if (!ASSERT_EQ(sent, sizeof(msg), "send(msg)"))
372 			goto out;
373 
374 		recvd = recv_timeout(p, rcv, sizeof(rcv), MSG_DONTWAIT, 1);
375 		if (!ASSERT_EQ(recvd, sizeof(msg) - pop_policy[i].pop_len, "pop len mismatch"))
376 			goto out;
377 
378 		/* verify the data
379 		 * msg: 0123456789a bcdefghij klmnopqrstuvwxyz
380 		 *                  |       |
381 		 *                  popped data
382 		 */
383 		for (m = 0, r = 0; m < sizeof(msg);) {
384 			/* skip checking the data that has been popped */
385 			if (m >= pop_policy[i].pop_start &&
386 			    m <= pop_policy[i].pop_start + pop_policy[i].pop_len - 1) {
387 				m++;
388 				continue;
389 			}
390 
391 			if (!ASSERT_EQ(msg[m], rcv[r], "data mismatch"))
392 				goto out;
393 			m++;
394 			r++;
395 		}
396 		i++;
397 	}
398 out:
399 	if (c)
400 		close(c);
401 	if (p)
402 		close(p);
403 	test_sockmap_ktls__destroy(skel);
404 }
405 
/*
 * Create a single-entry map of @map_type with int keys and int values,
 * run the "update_fails_when_sock_has_ulp" subtest against it for
 * @family, then close the map.
 */
static void run_tests(int family, enum bpf_map_type map_type)
{
	const char *subtest;
	int map_fd;

	map_fd = bpf_map_create(map_type, NULL, sizeof(int), sizeof(int), 1, NULL);
	if (!ASSERT_GE(map_fd, 0, "bpf_map_create"))
		return;

	subtest = fmt_test_name("update_fails_when_sock_has_ulp", family, map_type);
	if (test__start_subtest(subtest))
		test_sockmap_ktls_update_fails_when_sock_has_ulp(family, map_fd);

	close(map_fd);
}
419 
/*
 * Run every kTLS data-path subtest for one address family and socket
 * type. Each subtest only runs when test__start_subtest() selects it.
 */
static void run_ktls_test(int family, int sotype)
{
	/* Plain send/receive over a kTLS pair. */
	if (test__start_subtest("tls simple offload")) {
		test_sockmap_ktls_offload(family, sotype);
	}

	/* Corking, first without and then with pushed data. */
	if (test__start_subtest("tls tx cork")) {
		test_sockmap_ktls_tx_cork(family, sotype, false);
	}
	if (test__start_subtest("tls tx cork with push")) {
		test_sockmap_ktls_tx_cork(family, sotype, true);
	}

	/* Redirect policy with small forced socket buffers. */
	if (test__start_subtest("tls tx egress with no buf")) {
		test_sockmap_ktls_tx_no_buf(family, sotype, true);
	}

	/* Popping data from the message. */
	if (test__start_subtest("tls tx with pop")) {
		test_sockmap_ktls_tx_pop(family, sotype);
	}
}
433 
test_sockmap_ktls(void)434 void test_sockmap_ktls(void)
435 {
436 	run_tests(AF_INET, BPF_MAP_TYPE_SOCKMAP);
437 	run_tests(AF_INET, BPF_MAP_TYPE_SOCKHASH);
438 	run_tests(AF_INET6, BPF_MAP_TYPE_SOCKMAP);
439 	run_tests(AF_INET6, BPF_MAP_TYPE_SOCKHASH);
440 	run_ktls_test(AF_INET, SOCK_STREAM);
441 	run_ktls_test(AF_INET6, SOCK_STREAM);
442 }
443