1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3  * Copyright (C) 2018, ARM Limited
4  * Copyright (C) 2019, Linaro Limited
5  */
6 
#include <assert.h>
#include <crypto/crypto.h>
#include <crypto/crypto_impl.h>
#include <fault_mitigation.h>
#include <mbedtls/ctr_drbg.h>
#include <mbedtls/entropy.h>
#include <mbedtls/pk.h>
#include <mbedtls/platform_util.h>
#include <stdlib.h>
#include <string.h>
#include <tee/tee_cryp_utl.h>
#include <utee_defines.h>

#include "mbed_helpers.h"
#include "../mbedtls/library/pk_wrap.h"
#include "../mbedtls/library/rsa_alt_helpers.h"
22 
get_tee_result(int lmd_res)23 static TEE_Result get_tee_result(int lmd_res)
24 {
25 	switch (lmd_res) {
26 	case 0:
27 		return TEE_SUCCESS;
28 	case MBEDTLS_ERR_RSA_PRIVATE_FAILED +
29 		MBEDTLS_ERR_MPI_BAD_INPUT_DATA:
30 	case MBEDTLS_ERR_RSA_BAD_INPUT_DATA:
31 	case MBEDTLS_ERR_RSA_INVALID_PADDING:
32 	case MBEDTLS_ERR_PK_TYPE_MISMATCH:
33 		return TEE_ERROR_BAD_PARAMETERS;
34 	case MBEDTLS_ERR_RSA_OUTPUT_TOO_LARGE:
35 		return TEE_ERROR_SHORT_BUFFER;
36 	default:
37 		return TEE_ERROR_BAD_STATE;
38 	}
39 }
40 
tee_algo_to_mbedtls_hash_algo(uint32_t algo)41 static uint32_t tee_algo_to_mbedtls_hash_algo(uint32_t algo)
42 {
43 	switch (algo) {
44 #if defined(CFG_CRYPTO_SHA1)
45 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA1:
46 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA1:
47 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA1:
48 	case TEE_ALG_SHA1:
49 	case TEE_ALG_DSA_SHA1:
50 	case TEE_ALG_HMAC_SHA1:
51 		return MBEDTLS_MD_SHA1;
52 #endif
53 #if defined(CFG_CRYPTO_MD5)
54 	case TEE_ALG_RSASSA_PKCS1_V1_5_MD5:
55 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_MD5:
56 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_MD5:
57 	case TEE_ALG_MD5:
58 	case TEE_ALG_HMAC_MD5:
59 		return MBEDTLS_MD_MD5;
60 #endif
61 #if defined(CFG_CRYPTO_SHA224)
62 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA224:
63 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA224:
64 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA224:
65 	case TEE_ALG_SHA224:
66 	case TEE_ALG_DSA_SHA224:
67 	case TEE_ALG_HMAC_SHA224:
68 		return MBEDTLS_MD_SHA224;
69 #endif
70 #if defined(CFG_CRYPTO_SHA256)
71 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA256:
72 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256:
73 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA256:
74 	case TEE_ALG_SHA256:
75 	case TEE_ALG_DSA_SHA256:
76 	case TEE_ALG_HMAC_SHA256:
77 		return MBEDTLS_MD_SHA256;
78 #endif
79 #if defined(CFG_CRYPTO_SHA384)
80 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA384:
81 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA384:
82 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA384:
83 	case TEE_ALG_SHA384:
84 	case TEE_ALG_HMAC_SHA384:
85 		return MBEDTLS_MD_SHA384;
86 #endif
87 #if defined(CFG_CRYPTO_SHA512)
88 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA512:
89 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA512:
90 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA512:
91 	case TEE_ALG_SHA512:
92 	case TEE_ALG_HMAC_SHA512:
93 		return MBEDTLS_MD_SHA512;
94 #endif
95 	default:
96 		return MBEDTLS_MD_NONE;
97 	}
98 }
99 
rsa_complete_from_key_pair(mbedtls_rsa_context * rsa,struct rsa_keypair * key)100 static TEE_Result rsa_complete_from_key_pair(mbedtls_rsa_context *rsa,
101 						      struct rsa_keypair *key)
102 {
103 	int lmd_res = 0;
104 
105 	rsa->E = *(mbedtls_mpi *)key->e;
106 	rsa->N = *(mbedtls_mpi *)key->n;
107 	rsa->D = *(mbedtls_mpi *)key->d;
108 	rsa->len = mbedtls_mpi_size(&rsa->N);
109 
110 	if (key->p && crypto_bignum_num_bytes(key->p)) {
111 		rsa->P = *(mbedtls_mpi *)key->p;
112 		rsa->Q = *(mbedtls_mpi *)key->q;
113 		rsa->QP = *(mbedtls_mpi *)key->qp;
114 		rsa->DP = *(mbedtls_mpi *)key->dp;
115 		rsa->DQ = *(mbedtls_mpi *)key->dq;
116 	} else {
117 		mbedtls_mpi_init_mempool(&rsa->P);
118 		mbedtls_mpi_init_mempool(&rsa->Q);
119 		mbedtls_mpi_init_mempool(&rsa->QP);
120 		mbedtls_mpi_init_mempool(&rsa->DP);
121 		mbedtls_mpi_init_mempool(&rsa->DQ);
122 
123 		lmd_res = mbedtls_rsa_deduce_primes(&rsa->N, &rsa->E, &rsa->D,
124 						    &rsa->P, &rsa->Q);
125 		if (lmd_res) {
126 			DMSG("mbedtls_rsa_deduce_primes() returned 0x%x",
127 			     -lmd_res);
128 			goto err;
129 		}
130 
131 		lmd_res = mbedtls_rsa_deduce_crt(&rsa->P, &rsa->Q, &rsa->D,
132 						 &rsa->DP, &rsa->DQ, &rsa->QP);
133 		if (lmd_res) {
134 			DMSG("mbedtls_rsa_deduce_crt() returned 0x%x",
135 			     -lmd_res);
136 			goto err;
137 		}
138 	}
139 
140 	return TEE_SUCCESS;
141 err:
142 	mbedtls_mpi_free(&rsa->P);
143 	mbedtls_mpi_free(&rsa->Q);
144 	mbedtls_mpi_free(&rsa->QP);
145 	mbedtls_mpi_free(&rsa->DP);
146 	mbedtls_mpi_free(&rsa->DQ);
147 
148 	return get_tee_result(lmd_res);
149 }
150 
rsa_init_and_complete_from_key_pair(mbedtls_rsa_context * rsa,struct rsa_keypair * key)151 static TEE_Result rsa_init_and_complete_from_key_pair(mbedtls_rsa_context *rsa,
152 						      struct rsa_keypair *key)
153 {
154 	mbedtls_rsa_init(rsa);
155 
156 	return rsa_complete_from_key_pair(rsa, key);
157 }
158 
/*
 * Release @rsa after it was filled in from @key by
 * rsa_complete_from_key_pair().
 *
 * The mpi's borrowed from @key are reset to an empty state first, so that
 * mbedtls_rsa_free() only releases what @rsa owns itself. E, N and D are
 * always borrowed. P, Q, QP, DP and DQ are borrowed only when @key
 * supplied its primes; otherwise they were derived and belong to @rsa.
 */
static void mbd_rsa_free(mbedtls_rsa_context *rsa, struct rsa_keypair *key)
{
	mbedtls_mpi *borrowed[] = {
		&rsa->E, &rsa->N, &rsa->D,
		&rsa->P, &rsa->Q, &rsa->QP, &rsa->DP, &rsa->DQ,
	};
	size_t nb_borrowed = 3;	/* E, N and D */
	size_t i = 0;

	if (key->p && crypto_bignum_num_bytes(key->p))
		nb_borrowed = sizeof(borrowed) / sizeof(borrowed[0]);

	for (i = 0; i < nb_borrowed; i++)
		mbedtls_mpi_init(borrowed[i]);

	mbedtls_rsa_free(rsa);
}
179 
/*
 * Free a pk context whose RSA context was filled in from @key by
 * rsa_complete_from_key_pair(). The mpi's borrowed from @key are detached
 * first (see mbd_rsa_free()), so only mpi's owned by the context are
 * released.
 *
 * Safe to call when mbedtls_pk_setup() failed or was never called, in
 * which case ctx->pk_ctx is NULL and there is nothing to detach.
 */
static void mbd_pk_free(mbedtls_pk_context *ctx, struct rsa_keypair *key)
{
	mbedtls_rsa_context *rsa = ctx->pk_ctx;

	if (rsa)
		mbd_rsa_free(rsa, key);

	/*
	 * mbedtls_pk_free() presumably runs mbedtls_rsa_free() on the RSA
	 * context once more before releasing it. That second pass is harmless
	 * because every mpi has already been freed or reset above.
	 */
	mbedtls_pk_free(ctx);
}
191 
192 TEE_Result crypto_acipher_alloc_rsa_keypair(struct rsa_keypair *s,
193 					    size_t key_size_bits)
194 __weak __alias("sw_crypto_acipher_alloc_rsa_keypair");
195 
sw_crypto_acipher_alloc_rsa_keypair(struct rsa_keypair * s,size_t key_size_bits)196 TEE_Result sw_crypto_acipher_alloc_rsa_keypair(struct rsa_keypair *s,
197 					       size_t key_size_bits)
198 {
199 	memset(s, 0, sizeof(*s));
200 	s->e = crypto_bignum_allocate(key_size_bits);
201 	if (!s->e)
202 		goto err;
203 	s->d = crypto_bignum_allocate(key_size_bits);
204 	if (!s->d)
205 		goto err;
206 	s->n = crypto_bignum_allocate(key_size_bits);
207 	if (!s->n)
208 		goto err;
209 	s->p = crypto_bignum_allocate(key_size_bits);
210 	if (!s->p)
211 		goto err;
212 	s->q = crypto_bignum_allocate(key_size_bits);
213 	if (!s->q)
214 		goto err;
215 	s->qp = crypto_bignum_allocate(key_size_bits);
216 	if (!s->qp)
217 		goto err;
218 	s->dp = crypto_bignum_allocate(key_size_bits);
219 	if (!s->dp)
220 		goto err;
221 	s->dq = crypto_bignum_allocate(key_size_bits);
222 	if (!s->dq)
223 		goto err;
224 
225 	return TEE_SUCCESS;
226 err:
227 	crypto_acipher_free_rsa_keypair(s);
228 	return TEE_ERROR_OUT_OF_MEMORY;
229 }
230 
231 TEE_Result crypto_acipher_alloc_rsa_public_key(struct rsa_public_key *s,
232 					       size_t key_size_bits)
233 __weak __alias("sw_crypto_acipher_alloc_rsa_public_key");
234 
sw_crypto_acipher_alloc_rsa_public_key(struct rsa_public_key * s,size_t key_size_bits)235 TEE_Result sw_crypto_acipher_alloc_rsa_public_key(struct rsa_public_key *s,
236 						  size_t key_size_bits)
237 {
238 	memset(s, 0, sizeof(*s));
239 	s->e = crypto_bignum_allocate(key_size_bits);
240 	if (!s->e)
241 		return TEE_ERROR_OUT_OF_MEMORY;
242 	s->n = crypto_bignum_allocate(key_size_bits);
243 	if (!s->n)
244 		goto err;
245 	return TEE_SUCCESS;
246 err:
247 	crypto_bignum_free(&s->e);
248 	return TEE_ERROR_OUT_OF_MEMORY;
249 }
250 
251 void crypto_acipher_free_rsa_public_key(struct rsa_public_key *s)
252 __weak __alias("sw_crypto_acipher_free_rsa_public_key");
253 
sw_crypto_acipher_free_rsa_public_key(struct rsa_public_key * s)254 void sw_crypto_acipher_free_rsa_public_key(struct rsa_public_key *s)
255 {
256 	if (!s)
257 		return;
258 	crypto_bignum_free(&s->n);
259 	crypto_bignum_free(&s->e);
260 }
261 
262 void crypto_acipher_free_rsa_keypair(struct rsa_keypair *s)
263 __weak __alias("sw_crypto_acipher_free_rsa_keypair");
264 
sw_crypto_acipher_free_rsa_keypair(struct rsa_keypair * s)265 void sw_crypto_acipher_free_rsa_keypair(struct rsa_keypair *s)
266 {
267 	if (!s)
268 		return;
269 	crypto_bignum_free(&s->e);
270 	crypto_bignum_free(&s->d);
271 	crypto_bignum_free(&s->n);
272 	crypto_bignum_free(&s->p);
273 	crypto_bignum_free(&s->q);
274 	crypto_bignum_free(&s->qp);
275 	crypto_bignum_free(&s->dp);
276 	crypto_bignum_free(&s->dq);
277 }
278 
279 TEE_Result crypto_acipher_gen_rsa_key(struct rsa_keypair *key,
280 				      size_t key_size)
281 __weak __alias("sw_crypto_acipher_gen_rsa_key");
282 
sw_crypto_acipher_gen_rsa_key(struct rsa_keypair * key,size_t key_size)283 TEE_Result sw_crypto_acipher_gen_rsa_key(struct rsa_keypair *key,
284 					 size_t key_size)
285 {
286 	TEE_Result res = TEE_SUCCESS;
287 	mbedtls_rsa_context rsa;
288 	mbedtls_ctr_drbg_context rngctx;
289 	int lmd_res = 0;
290 	uint32_t e = 0;
291 
292 	mbedtls_ctr_drbg_init(&rngctx);
293 	if (mbedtls_ctr_drbg_seed(&rngctx, mbd_rand, NULL, NULL, 0))
294 		return TEE_ERROR_BAD_STATE;
295 
296 	memset(&rsa, 0, sizeof(rsa));
297 	mbedtls_rsa_init(&rsa);
298 
299 	/* get the public exponent */
300 	mbedtls_mpi_write_binary((mbedtls_mpi *)key->e,
301 				 (unsigned char *)&e, sizeof(uint32_t));
302 
303 	e = TEE_U32_FROM_BIG_ENDIAN(e);
304 	lmd_res = mbedtls_rsa_gen_key(&rsa, mbedtls_ctr_drbg_random, &rngctx,
305 				      key_size, (int)e);
306 	mbedtls_ctr_drbg_free(&rngctx);
307 	if (lmd_res != 0) {
308 		res = get_tee_result(lmd_res);
309 	} else if ((size_t)mbedtls_mpi_bitlen(&rsa.N) != key_size) {
310 		res = TEE_ERROR_BAD_PARAMETERS;
311 	} else {
312 		/* Copy the key */
313 		crypto_bignum_copy(key->e, (void *)&rsa.E);
314 		crypto_bignum_copy(key->d, (void *)&rsa.D);
315 		crypto_bignum_copy(key->n, (void *)&rsa.N);
316 		crypto_bignum_copy(key->p, (void *)&rsa.P);
317 
318 		crypto_bignum_copy(key->q, (void *)&rsa.Q);
319 		crypto_bignum_copy(key->qp, (void *)&rsa.QP);
320 		crypto_bignum_copy(key->dp, (void *)&rsa.DP);
321 		crypto_bignum_copy(key->dq, (void *)&rsa.DQ);
322 
323 		res = TEE_SUCCESS;
324 	}
325 
326 	mbedtls_rsa_free(&rsa);
327 
328 	return res;
329 }
330 
331 TEE_Result crypto_acipher_rsanopad_encrypt(struct rsa_public_key *key,
332 					   const uint8_t *src,
333 					   size_t src_len, uint8_t *dst,
334 					   size_t *dst_len)
335 __weak __alias("sw_crypto_acipher_rsanopad_encrypt");
336 
sw_crypto_acipher_rsanopad_encrypt(struct rsa_public_key * key,const uint8_t * src,size_t src_len,uint8_t * dst,size_t * dst_len)337 TEE_Result sw_crypto_acipher_rsanopad_encrypt(struct rsa_public_key *key,
338 					      const uint8_t *src,
339 					      size_t src_len, uint8_t *dst,
340 					      size_t *dst_len)
341 {
342 	TEE_Result res = TEE_SUCCESS;
343 	mbedtls_rsa_context rsa;
344 	int lmd_res = 0;
345 	uint8_t *buf = NULL;
346 	unsigned long blen = 0;
347 	unsigned long offset = 0;
348 
349 	memset(&rsa, 0, sizeof(rsa));
350 	mbedtls_rsa_init(&rsa);
351 
352 	rsa.E = *(mbedtls_mpi *)key->e;
353 	rsa.N = *(mbedtls_mpi *)key->n;
354 
355 	rsa.len = crypto_bignum_num_bytes((void *)&rsa.N);
356 
357 	blen = CFG_CORE_BIGNUM_MAX_BITS / 8;
358 	buf = malloc(blen);
359 	if (!buf) {
360 		res = TEE_ERROR_OUT_OF_MEMORY;
361 		goto out;
362 	}
363 
364 	memset(buf, 0, blen);
365 	memcpy(buf + rsa.len - src_len, src, src_len);
366 
367 	lmd_res = mbedtls_rsa_public(&rsa, buf, buf);
368 	if (lmd_res != 0) {
369 		FMSG("mbedtls_rsa_public() returned 0x%x", -lmd_res);
370 		res = get_tee_result(lmd_res);
371 		goto out;
372 	}
373 
374 	/* Remove the zero-padding (leave one zero if buff is all zeroes) */
375 	offset = 0;
376 	while ((offset < rsa.len - 1) && (buf[offset] == 0))
377 		offset++;
378 
379 	if (*dst_len < rsa.len - offset) {
380 		*dst_len = rsa.len - offset;
381 		res = TEE_ERROR_SHORT_BUFFER;
382 		goto out;
383 	}
384 	*dst_len = rsa.len - offset;
385 	memcpy(dst, buf + offset, *dst_len);
386 out:
387 	free(buf);
388 	/* Reset mpi to skip freeing here, those mpis will be freed with key */
389 	mbedtls_mpi_init(&rsa.E);
390 	mbedtls_mpi_init(&rsa.N);
391 	mbedtls_rsa_free(&rsa);
392 
393 	return res;
394 }
395 
396 TEE_Result crypto_acipher_rsanopad_decrypt(struct rsa_keypair *key,
397 					   const uint8_t *src,
398 					   size_t src_len, uint8_t *dst,
399 					   size_t *dst_len)
400 __weak __alias("sw_crypto_acipher_rsanopad_decrypt");
401 
sw_crypto_acipher_rsanopad_decrypt(struct rsa_keypair * key,const uint8_t * src,size_t src_len,uint8_t * dst,size_t * dst_len)402 TEE_Result sw_crypto_acipher_rsanopad_decrypt(struct rsa_keypair *key,
403 					      const uint8_t *src,
404 					      size_t src_len, uint8_t *dst,
405 					      size_t *dst_len)
406 {
407 	TEE_Result res = TEE_SUCCESS;
408 	mbedtls_rsa_context rsa = { };
409 	int lmd_res = 0;
410 	uint8_t *buf = NULL;
411 	unsigned long blen = 0;
412 	unsigned long offset = 0;
413 
414 	res = rsa_init_and_complete_from_key_pair(&rsa, key);
415 	if (res)
416 		return res;
417 
418 	blen = CFG_CORE_BIGNUM_MAX_BITS / 8;
419 	buf = malloc(blen);
420 	if (!buf) {
421 		res = TEE_ERROR_OUT_OF_MEMORY;
422 		goto out;
423 	}
424 
425 	memset(buf, 0, blen);
426 	memcpy(buf + rsa.len - src_len, src, src_len);
427 
428 	lmd_res = mbedtls_rsa_private(&rsa, mbd_rand, NULL, buf, buf);
429 	if (lmd_res != 0) {
430 		FMSG("mbedtls_rsa_private() returned 0x%x", -lmd_res);
431 		res = get_tee_result(lmd_res);
432 		goto out;
433 	}
434 
435 	/* Remove the zero-padding (leave one zero if buff is all zeroes) */
436 	offset = 0;
437 	while ((offset < rsa.len - 1) && (buf[offset] == 0))
438 		offset++;
439 
440 	if (*dst_len < rsa.len - offset) {
441 		*dst_len = rsa.len - offset;
442 		res = TEE_ERROR_SHORT_BUFFER;
443 		goto out;
444 	}
445 	*dst_len = rsa.len - offset;
446 	memcpy(dst, (char *)buf + offset, *dst_len);
447 out:
448 	if (buf)
449 		free(buf);
450 	mbd_rsa_free(&rsa, key);
451 	return res;
452 }
453 
454 TEE_Result crypto_acipher_rsaes_decrypt(uint32_t algo,
455 					struct rsa_keypair *key,
456 					const uint8_t *label __unused,
457 					size_t label_len __unused,
458 					uint32_t mgf_algo,
459 					const uint8_t *src, size_t src_len,
460 					uint8_t *dst, size_t *dst_len)
461 __weak __alias("sw_crypto_acipher_rsaes_decrypt");
462 
sw_crypto_acipher_rsaes_decrypt(uint32_t algo,struct rsa_keypair * key,const uint8_t * label __unused,size_t label_len __unused,uint32_t mgf_algo,const uint8_t * src,size_t src_len,uint8_t * dst,size_t * dst_len)463 TEE_Result sw_crypto_acipher_rsaes_decrypt(uint32_t algo,
464 					   struct rsa_keypair *key,
465 					   const uint8_t *label __unused,
466 					   size_t label_len __unused,
467 					   uint32_t mgf_algo,
468 					   const uint8_t *src, size_t src_len,
469 					   uint8_t *dst, size_t *dst_len)
470 {
471 	TEE_Result res = TEE_SUCCESS;
472 	int lmd_res = 0;
473 	int lmd_padding = 0;
474 	size_t blen = 0;
475 	size_t mod_size = 0;
476 	void *buf = NULL;
477 	mbedtls_pk_context ctx = { };
478 	mbedtls_rsa_context *rsa = NULL;
479 	const mbedtls_pk_info_t *pk_info = NULL;
480 	uint32_t md_algo = MBEDTLS_MD_NONE;
481 
482 	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
483 	if (!pk_info) {
484 		return TEE_ERROR_NOT_SUPPORTED;
485 	}
486 
487 	mbedtls_pk_init(&ctx);
488 	res = mbedtls_pk_setup(&ctx, pk_info);
489 	if (res != 0) {
490 		goto out;
491 	}
492 
493 	rsa = ctx.pk_ctx;
494 	res = rsa_complete_from_key_pair(rsa, key);
495 	if (res)
496 		return res;
497 
498 	/*
499 	 * Use a temporary buffer since we don't know exactly how large
500 	 * the required size of the out buffer without doing a partial
501 	 * decrypt. We know the upper bound though.
502 	 */
503 	if (algo == TEE_ALG_RSAES_PKCS1_V1_5) {
504 		mod_size = crypto_bignum_num_bytes(key->n);
505 		blen = mod_size - 11;
506 		lmd_padding = MBEDTLS_RSA_PKCS_V15;
507 	} else {
508 		/* Decoded message is always shorter than encrypted message */
509 		blen = src_len;
510 		lmd_padding = MBEDTLS_RSA_PKCS_V21;
511 	}
512 
513 	buf = malloc(blen);
514 	if (!buf) {
515 		res = TEE_ERROR_OUT_OF_MEMORY;
516 		goto out;
517 	}
518 
519 	/*
520 	 * TEE_ALG_RSAES_PKCS1_V1_5 is invalid in hash. But its hash algo will
521 	 * not be used in rsa, so skip it here.
522 	 */
523 	if (algo != TEE_ALG_RSAES_PKCS1_V1_5) {
524 		md_algo = tee_algo_to_mbedtls_hash_algo(algo);
525 		if (md_algo == MBEDTLS_MD_NONE) {
526 			res = TEE_ERROR_NOT_SUPPORTED;
527 			goto out;
528 		}
529 		if (md_algo != tee_algo_to_mbedtls_hash_algo(mgf_algo)) {
530 			DMSG("Using a different MGF1 algorithm is not supported");
531 			res = TEE_ERROR_NOT_SUPPORTED;
532 			goto out;
533 		}
534 	}
535 
536 	mbedtls_rsa_set_padding(rsa, lmd_padding, md_algo);
537 
538 	lmd_res = pk_info->decrypt_func(&ctx, src, src_len, buf, &blen,
539 					blen, mbd_rand, NULL);
540 	if (lmd_res != 0) {
541 		FMSG("decrypt_func() returned 0x%x", -lmd_res);
542 		res = get_tee_result(lmd_res);
543 		goto out;
544 	}
545 
546 	if (*dst_len < blen) {
547 		*dst_len = blen;
548 		res = TEE_ERROR_SHORT_BUFFER;
549 		goto out;
550 	}
551 
552 	res = TEE_SUCCESS;
553 	*dst_len = blen;
554 	memcpy(dst, buf, blen);
555 out:
556 	if (buf)
557 		free(buf);
558 	mbd_pk_free(&ctx, key);
559 	return res;
560 }
561 
562 TEE_Result crypto_acipher_rsaes_encrypt(uint32_t algo,
563 					struct rsa_public_key *key,
564 					const uint8_t *label __unused,
565 					size_t label_len __unused,
566 					uint32_t mgf_algo,
567 					const uint8_t *src, size_t src_len,
568 					uint8_t *dst, size_t *dst_len)
569 __weak __alias("sw_crypto_acipher_rsaes_encrypt");
570 
sw_crypto_acipher_rsaes_encrypt(uint32_t algo,struct rsa_public_key * key,const uint8_t * label __unused,size_t label_len __unused,uint32_t mgf_algo,const uint8_t * src,size_t src_len,uint8_t * dst,size_t * dst_len)571 TEE_Result sw_crypto_acipher_rsaes_encrypt(uint32_t algo,
572 					   struct rsa_public_key *key,
573 					   const uint8_t *label __unused,
574 					   size_t label_len __unused,
575 					   uint32_t mgf_algo,
576 					   const uint8_t *src, size_t src_len,
577 					   uint8_t *dst, size_t *dst_len)
578 {
579 	TEE_Result res = TEE_SUCCESS;
580 	int lmd_res = 0;
581 	int lmd_padding = 0;
582 	size_t mod_size = 0;
583 	mbedtls_pk_context ctx = { };
584 	mbedtls_rsa_context *rsa = NULL;
585 	const mbedtls_pk_info_t *pk_info = NULL;
586 	uint32_t md_algo = MBEDTLS_MD_NONE;
587 
588 	memset(&ctx, 0, sizeof(ctx));
589 
590 	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
591 	if (!pk_info) {
592 		return TEE_ERROR_NOT_SUPPORTED;
593 	}
594 
595 	mbedtls_pk_init(&ctx);
596 	res = mbedtls_pk_setup(&ctx, pk_info);
597 	if (res != 0) {
598 		goto out;
599 	}
600 
601 	rsa = ctx.pk_ctx;
602 
603 	rsa->E = *(mbedtls_mpi *)key->e;
604 	rsa->N = *(mbedtls_mpi *)key->n;
605 
606 	mod_size = crypto_bignum_num_bytes(key->n);
607 	if (*dst_len < mod_size) {
608 		*dst_len = mod_size;
609 		res = TEE_ERROR_SHORT_BUFFER;
610 		goto out;
611 	}
612 	*dst_len = mod_size;
613 	rsa->len = mod_size;
614 
615 	if (algo == TEE_ALG_RSAES_PKCS1_V1_5)
616 		lmd_padding = MBEDTLS_RSA_PKCS_V15;
617 	else
618 		lmd_padding = MBEDTLS_RSA_PKCS_V21;
619 
620 	/*
621 	 * TEE_ALG_RSAES_PKCS1_V1_5 is invalid in hash. But its hash algo will
622 	 * not be used in rsa, so skip it here.
623 	 */
624 	if (algo != TEE_ALG_RSAES_PKCS1_V1_5) {
625 		md_algo = tee_algo_to_mbedtls_hash_algo(algo);
626 		/* Using a different MGF1 algorithm is not supported. */
627 		if (md_algo == MBEDTLS_MD_NONE ||
628 		    md_algo != tee_algo_to_mbedtls_hash_algo(mgf_algo)) {
629 			res = TEE_ERROR_NOT_SUPPORTED;
630 			goto out;
631 		}
632 	}
633 
634 	mbedtls_rsa_set_padding(rsa, lmd_padding, md_algo);
635 
636 	lmd_res = pk_info->encrypt_func(&ctx, src, src_len, dst, dst_len,
637 					*dst_len, mbd_rand, NULL);
638 	if (lmd_res != 0) {
639 		FMSG("encrypt_func() returned 0x%x", -lmd_res);
640 		res = get_tee_result(lmd_res);
641 		goto out;
642 	}
643 	res = TEE_SUCCESS;
644 out:
645 	/* Reset mpi to skip freeing here, those mpis will be freed with key */
646 	mbedtls_mpi_init(&rsa->E);
647 	mbedtls_mpi_init(&rsa->N);
648 	mbedtls_pk_free(&ctx);
649 	return res;
650 }
651 
652 TEE_Result crypto_acipher_rsassa_sign(uint32_t algo, struct rsa_keypair *key,
653 				      int salt_len __unused,
654 				      const uint8_t *msg, size_t msg_len,
655 				      uint8_t *sig, size_t *sig_len)
656 __weak __alias("sw_crypto_acipher_rsassa_sign");
657 
sw_crypto_acipher_rsassa_sign(uint32_t algo,struct rsa_keypair * key,int salt_len __unused,const uint8_t * msg,size_t msg_len,uint8_t * sig,size_t * sig_len)658 TEE_Result sw_crypto_acipher_rsassa_sign(uint32_t algo, struct rsa_keypair *key,
659 					 int salt_len __unused,
660 					 const uint8_t *msg, size_t msg_len,
661 					 uint8_t *sig, size_t *sig_len)
662 {
663 	TEE_Result res = TEE_SUCCESS;
664 	int lmd_res = 0;
665 	int lmd_padding = 0;
666 	size_t mod_size = 0;
667 	size_t hash_size = 0;
668 	mbedtls_pk_context ctx = { };
669 	mbedtls_rsa_context *rsa = NULL;
670 	const mbedtls_pk_info_t *pk_info = NULL;
671 	uint32_t md_algo = 0;
672 
673 	memset(&ctx, 0, sizeof(ctx));
674 
675 	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
676 	if (!pk_info) {
677 		return TEE_ERROR_NOT_SUPPORTED;
678 	}
679 
680 	mbedtls_pk_init(&ctx);
681 	res = mbedtls_pk_setup(&ctx, pk_info);
682 	if (res != 0) {
683 		goto err;
684 	}
685 
686 	rsa = ctx.pk_ctx;
687 	res = rsa_complete_from_key_pair(rsa, key);
688 	if (res)
689 		return res;
690 
691 	switch (algo) {
692 	case TEE_ALG_RSASSA_PKCS1_V1_5_MD5:
693 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA1:
694 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA224:
695 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA256:
696 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA384:
697 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA512:
698 		lmd_padding = MBEDTLS_RSA_PKCS_V15;
699 		break;
700 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_MD5:
701 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA1:
702 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA224:
703 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256:
704 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA384:
705 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA512:
706 		lmd_padding = MBEDTLS_RSA_PKCS_V21;
707 		break;
708 	default:
709 		res = TEE_ERROR_BAD_PARAMETERS;
710 		goto err;
711 	}
712 
713 	res = tee_alg_get_digest_size(TEE_DIGEST_HASH_TO_ALGO(algo),
714 				      &hash_size);
715 	if (res != TEE_SUCCESS)
716 		goto err;
717 
718 	if (msg_len != hash_size) {
719 		res = TEE_ERROR_BAD_PARAMETERS;
720 		goto err;
721 	}
722 
723 	mod_size = crypto_bignum_num_bytes(key->n);
724 	if (*sig_len < mod_size) {
725 		*sig_len = mod_size;
726 		res = TEE_ERROR_SHORT_BUFFER;
727 		goto err;
728 	}
729 	rsa->len = mod_size;
730 
731 	md_algo = tee_algo_to_mbedtls_hash_algo(algo);
732 	if (md_algo == MBEDTLS_MD_NONE) {
733 		res = TEE_ERROR_NOT_SUPPORTED;
734 		goto err;
735 	}
736 
737 
738 	mbedtls_rsa_set_padding(rsa, lmd_padding, md_algo);
739 
740 	lmd_res = pk_info->sign_func(&ctx, md_algo, msg, msg_len, sig,
741 				     *sig_len, sig_len, mbd_rand, NULL);
742 	if (lmd_res != 0) {
743 		FMSG("sign_func failed, returned 0x%x", -lmd_res);
744 		res = get_tee_result(lmd_res);
745 		goto err;
746 	}
747 	res = TEE_SUCCESS;
748 err:
749 	mbd_pk_free(&ctx, key);
750 	return res;
751 }
752 
753 TEE_Result crypto_acipher_rsassa_verify(uint32_t algo,
754 					struct rsa_public_key *key,
755 					int salt_len __unused,
756 					const uint8_t *msg,
757 					size_t msg_len, const uint8_t *sig,
758 					size_t sig_len)
759 __weak __alias("sw_crypto_acipher_rsassa_verify");
760 
sw_crypto_acipher_rsassa_verify(uint32_t algo,struct rsa_public_key * key,int salt_len __unused,const uint8_t * msg,size_t msg_len,const uint8_t * sig,size_t sig_len)761 TEE_Result sw_crypto_acipher_rsassa_verify(uint32_t algo,
762 					   struct rsa_public_key *key,
763 					   int salt_len __unused,
764 					   const uint8_t *msg,
765 					   size_t msg_len, const uint8_t *sig,
766 					   size_t sig_len)
767 {
768 	TEE_Result res = TEE_SUCCESS;
769 	int lmd_res = 0;
770 	int lmd_padding = 0;
771 	size_t hash_size = 0;
772 	size_t bigint_size = 0;
773 	mbedtls_pk_context ctx = { };
774 	mbedtls_rsa_context *rsa = NULL;
775 	const mbedtls_pk_info_t *pk_info = NULL;
776 	uint32_t md_algo = 0;
777 	struct ftmn ftmn = { };
778 	unsigned long arg_hash = 0;
779 
780 	/*
781 	 * The caller expects to call crypto_acipher_rsassa_verify(),
782 	 * update the hash as needed.
783 	 */
784 	FTMN_CALLEE_SWAP_HASH(FTMN_FUNC_HASH("crypto_acipher_rsassa_verify"));
785 
786 	memset(&ctx, 0, sizeof(ctx));
787 
788 	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
789 	if (!pk_info) {
790 		return TEE_ERROR_NOT_SUPPORTED;
791 	}
792 
793 	mbedtls_pk_init(&ctx);
794 	res = mbedtls_pk_setup(&ctx, pk_info);
795 	if (res != 0) {
796 		goto err;
797 	}
798 
799 	rsa = ctx.pk_ctx;
800 
801 	rsa->E = *(mbedtls_mpi *)key->e;
802 	rsa->N = *(mbedtls_mpi *)key->n;
803 
804 	res = tee_alg_get_digest_size(TEE_DIGEST_HASH_TO_ALGO(algo),
805 				      &hash_size);
806 	if (res != TEE_SUCCESS)
807 		goto err;
808 
809 	if (msg_len != hash_size) {
810 		res = TEE_ERROR_BAD_PARAMETERS;
811 		goto err;
812 	}
813 
814 	bigint_size = crypto_bignum_num_bytes(key->n);
815 	if (sig_len < bigint_size) {
816 		res = TEE_ERROR_SIGNATURE_INVALID;
817 		goto err;
818 	}
819 
820 	rsa->len = bigint_size;
821 
822 	switch (algo) {
823 	case TEE_ALG_RSASSA_PKCS1_V1_5_MD5:
824 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA1:
825 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA224:
826 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA256:
827 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA384:
828 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA512:
829 		arg_hash = FTMN_FUNC_HASH("mbedtls_rsa_rsassa_pkcs1_v15_verify");
830 		lmd_padding = MBEDTLS_RSA_PKCS_V15;
831 		break;
832 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_MD5:
833 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA1:
834 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA224:
835 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256:
836 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA384:
837 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA512:
838 		arg_hash = FTMN_FUNC_HASH("mbedtls_rsa_rsassa_pss_verify_ext");
839 		lmd_padding = MBEDTLS_RSA_PKCS_V21;
840 		break;
841 	default:
842 		res = TEE_ERROR_BAD_PARAMETERS;
843 		goto err;
844 	}
845 
846 	md_algo = tee_algo_to_mbedtls_hash_algo(algo);
847 	if (md_algo == MBEDTLS_MD_NONE) {
848 		res = TEE_ERROR_NOT_SUPPORTED;
849 		goto err;
850 	}
851 
852 	mbedtls_rsa_set_padding(rsa, lmd_padding, md_algo);
853 
854 	FTMN_PUSH_LINKED_CALL(&ftmn, arg_hash);
855 	lmd_res = pk_info->verify_func(&ctx, md_algo, msg, msg_len,
856 	                               sig, sig_len);
857 	if (!lmd_res)
858 		FTMN_SET_CHECK_RES_FROM_CALL(&ftmn, FTMN_INCR0, lmd_res);
859 	FTMN_POP_LINKED_CALL(&ftmn);
860 	if (lmd_res != 0) {
861 		FMSG("verify_func failed, returned 0x%x", -lmd_res);
862 		res = TEE_ERROR_SIGNATURE_INVALID;
863 		goto err;
864 	}
865 	res = TEE_SUCCESS;
866 	goto out;
867 
868 err:
869 	FTMN_SET_CHECK_RES_NOT_ZERO(&ftmn, FTMN_INCR0, res);
870 out:
871 	FTMN_CALLEE_DONE_CHECK(&ftmn, FTMN_INCR0, FTMN_STEP_COUNT(1), res);
872 	/* Reset mpi to skip freeing here, those mpis will be freed with key */
873 	mbedtls_mpi_init(&rsa->E);
874 	mbedtls_mpi_init(&rsa->N);
875 	mbedtls_pk_free(&ctx);
876 	return res;
877 }
878