1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3  * Copyright (c) 2019 Huawei Technologies Co., Ltd
4  */
5 /*
6  * SM3 Hash algorithm
7  * thanks to Xyssl
8  * author:goldboar
9  * email:goldboar@163.com
10  * 2011-10-26
11  */
12 
13 #include <string.h>
14 #include <string_ext.h>
15 
16 #include "sm3.h"
17 
18 #define GET_UINT32_BE(n, b, i)				\
19 	do {						\
20 		(n) = ((uint32_t)(b)[(i)] << 24)     |	\
21 		      ((uint32_t)(b)[(i) + 1] << 16) |	\
22 		      ((uint32_t)(b)[(i) + 2] <<  8) |	\
23 		      ((uint32_t)(b)[(i) + 3]);		\
24 	} while (0)
25 
26 #define PUT_UINT32_BE(n, b, i)				\
27 	do {						\
28 		(b)[(i)] = (uint8_t)((n) >> 24);	\
29 		(b)[(i) + 1] = (uint8_t)((n) >> 16);	\
30 		(b)[(i) + 2] = (uint8_t)((n) >>  8);	\
31 		(b)[(i) + 3] = (uint8_t)((n));		\
32 	} while (0)
33 
sm3_init(struct sm3_context * ctx)34 void sm3_init(struct sm3_context *ctx)
35 {
36 	ctx->total[0] = 0;
37 	ctx->total[1] = 0;
38 
39 	ctx->state[0] = 0x7380166F;
40 	ctx->state[1] = 0x4914B2B9;
41 	ctx->state[2] = 0x172442D7;
42 	ctx->state[3] = 0xDA8A0600;
43 	ctx->state[4] = 0xA96F30BC;
44 	ctx->state[5] = 0x163138AA;
45 	ctx->state[6] = 0xE38DEE4D;
46 	ctx->state[7] = 0xB0FB0E4E;
47 }
48 
sm3_process(struct sm3_context * ctx,const uint8_t data[64])49 static void sm3_process(struct sm3_context *ctx, const uint8_t data[64])
50 {
51 	uint32_t SS1, SS2, TT1, TT2, W[68], W1[64];
52 	uint32_t A, B, C, D, E, F, G, H;
53 	uint32_t T[64];
54 	uint32_t Temp1, Temp2, Temp3, Temp4, Temp5;
55 	int j;
56 
57 	for (j = 0; j < 16; j++)
58 		T[j] = 0x79CC4519;
59 	for (j = 16; j < 64; j++)
60 		T[j] = 0x7A879D8A;
61 
62 	GET_UINT32_BE(W[0], data,  0);
63 	GET_UINT32_BE(W[1], data,  4);
64 	GET_UINT32_BE(W[2], data,  8);
65 	GET_UINT32_BE(W[3], data, 12);
66 	GET_UINT32_BE(W[4], data, 16);
67 	GET_UINT32_BE(W[5], data, 20);
68 	GET_UINT32_BE(W[6], data, 24);
69 	GET_UINT32_BE(W[7], data, 28);
70 	GET_UINT32_BE(W[8], data, 32);
71 	GET_UINT32_BE(W[9], data, 36);
72 	GET_UINT32_BE(W[10], data, 40);
73 	GET_UINT32_BE(W[11], data, 44);
74 	GET_UINT32_BE(W[12], data, 48);
75 	GET_UINT32_BE(W[13], data, 52);
76 	GET_UINT32_BE(W[14], data, 56);
77 	GET_UINT32_BE(W[15], data, 60);
78 
79 #define FF0(x, y, z)	((x) ^ (y) ^ (z))
80 #define FF1(x, y, z)	(((x) & (y)) | ((x) & (z)) | ((y) & (z)))
81 
82 #define GG0(x, y, z)	((x) ^ (y) ^ (z))
83 #define GG1(x, y, z)	(((x) & (y)) | ((~(x)) & (z)))
84 
85 #define SHL(x, n)	((x) << (n))
86 #define ROTL(x, n)	(SHL((x), (n) & 0x1F) | ((x) >> (32 - ((n) & 0x1F))))
87 
88 #define P0(x)	((x) ^ ROTL((x), 9) ^ ROTL((x), 17))
89 #define P1(x)	((x) ^ ROTL((x), 15) ^ ROTL((x), 23))
90 
91 	for (j = 16; j < 68; j++) {
92 		/*
93 		 * W[j] = P1( W[j-16] ^ W[j-9] ^ ROTL(W[j-3],15)) ^
94 		 *        ROTL(W[j - 13],7 ) ^ W[j-6];
95 		 */
96 
97 		Temp1 = W[j - 16] ^ W[j - 9];
98 		Temp2 = ROTL(W[j - 3], 15);
99 		Temp3 = Temp1 ^ Temp2;
100 		Temp4 = P1(Temp3);
101 		Temp5 =  ROTL(W[j - 13], 7) ^ W[j - 6];
102 		W[j] = Temp4 ^ Temp5;
103 	}
104 
105 	for (j =  0; j < 64; j++)
106 		W1[j] = W[j] ^ W[j + 4];
107 
108 	A = ctx->state[0];
109 	B = ctx->state[1];
110 	C = ctx->state[2];
111 	D = ctx->state[3];
112 	E = ctx->state[4];
113 	F = ctx->state[5];
114 	G = ctx->state[6];
115 	H = ctx->state[7];
116 
117 	for (j = 0; j < 16; j++) {
118 		SS1 = ROTL(ROTL(A, 12) + E + ROTL(T[j], j), 7);
119 		SS2 = SS1 ^ ROTL(A, 12);
120 		TT1 = FF0(A, B, C) + D + SS2 + W1[j];
121 		TT2 = GG0(E, F, G) + H + SS1 + W[j];
122 		D = C;
123 		C = ROTL(B, 9);
124 		B = A;
125 		A = TT1;
126 		H = G;
127 		G = ROTL(F, 19);
128 		F = E;
129 		E = P0(TT2);
130 	}
131 
132 	for (j = 16; j < 64; j++) {
133 		SS1 = ROTL(ROTL(A, 12) + E + ROTL(T[j], j), 7);
134 		SS2 = SS1 ^ ROTL(A, 12);
135 		TT1 = FF1(A, B, C) + D + SS2 + W1[j];
136 		TT2 = GG1(E, F, G) + H + SS1 + W[j];
137 		D = C;
138 		C = ROTL(B, 9);
139 		B = A;
140 		A = TT1;
141 		H = G;
142 		G = ROTL(F, 19);
143 		F = E;
144 		E = P0(TT2);
145 	}
146 
147 	ctx->state[0] ^= A;
148 	ctx->state[1] ^= B;
149 	ctx->state[2] ^= C;
150 	ctx->state[3] ^= D;
151 	ctx->state[4] ^= E;
152 	ctx->state[5] ^= F;
153 	ctx->state[6] ^= G;
154 	ctx->state[7] ^= H;
155 }
156 
sm3_update(struct sm3_context * ctx,const uint8_t * input,size_t ilen)157 void sm3_update(struct sm3_context *ctx, const uint8_t *input, size_t ilen)
158 {
159 	size_t fill;
160 	size_t left;
161 
162 	if (!ilen)
163 		return;
164 
165 	left = ctx->total[0] & 0x3F;
166 	fill = 64 - left;
167 
168 	ctx->total[0] += ilen;
169 
170 	if (ctx->total[0] < ilen)
171 		ctx->total[1]++;
172 
173 	if (left && ilen >= fill) {
174 		memcpy(ctx->buffer + left, input, fill);
175 		sm3_process(ctx, ctx->buffer);
176 		input += fill;
177 		ilen -= fill;
178 		left = 0;
179 	}
180 
181 	while (ilen >= 64) {
182 		sm3_process(ctx, input);
183 		input += 64;
184 		ilen -= 64;
185 	}
186 
187 	if (ilen > 0)
188 		memcpy(ctx->buffer + left, input, ilen);
189 }
190 
191 static const uint8_t sm3_padding[64] = {
192 	0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
193 	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
194 	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
195 	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
196 };
197 
sm3_final(struct sm3_context * ctx,uint8_t output[32])198 void sm3_final(struct sm3_context *ctx, uint8_t output[32])
199 {
200 	uint32_t last, padn;
201 	uint32_t high, low;
202 	uint8_t msglen[8];
203 
204 	high = (ctx->total[0] >> 29) | (ctx->total[1] <<  3);
205 	low  = ctx->total[0] << 3;
206 
207 	PUT_UINT32_BE(high, msglen, 0);
208 	PUT_UINT32_BE(low,  msglen, 4);
209 
210 	last = ctx->total[0] & 0x3F;
211 	padn = (last < 56) ? (56 - last) : (120 - last);
212 
213 	sm3_update(ctx, sm3_padding, padn);
214 	sm3_update(ctx, msglen, 8);
215 
216 	PUT_UINT32_BE(ctx->state[0], output,  0);
217 	PUT_UINT32_BE(ctx->state[1], output,  4);
218 	PUT_UINT32_BE(ctx->state[2], output,  8);
219 	PUT_UINT32_BE(ctx->state[3], output, 12);
220 	PUT_UINT32_BE(ctx->state[4], output, 16);
221 	PUT_UINT32_BE(ctx->state[5], output, 20);
222 	PUT_UINT32_BE(ctx->state[6], output, 24);
223 	PUT_UINT32_BE(ctx->state[7], output, 28);
224 }
225 
sm3(const uint8_t * input,size_t ilen,uint8_t output[32])226 void sm3(const uint8_t *input, size_t ilen, uint8_t output[32])
227 {
228 	struct sm3_context ctx = { };
229 
230 	sm3_init(&ctx);
231 	sm3_update(&ctx, input, ilen);
232 	sm3_final(&ctx, output);
233 
234 	memzero_explicit(&ctx, sizeof(ctx));
235 }
236 
sm3_hmac_init(struct sm3_context * ctx,const uint8_t * key,size_t keylen)237 void sm3_hmac_init(struct sm3_context *ctx, const uint8_t *key, size_t keylen)
238 {
239 	size_t i;
240 	uint8_t sum[32];
241 
242 	if (keylen > 64) {
243 		sm3(key, keylen, sum);
244 		keylen = 32;
245 		key = sum;
246 	}
247 
248 	memset(ctx->ipad, 0x36, 64);
249 	memset(ctx->opad, 0x5C, 64);
250 
251 	for (i = 0; i < keylen; i++) {
252 		ctx->ipad[i] ^= key[i];
253 		ctx->opad[i] ^= key[i];
254 	}
255 
256 	sm3_init(ctx);
257 	sm3_update(ctx, ctx->ipad, 64);
258 
259 	memzero_explicit(sum, sizeof(sum));
260 }
261 
sm3_hmac_update(struct sm3_context * ctx,const uint8_t * input,size_t ilen)262 void sm3_hmac_update(struct sm3_context *ctx, const uint8_t *input, size_t ilen)
263 {
264 	sm3_update(ctx, input, ilen);
265 }
266 
sm3_hmac_final(struct sm3_context * ctx,uint8_t output[32])267 void sm3_hmac_final(struct sm3_context *ctx, uint8_t output[32])
268 {
269 	uint8_t tmpbuf[32];
270 
271 	sm3_final(ctx, tmpbuf);
272 	sm3_init(ctx);
273 	sm3_update(ctx, ctx->opad, 64);
274 	sm3_update(ctx, tmpbuf, 32);
275 	sm3_final(ctx, output);
276 
277 	memzero_explicit(tmpbuf, sizeof(tmpbuf));
278 }
279 
sm3_hmac(const uint8_t * key,size_t keylen,const uint8_t * input,size_t ilen,uint8_t output[32])280 void sm3_hmac(const uint8_t *key, size_t keylen, const uint8_t *input,
281 	      size_t ilen, uint8_t output[32])
282 {
283 	struct sm3_context ctx;
284 
285 	sm3_hmac_init(&ctx, key, keylen);
286 	sm3_hmac_update(&ctx, input, ilen);
287 	sm3_hmac_final(&ctx, output);
288 
289 	memzero_explicit(&ctx, sizeof(ctx));
290 }
291