1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3  * Copyright (c) 2018, Linaro Limited
4  */
5 
6 #include <crypto/crypto.h>
7 #include <kernel/panic.h>
8 #include <mbedtls/bignum.h>
9 #include <mempool.h>
10 #include <stdlib.h>
11 #include <string.h>
12 #include <tomcrypt_private.h>
13 #include <tomcrypt_mp.h>
14 #include <util.h>
15 
16 #if defined(_CFG_CORE_LTC_PAGER)
17 #include <mm/core_mmu.h>
18 #include <mm/tee_pager.h>
19 #endif
20 
/*
 * Size in bytes of the MPI scratch memory pool.
 * Size needed for xtest to pass reliably on both ARM32 and ARM64
 */
#define MPI_MEMPOOL_SIZE	(46 * 1024)

/* From mbedtls/library/bignum.c */
#define ciL		(sizeof(mbedtls_mpi_uint))	/* chars in limb  */
#define biL		(ciL << 3)			/* bits  in limb  */
/* Number of limbs needed to hold @i bits, rounded up (evaluates @i twice) */
#define BITS_TO_LIMBS(i)	((i) / biL + ((i) % biL != 0))
28 
29 #if defined(_CFG_CORE_LTC_PAGER)
30 /* allocate pageable_zi vmem for mp scratch memory pool */
get_mp_scratch_memory_pool(void)31 static struct mempool *get_mp_scratch_memory_pool(void)
32 {
33 	size_t size;
34 	void *data;
35 
36 	size = ROUNDUP(MPI_MEMPOOL_SIZE, SMALL_PAGE_SIZE);
37 	data = tee_pager_alloc(size);
38 	if (!data)
39 		panic();
40 
41 	return mempool_alloc_pool(data, size, tee_pager_release_phys);
42 }
43 #else /* _CFG_CORE_LTC_PAGER */
get_mp_scratch_memory_pool(void)44 static struct mempool *get_mp_scratch_memory_pool(void)
45 {
46 	static uint8_t data[MPI_MEMPOOL_SIZE] __aligned(MEMPOOL_ALIGN);
47 
48 	return mempool_alloc_pool(data, sizeof(data), NULL);
49 }
50 #endif
51 
init_mp_tomcrypt(void)52 void init_mp_tomcrypt(void)
53 {
54 	struct mempool *p = get_mp_scratch_memory_pool();
55 
56 	if (!p)
57 		panic();
58 	mbedtls_mpi_mempool = p;
59 	assert(!mempool_default);
60 	mempool_default = p;
61 }
62 
init(void ** a)63 static int init(void **a)
64 {
65 	mbedtls_mpi *bn = mempool_alloc(mbedtls_mpi_mempool, sizeof(*bn));
66 
67 	if (!bn)
68 		return CRYPT_MEM;
69 
70 	mbedtls_mpi_init_mempool(bn);
71 	*a = bn;
72 	return CRYPT_OK;
73 }
74 
/* Same as init(); the requested size in bits is ignored. */
static int init_size(int size_bits __unused, void **a)
{
	int res = init(a);

	return res;
}
79 
deinit(void * a)80 static void deinit(void *a)
81 {
82 	mbedtls_mpi_free((mbedtls_mpi *)a);
83 	mempool_free(mbedtls_mpi_mempool, a);
84 }
85 
neg(void * a,void * b)86 static int neg(void *a, void *b)
87 {
88 	if (mbedtls_mpi_copy(b, a))
89 		return CRYPT_MEM;
90 	((mbedtls_mpi *)b)->s *= -1;
91 	return CRYPT_OK;
92 }
93 
copy(void * a,void * b)94 static int copy(void *a, void *b)
95 {
96 	if (mbedtls_mpi_copy(b, a))
97 		return CRYPT_MEM;
98 	return CRYPT_OK;
99 }
100 
init_copy(void ** a,void * b)101 static int init_copy(void **a, void *b)
102 {
103 	if (init(a) != CRYPT_OK) {
104 		return CRYPT_MEM;
105 	}
106 	return copy(b, *a);
107 }
108 
109 /* ---- trivial ---- */
/*
 * a = b. Only values representable in 32 bits are accepted; anything
 * larger yields CRYPT_INVALID_ARG. Returns CRYPT_MEM if the copy fails.
 */
static int set_int(void *a, ltc_mp_digit b)
{
	mbedtls_mpi_uint limb = 0;
	mbedtls_mpi src = { .s = 1, .n = 1, .p = &limb };

	if ((uint32_t)b != b)
		return CRYPT_INVALID_ARG;

	limb = (uint32_t)b;

	return mbedtls_mpi_copy(a, &src) ? CRYPT_MEM : CRYPT_OK;
}
124 
get_int(void * a)125 static unsigned long get_int(void *a)
126 {
127 	mbedtls_mpi *bn = a;
128 
129 	if (!bn->n)
130 		return 0;
131 
132 	return bn->p[bn->n - 1];
133 }
134 
get_digit(void * a,int n)135 static ltc_mp_digit get_digit(void *a, int n)
136 {
137 	mbedtls_mpi *bn = a;
138 
139 	COMPILE_TIME_ASSERT(sizeof(ltc_mp_digit) >= sizeof(mbedtls_mpi_uint));
140 
141 	if (n < 0 || (size_t)n >= bn->n)
142 		return 0;
143 
144 	return bn->p[n];
145 }
146 
get_digit_count(void * a)147 static int get_digit_count(void *a)
148 {
149 	return ROUNDUP_DIV(mbedtls_mpi_size(a), sizeof(mbedtls_mpi_uint));
150 }
151 
compare(void * a,void * b)152 static int compare(void *a, void *b)
153 {
154 	int ret = mbedtls_mpi_cmp_mpi(a, b);
155 
156 	if (ret < 0)
157 		return LTC_MP_LT;
158 
159 	if (ret > 0)
160 		return LTC_MP_GT;
161 
162 	return LTC_MP_EQ;
163 }
164 
/*
 * Compare the bignum @a against the single digit @b.
 *
 * @b is split little-endian into enough stack limbs to hold a whole
 * ltc_mp_digit, so digits wider than one limb (or wider than unsigned
 * long) are compared correctly. No memory is allocated, so this cannot
 * fail.
 *
 * Returns LTC_MP_LT, LTC_MP_EQ or LTC_MP_GT.
 */
static int compare_d(void *a, ltc_mp_digit b)
{
	enum {
		NLIMBS = (sizeof(ltc_mp_digit) + sizeof(mbedtls_mpi_uint) - 1) /
			 sizeof(mbedtls_mpi_uint)
	};
	mbedtls_mpi_uint limbs[NLIMBS] = { 0 };
	mbedtls_mpi bn = { .s = 1, .n = NLIMBS, .p = limbs };
	ltc_mp_digit v = b;
	unsigned int i = 0;

	for (i = 0; i < NLIMBS && v; i++) {
		limbs[i] = (mbedtls_mpi_uint)v;
		/*
		 * Shift in two halves: a single shift by the full limb
		 * width is undefined when a limb is as wide as the digit.
		 */
		v >>= sizeof(mbedtls_mpi_uint) * 4;
		v >>= sizeof(mbedtls_mpi_uint) * 4;
	}

	return compare(a, &bn);
}
187 
/* Bit length of @a as reported by mbedtls_mpi_bitlen(). */
static int count_bits(void *a)
{
	size_t nbits = mbedtls_mpi_bitlen(a);

	return nbits;
}
192 
/* Number of trailing zero bits of @a, as reported by mbedtls_mpi_lsb(). */
static int count_lsb_bits(void *a)
{
	size_t nzeros = mbedtls_mpi_lsb(a);

	return nzeros;
}
197 
198 
twoexpt(void * a,int n)199 static int twoexpt(void *a, int n)
200 {
201 	if (mbedtls_mpi_set_bit(a, n, 1))
202 		return CRYPT_MEM;
203 
204 	return CRYPT_OK;
205 }
206 
207 /* ---- conversions ---- */
208 
209 /* read ascii string */
read_radix(void * a,const char * b,int radix)210 static int read_radix(void *a, const char *b, int radix)
211 {
212 	int res = mbedtls_mpi_read_string(a, radix, b);
213 
214 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
215 		return CRYPT_MEM;
216 	if (res)
217 		return CRYPT_ERROR;
218 
219 	return CRYPT_OK;
220 }
221 
222 /* write one */
write_radix(void * a,char * b,int radix)223 static int write_radix(void *a, char *b, int radix)
224 {
225 	size_t ol = SIZE_MAX;
226 	int res = mbedtls_mpi_write_string(a, radix, b, ol, &ol);
227 
228 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
229 		return CRYPT_MEM;
230 	if (res)
231 		return CRYPT_ERROR;
232 
233 	return CRYPT_OK;
234 }
235 
236 /* get size as unsigned char string */
/* Byte length of @a as reported by mbedtls_mpi_size(). */
static unsigned long unsigned_size(void *a)
{
	size_t nbytes = mbedtls_mpi_size(a);

	return nbytes;
}
241 
242 /* store */
unsigned_write(void * a,unsigned char * b)243 static int unsigned_write(void *a, unsigned char *b)
244 {
245 	int res = mbedtls_mpi_write_binary(a, b, unsigned_size(a));
246 
247 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
248 		return CRYPT_MEM;
249 	if (res)
250 		return CRYPT_ERROR;
251 
252 	return CRYPT_OK;
253 }
254 
255 /* read */
unsigned_read(void * a,unsigned char * b,unsigned long len)256 static int unsigned_read(void *a, unsigned char *b, unsigned long len)
257 {
258 	int res = mbedtls_mpi_read_binary(a, b, len);
259 
260 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
261 		return CRYPT_MEM;
262 	if (res)
263 		return CRYPT_ERROR;
264 
265 	return CRYPT_OK;
266 }
267 
268 /* add */
add(void * a,void * b,void * c)269 static int add(void *a, void *b, void *c)
270 {
271 	if (mbedtls_mpi_add_mpi(c, a, b))
272 		return CRYPT_MEM;
273 
274 	return CRYPT_OK;
275 }
276 
/* c = a + b, where @b must fit in 32 bits (else CRYPT_INVALID_ARG). */
static int addi(void *a, ltc_mp_digit b, void *c)
{
	mbedtls_mpi_uint limb = 0;
	mbedtls_mpi rhs = { .s = 1, .n = 1, .p = &limb };

	if ((uint32_t)b != b)
		return CRYPT_INVALID_ARG;

	limb = (uint32_t)b;
	return add(a, &rhs, c);
}
289 
290 /* sub */
sub(void * a,void * b,void * c)291 static int sub(void *a, void *b, void *c)
292 {
293 	if (mbedtls_mpi_sub_mpi(c, a, b))
294 		return CRYPT_MEM;
295 
296 	return CRYPT_OK;
297 }
298 
/* c = a - b, where @b must fit in 32 bits (else CRYPT_INVALID_ARG). */
static int subi(void *a, ltc_mp_digit b, void *c)
{
	mbedtls_mpi_uint limb = 0;
	mbedtls_mpi rhs = { .s = 1, .n = 1, .p = &limb };

	if ((uint32_t)b != b)
		return CRYPT_INVALID_ARG;

	limb = (uint32_t)b;
	return sub(a, &rhs, c);
}
311 
312 /* mul */
mul(void * a,void * b,void * c)313 static int mul(void *a, void *b, void *c)
314 {
315 	if (mbedtls_mpi_mul_mpi(c, a, b))
316 		return CRYPT_MEM;
317 
318 	return CRYPT_OK;
319 }
320 
/* c = a * b, where @b must not exceed UINT32_MAX (else CRYPT_INVALID_ARG). */
static int muli(void *a, ltc_mp_digit b, void *c)
{
	if (b <= (unsigned long)UINT32_MAX)
		return mbedtls_mpi_mul_int(c, a, b) ? CRYPT_MEM : CRYPT_OK;

	return CRYPT_INVALID_ARG;
}
331 
332 /* sqr */
/* b = a * a, done as a plain multiplication. */
static int sqr(void *a, void *b)
{
	void *factor = a;

	return mul(factor, factor, b);
}
337 
338 /* div */
divide(void * a,void * b,void * c,void * d)339 static int divide(void *a, void *b, void *c, void *d)
340 {
341 	int res = mbedtls_mpi_div_mpi(c, d, a, b);
342 
343 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
344 		return CRYPT_MEM;
345 	if (res)
346 		return CRYPT_ERROR;
347 
348 	return CRYPT_OK;
349 }
350 
div_2(void * a,void * b)351 static int div_2(void *a, void *b)
352 {
353 	if (mbedtls_mpi_copy(b, a))
354 		return CRYPT_MEM;
355 
356 	if (mbedtls_mpi_shift_r(b, 1))
357 		return CRYPT_MEM;
358 
359 	return CRYPT_OK;
360 }
361 
362 /* modi */
/*
 * *c = a mod b for a single digit modulus @b.
 *
 * @b must fit in 32 bits (set_int() returns CRYPT_INVALID_ARG otherwise).
 * The remainder lies in [0, b), so it fits in the least significant limb.
 *
 * Returns CRYPT_OK, CRYPT_INVALID_ARG, CRYPT_MEM on allocation failure or
 * CRYPT_ERROR for any other mbedtls error (e.g. b == 0). *c is only
 * written on success.
 */
static int modi(void *a, ltc_mp_digit b, ltc_mp_digit *c)
{
	mbedtls_mpi bn_b;
	mbedtls_mpi bn_c;
	int res = 0;

	mbedtls_mpi_init_mempool(&bn_b);
	mbedtls_mpi_init_mempool(&bn_c);

	res = set_int(&bn_b, b);
	if (res)
		goto out;

	/* Dividend is @a, modulus is @b: bn_c = a mod b */
	res = mbedtls_mpi_mod_mpi(&bn_c, a, &bn_b);
	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) {
		res = CRYPT_MEM;
		goto out;
	}
	if (res) {
		res = CRYPT_ERROR;
		goto out;
	}

	*c = get_digit(&bn_c, 0);
	res = CRYPT_OK;
out:
	mbedtls_mpi_free(&bn_b);
	mbedtls_mpi_free(&bn_c);

	return res;
}
388 
389 /* gcd */
gcd(void * a,void * b,void * c)390 static int gcd(void *a, void *b, void *c)
391 {
392 	if (mbedtls_mpi_gcd(c, a, b))
393 		return CRYPT_MEM;
394 
395 	return CRYPT_OK;
396 }
397 
398 /* lcm */
lcm(void * a,void * b,void * c)399 static int lcm(void *a, void *b, void *c)
400 {
401 	int res = CRYPT_MEM;
402 	mbedtls_mpi tmp;
403 
404 	mbedtls_mpi_init_mempool(&tmp);
405 	if (mbedtls_mpi_mul_mpi(&tmp, a, b))
406 		goto out;
407 
408 	if (mbedtls_mpi_gcd(c, a, b))
409 		goto out;
410 
411 	/* We use the following equality: gcd(a, b) * lcm(a, b) = a * b */
412 	res = divide(&tmp, c, c, NULL);
413 out:
414 	mbedtls_mpi_free(&tmp);
415 	return res;
416 }
417 
mod(void * a,void * b,void * c)418 static int mod(void *a, void *b, void *c)
419 {
420 	int res = mbedtls_mpi_mod_mpi(c, a, b);
421 
422 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
423 		return CRYPT_MEM;
424 	if (res)
425 		return CRYPT_ERROR;
426 
427 	return CRYPT_OK;
428 }
429 
/* d = (a + b) mod c. */
static int addmod(void *a, void *b, void *c, void *d)
{
	int err = add(a, b, d);

	return err ? err : mod(d, c, d);
}
439 
/* d = (a - b) mod c. */
static int submod(void *a, void *b, void *c, void *d)
{
	int err = sub(a, b, d);

	return err ? err : mod(d, c, d);
}
449 
mulmod(void * a,void * b,void * c,void * d)450 static int mulmod(void *a, void *b, void *c, void *d)
451 {
452 	int res;
453 	mbedtls_mpi ta;
454 	mbedtls_mpi tb;
455 
456 	mbedtls_mpi_init_mempool(&ta);
457 	mbedtls_mpi_init_mempool(&tb);
458 
459 	res = mod(a, c, &ta);
460 	if (res)
461 		goto out;
462 	res = mod(b, c, &tb);
463 	if (res)
464 		goto out;
465 	res = mul(&ta, &tb, d);
466 	if (res)
467 		goto out;
468 	res = mod(d, c, d);
469 out:
470 	mbedtls_mpi_free(&ta);
471 	mbedtls_mpi_free(&tb);
472 	return res;
473 }
474 
/* c = (a * a) mod b. */
static int sqrmod(void *a, void *b, void *c)
{
	void *base = a;

	return mulmod(base, base, b, c);
}
479 
480 /* invmod */
invmod(void * a,void * b,void * c)481 static int invmod(void *a, void *b, void *c)
482 {
483 	int res = mbedtls_mpi_inv_mod(c, a, b);
484 
485 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
486 		return CRYPT_MEM;
487 	if (res)
488 		return CRYPT_ERROR;
489 
490 	return CRYPT_OK;
491 }
492 
493 
494 /* setup */
montgomery_setup(void * a,void ** b)495 static int montgomery_setup(void *a, void **b)
496 {
497 	*b = mempool_alloc(mbedtls_mpi_mempool, sizeof(mbedtls_mpi_uint));
498 	if (!*b)
499 		return CRYPT_MEM;
500 
501 	mbedtls_mpi_montg_init(*b, a);
502 
503 	return CRYPT_OK;
504 }
505 
506 /* get normalization value */
montgomery_normalization(void * a,void * b)507 static int montgomery_normalization(void *a, void *b)
508 {
509 	size_t c = ROUNDUP(mbedtls_mpi_size(b), sizeof(mbedtls_mpi_uint)) * 8;
510 
511 	if (mbedtls_mpi_lset(a, 1))
512 		return CRYPT_MEM;
513 	if (mbedtls_mpi_shift_l(a, c))
514 		return CRYPT_MEM;
515 	if (mbedtls_mpi_mod_mpi(a, a, b))
516 		return CRYPT_MEM;
517 
518 	return CRYPT_OK;
519 }
520 
521 /* reduce */
/*
 * Montgomery-reduce @a modulo @b in place using mbedtls_mpi_montred().
 *
 * @a: value to reduce; overwritten with the result on success
 * @b: modulus N (mbedtls_mpi)
 * @c: pointer to the mbedtls_mpi_uint constant from montgomery_setup()
 *
 * Returns CRYPT_OK on success, CRYPT_MEM if any grow/mod/copy step fails.
 */
static int montgomery_reduce(void *a, void *b, void *c)
{
	mbedtls_mpi A;
	mbedtls_mpi *N = b;
	mbedtls_mpi_uint *mm = c;
	mbedtls_mpi T;
	int ret = CRYPT_MEM;

	mbedtls_mpi_init_mempool(&T);
	mbedtls_mpi_init_mempool(&A);

	/* Scratch buffer for montred(): 2 * (N->n + 1) limbs */
	if (mbedtls_mpi_grow(&T, (N->n + 1) * 2))
		goto out;

	/* Work on a copy of @a, first reduced mod N when it exceeds N */
	if (mbedtls_mpi_cmp_mpi(a, N) > 0) {
		if (mbedtls_mpi_mod_mpi(&A, a, N))
			goto out;
	} else {
		if (mbedtls_mpi_copy(&A, a))
			goto out;
	}

	/*
	 * Grow A to N->n + 1 limbs; presumably the minimum montred() works
	 * on — NOTE(review): confirm against mbedtls_mpi_montred().
	 */
	if (mbedtls_mpi_grow(&A, N->n + 1))
		goto out;

	/* Reduce A in place; any return value from montred() is not checked */
	mbedtls_mpi_montred(&A, N, *mm, &T);

	if (mbedtls_mpi_copy(a, &A))
		goto out;

	ret = CRYPT_OK;
out:
	mbedtls_mpi_free(&A);
	mbedtls_mpi_free(&T);

	return ret;
}
559 
560 /* clean up */
montgomery_deinit(void * a)561 static void montgomery_deinit(void *a)
562 {
563 	mempool_free(mbedtls_mpi_mempool, a);
564 }
565 
566 /*
567  * This function calculates:
568  *  d = a^b mod c
569  *
570  * @a: base
571  * @b: exponent
572  * @c: modulus
573  * @d: destination
574  */
exptmod(void * a,void * b,void * c,void * d)575 static int exptmod(void *a, void *b, void *c, void *d)
576 {
577 	int res;
578 
579 	if (d == a || d == b || d == c) {
580 		mbedtls_mpi dest;
581 
582 		mbedtls_mpi_init_mempool(&dest);
583 		res = mbedtls_mpi_exp_mod(&dest, a, b, c, NULL);
584 		if (!res)
585 			res = mbedtls_mpi_copy(d, &dest);
586 		mbedtls_mpi_free(&dest);
587 	} else {
588 		res = mbedtls_mpi_exp_mod(d, a, b, c, NULL);
589 	}
590 
591 	if (res)
592 		return CRYPT_MEM;
593 	else
594 		return CRYPT_OK;
595 }
596 
/*
 * RNG callback handed to mbedtls: fill @buf with @blen bytes from
 * crypto_rng_read(), returning MBEDTLS_ERR_MPI_FILE_IO_ERROR on failure.
 */
static int rng_read(void *ignored __unused, unsigned char *buf, size_t blen)
{
	return crypto_rng_read(buf, blen) ? MBEDTLS_ERR_MPI_FILE_IO_ERROR : 0;
}
603 
isprime(void * a,int b,int * c)604 static int isprime(void *a, int b, int *c)
605 {
606 	int res = mbedtls_mpi_is_prime_ext(a, b, rng_read, NULL);
607 
608 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
609 		return CRYPT_MEM;
610 
611 	if (res)
612 		*c = LTC_MP_NO;
613 	else
614 		*c = LTC_MP_YES;
615 
616 	return CRYPT_OK;
617 }
618 
mpi_rand(void * a,int size)619 static int mpi_rand(void *a, int size)
620 {
621 	if (mbedtls_mpi_fill_random(a, size, rng_read, NULL))
622 		return CRYPT_MEM;
623 
624 	return CRYPT_OK;
625 }
626 
/*
 * libtomcrypt math descriptor: routes libtomcrypt's bignum operations to
 * the mbedtls-backed helpers defined in this file. The ECC and RSA entries
 * point at libtomcrypt's own ltc_ecc_* / rsa_* routines and are only
 * present when LTC_MECC / LTC_MRSA are defined.
 */
ltc_math_descriptor ltc_mp = {
	.name = "MPI",
	/* One libtomcrypt digit corresponds to one mbedtls limb */
	.bits_per_digit = sizeof(mbedtls_mpi_uint) * 8,

	.init = init,
	.init_size = init_size,
	.init_copy = init_copy,
	.deinit = deinit,

	.neg = neg,
	.copy = copy,

	.set_int = set_int,
	.get_int = get_int,
	.get_digit = get_digit,
	.get_digit_count = get_digit_count,
	.compare = compare,
	.compare_d = compare_d,
	.count_bits = count_bits,
	.count_lsb_bits = count_lsb_bits,
	.twoexpt = twoexpt,

	.read_radix = read_radix,
	.write_radix = write_radix,
	.unsigned_size = unsigned_size,
	.unsigned_write = unsigned_write,
	.unsigned_read = unsigned_read,

	.add = add,
	.addi = addi,
	.sub = sub,
	.subi = subi,
	.mul = mul,
	.muli = muli,
	.sqr = sqr,
	.mpdiv = divide,
	.div_2 = div_2,
	.modi = modi,
	.gcd = gcd,
	.lcm = lcm,

	.mulmod = mulmod,
	.sqrmod = sqrmod,
	.invmod = invmod,

	.montgomery_setup = montgomery_setup,
	.montgomery_normalization = montgomery_normalization,
	.montgomery_reduce = montgomery_reduce,
	.montgomery_deinit = montgomery_deinit,

	.exptmod = exptmod,
	.isprime = isprime,

#ifdef LTC_MECC
#ifdef LTC_MECC_FP
	.ecc_ptmul = ltc_ecc_fp_mulmod,
#else
	.ecc_ptmul = ltc_ecc_mulmod,
#endif /* LTC_MECC_FP */
	.ecc_ptadd = ltc_ecc_projective_add_point,
	.ecc_ptdbl = ltc_ecc_projective_dbl_point,
	.ecc_map = ltc_ecc_map,
#ifdef LTC_ECC_SHAMIR
#ifdef LTC_MECC_FP
	.ecc_mul2add = ltc_ecc_fp_mul2add,
#else
	.ecc_mul2add = ltc_ecc_mul2add,
#endif /* LTC_MECC_FP */
#endif /* LTC_ECC_SHAMIR */
#endif /* LTC_MECC */

#ifdef LTC_MRSA
	.rsa_keygen = rsa_make_key,
	.rsa_me = rsa_exptmod,
#endif
	.addmod = addmod,
	.submod = submod,
	.rand = mpi_rand,

};
707 
crypto_bignum_num_bytes(struct bignum * a)708 size_t crypto_bignum_num_bytes(struct bignum *a)
709 {
710 	return mbedtls_mpi_size((mbedtls_mpi *)a);
711 }
712 
crypto_bignum_num_bits(struct bignum * a)713 size_t crypto_bignum_num_bits(struct bignum *a)
714 {
715 	return mbedtls_mpi_bitlen((mbedtls_mpi *)a);
716 }
717 
crypto_bignum_compare(struct bignum * a,struct bignum * b)718 int32_t crypto_bignum_compare(struct bignum *a, struct bignum *b)
719 {
720 	return mbedtls_mpi_cmp_mpi((mbedtls_mpi *)a, (mbedtls_mpi *)b);
721 }
722 
crypto_bignum_bn2bin(const struct bignum * from,uint8_t * to)723 void crypto_bignum_bn2bin(const struct bignum *from, uint8_t *to)
724 {
725 	const mbedtls_mpi *f = (const mbedtls_mpi *)from;
726 	int rc __maybe_unused = 0;
727 
728 	rc = mbedtls_mpi_write_binary(f, (void *)to, mbedtls_mpi_size(f));
729 	assert(!rc);
730 }
731 
crypto_bignum_bin2bn(const uint8_t * from,size_t fromsize,struct bignum * to)732 TEE_Result crypto_bignum_bin2bn(const uint8_t *from, size_t fromsize,
733 			 struct bignum *to)
734 {
735 	if (mbedtls_mpi_read_binary((mbedtls_mpi *)to, (const void *)from,
736 				    fromsize))
737 		return TEE_ERROR_BAD_PARAMETERS;
738 	return TEE_SUCCESS;
739 }
740 
crypto_bignum_copy(struct bignum * to,const struct bignum * from)741 void crypto_bignum_copy(struct bignum *to, const struct bignum *from)
742 {
743 	int rc __maybe_unused = 0;
744 
745 	rc = mbedtls_mpi_copy((mbedtls_mpi *)to, (const mbedtls_mpi *)from);
746 	assert(!rc);
747 }
748 
crypto_bignum_allocate(size_t size_bits)749 struct bignum *crypto_bignum_allocate(size_t size_bits)
750 {
751 	mbedtls_mpi *bn = malloc(sizeof(*bn));
752 
753 	if (!bn)
754 		return NULL;
755 
756 	mbedtls_mpi_init(bn);
757 	if (mbedtls_mpi_grow(bn, BITS_TO_LIMBS(size_bits))) {
758 		free(bn);
759 		return NULL;
760 	}
761 
762 	return (struct bignum *)bn;
763 }
764 
crypto_bignum_free(struct bignum ** s)765 void crypto_bignum_free(struct bignum **s)
766 {
767 	assert(s);
768 
769 	mbedtls_mpi_free((mbedtls_mpi *)*s);
770 	free(*s);
771 	*s = NULL;
772 }
773 
crypto_bignum_clear(struct bignum * s)774 void crypto_bignum_clear(struct bignum *s)
775 {
776 	mbedtls_mpi *bn = (mbedtls_mpi *)s;
777 
778 	bn->s = 1;
779 	if (bn->p)
780 		memset(bn->p, 0, sizeof(*bn->p) * bn->n);
781 }
782