1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3  * Copyright (c) 2019-2021 Huawei Technologies Co., Ltd
4  */
5 
6 #include <crypto/crypto.h>
7 #include <crypto/sm2-kdf.h>
8 #include <mbedtls/bignum.h>
9 #include <mbedtls/ecp.h>
10 #include <stdlib.h>
11 #include <string.h>
12 #include <string_ext.h>
13 #include <tee_api_types.h>
14 #include <util.h>
15 #include <utee_defines.h>
16 
17 #include "mbed_helpers.h"
18 #include "sm2-pke.h"
19 
/*
 * SM2 field elements and scalars are 256-bit unsigned integers, serialized
 * big endian on a fixed width of 32 bytes.
 */
#define SM2_INT_SIZE_BYTES (256 / 8)
22 
23 static TEE_Result
sm2_uncompressed_bytes_to_point(const mbedtls_ecp_group * grp,mbedtls_ecp_point * p,const uint8_t * x1y1,size_t max_size,size_t * consumed)24 sm2_uncompressed_bytes_to_point(const mbedtls_ecp_group *grp,
25 				mbedtls_ecp_point *p, const uint8_t *x1y1,
26 				size_t max_size, size_t *consumed)
27 {
28 	uint8_t *ptr = (uint8_t *)x1y1;
29 	int mres = 0;
30 
31 	if (max_size < (size_t)(2 * SM2_INT_SIZE_BYTES))
32 		return TEE_ERROR_BAD_PARAMETERS;
33 
34 	mres = mbedtls_mpi_read_binary(&p->X, ptr, SM2_INT_SIZE_BYTES);
35 	if (mres)
36 		return TEE_ERROR_BAD_PARAMETERS;
37 
38 	ptr += SM2_INT_SIZE_BYTES;
39 
40 	mres = mbedtls_mpi_read_binary(&p->Y, ptr, SM2_INT_SIZE_BYTES);
41 	if (mres)
42 		return TEE_ERROR_BAD_PARAMETERS;
43 
44 	mres = mbedtls_mpi_lset(&p->Z, 1);
45 	if (mres)
46 		return TEE_ERROR_BAD_PARAMETERS;
47 
48 	mres = mbedtls_ecp_check_pubkey(grp, p);
49 	if (mres)
50 		return TEE_ERROR_BAD_PARAMETERS;
51 
52 	*consumed = 2 * SM2_INT_SIZE_BYTES + 1; /* PC */
53 
54 	return TEE_SUCCESS;
55 }
56 
57 /*
58  * GM/T 0003.1‒2012 Part 1 Section 4.2.9
59  * Conversion of a byte string @buf to a point @p. Makes sure @p is on the curve
60  * defined by domain parameters @dp.
61  * Note: only the uncompressed form is supported. Uncompressed and hybrid forms
62  * are TBD.
63  */
sm2_bytes_to_point(const mbedtls_ecp_group * grp,mbedtls_ecp_point * p,const uint8_t * buf,size_t max_size,size_t * consumed)64 static TEE_Result sm2_bytes_to_point(const mbedtls_ecp_group *grp,
65 				     mbedtls_ecp_point *p, const uint8_t *buf,
66 				     size_t max_size, size_t *consumed)
67 {
68 	uint8_t PC = 0;
69 
70 	if (!max_size)
71 		return TEE_ERROR_BAD_PARAMETERS;
72 
73 	PC = buf[0];
74 
75 	switch (PC) {
76 	case 0x02:
77 	case 0x03:
78 		/* Compressed form */
79 		return TEE_ERROR_NOT_SUPPORTED;
80 	case 0x04:
81 		/* Uncompressed form */
82 		return sm2_uncompressed_bytes_to_point(grp, p, buf + 1,
83 						       max_size - 1, consumed);
84 	case 0x06:
85 	case 0x07:
86 		/* Hybrid form */
87 		return TEE_ERROR_NOT_SUPPORTED;
88 	default:
89 		return TEE_ERROR_BAD_PARAMETERS;
90 	}
91 
92 	return TEE_ERROR_GENERIC;
93 }
94 
/*
 * Return true when every one of the @size bytes at @buf is zero (an empty
 * range counts as zero). All bytes are visited, with no early exit on the
 * first non-zero byte.
 */
static bool is_zero(const uint8_t *buf, size_t size)
{
	uint8_t acc = 0;
	size_t n = size;

	while (n--)
		acc |= buf[n];

	return acc == 0;
}
105 
106 /*
107  * GM/T 0003.1‒2012 Part 4 Section 7.1
108  * Decryption algorithm
109  */
sm2_mbedtls_pke_decrypt(struct ecc_keypair * key,const uint8_t * src,size_t src_len,uint8_t * dst,size_t * dst_len)110 TEE_Result sm2_mbedtls_pke_decrypt(struct ecc_keypair *key, const uint8_t *src,
111 				   size_t src_len, uint8_t *dst,
112 				   size_t *dst_len)
113 {
114 	TEE_Result res = TEE_SUCCESS;
115 	uint8_t x2y2[64] = { };
116 	mbedtls_ecp_point C1 = { };
117 	size_t C1_len = 0;
118 	mbedtls_ecp_point x2y2p = { };
119 	mbedtls_ecp_group grp = { };
120 	void *ctx = NULL;
121 	int mres = 0;
122 	uint8_t *t = NULL;
123 	size_t C2_len = 0;
124 	size_t i = 0;
125 	size_t out_len = 0;
126 	uint8_t *eom = NULL;
127 	uint8_t u[TEE_SM3_HASH_SIZE] = { };
128 
129 	/*
130 	 * Input buffer src is (C1 || C2 || C3)
131 	 * - C1 represents a point (should be on the curve)
132 	 * - C2 is the encrypted message
133 	 * - C3 is a SM3 hash
134 	 */
135 
136 	mbedtls_ecp_point_init(&C1);
137 	mbedtls_ecp_point_init(&x2y2p);
138 
139 	mbedtls_ecp_group_init(&grp);
140 	mres = mbedtls_ecp_group_load(&grp, MBEDTLS_ECP_DP_SM2);
141 	if (mres) {
142 		res = TEE_ERROR_GENERIC;
143 		goto out;
144 	}
145 
146 	/* Step B1: read and validate point C1 from encrypted message */
147 
148 	res = sm2_bytes_to_point(&grp, &C1, src, src_len, &C1_len);
149 	if (res)
150 		goto out;
151 
152 	/*
153 	 * Step B2: S = [h]C1, the cofactor h is 1 for SM2 so S == C1.
154 	 * The fact that S is on the curve has already been checked in
155 	 * sm2_bytes_to_point().
156 	 */
157 
158 	/* Step B3: (x2, y2) = [dB]C1 */
159 
160 	mres = mbedtls_ecp_mul(&grp, &x2y2p, (mbedtls_mpi *)key->d, &C1,
161 			       mbd_rand, NULL);
162 	if (mres) {
163 		res = TEE_ERROR_BAD_STATE;
164 		goto out;
165 	}
166 
167 	if (mbedtls_mpi_size(&x2y2p.X) > SM2_INT_SIZE_BYTES ||
168 	    mbedtls_mpi_size(&x2y2p.Y) > SM2_INT_SIZE_BYTES) {
169 		res = TEE_ERROR_BAD_STATE;
170 		goto out;
171 	}
172 
173 	mres = mbedtls_mpi_write_binary(&x2y2p.X, x2y2, SM2_INT_SIZE_BYTES);
174 	if (mres) {
175 		res = TEE_ERROR_BAD_STATE;
176 		goto out;
177 	}
178 	mres = mbedtls_mpi_write_binary(&x2y2p.Y, x2y2 + SM2_INT_SIZE_BYTES,
179 					SM2_INT_SIZE_BYTES);
180 	if (mres) {
181 		res = TEE_ERROR_BAD_STATE;
182 		goto out;
183 	}
184 
185 	/* Step B4: t = KDF(x2 || y2, klen) */
186 
187 	/* C = C1 || C2 || C3 */
188 	if (src_len <= C1_len + TEE_SM3_HASH_SIZE) {
189 		res = TEE_ERROR_BAD_PARAMETERS;
190 		goto out;
191 	}
192 
193 	C2_len = src_len - C1_len - TEE_SM3_HASH_SIZE;
194 
195 	t = calloc(1, C2_len);
196 	if (!t) {
197 		res = TEE_ERROR_OUT_OF_MEMORY;
198 		goto out;
199 	}
200 
201 	res = sm2_kdf(x2y2, sizeof(x2y2), t, C2_len);
202 	if (res)
203 		goto out;
204 
205 	if (is_zero(t, C2_len)) {
206 		res = TEE_ERROR_CIPHERTEXT_INVALID;
207 		goto out;
208 	}
209 
210 	/* Step B5: get C2 from C and compute Mprime = C2 (+) t */
211 
212 	out_len = MIN(*dst_len, C2_len);
213 	for (i = 0; i < out_len; i++)
214 		dst[i] = src[C1_len + i] ^ t[i];
215 	*dst_len = out_len;
216 	if (out_len < C2_len) {
217 		eom = calloc(1, C2_len - out_len);
218 		if (!eom) {
219 			res = TEE_ERROR_OUT_OF_MEMORY;
220 			goto out;
221 		}
222 		for (i = out_len; i < C2_len; i++)
223 		       eom[i - out_len] = src[C1_len + i] ^ t[i];
224 	}
225 
226 	/* Step B6: compute u = Hash(x2 || M' || y2) and compare with C3 */
227 
228 	res = crypto_hash_alloc_ctx(&ctx, TEE_ALG_SM3);
229 	if (res)
230 		goto out;
231 	res = crypto_hash_init(ctx);
232 	if (res)
233 		goto out;
234 	res = crypto_hash_update(ctx, x2y2, SM2_INT_SIZE_BYTES);
235 	if (res)
236 		goto out;
237 	res = crypto_hash_update(ctx, dst, out_len);
238 	if (res)
239 		goto out;
240 	if (out_len < C2_len) {
241 		res = crypto_hash_update(ctx, eom, C2_len - out_len);
242 		if (res)
243 			goto out;
244 	}
245 	res = crypto_hash_update(ctx, x2y2 + SM2_INT_SIZE_BYTES,
246 				 SM2_INT_SIZE_BYTES);
247 	if (res)
248 		goto out;
249 	res = crypto_hash_final(ctx, u, sizeof(u));
250 	if (res)
251 		goto out;
252 
253 	if (consttime_memcmp(u, src + C1_len + C2_len, TEE_SM3_HASH_SIZE)) {
254 		res = TEE_ERROR_CIPHERTEXT_INVALID;
255 		goto out;
256 	}
257 out:
258 	free(eom);
259 	free(t);
260 	crypto_hash_free_ctx(ctx);
261 	mbedtls_ecp_point_free(&C1);
262 	mbedtls_ecp_point_free(&x2y2p);
263 	mbedtls_ecp_group_free(&grp);
264 	return res;
265 }
266 
267 /*
268  * GM/T 0003.1‒2012 Part 1 Section 4.2.8
269  * Conversion of point @p to a byte string @buf (uncompressed form).
270  */
sm2_point_to_bytes(uint8_t * buf,size_t * size,const mbedtls_ecp_point * p)271 static TEE_Result sm2_point_to_bytes(uint8_t *buf, size_t *size,
272 				     const mbedtls_ecp_point *p)
273 {
274 	size_t xsize = mbedtls_mpi_size(&p->X);
275 	size_t ysize = mbedtls_mpi_size(&p->Y);
276 	size_t sz = 2 * SM2_INT_SIZE_BYTES + 1;
277 	int mres = 0;
278 
279 	if (xsize > SM2_INT_SIZE_BYTES || ysize > SM2_INT_SIZE_BYTES ||
280 	    *size < sz)
281 		return TEE_ERROR_BAD_STATE;
282 
283 	memset(buf, 0, sz);
284 	buf[0] = 0x04;  /* Uncompressed form indicator */
285 	mres = mbedtls_mpi_write_binary(&p->X, buf + 1, SM2_INT_SIZE_BYTES);
286 	if (mres)
287 		return TEE_ERROR_BAD_STATE;
288 	mres = mbedtls_mpi_write_binary(&p->Y, buf + 1 + SM2_INT_SIZE_BYTES,
289 					SM2_INT_SIZE_BYTES);
290 	if (mres)
291 		return TEE_ERROR_BAD_STATE;
292 
293 	*size = sz;
294 
295 	return TEE_SUCCESS;
296 }
297 
298 /*
299  * GM/T 0003.1‒2012 Part 4 Section 6.1
300  * Encryption algorithm
301  */
sm2_mbedtls_pke_encrypt(struct ecc_public_key * key,const uint8_t * src,size_t src_len,uint8_t * dst,size_t * dst_len)302 TEE_Result sm2_mbedtls_pke_encrypt(struct ecc_public_key *key,
303 				   const uint8_t *src, size_t src_len,
304 				   uint8_t *dst, size_t *dst_len)
305 {
306 	TEE_Result res = TEE_SUCCESS;
307 	mbedtls_ecp_group grp = { };
308 	mbedtls_ecp_point x2y2p = { };
309 	mbedtls_ecp_point PB = { };
310 	mbedtls_ecp_point C1 = { };
311 	uint8_t x2y2[64] = { };
312 	uint8_t *t = NULL;
313 	int mres = 0;
314 	mbedtls_mpi k = { };
315 	size_t C1_len = 0;
316 	void *ctx = NULL;
317 	size_t i = 0;
318 
319 	mbedtls_mpi_init(&k);
320 
321 	mbedtls_ecp_point_init(&x2y2p);
322 	mbedtls_ecp_point_init(&PB);
323 	mbedtls_ecp_point_init(&C1);
324 
325 	mbedtls_ecp_group_init(&grp);
326 	mres = mbedtls_ecp_group_load(&grp, MBEDTLS_ECP_DP_SM2);
327 	if (mres) {
328 		res = TEE_ERROR_GENERIC;
329 		goto out;
330 	}
331 
332 	/* Step A1: generate random number 1 <= k < n */
333 
334 	res = mbed_gen_random_upto(&k, &grp.N);
335 	if (res)
336 		goto out;
337 
338 	/* Step A2: compute C1 = [k]G */
339 
340 	mres = mbedtls_ecp_mul(&grp, &C1, &k, &grp.G, mbd_rand, NULL);
341 	if (mres) {
342 		res = TEE_ERROR_BAD_STATE;
343 		goto out;
344 	}
345 
346 	/*
347 	 * Step A3: compute S = [h]PB and check for infinity.
348 	 * The cofactor h is 1 for SM2 so S == PB, nothing to do.
349 	 */
350 
351 	/* Step A4: compute (x2, y2) = [k]PB */
352 
353 	mbedtls_mpi_copy(&PB.X, (mbedtls_mpi *)key->x);
354 	mbedtls_mpi_copy(&PB.Y, (mbedtls_mpi *)key->y);
355 	mbedtls_mpi_lset(&PB.Z, 1);
356 
357 	mres = mbedtls_ecp_mul(&grp, &x2y2p, &k, &PB, mbd_rand, NULL);
358 	if (mres) {
359 		res = TEE_ERROR_BAD_STATE;
360 		goto out;
361 	}
362 
363 	if (mbedtls_mpi_size(&x2y2p.X) > SM2_INT_SIZE_BYTES ||
364 	    mbedtls_mpi_size(&x2y2p.Y) > SM2_INT_SIZE_BYTES) {
365 		res = TEE_ERROR_BAD_STATE;
366 		goto out;
367 	}
368 
369 	mres = mbedtls_mpi_write_binary(&x2y2p.X, x2y2, SM2_INT_SIZE_BYTES);
370 	if (mres) {
371 		res = TEE_ERROR_BAD_STATE;
372 		goto out;
373 	}
374 	mres = mbedtls_mpi_write_binary(&x2y2p.Y, x2y2 + SM2_INT_SIZE_BYTES,
375 					SM2_INT_SIZE_BYTES);
376 	if (mres) {
377 		res = TEE_ERROR_BAD_STATE;
378 		goto out;
379 	}
380 
381 	/* Step A5: compute t = KDF(x2 || y2, klen) */
382 
383 	t = calloc(1, src_len);
384 	if (!t) {
385 		res = TEE_ERROR_OUT_OF_MEMORY;
386 		goto out;
387 	}
388 
389 	res = sm2_kdf(x2y2, sizeof(x2y2), t, src_len);
390 	if (res)
391 		goto out;
392 
393 	if (is_zero(t, src_len)) {
394 		res = TEE_ERROR_CIPHERTEXT_INVALID;
395 		goto out;
396 	}
397 
398 	/*
399 	 * Steps A6, A7, A8:
400 	 * Compute C2 = M (+) t
401 	 * Compute C3 = Hash(x2 || M || y2)
402 	 * Output C = C1 || C2 || C3
403 	 */
404 
405 	/* C1 */
406 	C1_len = *dst_len;
407 	res = sm2_point_to_bytes(dst, &C1_len, &C1);
408 	if (res)
409 		goto out;
410 
411 	if (*dst_len < C1_len + src_len + TEE_SM3_HASH_SIZE) {
412 		*dst_len = C1_len + src_len + TEE_SM3_HASH_SIZE;
413 		res = TEE_ERROR_SHORT_BUFFER;
414 		goto out;
415 	}
416 
417 	/* C2 */
418 	for (i = 0; i < src_len; i++)
419 		dst[i + C1_len] = src[i] ^ t[i];
420 
421 	/* C3 */
422         res = crypto_hash_alloc_ctx(&ctx, TEE_ALG_SM3);
423         if (res)
424                 goto out;
425         res = crypto_hash_init(ctx);
426         if (res)
427                 goto out;
428         res = crypto_hash_update(ctx, x2y2, SM2_INT_SIZE_BYTES);
429         if (res)
430                 goto out;
431         res = crypto_hash_update(ctx, src, src_len);
432         if (res)
433                 goto out;
434         res = crypto_hash_update(ctx, x2y2 + SM2_INT_SIZE_BYTES,
435 				 SM2_INT_SIZE_BYTES);
436         if (res)
437                 goto out;
438         res = crypto_hash_final(ctx, dst + C1_len + src_len, TEE_SM3_HASH_SIZE);
439         if (res)
440                 goto out;
441 
442 	*dst_len = C1_len + src_len + TEE_SM3_HASH_SIZE;
443 out:
444 	crypto_hash_free_ctx(ctx);
445 	free(t);
446 	mbedtls_ecp_point_free(&x2y2p);
447 	mbedtls_ecp_point_free(&PB);
448 	mbedtls_ecp_point_free(&C1);
449 	mbedtls_ecp_group_free(&grp);
450 	mbedtls_mpi_free(&k);
451 	return res;
452 }
453