1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3  * Copyright (C) 2018, ARM Limited
4  * Copyright (C) 2019, Linaro Limited
5  */
6 
7 #include <assert.h>
8 #include <crypto/crypto.h>
9 #include <crypto/crypto_impl.h>
10 #include <mbedtls/ctr_drbg.h>
11 #include <mbedtls/entropy.h>
12 #include <mbedtls/pk.h>
13 #include <mbedtls/pk_internal.h>
14 #include <stdlib.h>
15 #include <string.h>
16 #include <tee/tee_cryp_utl.h>
17 #include <utee_defines.h>
18 #include <fault_mitigation.h>
19 
20 #include "mbed_helpers.h"
21 
get_tee_result(int lmd_res)22 static TEE_Result get_tee_result(int lmd_res)
23 {
24 	switch (lmd_res) {
25 	case 0:
26 		return TEE_SUCCESS;
27 	case MBEDTLS_ERR_RSA_PRIVATE_FAILED +
28 		MBEDTLS_ERR_MPI_BAD_INPUT_DATA:
29 	case MBEDTLS_ERR_RSA_BAD_INPUT_DATA:
30 	case MBEDTLS_ERR_RSA_INVALID_PADDING:
31 	case MBEDTLS_ERR_PK_TYPE_MISMATCH:
32 		return TEE_ERROR_BAD_PARAMETERS;
33 	case MBEDTLS_ERR_RSA_OUTPUT_TOO_LARGE:
34 		return TEE_ERROR_SHORT_BUFFER;
35 	default:
36 		return TEE_ERROR_BAD_STATE;
37 	}
38 }
39 
tee_algo_to_mbedtls_hash_algo(uint32_t algo)40 static uint32_t tee_algo_to_mbedtls_hash_algo(uint32_t algo)
41 {
42 	switch (algo) {
43 #if defined(CFG_CRYPTO_SHA1)
44 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA1:
45 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA1:
46 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA1:
47 	case TEE_ALG_SHA1:
48 	case TEE_ALG_DSA_SHA1:
49 	case TEE_ALG_HMAC_SHA1:
50 		return MBEDTLS_MD_SHA1;
51 #endif
52 #if defined(CFG_CRYPTO_MD5)
53 	case TEE_ALG_RSASSA_PKCS1_V1_5_MD5:
54 	case TEE_ALG_MD5:
55 	case TEE_ALG_HMAC_MD5:
56 		return MBEDTLS_MD_MD5;
57 #endif
58 #if defined(CFG_CRYPTO_SHA224)
59 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA224:
60 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA224:
61 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA224:
62 	case TEE_ALG_SHA224:
63 	case TEE_ALG_DSA_SHA224:
64 	case TEE_ALG_HMAC_SHA224:
65 		return MBEDTLS_MD_SHA224;
66 #endif
67 #if defined(CFG_CRYPTO_SHA256)
68 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA256:
69 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256:
70 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA256:
71 	case TEE_ALG_SHA256:
72 	case TEE_ALG_DSA_SHA256:
73 	case TEE_ALG_HMAC_SHA256:
74 		return MBEDTLS_MD_SHA256;
75 #endif
76 #if defined(CFG_CRYPTO_SHA384)
77 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA384:
78 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA384:
79 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA384:
80 	case TEE_ALG_SHA384:
81 	case TEE_ALG_HMAC_SHA384:
82 		return MBEDTLS_MD_SHA384;
83 #endif
84 #if defined(CFG_CRYPTO_SHA512)
85 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA512:
86 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA512:
87 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA512:
88 	case TEE_ALG_SHA512:
89 	case TEE_ALG_HMAC_SHA512:
90 		return MBEDTLS_MD_SHA512;
91 #endif
92 	default:
93 		return MBEDTLS_MD_NONE;
94 	}
95 }
96 
/*
 * Set up an mbedTLS RSA context that borrows the bignums of @key.
 *
 * The mbedtls_mpi structures are copied by value, so the context shares
 * its limb storage with @key. N, E and D are always taken; the CRT
 * parameters (P, Q, QP, DP, DQ) are taken only when key->p is present and
 * non-empty. rsa->len is set to the byte size of the modulus.
 */
static void rsa_init_from_key_pair(mbedtls_rsa_context *rsa,
				   struct rsa_keypair *key)
{
	mbedtls_rsa_init(rsa, 0, 0);

	rsa->N = *(mbedtls_mpi *)key->n;
	rsa->E = *(mbedtls_mpi *)key->e;
	rsa->D = *(mbedtls_mpi *)key->d;
	rsa->len = mbedtls_mpi_size(&rsa->N);

	/* No usable CRT parameters: leave P, Q, QP, DP and DQ untouched */
	if (!key->p || !crypto_bignum_num_bytes(key->p))
		return;

	rsa->P = *(mbedtls_mpi *)key->p;
	rsa->Q = *(mbedtls_mpi *)key->q;
	rsa->DP = *(mbedtls_mpi *)key->dp;
	rsa->DQ = *(mbedtls_mpi *)key->dq;
	rsa->QP = *(mbedtls_mpi *)key->qp;
}
114 
/*
 * Release an RSA context without freeing the bignums it borrowed.
 *
 * The key components are re-initialised first so that mbedtls_rsa_free()
 * does not free them; those bignums are freed together with the key.
 * The CRT parameters are reset only when P is non-empty.
 */
static void mbd_rsa_free(mbedtls_rsa_context *rsa)
{
	size_t p_bytes = mbedtls_mpi_size(&rsa->P);

	mbedtls_mpi_init(&rsa->N);
	mbedtls_mpi_init(&rsa->E);
	mbedtls_mpi_init(&rsa->D);

	if (p_bytes != 0) {
		mbedtls_mpi_init(&rsa->P);
		mbedtls_mpi_init(&rsa->Q);
		mbedtls_mpi_init(&rsa->DP);
		mbedtls_mpi_init(&rsa->DQ);
		mbedtls_mpi_init(&rsa->QP);
	}

	mbedtls_rsa_free(rsa);
}
130 
131 TEE_Result crypto_acipher_alloc_rsa_keypair(struct rsa_keypair *s,
132 					    size_t key_size_bits)
133 __weak __alias("sw_crypto_acipher_alloc_rsa_keypair");
134 
sw_crypto_acipher_alloc_rsa_keypair(struct rsa_keypair * s,size_t key_size_bits)135 TEE_Result sw_crypto_acipher_alloc_rsa_keypair(struct rsa_keypair *s,
136 					       size_t key_size_bits)
137 {
138 	memset(s, 0, sizeof(*s));
139 	s->e = crypto_bignum_allocate(key_size_bits);
140 	if (!s->e)
141 		goto err;
142 	s->d = crypto_bignum_allocate(key_size_bits);
143 	if (!s->d)
144 		goto err;
145 	s->n = crypto_bignum_allocate(key_size_bits);
146 	if (!s->n)
147 		goto err;
148 	s->p = crypto_bignum_allocate(key_size_bits);
149 	if (!s->p)
150 		goto err;
151 	s->q = crypto_bignum_allocate(key_size_bits);
152 	if (!s->q)
153 		goto err;
154 	s->qp = crypto_bignum_allocate(key_size_bits);
155 	if (!s->qp)
156 		goto err;
157 	s->dp = crypto_bignum_allocate(key_size_bits);
158 	if (!s->dp)
159 		goto err;
160 	s->dq = crypto_bignum_allocate(key_size_bits);
161 	if (!s->dq)
162 		goto err;
163 
164 	return TEE_SUCCESS;
165 err:
166 	crypto_acipher_free_rsa_keypair(s);
167 	return TEE_ERROR_OUT_OF_MEMORY;
168 }
169 
170 TEE_Result crypto_acipher_alloc_rsa_public_key(struct rsa_public_key *s,
171 					       size_t key_size_bits)
172 __weak __alias("sw_crypto_acipher_alloc_rsa_public_key");
173 
sw_crypto_acipher_alloc_rsa_public_key(struct rsa_public_key * s,size_t key_size_bits)174 TEE_Result sw_crypto_acipher_alloc_rsa_public_key(struct rsa_public_key *s,
175 						  size_t key_size_bits)
176 {
177 	memset(s, 0, sizeof(*s));
178 	s->e = crypto_bignum_allocate(key_size_bits);
179 	if (!s->e)
180 		return TEE_ERROR_OUT_OF_MEMORY;
181 	s->n = crypto_bignum_allocate(key_size_bits);
182 	if (!s->n)
183 		goto err;
184 	return TEE_SUCCESS;
185 err:
186 	crypto_bignum_free(s->e);
187 	return TEE_ERROR_OUT_OF_MEMORY;
188 }
189 
190 void crypto_acipher_free_rsa_public_key(struct rsa_public_key *s)
191 __weak __alias("sw_crypto_acipher_free_rsa_public_key");
192 
sw_crypto_acipher_free_rsa_public_key(struct rsa_public_key * s)193 void sw_crypto_acipher_free_rsa_public_key(struct rsa_public_key *s)
194 {
195 	if (!s)
196 		return;
197 	crypto_bignum_free(s->n);
198 	crypto_bignum_free(s->e);
199 }
200 
201 void crypto_acipher_free_rsa_keypair(struct rsa_keypair *s)
202 __weak __alias("sw_crypto_acipher_free_rsa_keypair");
203 
sw_crypto_acipher_free_rsa_keypair(struct rsa_keypair * s)204 void sw_crypto_acipher_free_rsa_keypair(struct rsa_keypair *s)
205 {
206 	if (!s)
207 		return;
208 	crypto_bignum_free(s->e);
209 	crypto_bignum_free(s->d);
210 	crypto_bignum_free(s->n);
211 	crypto_bignum_free(s->p);
212 	crypto_bignum_free(s->q);
213 	crypto_bignum_free(s->qp);
214 	crypto_bignum_free(s->dp);
215 	crypto_bignum_free(s->dq);
216 }
217 
218 TEE_Result crypto_acipher_gen_rsa_key(struct rsa_keypair *key,
219 				      size_t key_size)
220 __weak __alias("sw_crypto_acipher_gen_rsa_key");
221 
sw_crypto_acipher_gen_rsa_key(struct rsa_keypair * key,size_t key_size)222 TEE_Result sw_crypto_acipher_gen_rsa_key(struct rsa_keypair *key,
223 					 size_t key_size)
224 {
225 	TEE_Result res = TEE_SUCCESS;
226 	mbedtls_rsa_context rsa;
227 	mbedtls_ctr_drbg_context rngctx;
228 	int lmd_res = 0;
229 	uint32_t e = 0;
230 
231 	mbedtls_ctr_drbg_init(&rngctx);
232 	if (mbedtls_ctr_drbg_seed(&rngctx, mbd_rand, NULL, NULL, 0))
233 		return TEE_ERROR_BAD_STATE;
234 
235 	memset(&rsa, 0, sizeof(rsa));
236 	mbedtls_rsa_init(&rsa, 0, 0);
237 
238 	/* get the public exponent */
239 	mbedtls_mpi_write_binary((mbedtls_mpi *)key->e,
240 				 (unsigned char *)&e, sizeof(uint32_t));
241 
242 	e = TEE_U32_FROM_BIG_ENDIAN(e);
243 	lmd_res = mbedtls_rsa_gen_key(&rsa, mbedtls_ctr_drbg_random, &rngctx,
244 				      key_size, (int)e);
245 	mbedtls_ctr_drbg_free(&rngctx);
246 	if (lmd_res != 0) {
247 		res = get_tee_result(lmd_res);
248 	} else if ((size_t)mbedtls_mpi_bitlen(&rsa.N) != key_size) {
249 		res = TEE_ERROR_BAD_PARAMETERS;
250 	} else {
251 		/* Copy the key */
252 		crypto_bignum_copy(key->e, (void *)&rsa.E);
253 		crypto_bignum_copy(key->d, (void *)&rsa.D);
254 		crypto_bignum_copy(key->n, (void *)&rsa.N);
255 		crypto_bignum_copy(key->p, (void *)&rsa.P);
256 
257 		crypto_bignum_copy(key->q, (void *)&rsa.Q);
258 		crypto_bignum_copy(key->qp, (void *)&rsa.QP);
259 		crypto_bignum_copy(key->dp, (void *)&rsa.DP);
260 		crypto_bignum_copy(key->dq, (void *)&rsa.DQ);
261 
262 		res = TEE_SUCCESS;
263 	}
264 
265 	mbedtls_rsa_free(&rsa);
266 
267 	return res;
268 }
269 
270 TEE_Result crypto_acipher_rsanopad_encrypt(struct rsa_public_key *key,
271 					   const uint8_t *src,
272 					   size_t src_len, uint8_t *dst,
273 					   size_t *dst_len)
274 __weak __alias("sw_crypto_acipher_rsanopad_encrypt");
275 
sw_crypto_acipher_rsanopad_encrypt(struct rsa_public_key * key,const uint8_t * src,size_t src_len,uint8_t * dst,size_t * dst_len)276 TEE_Result sw_crypto_acipher_rsanopad_encrypt(struct rsa_public_key *key,
277 					      const uint8_t *src,
278 					      size_t src_len, uint8_t *dst,
279 					      size_t *dst_len)
280 {
281 	TEE_Result res = TEE_SUCCESS;
282 	mbedtls_rsa_context rsa;
283 	int lmd_res = 0;
284 	uint8_t *buf = NULL;
285 	unsigned long blen = 0;
286 	unsigned long offset = 0;
287 
288 	memset(&rsa, 0, sizeof(rsa));
289 	mbedtls_rsa_init(&rsa, 0, 0);
290 
291 	rsa.E = *(mbedtls_mpi *)key->e;
292 	rsa.N = *(mbedtls_mpi *)key->n;
293 
294 	rsa.len = crypto_bignum_num_bytes((void *)&rsa.N);
295 
296 	blen = CFG_CORE_BIGNUM_MAX_BITS / 8;
297 	buf = malloc(blen);
298 	if (!buf) {
299 		res = TEE_ERROR_OUT_OF_MEMORY;
300 		goto out;
301 	}
302 
303 	memset(buf, 0, blen);
304 	memcpy(buf + rsa.len - src_len, src, src_len);
305 
306 	lmd_res = mbedtls_rsa_public(&rsa, buf, buf);
307 	if (lmd_res != 0) {
308 		FMSG("mbedtls_rsa_public() returned 0x%x", -lmd_res);
309 		res = get_tee_result(lmd_res);
310 		goto out;
311 	}
312 
313 	/* Remove the zero-padding (leave one zero if buff is all zeroes) */
314 	offset = 0;
315 	while ((offset < rsa.len - 1) && (buf[offset] == 0))
316 		offset++;
317 
318 	if (*dst_len < rsa.len - offset) {
319 		*dst_len = rsa.len - offset;
320 		res = TEE_ERROR_SHORT_BUFFER;
321 		goto out;
322 	}
323 	*dst_len = rsa.len - offset;
324 	memcpy(dst, buf + offset, *dst_len);
325 out:
326 	free(buf);
327 	/* Reset mpi to skip freeing here, those mpis will be freed with key */
328 	mbedtls_mpi_init(&rsa.E);
329 	mbedtls_mpi_init(&rsa.N);
330 	mbedtls_rsa_free(&rsa);
331 
332 	return res;
333 }
334 
335 TEE_Result crypto_acipher_rsanopad_decrypt(struct rsa_keypair *key,
336 					   const uint8_t *src,
337 					   size_t src_len, uint8_t *dst,
338 					   size_t *dst_len)
339 __weak __alias("sw_crypto_acipher_rsanopad_decrypt");
340 
sw_crypto_acipher_rsanopad_decrypt(struct rsa_keypair * key,const uint8_t * src,size_t src_len,uint8_t * dst,size_t * dst_len)341 TEE_Result sw_crypto_acipher_rsanopad_decrypt(struct rsa_keypair *key,
342 					      const uint8_t *src,
343 					      size_t src_len, uint8_t *dst,
344 					      size_t *dst_len)
345 {
346 	TEE_Result res = TEE_SUCCESS;
347 	mbedtls_rsa_context rsa;
348 	int lmd_res = 0;
349 	uint8_t *buf = NULL;
350 	unsigned long blen = 0;
351 	unsigned long offset = 0;
352 
353 	memset(&rsa, 0, sizeof(rsa));
354 	rsa_init_from_key_pair(&rsa, key);
355 
356 	blen = CFG_CORE_BIGNUM_MAX_BITS / 8;
357 	buf = malloc(blen);
358 	if (!buf) {
359 		res = TEE_ERROR_OUT_OF_MEMORY;
360 		goto out;
361 	}
362 
363 	memset(buf, 0, blen);
364 	memcpy(buf + rsa.len - src_len, src, src_len);
365 
366 	lmd_res = mbedtls_rsa_private(&rsa, NULL, NULL, buf, buf);
367 	if (lmd_res != 0) {
368 		FMSG("mbedtls_rsa_private() returned 0x%x", -lmd_res);
369 		res = get_tee_result(lmd_res);
370 		goto out;
371 	}
372 
373 	/* Remove the zero-padding (leave one zero if buff is all zeroes) */
374 	offset = 0;
375 	while ((offset < rsa.len - 1) && (buf[offset] == 0))
376 		offset++;
377 
378 	if (*dst_len < rsa.len - offset) {
379 		*dst_len = rsa.len - offset;
380 		res = TEE_ERROR_SHORT_BUFFER;
381 		goto out;
382 	}
383 	*dst_len = rsa.len - offset;
384 	memcpy(dst, (char *)buf + offset, *dst_len);
385 out:
386 	if (buf)
387 		free(buf);
388 	mbd_rsa_free(&rsa);
389 	return res;
390 }
391 
392 TEE_Result crypto_acipher_rsaes_decrypt(uint32_t algo,
393 					struct rsa_keypair *key,
394 					const uint8_t *label __unused,
395 					size_t label_len __unused,
396 					const uint8_t *src, size_t src_len,
397 					uint8_t *dst, size_t *dst_len)
398 __weak __alias("sw_crypto_acipher_rsaes_decrypt");
399 
sw_crypto_acipher_rsaes_decrypt(uint32_t algo,struct rsa_keypair * key,const uint8_t * label __unused,size_t label_len __unused,const uint8_t * src,size_t src_len,uint8_t * dst,size_t * dst_len)400 TEE_Result sw_crypto_acipher_rsaes_decrypt(uint32_t algo,
401 					   struct rsa_keypair *key,
402 					   const uint8_t *label __unused,
403 					   size_t label_len __unused,
404 					   const uint8_t *src, size_t src_len,
405 					   uint8_t *dst, size_t *dst_len)
406 {
407 	TEE_Result res = TEE_SUCCESS;
408 	int lmd_res = 0;
409 	int lmd_padding = 0;
410 	size_t blen = 0;
411 	size_t mod_size = 0;
412 	void *buf = NULL;
413 	mbedtls_rsa_context rsa;
414 	const mbedtls_pk_info_t *pk_info = NULL;
415 	uint32_t md_algo = MBEDTLS_MD_NONE;
416 
417 	memset(&rsa, 0, sizeof(rsa));
418 	rsa_init_from_key_pair(&rsa, key);
419 
420 	/*
421 	 * Use a temporary buffer since we don't know exactly how large
422 	 * the required size of the out buffer without doing a partial
423 	 * decrypt. We know the upper bound though.
424 	 */
425 	if (algo == TEE_ALG_RSAES_PKCS1_V1_5) {
426 		mod_size = crypto_bignum_num_bytes(key->n);
427 		blen = mod_size - 11;
428 		lmd_padding = MBEDTLS_RSA_PKCS_V15;
429 	} else {
430 		/* Decoded message is always shorter than encrypted message */
431 		blen = src_len;
432 		lmd_padding = MBEDTLS_RSA_PKCS_V21;
433 	}
434 
435 	buf = malloc(blen);
436 	if (!buf) {
437 		res = TEE_ERROR_OUT_OF_MEMORY;
438 		goto out;
439 	}
440 
441 	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
442 	if (!pk_info) {
443 		res = TEE_ERROR_NOT_SUPPORTED;
444 		goto out;
445 	}
446 
447 	/*
448 	 * TEE_ALG_RSAES_PKCS1_V1_5 is invalid in hash. But its hash algo will
449 	 * not be used in rsa, so skip it here.
450 	 */
451 	if (algo != TEE_ALG_RSAES_PKCS1_V1_5) {
452 		md_algo = tee_algo_to_mbedtls_hash_algo(algo);
453 		if (md_algo == MBEDTLS_MD_NONE) {
454 			res = TEE_ERROR_NOT_SUPPORTED;
455 			goto out;
456 		}
457 	}
458 
459 	mbedtls_rsa_set_padding(&rsa, lmd_padding, md_algo);
460 
461 	if (lmd_padding == MBEDTLS_RSA_PKCS_V15)
462 		lmd_res = pk_info->decrypt_func(&rsa, src, src_len, buf, &blen,
463 						blen, NULL, NULL);
464 	else
465 		lmd_res = pk_info->decrypt_func(&rsa, src, src_len, buf, &blen,
466 						blen, mbd_rand, NULL);
467 	if (lmd_res != 0) {
468 		FMSG("decrypt_func() returned 0x%x", -lmd_res);
469 		res = get_tee_result(lmd_res);
470 		goto out;
471 	}
472 
473 	if (*dst_len < blen) {
474 		*dst_len = blen;
475 		res = TEE_ERROR_SHORT_BUFFER;
476 		goto out;
477 	}
478 
479 	res = TEE_SUCCESS;
480 	*dst_len = blen;
481 	memcpy(dst, buf, blen);
482 out:
483 	if (buf)
484 		free(buf);
485 	mbd_rsa_free(&rsa);
486 	return res;
487 }
488 
489 TEE_Result crypto_acipher_rsaes_encrypt(uint32_t algo,
490 					struct rsa_public_key *key,
491 					const uint8_t *label __unused,
492 					size_t label_len __unused,
493 					const uint8_t *src, size_t src_len,
494 					uint8_t *dst, size_t *dst_len)
495 __weak __alias("sw_crypto_acipher_rsaes_encrypt");
496 
sw_crypto_acipher_rsaes_encrypt(uint32_t algo,struct rsa_public_key * key,const uint8_t * label __unused,size_t label_len __unused,const uint8_t * src,size_t src_len,uint8_t * dst,size_t * dst_len)497 TEE_Result sw_crypto_acipher_rsaes_encrypt(uint32_t algo,
498 					   struct rsa_public_key *key,
499 					   const uint8_t *label __unused,
500 					   size_t label_len __unused,
501 					   const uint8_t *src, size_t src_len,
502 					   uint8_t *dst, size_t *dst_len)
503 {
504 	TEE_Result res = TEE_SUCCESS;
505 	int lmd_res = 0;
506 	int lmd_padding = 0;
507 	size_t mod_size = 0;
508 	mbedtls_rsa_context rsa;
509 	const mbedtls_pk_info_t *pk_info = NULL;
510 	uint32_t md_algo = MBEDTLS_MD_NONE;
511 
512 	memset(&rsa, 0, sizeof(rsa));
513 	mbedtls_rsa_init(&rsa, 0, 0);
514 
515 	rsa.E = *(mbedtls_mpi *)key->e;
516 	rsa.N = *(mbedtls_mpi *)key->n;
517 
518 	mod_size = crypto_bignum_num_bytes(key->n);
519 	if (*dst_len < mod_size) {
520 		*dst_len = mod_size;
521 		res = TEE_ERROR_SHORT_BUFFER;
522 		goto out;
523 	}
524 	*dst_len = mod_size;
525 	rsa.len = mod_size;
526 
527 	if (algo == TEE_ALG_RSAES_PKCS1_V1_5)
528 		lmd_padding = MBEDTLS_RSA_PKCS_V15;
529 	else
530 		lmd_padding = MBEDTLS_RSA_PKCS_V21;
531 
532 	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
533 	if (!pk_info) {
534 		res = TEE_ERROR_NOT_SUPPORTED;
535 		goto out;
536 	}
537 
538 	/*
539 	 * TEE_ALG_RSAES_PKCS1_V1_5 is invalid in hash. But its hash algo will
540 	 * not be used in rsa, so skip it here.
541 	 */
542 	if (algo != TEE_ALG_RSAES_PKCS1_V1_5) {
543 		md_algo = tee_algo_to_mbedtls_hash_algo(algo);
544 		if (md_algo == MBEDTLS_MD_NONE) {
545 			res = TEE_ERROR_NOT_SUPPORTED;
546 			goto out;
547 		}
548 	}
549 
550 	mbedtls_rsa_set_padding(&rsa, lmd_padding, md_algo);
551 
552 	lmd_res = pk_info->encrypt_func(&rsa, src, src_len, dst, dst_len,
553 					*dst_len, mbd_rand, NULL);
554 	if (lmd_res != 0) {
555 		FMSG("encrypt_func() returned 0x%x", -lmd_res);
556 		res = get_tee_result(lmd_res);
557 		goto out;
558 	}
559 	res = TEE_SUCCESS;
560 out:
561 	/* Reset mpi to skip freeing here, those mpis will be freed with key */
562 	mbedtls_mpi_init(&rsa.E);
563 	mbedtls_mpi_init(&rsa.N);
564 	mbedtls_rsa_free(&rsa);
565 	return res;
566 }
567 
568 TEE_Result crypto_acipher_rsassa_sign(uint32_t algo, struct rsa_keypair *key,
569 				      int salt_len __unused,
570 				      const uint8_t *msg, size_t msg_len,
571 				      uint8_t *sig, size_t *sig_len)
572 __weak __alias("sw_crypto_acipher_rsassa_sign");
573 
sw_crypto_acipher_rsassa_sign(uint32_t algo,struct rsa_keypair * key,int salt_len __unused,const uint8_t * msg,size_t msg_len,uint8_t * sig,size_t * sig_len)574 TEE_Result sw_crypto_acipher_rsassa_sign(uint32_t algo, struct rsa_keypair *key,
575 					 int salt_len __unused,
576 					 const uint8_t *msg, size_t msg_len,
577 					 uint8_t *sig, size_t *sig_len)
578 {
579 	TEE_Result res = TEE_SUCCESS;
580 	int lmd_res = 0;
581 	int lmd_padding = 0;
582 	size_t mod_size = 0;
583 	size_t hash_size = 0;
584 	mbedtls_rsa_context rsa;
585 	const mbedtls_pk_info_t *pk_info = NULL;
586 	uint32_t md_algo = 0;
587 
588 	memset(&rsa, 0, sizeof(rsa));
589 	rsa_init_from_key_pair(&rsa, key);
590 
591 	switch (algo) {
592 	case TEE_ALG_RSASSA_PKCS1_V1_5_MD5:
593 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA1:
594 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA224:
595 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA256:
596 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA384:
597 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA512:
598 		lmd_padding = MBEDTLS_RSA_PKCS_V15;
599 		break;
600 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA1:
601 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA224:
602 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256:
603 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA384:
604 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA512:
605 		lmd_padding = MBEDTLS_RSA_PKCS_V21;
606 		break;
607 	default:
608 		res = TEE_ERROR_BAD_PARAMETERS;
609 		goto err;
610 	}
611 
612 	res = tee_alg_get_digest_size(TEE_DIGEST_HASH_TO_ALGO(algo),
613 				      &hash_size);
614 	if (res != TEE_SUCCESS)
615 		goto err;
616 
617 	if (msg_len != hash_size) {
618 		res = TEE_ERROR_BAD_PARAMETERS;
619 		goto err;
620 	}
621 
622 	mod_size = crypto_bignum_num_bytes(key->n);
623 	if (*sig_len < mod_size) {
624 		*sig_len = mod_size;
625 		res = TEE_ERROR_SHORT_BUFFER;
626 		goto err;
627 	}
628 	rsa.len = mod_size;
629 
630 	md_algo = tee_algo_to_mbedtls_hash_algo(algo);
631 	if (md_algo == MBEDTLS_MD_NONE) {
632 		res = TEE_ERROR_NOT_SUPPORTED;
633 		goto err;
634 	}
635 
636 	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
637 	if (!pk_info) {
638 		res = TEE_ERROR_NOT_SUPPORTED;
639 		goto err;
640 	}
641 
642 	mbedtls_rsa_set_padding(&rsa, lmd_padding, md_algo);
643 
644 	if (lmd_padding == MBEDTLS_RSA_PKCS_V15)
645 		lmd_res = pk_info->sign_func(&rsa, md_algo, msg, msg_len, sig,
646 					     sig_len, NULL, NULL);
647 	else
648 		lmd_res = pk_info->sign_func(&rsa, md_algo, msg, msg_len, sig,
649 					     sig_len, mbd_rand, NULL);
650 	if (lmd_res != 0) {
651 		FMSG("sign_func failed, returned 0x%x", -lmd_res);
652 		res = get_tee_result(lmd_res);
653 		goto err;
654 	}
655 	res = TEE_SUCCESS;
656 err:
657 	mbd_rsa_free(&rsa);
658 	return res;
659 }
660 
661 TEE_Result crypto_acipher_rsassa_verify(uint32_t algo,
662 					struct rsa_public_key *key,
663 					int salt_len __unused,
664 					const uint8_t *msg,
665 					size_t msg_len, const uint8_t *sig,
666 					size_t sig_len)
667 __weak __alias("sw_crypto_acipher_rsassa_verify");
668 
/*
 * RSASSA signature verification (PKCS#1 v1.5 or PSS, selected by @algo).
 *
 * @algo      One of the TEE_ALG_RSASSA_PKCS1_V1_5_* or
 *            TEE_ALG_RSASSA_PKCS1_PSS_MGF1_* identifiers
 * @key       RSA public key (e, n)
 * @salt_len  Marked __unused
 * @msg       Digest that was signed
 * @msg_len   Length of @msg; must equal the digest size implied by @algo
 * @sig       Signature to check
 * @sig_len   Length of @sig in bytes
 *
 * Returns TEE_SUCCESS if the signature verifies,
 * TEE_ERROR_SIGNATURE_INVALID if @sig is shorter than the modulus or
 * verification fails, TEE_ERROR_BAD_PARAMETERS for an unknown @algo or a
 * digest length mismatch, TEE_ERROR_NOT_SUPPORTED if RSA or the digest
 * is unavailable, or the error from tee_alg_get_digest_size().
 *
 * NOTE(review): the FTMN_* macros come from <fault_mitigation.h>; their
 * exact semantics are not visible in this file. The statement order
 * around them is presumably significant - confirm before reordering.
 */
TEE_Result sw_crypto_acipher_rsassa_verify(uint32_t algo,
					   struct rsa_public_key *key,
					   int salt_len __unused,
					   const uint8_t *msg,
					   size_t msg_len, const uint8_t *sig,
					   size_t sig_len)
{
	TEE_Result res = TEE_SUCCESS;
	int lmd_res = 0;
	int lmd_padding = 0;
	size_t hash_size = 0;
	size_t bigint_size = 0;
	mbedtls_rsa_context rsa;
	const mbedtls_pk_info_t *pk_info = NULL;
	uint32_t md_algo = 0;
	struct ftmn ftmn = { };
	unsigned long arg_hash = 0;

	/*
	 * The caller expects to call crypto_acipher_rsassa_verify(),
	 * update the hash as needed.
	 */
	FTMN_CALLEE_SWAP_HASH(FTMN_FUNC_HASH("crypto_acipher_rsassa_verify"));

	memset(&rsa, 0, sizeof(rsa));
	mbedtls_rsa_init(&rsa, 0, 0);

	/* Borrow the key's bignums; they are detached again at "out" */
	rsa.E = *(mbedtls_mpi *)key->e;
	rsa.N = *(mbedtls_mpi *)key->n;

	/* The caller must hand in a digest of exactly the expected size */
	res = tee_alg_get_digest_size(TEE_DIGEST_HASH_TO_ALGO(algo),
				      &hash_size);
	if (res != TEE_SUCCESS)
		goto err;

	if (msg_len != hash_size) {
		res = TEE_ERROR_BAD_PARAMETERS;
		goto err;
	}

	/* A signature shorter than the modulus cannot be valid */
	bigint_size = crypto_bignum_num_bytes(key->n);
	if (sig_len < bigint_size) {
		res = TEE_ERROR_SIGNATURE_INVALID;
		goto err;
	}

	rsa.len = bigint_size;

	/*
	 * Pick the padding scheme and the function-name hash that is pushed
	 * as the linked call below.
	 */
	switch (algo) {
	case TEE_ALG_RSASSA_PKCS1_V1_5_MD5:
	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA1:
	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA224:
	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA256:
	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA384:
	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA512:
		arg_hash = FTMN_FUNC_HASH("mbedtls_rsa_rsassa_pkcs1_v15_verify");
		lmd_padding = MBEDTLS_RSA_PKCS_V15;
		break;
	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA1:
	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA224:
	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256:
	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA384:
	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA512:
		arg_hash = FTMN_FUNC_HASH("mbedtls_rsa_rsassa_pss_verify_ext");
		lmd_padding = MBEDTLS_RSA_PKCS_V21;
		break;
	default:
		res = TEE_ERROR_BAD_PARAMETERS;
		goto err;
	}

	md_algo = tee_algo_to_mbedtls_hash_algo(algo);
	if (md_algo == MBEDTLS_MD_NONE) {
		res = TEE_ERROR_NOT_SUPPORTED;
		goto err;
	}

	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
	if (!pk_info) {
		res = TEE_ERROR_NOT_SUPPORTED;
		goto err;
	}

	mbedtls_rsa_set_padding(&rsa, lmd_padding, md_algo);

	/* The verify call is bracketed by the FTMN linked-call push/pop */
	FTMN_PUSH_LINKED_CALL(&ftmn, arg_hash);
	lmd_res = pk_info->verify_func(&rsa, md_algo, msg, msg_len,
				       sig, sig_len);
	/* The success result is recorded from the call only when it is 0 */
	if (!lmd_res)
		FTMN_SET_CHECK_RES_FROM_CALL(&ftmn, FTMN_INCR0, lmd_res);
	FTMN_POP_LINKED_CALL(&ftmn);
	if (lmd_res != 0) {
		/* Every mbedTLS failure is reported as an invalid signature */
		FMSG("verify_func failed, returned 0x%x", -lmd_res);
		res = TEE_ERROR_SIGNATURE_INVALID;
		goto err;
	}
	res = TEE_SUCCESS;
	goto out;

err:
	/* Failure paths record their non-zero result before the final check */
	FTMN_SET_CHECK_RES_NOT_ZERO(&ftmn, FTMN_INCR0, res);
out:
	FTMN_CALLEE_DONE_CHECK(&ftmn, FTMN_INCR0, FTMN_STEP_COUNT(1), res);
	/* Reset mpi to skip freeing here, those mpis will be freed with key */
	mbedtls_mpi_init(&rsa.E);
	mbedtls_mpi_init(&rsa.N);
	mbedtls_rsa_free(&rsa);
	return res;
}
778