1 // SPDX-License-Identifier: GPL-2.0-only
2 // Copyright (C) 2019-2020 Arm Ltd.
3
4 #include <linux/compiler.h>
5 #include <linux/export.h>
6 #include <linux/kasan-checks.h>
7 #include <linux/kernel.h>
8
9 #include <net/checksum.h>
10
/*
 * Ones' complement addition on 64-bit words: add @data to @sum and feed any
 * carry out of bit 63 back into bit 0 (end-around carry).
 */
static u64 accumulate(u64 sum, u64 data)
{
	u64 total = sum + data;

	/* Unsigned addition wrapped exactly when the result is below an operand. */
	return total + (total < data);
}
18
19 /*
20 * We over-read the buffer and this makes KASAN unhappy. Instead, disable
21 * instrumentation and call kasan explicitly.
22 */
do_csum(const unsigned char * buff,int len)23 unsigned int __no_sanitize_address do_csum(const unsigned char *buff, int len)
24 {
25 unsigned int offset, shift, sum;
26 const u64 *ptr;
27 u64 data, sum64 = 0;
28
29 if (unlikely(len <= 0))
30 return 0;
31
32 offset = (unsigned long)buff & 7;
33 /*
34 * This is to all intents and purposes safe, since rounding down cannot
35 * result in a different page or cache line being accessed, and @buff
36 * should absolutely not be pointing to anything read-sensitive. We do,
37 * however, have to be careful not to piss off KASAN, which means using
38 * unchecked reads to accommodate the head and tail, for which we'll
39 * compensate with an explicit check up-front.
40 */
41 kasan_check_read(buff, len);
42 ptr = (u64 *)(buff - offset);
43 len = len + offset - 8;
44
45 /*
46 * Head: zero out any excess leading bytes. Shifting back by the same
47 * amount should be at least as fast as any other way of handling the
48 * odd/even alignment, and means we can ignore it until the very end.
49 */
50 shift = offset * 8;
51 data = *ptr++;
52 data = (data >> shift) << shift;
53
54 /*
55 * Body: straightforward aligned loads from here on (the paired loads
56 * underlying the quadword type still only need dword alignment). The
57 * main loop strictly excludes the tail, so the second loop will always
58 * run at least once.
59 */
60 while (unlikely(len > 64)) {
61 __uint128_t tmp1, tmp2, tmp3, tmp4;
62
63 tmp1 = *(__uint128_t *)ptr;
64 tmp2 = *(__uint128_t *)(ptr + 2);
65 tmp3 = *(__uint128_t *)(ptr + 4);
66 tmp4 = *(__uint128_t *)(ptr + 6);
67
68 len -= 64;
69 ptr += 8;
70
71 /* This is the "don't dump the carry flag into a GPR" idiom */
72 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
73 tmp2 += (tmp2 >> 64) | (tmp2 << 64);
74 tmp3 += (tmp3 >> 64) | (tmp3 << 64);
75 tmp4 += (tmp4 >> 64) | (tmp4 << 64);
76 tmp1 = ((tmp1 >> 64) << 64) | (tmp2 >> 64);
77 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
78 tmp3 = ((tmp3 >> 64) << 64) | (tmp4 >> 64);
79 tmp3 += (tmp3 >> 64) | (tmp3 << 64);
80 tmp1 = ((tmp1 >> 64) << 64) | (tmp3 >> 64);
81 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
82 tmp1 = ((tmp1 >> 64) << 64) | sum64;
83 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
84 sum64 = tmp1 >> 64;
85 }
86 while (len > 8) {
87 __uint128_t tmp;
88
89 sum64 = accumulate(sum64, data);
90 tmp = *(__uint128_t *)ptr;
91
92 len -= 16;
93 ptr += 2;
94
95 data = tmp >> 64;
96 sum64 = accumulate(sum64, tmp);
97 }
98 if (len > 0) {
99 sum64 = accumulate(sum64, data);
100 data = *ptr;
101 len -= 8;
102 }
103 /*
104 * Tail: zero any over-read bytes similarly to the head, again
105 * preserving odd/even alignment.
106 */
107 shift = len * -8;
108 data = (data << shift) >> shift;
109 sum64 = accumulate(sum64, data);
110
111 /* Finally, folding */
112 sum64 += (sum64 >> 32) | (sum64 << 32);
113 sum = sum64 >> 32;
114 sum += (sum >> 16) | (sum << 16);
115 if (offset & 1)
116 return (u16)swab32(sum);
117
118 return sum >> 16;
119 }
120
csum_ipv6_magic(const struct in6_addr * saddr,const struct in6_addr * daddr,__u32 len,__u8 proto,__wsum csum)121 __sum16 csum_ipv6_magic(const struct in6_addr *saddr,
122 const struct in6_addr *daddr,
123 __u32 len, __u8 proto, __wsum csum)
124 {
125 __uint128_t src, dst;
126 u64 sum = (__force u64)csum;
127
128 src = *(const __uint128_t *)saddr->s6_addr;
129 dst = *(const __uint128_t *)daddr->s6_addr;
130
131 sum += (__force u32)htonl(len);
132 sum += (u32)proto << 24;
133 src += (src >> 64) | (src << 64);
134 dst += (dst >> 64) | (dst << 64);
135
136 sum = accumulate(sum, src >> 64);
137 sum = accumulate(sum, dst >> 64);
138
139 sum += ((sum >> 32) | (sum << 32));
140 return csum_fold((__force __wsum)(sum >> 32));
141 }
142 EXPORT_SYMBOL(csum_ipv6_magic);
143