1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3  * Copyright (c) 2019 Huawei Technologies Co., Ltd
4  */
5 /*
6  * SM3 Hash algorithm
7  * thanks to Xyssl
8  * author:goldboar
9  * email:goldboar@163.com
10  * 2011-10-26
11  */
12 
13 #include <compiler.h>
14 #include <crypto/crypto_accel.h>
15 #include <string_ext.h>
16 #include <string.h>
17 
18 #include "sm3.h"
19 
20 #define SM3_BLOCK_SIZE	64
21 
22 #define GET_UINT32_BE(n, b, i)				\
23 	do {						\
24 		(n) = ((uint32_t)(b)[(i)] << 24)     |	\
25 		      ((uint32_t)(b)[(i) + 1] << 16) |	\
26 		      ((uint32_t)(b)[(i) + 2] <<  8) |	\
27 		      ((uint32_t)(b)[(i) + 3]);		\
28 	} while (0)
29 
30 #define PUT_UINT32_BE(n, b, i)				\
31 	do {						\
32 		(b)[(i)] = (uint8_t)((n) >> 24);	\
33 		(b)[(i) + 1] = (uint8_t)((n) >> 16);	\
34 		(b)[(i) + 2] = (uint8_t)((n) >>  8);	\
35 		(b)[(i) + 3] = (uint8_t)((n));		\
36 	} while (0)
37 
sm3_init(struct sm3_context * ctx)38 void sm3_init(struct sm3_context *ctx)
39 {
40 	ctx->total[0] = 0;
41 	ctx->total[1] = 0;
42 
43 	ctx->state[0] = 0x7380166F;
44 	ctx->state[1] = 0x4914B2B9;
45 	ctx->state[2] = 0x172442D7;
46 	ctx->state[3] = 0xDA8A0600;
47 	ctx->state[4] = 0xA96F30BC;
48 	ctx->state[5] = 0x163138AA;
49 	ctx->state[6] = 0xE38DEE4D;
50 	ctx->state[7] = 0xB0FB0E4E;
51 }
52 
sm3_process(struct sm3_context * ctx,const uint8_t data[64])53 static void __maybe_unused sm3_process(struct sm3_context *ctx,
54 				       const uint8_t data[64])
55 {
56 	uint32_t SS1, SS2, TT1, TT2, W[68], W1[64];
57 	uint32_t A, B, C, D, E, F, G, H;
58 	uint32_t T[64];
59 	uint32_t Temp1, Temp2, Temp3, Temp4, Temp5;
60 	int j;
61 
62 	for (j = 0; j < 16; j++)
63 		T[j] = 0x79CC4519;
64 	for (j = 16; j < 64; j++)
65 		T[j] = 0x7A879D8A;
66 
67 	GET_UINT32_BE(W[0], data,  0);
68 	GET_UINT32_BE(W[1], data,  4);
69 	GET_UINT32_BE(W[2], data,  8);
70 	GET_UINT32_BE(W[3], data, 12);
71 	GET_UINT32_BE(W[4], data, 16);
72 	GET_UINT32_BE(W[5], data, 20);
73 	GET_UINT32_BE(W[6], data, 24);
74 	GET_UINT32_BE(W[7], data, 28);
75 	GET_UINT32_BE(W[8], data, 32);
76 	GET_UINT32_BE(W[9], data, 36);
77 	GET_UINT32_BE(W[10], data, 40);
78 	GET_UINT32_BE(W[11], data, 44);
79 	GET_UINT32_BE(W[12], data, 48);
80 	GET_UINT32_BE(W[13], data, 52);
81 	GET_UINT32_BE(W[14], data, 56);
82 	GET_UINT32_BE(W[15], data, 60);
83 
84 #define FF0(x, y, z)	((x) ^ (y) ^ (z))
85 #define FF1(x, y, z)	(((x) & (y)) | ((x) & (z)) | ((y) & (z)))
86 
87 #define GG0(x, y, z)	((x) ^ (y) ^ (z))
88 #define GG1(x, y, z)	(((x) & (y)) | ((~(x)) & (z)))
89 
90 #define SHL(x, n)	((x) << (n))
91 #define ROTL(x, n)	(SHL((x), (n) & 0x1F) | ((x) >> (32 - ((n) & 0x1F))))
92 
93 #define P0(x)	((x) ^ ROTL((x), 9) ^ ROTL((x), 17))
94 #define P1(x)	((x) ^ ROTL((x), 15) ^ ROTL((x), 23))
95 
96 	for (j = 16; j < 68; j++) {
97 		/*
98 		 * W[j] = P1( W[j-16] ^ W[j-9] ^ ROTL(W[j-3],15)) ^
99 		 *        ROTL(W[j - 13],7 ) ^ W[j-6];
100 		 */
101 
102 		Temp1 = W[j - 16] ^ W[j - 9];
103 		Temp2 = ROTL(W[j - 3], 15);
104 		Temp3 = Temp1 ^ Temp2;
105 		Temp4 = P1(Temp3);
106 		Temp5 =  ROTL(W[j - 13], 7) ^ W[j - 6];
107 		W[j] = Temp4 ^ Temp5;
108 	}
109 
110 	for (j =  0; j < 64; j++)
111 		W1[j] = W[j] ^ W[j + 4];
112 
113 	A = ctx->state[0];
114 	B = ctx->state[1];
115 	C = ctx->state[2];
116 	D = ctx->state[3];
117 	E = ctx->state[4];
118 	F = ctx->state[5];
119 	G = ctx->state[6];
120 	H = ctx->state[7];
121 
122 	for (j = 0; j < 16; j++) {
123 		SS1 = ROTL(ROTL(A, 12) + E + ROTL(T[j], j), 7);
124 		SS2 = SS1 ^ ROTL(A, 12);
125 		TT1 = FF0(A, B, C) + D + SS2 + W1[j];
126 		TT2 = GG0(E, F, G) + H + SS1 + W[j];
127 		D = C;
128 		C = ROTL(B, 9);
129 		B = A;
130 		A = TT1;
131 		H = G;
132 		G = ROTL(F, 19);
133 		F = E;
134 		E = P0(TT2);
135 	}
136 
137 	for (j = 16; j < 64; j++) {
138 		SS1 = ROTL(ROTL(A, 12) + E + ROTL(T[j], j), 7);
139 		SS2 = SS1 ^ ROTL(A, 12);
140 		TT1 = FF1(A, B, C) + D + SS2 + W1[j];
141 		TT2 = GG1(E, F, G) + H + SS1 + W[j];
142 		D = C;
143 		C = ROTL(B, 9);
144 		B = A;
145 		A = TT1;
146 		H = G;
147 		G = ROTL(F, 19);
148 		F = E;
149 		E = P0(TT2);
150 	}
151 
152 	ctx->state[0] ^= A;
153 	ctx->state[1] ^= B;
154 	ctx->state[2] ^= C;
155 	ctx->state[3] ^= D;
156 	ctx->state[4] ^= E;
157 	ctx->state[5] ^= F;
158 	ctx->state[6] ^= G;
159 	ctx->state[7] ^= H;
160 }
161 
sm3_process_blocks(struct sm3_context * ctx,const uint8_t * input,unsigned int block_count)162 static void sm3_process_blocks(struct sm3_context *ctx, const uint8_t *input,
163 			       unsigned int block_count)
164 {
165 #ifdef CFG_CRYPTO_SM3_ARM_CE
166 	if (block_count)
167 		crypto_accel_sm3_compress(ctx->state, input, block_count);
168 #else
169 	unsigned int n = 0;
170 
171 	for (n = 0; n < block_count; n++)
172 		sm3_process(ctx, input + n * SM3_BLOCK_SIZE);
173 #endif
174 }
175 
sm3_update(struct sm3_context * ctx,const uint8_t * input,size_t ilen)176 void sm3_update(struct sm3_context *ctx, const uint8_t *input, size_t ilen)
177 {
178 	unsigned int block_count = 0;
179 	size_t fill = 0;
180 	size_t left = 0;
181 
182 	if (!ilen)
183 		return;
184 
185 	left = ctx->total[0] & 0x3F;
186 	fill = 64 - left;
187 
188 	ctx->total[0] += ilen;
189 
190 	if (ctx->total[0] < ilen)
191 		ctx->total[1]++;
192 
193 	if (left && ilen >= fill) {
194 		memcpy(ctx->buffer + left, input, fill);
195 		sm3_process_blocks(ctx, ctx->buffer, 1);
196 		input += fill;
197 		ilen -= fill;
198 		left = 0;
199 	}
200 
201 	block_count = ilen / SM3_BLOCK_SIZE;
202 	sm3_process_blocks(ctx, input, block_count);
203 	ilen -= block_count * SM3_BLOCK_SIZE;
204 	input += block_count * SM3_BLOCK_SIZE;
205 
206 	if (ilen > 0)
207 		memcpy(ctx->buffer + left, input, ilen);
208 }
209 
210 static const uint8_t sm3_padding[64] = {
211 	0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
212 	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
213 	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
214 	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
215 };
216 
sm3_final(struct sm3_context * ctx,uint8_t output[32])217 void sm3_final(struct sm3_context *ctx, uint8_t output[32])
218 {
219 	uint32_t last, padn;
220 	uint32_t high, low;
221 	uint8_t msglen[8];
222 
223 	high = (ctx->total[0] >> 29) | (ctx->total[1] <<  3);
224 	low  = ctx->total[0] << 3;
225 
226 	PUT_UINT32_BE(high, msglen, 0);
227 	PUT_UINT32_BE(low,  msglen, 4);
228 
229 	last = ctx->total[0] & 0x3F;
230 	padn = (last < 56) ? (56 - last) : (120 - last);
231 
232 	sm3_update(ctx, sm3_padding, padn);
233 	sm3_update(ctx, msglen, 8);
234 
235 	PUT_UINT32_BE(ctx->state[0], output,  0);
236 	PUT_UINT32_BE(ctx->state[1], output,  4);
237 	PUT_UINT32_BE(ctx->state[2], output,  8);
238 	PUT_UINT32_BE(ctx->state[3], output, 12);
239 	PUT_UINT32_BE(ctx->state[4], output, 16);
240 	PUT_UINT32_BE(ctx->state[5], output, 20);
241 	PUT_UINT32_BE(ctx->state[6], output, 24);
242 	PUT_UINT32_BE(ctx->state[7], output, 28);
243 }
244 
sm3(const uint8_t * input,size_t ilen,uint8_t output[32])245 void sm3(const uint8_t *input, size_t ilen, uint8_t output[32])
246 {
247 	struct sm3_context ctx = { };
248 
249 	sm3_init(&ctx);
250 	sm3_update(&ctx, input, ilen);
251 	sm3_final(&ctx, output);
252 
253 	memzero_explicit(&ctx, sizeof(ctx));
254 }
255 
sm3_hmac_init(struct sm3_context * ctx,const uint8_t * key,size_t keylen)256 void sm3_hmac_init(struct sm3_context *ctx, const uint8_t *key, size_t keylen)
257 {
258 	size_t i;
259 	uint8_t sum[32];
260 
261 	if (keylen > 64) {
262 		sm3(key, keylen, sum);
263 		keylen = 32;
264 		key = sum;
265 	}
266 
267 	memset(ctx->ipad, 0x36, 64);
268 	memset(ctx->opad, 0x5C, 64);
269 
270 	for (i = 0; i < keylen; i++) {
271 		ctx->ipad[i] ^= key[i];
272 		ctx->opad[i] ^= key[i];
273 	}
274 
275 	sm3_init(ctx);
276 	sm3_update(ctx, ctx->ipad, 64);
277 
278 	memzero_explicit(sum, sizeof(sum));
279 }
280 
sm3_hmac_update(struct sm3_context * ctx,const uint8_t * input,size_t ilen)281 void sm3_hmac_update(struct sm3_context *ctx, const uint8_t *input, size_t ilen)
282 {
283 	sm3_update(ctx, input, ilen);
284 }
285 
sm3_hmac_final(struct sm3_context * ctx,uint8_t output[32])286 void sm3_hmac_final(struct sm3_context *ctx, uint8_t output[32])
287 {
288 	uint8_t tmpbuf[32];
289 
290 	sm3_final(ctx, tmpbuf);
291 	sm3_init(ctx);
292 	sm3_update(ctx, ctx->opad, 64);
293 	sm3_update(ctx, tmpbuf, 32);
294 	sm3_final(ctx, output);
295 
296 	memzero_explicit(tmpbuf, sizeof(tmpbuf));
297 }
298 
sm3_hmac(const uint8_t * key,size_t keylen,const uint8_t * input,size_t ilen,uint8_t output[32])299 void sm3_hmac(const uint8_t *key, size_t keylen, const uint8_t *input,
300 	      size_t ilen, uint8_t output[32])
301 {
302 	struct sm3_context ctx;
303 
304 	sm3_hmac_init(&ctx, key, keylen);
305 	sm3_hmac_update(&ctx, input, ilen);
306 	sm3_hmac_final(&ctx, output);
307 
308 	memzero_explicit(&ctx, sizeof(ctx));
309 }
310