// SPDX-License-Identifier: GPL-2.0
/* Copyright Amazon.com Inc. or its affiliates. */

#include "vmlinux.h"

#include <bpf/bpf_helpers.h>
#include <bpf/bpf_endian.h>
#include "bpf_tracing_net.h"
#include "bpf_kfuncs.h"
#include "test_siphash.h"
#include "test_tcp_custom_syncookie.h"
#include "bpf_misc.h"

#define MAX_PACKET_OFF 0xffff

/* Hash is calculated for each client and split into ISN and TS.
 *
 *      MSB                                        LSB
 * ISN: | 31 ... 8 |  7 6  |  5  |  4   | 3 2 1 0 |
 *      |  Hash_1  |  MSS  | ECN | SACK |  WScale |
 *
 * TS:  | 31 ... 8 | 7 ... 0 |
 *      |  Random  |  Hash_2 |
 */
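/* e.g. an ISN whose low byte is 0xb7 (0b10110111) encodes mssind = 2
 * (bits 7:6), ECN = 1 (bit 5), SACK = 1 (bit 4) and WScale = 7 (bits 3:0),
 * while bits 31:8 carry the truncated hash.
 */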
#define COOKIE_BITS 8
#define COOKIE_MASK (((__u32)1 << COOKIE_BITS) - 1)

enum {
	/* 0xf is invalid thus means that SYN did not have WScale. */
	BPF_SYNCOOKIE_WSCALE_MASK = (1 << 4) - 1,
	BPF_SYNCOOKIE_SACK = (1 << 4),
	BPF_SYNCOOKIE_ECN = (1 << 5),
};

#define MSS_LOCAL_IPV4 65495
#define MSS_LOCAL_IPV6 65476

const __u16 msstab4[] = {
	536,
	1300,
	1460,
	MSS_LOCAL_IPV4,
};

const __u16 msstab6[] = {
	1280 - 60, /* IPV6_MIN_MTU - 60 */
	1480 - 60,
	9000 - 60,
	MSS_LOCAL_IPV6,
};
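/* The 60 subtracted above presumably covers the IPv6 header (40 bytes) plus
 * the minimal TCP header (20 bytes), so each entry is the MSS corresponding
 * to a common MTU (1280 = IPV6_MIN_MTU, 1480, 9000).
 */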

static siphash_key_t test_key_siphash = {
	{ 0x0706050403020100ULL, 0x0f0e0d0c0b0a0908ULL }
};

struct tcp_syncookie {
	struct __sk_buff *skb;
	void *data;
	void *data_end;
	struct ethhdr *eth;
	struct iphdr *ipv4;
	struct ipv6hdr *ipv6;
	struct tcphdr *tcp;
	__be32 *ptr32;
	struct bpf_tcp_req_attrs attrs;
	u32 off;
	u32 cookie;
	u64 first;
};

bool handled_syn, handled_ack;

static int tcp_load_headers(struct tcp_syncookie *ctx)
{
	ctx->data = (void *)(long)ctx->skb->data;
	ctx->data_end = (void *)(long)ctx->skb->data_end;
	ctx->eth = (struct ethhdr *)(long)ctx->skb->data;

	if (ctx->eth + 1 > ctx->data_end)
		goto err;

	switch (bpf_ntohs(ctx->eth->h_proto)) {
	case ETH_P_IP:
		ctx->ipv4 = (struct iphdr *)(ctx->eth + 1);

		if (ctx->ipv4 + 1 > ctx->data_end)
			goto err;

		if (ctx->ipv4->ihl != sizeof(*ctx->ipv4) / 4)
			goto err;

		if (ctx->ipv4->version != 4)
			goto err;

		if (ctx->ipv4->protocol != IPPROTO_TCP)
			goto err;

		ctx->tcp = (struct tcphdr *)(ctx->ipv4 + 1);
		break;
	case ETH_P_IPV6:
		ctx->ipv6 = (struct ipv6hdr *)(ctx->eth + 1);

		if (ctx->ipv6 + 1 > ctx->data_end)
			goto err;

		if (ctx->ipv6->version != 6)
			goto err;

		if (ctx->ipv6->nexthdr != NEXTHDR_TCP)
			goto err;

		ctx->tcp = (struct tcphdr *)(ctx->ipv6 + 1);
		break;
	default:
		goto err;
	}

	if (ctx->tcp + 1 > ctx->data_end)
		goto err;

	return 0;
err:
	return -1;
}

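/* 60 below is the maximum TCP header length (doff is a 4-bit word count, so
 * at most 15 * 4 = 60 bytes); the tail is grown so the whole option area
 * stays within data_end for direct packet access.
 */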
static int tcp_reload_headers(struct tcp_syncookie *ctx)
{
	/* Without volatile,
	 * R3 32-bit pointer arithmetic prohibited
	 */
	volatile u64 data_len = ctx->skb->data_end - ctx->skb->data;

	if (ctx->tcp->doff < sizeof(*ctx->tcp) / 4)
		goto err;

	/* Needed to calculate csum and parse TCP options. */
	if (bpf_skb_change_tail(ctx->skb, data_len + 60 - ctx->tcp->doff * 4, 0))
		goto err;

	ctx->data = (void *)(long)ctx->skb->data;
	ctx->data_end = (void *)(long)ctx->skb->data_end;
	ctx->eth = (struct ethhdr *)(long)ctx->skb->data;
	if (ctx->ipv4) {
		ctx->ipv4 = (struct iphdr *)(ctx->eth + 1);
		ctx->ipv6 = NULL;
		ctx->tcp = (struct tcphdr *)(ctx->ipv4 + 1);
	} else {
		ctx->ipv4 = NULL;
		ctx->ipv6 = (struct ipv6hdr *)(ctx->eth + 1);
		ctx->tcp = (struct tcphdr *)(ctx->ipv6 + 1);
	}

	if ((void *)ctx->tcp + 60 > ctx->data_end)
		goto err;

	return 0;
err:
	return -1;
}

static __sum16 tcp_v4_csum(struct tcp_syncookie *ctx, __wsum csum)
{
	return csum_tcpudp_magic(ctx->ipv4->saddr, ctx->ipv4->daddr,
				 ctx->tcp->doff * 4, IPPROTO_TCP, csum);
}

static __sum16 tcp_v6_csum(struct tcp_syncookie *ctx, __wsum csum)
{
	return csum_ipv6_magic(&ctx->ipv6->saddr, &ctx->ipv6->daddr,
			       ctx->tcp->doff * 4, IPPROTO_TCP, csum);
}

static int tcp_validate_header(struct tcp_syncookie *ctx)
{
	s64 csum;

	if (tcp_reload_headers(ctx))
		goto err;

	csum = bpf_csum_diff(0, 0, (void *)ctx->tcp, ctx->tcp->doff * 4, 0);
	if (csum < 0)
		goto err;

	if (ctx->ipv4) {
		/* check tcp_v4_csum(csum) is 0 if not on lo. */

		csum = bpf_csum_diff(0, 0, (void *)ctx->ipv4, ctx->ipv4->ihl * 4, 0);
		if (csum < 0)
			goto err;

		if (csum_fold(csum) != 0)
			goto err;
	} else if (ctx->ipv6) {
		/* check tcp_v6_csum(csum) is 0 if not on lo. */
	}

	return 0;
err:
	return -1;
}

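/* next() is a small bounded-access helper: it returns a pointer to the next
 * sz bytes of the packet (advancing ctx->off) or NULL when the read would go
 * past data_end or MAX_PACKET_OFF, which keeps the verifier satisfied about
 * every dereference in the option parser below.
 */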
static __always_inline void *next(struct tcp_syncookie *ctx, __u32 sz)
{
	__u64 off = ctx->off;
	__u8 *data;

	/* Verifier forbids access to packet when offset exceeds MAX_PACKET_OFF */
	if (off > MAX_PACKET_OFF - sz)
		return NULL;

	data = ctx->data + off;
	barrier_var(data);
	if (data + sz >= ctx->data_end)
		return NULL;

	ctx->off += sz;
	return data;
}

static int tcp_parse_option(__u32 index, struct tcp_syncookie *ctx)
{
	__u8 *opcode, *opsize, *wscale;
	__u32 *tsval, *tsecr;
	__u16 *mss;
	__u32 off;

	off = ctx->off;
	opcode = next(ctx, 1);
	if (!opcode)
		goto stop;

	if (*opcode == TCPOPT_EOL)
		goto stop;

	if (*opcode == TCPOPT_NOP)
		goto next;

	opsize = next(ctx, 1);
	if (!opsize)
		goto stop;

	if (*opsize < 2)
		goto stop;

	switch (*opcode) {
	case TCPOPT_MSS:
		mss = next(ctx, 2);
		if (*opsize == TCPOLEN_MSS && ctx->tcp->syn && mss)
			ctx->attrs.mss = get_unaligned_be16(mss);
		break;
	case TCPOPT_WINDOW:
		wscale = next(ctx, 1);
		if (*opsize == TCPOLEN_WINDOW && ctx->tcp->syn && wscale) {
			ctx->attrs.wscale_ok = 1;
			ctx->attrs.snd_wscale = *wscale;
		}
		break;
	case TCPOPT_TIMESTAMP:
		tsval = next(ctx, 4);
		tsecr = next(ctx, 4);
		if (*opsize == TCPOLEN_TIMESTAMP && tsval && tsecr) {
			ctx->attrs.rcv_tsval = get_unaligned_be32(tsval);
			ctx->attrs.rcv_tsecr = get_unaligned_be32(tsecr);

			if (ctx->tcp->syn && ctx->attrs.rcv_tsecr)
				ctx->attrs.tstamp_ok = 0;
			else
				ctx->attrs.tstamp_ok = 1;
		}
		break;
	case TCPOPT_SACK_PERM:
		if (*opsize == TCPOLEN_SACK_PERM && ctx->tcp->syn)
			ctx->attrs.sack_ok = 1;
		break;
	}

	ctx->off = off + *opsize;
next:
	return 0;
stop:
	return 1;
}

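/* The TCP option area is at most 40 bytes (60-byte maximum header minus the
 * 20-byte fixed part), so bpf_loop() is bounded to 40 iterations;
 * tcp_parse_option() returns non-zero to stop early on EOL or a malformed
 * option.
 */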
static void tcp_parse_options(struct tcp_syncookie *ctx)
{
	ctx->off = (__u8 *)(ctx->tcp + 1) - (__u8 *)ctx->data;

	bpf_loop(40, tcp_parse_option, ctx, 0);
}

static int tcp_validate_sysctl(struct tcp_syncookie *ctx)
{
	if ((ctx->ipv4 && ctx->attrs.mss != MSS_LOCAL_IPV4) ||
	    (ctx->ipv6 && ctx->attrs.mss != MSS_LOCAL_IPV6))
		goto err;

	if (!ctx->attrs.wscale_ok ||
	    !ctx->attrs.snd_wscale ||
	    ctx->attrs.snd_wscale >= BPF_SYNCOOKIE_WSCALE_MASK)
		goto err;

	if (!ctx->attrs.tstamp_ok)
		goto err;

	if (!ctx->attrs.sack_ok)
		goto err;

	if (!ctx->tcp->ece || !ctx->tcp->cwr)
		goto err;

	return 0;
err:
	return -1;
}

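/* Cookie construction, similar in spirit to the kernel's syncookie encoding:
 * pick the largest msstab entry not exceeding the client's MSS, siphash the
 * addresses, ports and client ISN, then fold the MSS index, SACK/ECN bits and
 * window scale into the low byte of the ISN (and the hash's low byte into the
 * timestamp) as described in the layout comment at the top of this file.
 */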
static void tcp_prepare_cookie(struct tcp_syncookie *ctx)
{
	u32 seq = bpf_ntohl(ctx->tcp->seq);
	u64 first = 0, second;
	int mssind = 0;
	u32 hash;

	if (ctx->ipv4) {
		for (mssind = ARRAY_SIZE(msstab4) - 1; mssind; mssind--)
			if (ctx->attrs.mss >= msstab4[mssind])
				break;

		ctx->attrs.mss = msstab4[mssind];

		first = (u64)ctx->ipv4->saddr << 32 | ctx->ipv4->daddr;
	} else if (ctx->ipv6) {
		for (mssind = ARRAY_SIZE(msstab6) - 1; mssind; mssind--)
			if (ctx->attrs.mss >= msstab6[mssind])
				break;

		ctx->attrs.mss = msstab6[mssind];

		first = (u64)ctx->ipv6->saddr.in6_u.u6_addr8[0] << 32 |
			ctx->ipv6->daddr.in6_u.u6_addr32[0];
	}

	second = (u64)seq << 32 | ctx->tcp->source << 16 | ctx->tcp->dest;
	hash = siphash_2u64(first, second, &test_key_siphash);

	if (ctx->attrs.tstamp_ok) {
		ctx->attrs.rcv_tsecr = bpf_get_prandom_u32();
		ctx->attrs.rcv_tsecr &= ~COOKIE_MASK;
		ctx->attrs.rcv_tsecr |= hash & COOKIE_MASK;
	}

	hash &= ~COOKIE_MASK;
	hash |= mssind << 6;

	if (ctx->attrs.wscale_ok)
		hash |= ctx->attrs.snd_wscale & BPF_SYNCOOKIE_WSCALE_MASK;

	if (ctx->attrs.sack_ok)
		hash |= BPF_SYNCOOKIE_SACK;

	if (ctx->attrs.tstamp_ok && ctx->tcp->ece && ctx->tcp->cwr)
		hash |= BPF_SYNCOOKIE_ECN;

	ctx->cookie = hash;
}

static void tcp_write_options(struct tcp_syncookie *ctx)
{
	ctx->ptr32 = (__be32 *)(ctx->tcp + 1);

	*ctx->ptr32++ = bpf_htonl(TCPOPT_MSS << 24 | TCPOLEN_MSS << 16 |
				  ctx->attrs.mss);

	if (ctx->attrs.wscale_ok)
		*ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 |
					  TCPOPT_WINDOW << 16 |
					  TCPOLEN_WINDOW << 8 |
					  ctx->attrs.snd_wscale);

	if (ctx->attrs.tstamp_ok) {
		if (ctx->attrs.sack_ok)
			*ctx->ptr32++ = bpf_htonl(TCPOPT_SACK_PERM << 24 |
						  TCPOLEN_SACK_PERM << 16 |
						  TCPOPT_TIMESTAMP << 8 |
						  TCPOLEN_TIMESTAMP);
		else
			*ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 |
						  TCPOPT_NOP << 16 |
						  TCPOPT_TIMESTAMP << 8 |
						  TCPOLEN_TIMESTAMP);

		*ctx->ptr32++ = bpf_htonl(ctx->attrs.rcv_tsecr);
		*ctx->ptr32++ = bpf_htonl(ctx->attrs.rcv_tsval);
	} else if (ctx->attrs.sack_ok) {
		*ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 |
					  TCPOPT_NOP << 16 |
					  TCPOPT_SACK_PERM << 8 |
					  TCPOLEN_SACK_PERM);
	}
}

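/* The SYN is answered by rewriting the packet in place into a SYN+ACK:
 * swap MACs, addresses and ports, set seq to the cookie and ack_seq to the
 * client ISN + 1, rebuild the options and checksums, trim the skb to the new
 * length and bounce it back out of the same interface with bpf_redirect().
 */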
static int tcp_handle_syn(struct tcp_syncookie *ctx)
{
	s64 csum;

	if (tcp_validate_header(ctx))
		goto err;

	tcp_parse_options(ctx);

	if (tcp_validate_sysctl(ctx))
		goto err;

	tcp_prepare_cookie(ctx);
	tcp_write_options(ctx);

	swap(ctx->tcp->source, ctx->tcp->dest);
	ctx->tcp->check = 0;
	ctx->tcp->ack_seq = bpf_htonl(bpf_ntohl(ctx->tcp->seq) + 1);
	ctx->tcp->seq = bpf_htonl(ctx->cookie);
	ctx->tcp->doff = ((long)ctx->ptr32 - (long)ctx->tcp) >> 2;
	ctx->tcp->ack = 1;
	if (!ctx->attrs.tstamp_ok || !ctx->tcp->ece || !ctx->tcp->cwr)
		ctx->tcp->ece = 0;
	ctx->tcp->cwr = 0;

	csum = bpf_csum_diff(0, 0, (void *)ctx->tcp, ctx->tcp->doff * 4, 0);
	if (csum < 0)
		goto err;

	if (ctx->ipv4) {
		swap(ctx->ipv4->saddr, ctx->ipv4->daddr);
		ctx->tcp->check = tcp_v4_csum(ctx, csum);

		ctx->ipv4->check = 0;
		ctx->ipv4->tos = 0;
		ctx->ipv4->tot_len = bpf_htons((long)ctx->ptr32 - (long)ctx->ipv4);
		ctx->ipv4->id = 0;
		ctx->ipv4->ttl = 64;

		csum = bpf_csum_diff(0, 0, (void *)ctx->ipv4, sizeof(*ctx->ipv4), 0);
		if (csum < 0)
			goto err;

		ctx->ipv4->check = csum_fold(csum);
	} else if (ctx->ipv6) {
		swap(ctx->ipv6->saddr, ctx->ipv6->daddr);
		ctx->tcp->check = tcp_v6_csum(ctx, csum);

		*(__be32 *)ctx->ipv6 = bpf_htonl(0x60000000);
		ctx->ipv6->payload_len = bpf_htons((long)ctx->ptr32 - (long)ctx->tcp);
		ctx->ipv6->hop_limit = 64;
	}

	swap_array(ctx->eth->h_source, ctx->eth->h_dest);

	if (bpf_skb_change_tail(ctx->skb, (long)ctx->ptr32 - (long)ctx->eth, 0))
		goto err;

	return bpf_redirect(ctx->skb->ifindex, 0);
err:
	return TC_ACT_SHOT;
}

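/* Validation recomputes the same siphash from the ACK and checks it against
 * the upper 24 bits of the echoed ISN (and, when timestamps are on, against
 * the low byte of the echoed tsecr); the remaining cookie bits are then
 * decoded back into the request attributes.
 */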
static int tcp_validate_cookie(struct tcp_syncookie *ctx)
{
	u32 cookie = bpf_ntohl(ctx->tcp->ack_seq) - 1;
	u32 seq = bpf_ntohl(ctx->tcp->seq) - 1;
	u64 first = 0, second;
	int mssind;
	u32 hash;

	if (ctx->ipv4)
		first = (u64)ctx->ipv4->saddr << 32 | ctx->ipv4->daddr;
	else if (ctx->ipv6)
		first = (u64)ctx->ipv6->saddr.in6_u.u6_addr8[0] << 32 |
			ctx->ipv6->daddr.in6_u.u6_addr32[0];

	second = (u64)seq << 32 | ctx->tcp->source << 16 | ctx->tcp->dest;
	hash = siphash_2u64(first, second, &test_key_siphash);

	if (ctx->attrs.tstamp_ok)
		hash -= ctx->attrs.rcv_tsecr & COOKIE_MASK;
	else
		hash &= ~COOKIE_MASK;

	hash -= cookie & ~COOKIE_MASK;
	if (hash)
		goto err;

	mssind = (cookie & (3 << 6)) >> 6;
	if (ctx->ipv4)
		ctx->attrs.mss = msstab4[mssind];
	else
		ctx->attrs.mss = msstab6[mssind];

	ctx->attrs.snd_wscale = cookie & BPF_SYNCOOKIE_WSCALE_MASK;
	ctx->attrs.rcv_wscale = ctx->attrs.snd_wscale;
	ctx->attrs.wscale_ok = ctx->attrs.snd_wscale != BPF_SYNCOOKIE_WSCALE_MASK;
	ctx->attrs.sack_ok = cookie & BPF_SYNCOOKIE_SACK;
	ctx->attrs.ecn_ok = cookie & BPF_SYNCOOKIE_ECN;

	return 0;
err:
	return -1;
}

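/* For the ACK, look up the listener with bpf_skc_lookup_tcp(), re-validate
 * the header and cookie, and hand the parsed attributes to the
 * bpf_sk_assign_tcp_reqsk() kfunc so the kernel can create the request
 * socket as if its own syncookie had been validated.
 */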
static int tcp_handle_ack(struct tcp_syncookie *ctx)
{
	struct bpf_sock_tuple tuple;
	struct bpf_sock *skc;
	int ret = TC_ACT_OK;
	struct sock *sk;
	u32 tuple_size;

	if (ctx->ipv4) {
		tuple.ipv4.saddr = ctx->ipv4->saddr;
		tuple.ipv4.daddr = ctx->ipv4->daddr;
		tuple.ipv4.sport = ctx->tcp->source;
		tuple.ipv4.dport = ctx->tcp->dest;
		tuple_size = sizeof(tuple.ipv4);
	} else if (ctx->ipv6) {
		__builtin_memcpy(tuple.ipv6.saddr, &ctx->ipv6->saddr, sizeof(tuple.ipv6.saddr));
		__builtin_memcpy(tuple.ipv6.daddr, &ctx->ipv6->daddr, sizeof(tuple.ipv6.daddr));
		tuple.ipv6.sport = ctx->tcp->source;
		tuple.ipv6.dport = ctx->tcp->dest;
		tuple_size = sizeof(tuple.ipv6);
	} else {
		goto out;
	}

	skc = bpf_skc_lookup_tcp(ctx->skb, &tuple, tuple_size, -1, 0);
	if (!skc)
		goto out;

	if (skc->state != TCP_LISTEN)
		goto release;

	sk = (struct sock *)bpf_skc_to_tcp_sock(skc);
	if (!sk)
		goto err;

	if (tcp_validate_header(ctx))
		goto err;

	tcp_parse_options(ctx);

	if (tcp_validate_cookie(ctx))
		goto err;

	ret = bpf_sk_assign_tcp_reqsk(ctx->skb, sk, &ctx->attrs, sizeof(ctx->attrs));
	if (ret < 0)
		goto err;

release:
	bpf_sk_release(skc);
out:
	return ret;

err:
	ret = TC_ACT_SHOT;
	goto release;
}

SEC("tc")
int tcp_custom_syncookie(struct __sk_buff *skb)
{
	struct tcp_syncookie ctx = {
		.skb = skb,
	};

	if (tcp_load_headers(&ctx))
		return TC_ACT_OK;

	if (ctx.tcp->rst)
		return TC_ACT_OK;

	if (ctx.tcp->syn) {
		if (ctx.tcp->ack)
			return TC_ACT_OK;

		handled_syn = true;

		return tcp_handle_syn(&ctx);
	}

	handled_ack = true;

	return tcp_handle_ack(&ctx);
}

char _license[] SEC("license") = "GPL";