1 // SPDX-License-Identifier: GPL-2.0-only
2 // Copyright (C) 2019-2020 Arm Ltd.
3
4 #include <linux/compiler.h>
5 #include <linux/kasan-checks.h>
6 #include <linux/kernel.h>
7
8 #include <net/checksum.h>
9
10 /* Looks dumb, but generates nice-ish code */
accumulate(u64 sum,u64 data)11 static u64 accumulate(u64 sum, u64 data)
12 {
13 __uint128_t tmp = (__uint128_t)sum + data;
14 return tmp + (tmp >> 64);
15 }
16
17 /*
18 * We over-read the buffer and this makes KASAN unhappy. Instead, disable
19 * instrumentation and call kasan explicitly.
20 */
do_csum(const unsigned char * buff,int len)21 unsigned int __no_sanitize_address do_csum(const unsigned char *buff, int len)
22 {
23 unsigned int offset, shift, sum;
24 const u64 *ptr;
25 u64 data, sum64 = 0;
26
27 if (unlikely(len == 0))
28 return 0;
29
30 offset = (unsigned long)buff & 7;
31 /*
32 * This is to all intents and purposes safe, since rounding down cannot
33 * result in a different page or cache line being accessed, and @buff
34 * should absolutely not be pointing to anything read-sensitive. We do,
35 * however, have to be careful not to piss off KASAN, which means using
36 * unchecked reads to accommodate the head and tail, for which we'll
37 * compensate with an explicit check up-front.
38 */
39 kasan_check_read(buff, len);
40 ptr = (u64 *)(buff - offset);
41 len = len + offset - 8;
42
43 /*
44 * Head: zero out any excess leading bytes. Shifting back by the same
45 * amount should be at least as fast as any other way of handling the
46 * odd/even alignment, and means we can ignore it until the very end.
47 */
48 shift = offset * 8;
49 data = *ptr++;
50 #ifdef __LITTLE_ENDIAN
51 data = (data >> shift) << shift;
52 #else
53 data = (data << shift) >> shift;
54 #endif
55
56 /*
57 * Body: straightforward aligned loads from here on (the paired loads
58 * underlying the quadword type still only need dword alignment). The
59 * main loop strictly excludes the tail, so the second loop will always
60 * run at least once.
61 */
62 while (unlikely(len > 64)) {
63 __uint128_t tmp1, tmp2, tmp3, tmp4;
64
65 tmp1 = *(__uint128_t *)ptr;
66 tmp2 = *(__uint128_t *)(ptr + 2);
67 tmp3 = *(__uint128_t *)(ptr + 4);
68 tmp4 = *(__uint128_t *)(ptr + 6);
69
70 len -= 64;
71 ptr += 8;
72
73 /* This is the "don't dump the carry flag into a GPR" idiom */
74 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
75 tmp2 += (tmp2 >> 64) | (tmp2 << 64);
76 tmp3 += (tmp3 >> 64) | (tmp3 << 64);
77 tmp4 += (tmp4 >> 64) | (tmp4 << 64);
78 tmp1 = ((tmp1 >> 64) << 64) | (tmp2 >> 64);
79 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
80 tmp3 = ((tmp3 >> 64) << 64) | (tmp4 >> 64);
81 tmp3 += (tmp3 >> 64) | (tmp3 << 64);
82 tmp1 = ((tmp1 >> 64) << 64) | (tmp3 >> 64);
83 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
84 tmp1 = ((tmp1 >> 64) << 64) | sum64;
85 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
86 sum64 = tmp1 >> 64;
87 }
88 while (len > 8) {
89 __uint128_t tmp;
90
91 sum64 = accumulate(sum64, data);
92 tmp = *(__uint128_t *)ptr;
93
94 len -= 16;
95 ptr += 2;
96
97 #ifdef __LITTLE_ENDIAN
98 data = tmp >> 64;
99 sum64 = accumulate(sum64, tmp);
100 #else
101 data = tmp;
102 sum64 = accumulate(sum64, tmp >> 64);
103 #endif
104 }
105 if (len > 0) {
106 sum64 = accumulate(sum64, data);
107 data = *ptr;
108 len -= 8;
109 }
110 /*
111 * Tail: zero any over-read bytes similarly to the head, again
112 * preserving odd/even alignment.
113 */
114 shift = len * -8;
115 #ifdef __LITTLE_ENDIAN
116 data = (data << shift) >> shift;
117 #else
118 data = (data >> shift) << shift;
119 #endif
120 sum64 = accumulate(sum64, data);
121
122 /* Finally, folding */
123 sum64 += (sum64 >> 32) | (sum64 << 32);
124 sum = sum64 >> 32;
125 sum += (sum >> 16) | (sum << 16);
126 if (offset & 1)
127 return (u16)swab32(sum);
128
129 return sum >> 16;
130 }
131
csum_ipv6_magic(const struct in6_addr * saddr,const struct in6_addr * daddr,__u32 len,__u8 proto,__wsum csum)132 __sum16 csum_ipv6_magic(const struct in6_addr *saddr,
133 const struct in6_addr *daddr,
134 __u32 len, __u8 proto, __wsum csum)
135 {
136 __uint128_t src, dst;
137 u64 sum = (__force u64)csum;
138
139 src = *(const __uint128_t *)saddr->s6_addr;
140 dst = *(const __uint128_t *)daddr->s6_addr;
141
142 sum += (__force u32)htonl(len);
143 #ifdef __LITTLE_ENDIAN
144 sum += (u32)proto << 24;
145 #else
146 sum += proto;
147 #endif
148 src += (src >> 64) | (src << 64);
149 dst += (dst >> 64) | (dst << 64);
150
151 sum = accumulate(sum, src >> 64);
152 sum = accumulate(sum, dst >> 64);
153
154 sum += ((sum >> 32) | (sum << 32));
155 return csum_fold((__force __wsum)(sum >> 32));
156 }
157 EXPORT_SYMBOL(csum_ipv6_magic);
158