1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3  * Copyright 2018-2020 NXP
4  *
5  * RSA Signature Software common implementation.
6  * Functions preparing and/or verifying the signature
7  * encoded string.
8  *
9  * PKCS #1 v2.1: RSA Cryptography Standard
10  * https://www.ietf.org/rfc/rfc3447.txt
11  */
12 #include <crypto/crypto.h>
13 #include <drvcrypt.h>
14 #include <drvcrypt_asn1_oid.h>
15 #include <drvcrypt_math.h>
16 #include <malloc.h>
17 #include <string.h>
18 #include <tee_api_defines_extensions.h>
19 #include <tee/tee_cryp_utl.h>
20 #include <utee_defines.h>
21 #include <util.h>
22 
23 #include "local.h"
24 
25 /*
26  * PKCS#1 V1.5 - Encode the message in Distinguished Encoding Rules
27  * (DER) format.
28  * Refer to EMSA-PKCS1-v1_5 chapter of the PKCS#1 v2.1
29  *
30  * @ssa_data  RSA data to encode
31  * @EM        [out] Encoded Message
32  */
emsa_pkcs1_v1_5_encode(struct drvcrypt_rsa_ssa * ssa_data,struct drvcrypt_buf * EM)33 static TEE_Result emsa_pkcs1_v1_5_encode(struct drvcrypt_rsa_ssa *ssa_data,
34 					 struct drvcrypt_buf *EM)
35 {
36 	const struct drvcrypt_oid *hash_oid = NULL;
37 	size_t ps_size = 0;
38 	uint8_t *buf = NULL;
39 
40 	hash_oid = drvcrypt_get_alg_hash_oid(ssa_data->hash_algo);
41 	if (!hash_oid)
42 		return TEE_ERROR_NOT_SUPPORTED;
43 
44 	/*
45 	 * Calculate the PS size
46 	 *  EM Size (modulus size) - 3 bytes - DigestInfo DER format size
47 	 */
48 	ps_size = ssa_data->key.n_size - 3;
49 	ps_size -= ssa_data->digest_size;
50 	ps_size -= 10 + hash_oid->asn1_length;
51 
52 	CRYPTO_TRACE("PS size = %zu (n %zu)", ps_size, ssa_data->key.n_size);
53 
54 	/*
55 	 * EM = 0x00 || 0x01 || PS || 0x00 || T
56 	 *
57 	 * where T represent the message DigestInfo in DER:
58 	 *    DigestInfo ::= SEQUENCE {
59 	 *                digestAlgorithm AlgorithmIdentifier,
60 	 *                digest OCTET STRING
61 	 *                }
62 	 *
63 	 * T  Length = digest length + oid length
64 	 * EM Length = T Length + 11 + PS Length
65 	 */
66 	buf = EM->data;
67 
68 	/* Set the EM first byte to 0x00 */
69 	*buf++ = 0x00;
70 
71 	/* Set the EM second byte to 0x01 */
72 	*buf++ = 0x01;
73 
74 	/* Fill PS with 0xFF */
75 	memset(buf, UINT8_MAX, ps_size);
76 	buf += ps_size;
77 
78 	/* Set the Byte after PS to 0x00 */
79 	*buf++ = 0x00;
80 
81 	/*
82 	 * Create the DigestInfo DER Sequence
83 	 *
84 	 *  DigestInfo ::= SEQUENCE {
85 	 *                digestAlgorithm AlgorithmIdentifier,
86 	 *                digest OCTET STRING
87 	 *                }
88 	 *
89 	 */
90 	/* SEQUENCE { */
91 	*buf++ = DRVCRYPT_ASN1_SEQUENCE | DRVCRYPT_ASN1_CONSTRUCTED;
92 	*buf++ = 0x08 + hash_oid->asn1_length + ssa_data->digest_size;
93 
94 	/* digestAlgorithm AlgorithmIdentifier */
95 	*buf++ = DRVCRYPT_ASN1_SEQUENCE | DRVCRYPT_ASN1_CONSTRUCTED;
96 	*buf++ = 0x04 + hash_oid->asn1_length;
97 	*buf++ = DRVCRYPT_ASN1_OID;
98 	*buf++ = hash_oid->asn1_length;
99 
100 	/* digest OCTET STRING */
101 	memcpy(buf, hash_oid->asn1, hash_oid->asn1_length);
102 	buf += hash_oid->asn1_length;
103 	*buf++ = DRVCRYPT_ASN1_NULL;
104 	*buf++ = 0x00;
105 	*buf++ = DRVCRYPT_ASN1_OCTET_STRING;
106 	*buf++ = ssa_data->digest_size;
107 	/* } */
108 
109 	memcpy(buf, ssa_data->message.data, ssa_data->digest_size);
110 
111 	CRYPTO_DUMPBUF("Encoded Message", EM->data, (size_t)EM->length);
112 
113 	return TEE_SUCCESS;
114 }
115 
116 /*
117  * PKCS#1 V1.5 - Encode the message in Distinguished Encoding Rules
118  * (DER) format.
119  * Refer to EMSA-PKCS1-v1_5 chapter of the PKCS#1 v2.1
120  *
121  * @ssa_data  RSA data to encode
122  * @EM        [out] Encoded Message
123  */
124 static TEE_Result
emsa_pkcs1_v1_5_encode_noasn1(struct drvcrypt_rsa_ssa * ssa_data,struct drvcrypt_buf * EM)125 emsa_pkcs1_v1_5_encode_noasn1(struct drvcrypt_rsa_ssa *ssa_data,
126 			      struct drvcrypt_buf *EM)
127 {
128 	size_t ps_size = 0;
129 	uint8_t *buf = NULL;
130 
131 	/*
132 	 * Calculate the PS size
133 	 *  EM Size (modulus size) - 3 bytes - Message Length
134 	 */
135 	ps_size = ssa_data->key.n_size - 3;
136 
137 	if (ps_size < ssa_data->message.length)
138 		return TEE_ERROR_BAD_PARAMETERS;
139 
140 	ps_size -= ssa_data->message.length;
141 
142 	CRYPTO_TRACE("PS size = %zu (n %zu)", ps_size, ssa_data->key.n_size);
143 
144 	/*
145 	 * EM = 0x00 || 0x01 || PS || 0x00 || T
146 	 *
147 	 * T  Length = message length
148 	 * EM Length = T Length + PS Length
149 	 */
150 	buf = EM->data;
151 
152 	/* Set the EM first byte to 0x00 */
153 	*buf++ = 0x00;
154 
155 	/* Set the EM second byte to 0x01 */
156 	*buf++ = 0x01;
157 
158 	/* Fill PS with 0xFF */
159 	memset(buf, UINT8_MAX, ps_size);
160 	buf += ps_size;
161 
162 	/* Set the Byte after PS to 0x00 */
163 	*buf++ = 0x00;
164 
165 	memcpy(buf, ssa_data->message.data, ssa_data->message.length);
166 
167 	CRYPTO_DUMPBUF("Encoded Message", EM->data, EM->length);
168 
169 	return TEE_SUCCESS;
170 }
171 
172 /*
173  * PKCS#1 V1.5 - Signature of RSA message and encodes the signature.
174  * Refer to RSASSA-PKCS1-v1_5 chapter of the PKCS#1 v2.1
175  *
176  * @ssa_data   [in/out] RSA data to sign / Signature
177  */
rsassa_pkcs1_v1_5_sign(struct drvcrypt_rsa_ssa * ssa_data)178 static TEE_Result rsassa_pkcs1_v1_5_sign(struct drvcrypt_rsa_ssa *ssa_data)
179 {
180 	TEE_Result ret = TEE_ERROR_BAD_PARAMETERS;
181 	struct drvcrypt_buf EM = { };
182 	struct drvcrypt_rsa_ed rsa_data = { };
183 	struct drvcrypt_rsa *rsa = NULL;
184 
185 	EM.length = ssa_data->key.n_size;
186 	EM.data = malloc(EM.length);
187 	if (!EM.data)
188 		return TEE_ERROR_OUT_OF_MEMORY;
189 
190 	/* Encode the Message */
191 	if (ssa_data->algo != TEE_ALG_RSASSA_PKCS1_V1_5)
192 		ret = emsa_pkcs1_v1_5_encode(ssa_data, &EM);
193 	else
194 		ret = emsa_pkcs1_v1_5_encode_noasn1(ssa_data, &EM);
195 
196 	if (ret != TEE_SUCCESS)
197 		goto out;
198 
199 	/*
200 	 * RSA Encrypt/Decrypt are doing the same operation except
201 	 * that decrypt takes a RSA Private key in parameter
202 	 */
203 	rsa_data.key.key = ssa_data->key.key;
204 	rsa_data.key.isprivate = true;
205 	rsa_data.key.n_size = ssa_data->key.n_size;
206 
207 	rsa = drvcrypt_get_ops(CRYPTO_RSA);
208 	if (!rsa) {
209 		ret = TEE_ERROR_NOT_IMPLEMENTED;
210 		goto out;
211 	}
212 
213 	/* Prepare the decryption data parameters */
214 	rsa_data.rsa_id = DRVCRYPT_RSASSA_PKCS_V1_5;
215 	rsa_data.message.data = ssa_data->signature.data;
216 	rsa_data.message.length = ssa_data->signature.length;
217 	rsa_data.cipher.data = EM.data;
218 	rsa_data.cipher.length = EM.length;
219 	rsa_data.hash_algo = ssa_data->hash_algo;
220 	rsa_data.algo = ssa_data->algo;
221 
222 	ret = rsa->decrypt(&rsa_data);
223 
224 	/* Set the message decrypted size */
225 	ssa_data->signature.length = rsa_data.message.length;
226 
227 out:
228 	free(EM.data);
229 
230 	return ret;
231 }
232 
233 /*
234  * PKCS#1 V1.5 - Verification of the RSA message's signature.
235  * Refer to RSASSA-PKCS1-v1_5 chapter of the PKCS#1 v2.1
236  *
237  * @ssa_data   [int/out] RSA data signed and encoded signature
238  */
rsassa_pkcs1_v1_5_verify(struct drvcrypt_rsa_ssa * ssa_data)239 static TEE_Result rsassa_pkcs1_v1_5_verify(struct drvcrypt_rsa_ssa *ssa_data)
240 {
241 	TEE_Result ret = TEE_ERROR_BAD_PARAMETERS;
242 	struct drvcrypt_buf EM = { };
243 	struct drvcrypt_buf EM_gen = { };
244 	struct drvcrypt_rsa_ed rsa_data = { };
245 	struct drvcrypt_rsa *rsa = NULL;
246 
247 	EM.length = ssa_data->key.n_size;
248 	EM.data = malloc(EM.length);
249 
250 	EM_gen.length = ssa_data->key.n_size;
251 	EM_gen.data = malloc(EM.length);
252 
253 	if (!EM.data || !EM_gen.data) {
254 		ret = TEE_ERROR_OUT_OF_MEMORY;
255 		goto end_verify;
256 	}
257 
258 	/*
259 	 * RSA Encrypt/Decrypt are doing the same operation except
260 	 * that the encrypt takes a RSA Public key in parameter
261 	 */
262 	rsa_data.key.key = ssa_data->key.key;
263 	rsa_data.key.isprivate = false;
264 	rsa_data.key.n_size = ssa_data->key.n_size;
265 
266 	rsa = drvcrypt_get_ops(CRYPTO_RSA);
267 	if (rsa) {
268 		/* Prepare the encryption data parameters */
269 		rsa_data.rsa_id = DRVCRYPT_RSASSA_PKCS_V1_5;
270 		rsa_data.message.data = ssa_data->signature.data;
271 		rsa_data.message.length = ssa_data->signature.length;
272 		rsa_data.cipher.data = EM.data;
273 		rsa_data.cipher.length = EM.length;
274 		rsa_data.hash_algo = ssa_data->hash_algo;
275 		ret = rsa->encrypt(&rsa_data);
276 
277 		/* Set the cipher size */
278 		EM.length = rsa_data.cipher.length;
279 	} else {
280 		ret = TEE_ERROR_NOT_IMPLEMENTED;
281 	}
282 
283 	if (ret != TEE_SUCCESS)
284 		goto end_verify;
285 
286 	/* Encode the Message */
287 	if (ssa_data->algo != TEE_ALG_RSASSA_PKCS1_V1_5)
288 		ret = emsa_pkcs1_v1_5_encode(ssa_data, &EM_gen);
289 	else
290 		ret = emsa_pkcs1_v1_5_encode_noasn1(ssa_data, &EM_gen);
291 
292 	if (ret != TEE_SUCCESS)
293 		goto end_verify;
294 
295 	/* Check if EM decrypted and EM re-generated are identical */
296 	ret = TEE_ERROR_SIGNATURE_INVALID;
297 	if (EM.length == EM_gen.length) {
298 		if (!memcmp(EM.data, EM_gen.data, EM.length))
299 			ret = TEE_SUCCESS;
300 	}
301 
302 end_verify:
303 	free(EM.data);
304 	free(EM_gen.data);
305 
306 	return ret;
307 }
308 
309 /*
310  * PSS - Encode the message using a Probabilistic Signature Scheme (PSS)
311  * Refer to EMSA-PSS (encoding) chapter of the PKCS#1 v2.1
312  *
313  * @ssa_data  RSA data to encode
314  * @emBits    EM size in bits
315  * @EM        [out] Encoded Message
316  */
emsa_pss_encode(struct drvcrypt_rsa_ssa * ssa_data,size_t emBits,struct drvcrypt_buf * EM)317 static TEE_Result emsa_pss_encode(struct drvcrypt_rsa_ssa *ssa_data,
318 				  size_t emBits, struct drvcrypt_buf *EM)
319 {
320 	TEE_Result ret = TEE_ERROR_GENERIC;
321 	struct drvcrypt_rsa_mgf mgf_data = { };
322 	struct drvcrypt_buf hash = { };
323 	struct drvcrypt_buf dbMask = { };
324 	struct drvcrypt_buf DB = { };
325 	size_t db_size = 0;
326 	size_t ps_size = 0;
327 	size_t msg_size = 0;
328 	uint8_t *buf = NULL;
329 	uint8_t *msg_db = NULL;
330 	uint8_t *salt = NULL;
331 	struct drvcrypt_mod_op mod_op = { };
332 
333 	/*
334 	 * Build EM = maskedDB || H || 0xbc
335 	 *
336 	 * where
337 	 *    maskedDB = DB xor dbMask
338 	 *       DB     = PS || 0x01 || salt
339 	 *       dbMask = MGF(H, emLen - hLen - 1)
340 	 *
341 	 *    H  = Hash(M')
342 	 *       M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt
343 	 *
344 	 * PS size = emLen - sLen - hLen - 2 (may be = 0)
345 	 */
346 
347 	/*
348 	 * Calculate the M' and DB size to allocate a temporary buffer
349 	 * used for both object
350 	 */
351 	ps_size = EM->length - ssa_data->digest_size - ssa_data->salt_len - 2;
352 	db_size = EM->length - ssa_data->digest_size - 1;
353 	msg_size = 8 + ssa_data->digest_size + ssa_data->salt_len;
354 
355 	CRYPTO_TRACE("PS Len = %zu, DB Len = %zu, M' Len = %zu", ps_size,
356 		     db_size, msg_size);
357 
358 	msg_db = malloc(MAX(db_size, msg_size));
359 	if (!msg_db)
360 		return TEE_ERROR_OUT_OF_MEMORY;
361 
362 	if (ssa_data->salt_len) {
363 		salt = malloc(ssa_data->salt_len);
364 
365 		if (!salt) {
366 			ret = TEE_ERROR_OUT_OF_MEMORY;
367 			goto end_pss_encode;
368 		}
369 	}
370 
371 	/*
372 	 * Step 4 and 5
373 	 * Generate the M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt
374 	 *
375 	 * where
376 	 *   mHash is the input message (already hash)
377 	 *   salt is a random number of salt_len (input data) can be empty
378 	 */
379 	buf = msg_db;
380 
381 	memset(buf, 0, 8);
382 	buf += 8;
383 
384 	memcpy(buf, ssa_data->message.data, ssa_data->message.length);
385 	buf += ssa_data->message.length;
386 
387 	/* Get salt random number if salt length not 0 */
388 	if (ssa_data->salt_len) {
389 		ret = crypto_rng_read(salt, ssa_data->salt_len);
390 		CRYPTO_TRACE("Get salt of %zu bytes (ret = 0x%08" PRIx32 ")",
391 			     ssa_data->salt_len, ret);
392 		if (ret != TEE_SUCCESS)
393 			goto end_pss_encode;
394 
395 		memcpy(buf, salt, ssa_data->salt_len);
396 	}
397 
398 	/*
399 	 * Step 6
400 	 * Hash the M' generated new message
401 	 * H = hash(M')
402 	 */
403 	hash.data = &EM->data[db_size];
404 	hash.length = ssa_data->digest_size;
405 
406 	ret = tee_hash_createdigest(ssa_data->hash_algo, msg_db, msg_size,
407 				    hash.data, hash.length);
408 
409 	CRYPTO_TRACE("H = hash(M') returned 0x%08" PRIx32, ret);
410 	if (ret != TEE_SUCCESS)
411 		goto end_pss_encode;
412 
413 	CRYPTO_DUMPBUF("H = hash(M')", hash.data, hash.length);
414 
415 	/*
416 	 * Step 7 and 8
417 	 *   DB = PS || 0x01 || salt
418 	 */
419 	buf = msg_db;
420 	if (ps_size)
421 		memset(buf, 0, ps_size);
422 	buf += ps_size;
423 	*buf++ = 0x01;
424 
425 	if (ssa_data->salt_len)
426 		memcpy(buf, salt, ssa_data->salt_len);
427 
428 	DB.data = msg_db;
429 	DB.length = db_size;
430 
431 	CRYPTO_DUMPBUF("DB", DB.data, DB.length);
432 
433 	/*
434 	 * Step 9
435 	 * Generate a Mask of the seed value
436 	 * dbMask = MGF(H, emLen - hLen - 1)
437 	 *
438 	 * Note: Will use the same buffer for the dbMask and maskedDB
439 	 *       maskedDB is in the EM output
440 	 */
441 	dbMask.data = EM->data;
442 	dbMask.length = db_size;
443 
444 	mgf_data.hash_algo = ssa_data->hash_algo;
445 	mgf_data.digest_size = ssa_data->digest_size;
446 	mgf_data.seed.data = hash.data;
447 	mgf_data.seed.length = hash.length;
448 	mgf_data.mask.data = dbMask.data;
449 	mgf_data.mask.length = dbMask.length;
450 	ret = ssa_data->mgf(&mgf_data);
451 
452 	CRYPTO_TRACE("dbMask = MGF(H, emLen - hLen - 1) returned 0x%08" PRIx32,
453 		     ret);
454 	if (ret != TEE_SUCCESS)
455 		goto end_pss_encode;
456 
457 	CRYPTO_DUMPBUF("dbMask", dbMask.data, dbMask.length);
458 
459 	/*
460 	 * Step 10
461 	 * maskedDB = DB xor dbMask
462 	 */
463 	mod_op.n.length = dbMask.length;
464 	mod_op.a.data = DB.data;
465 	mod_op.a.length = DB.length;
466 	mod_op.b.data = dbMask.data;
467 	mod_op.b.length = dbMask.length;
468 	mod_op.result.data = dbMask.data;
469 	mod_op.result.length = dbMask.length;
470 
471 	ret = drvcrypt_xor_mod_n(&mod_op);
472 	CRYPTO_TRACE("maskedDB = DB xor dbMask returned 0x%08" PRIx32, ret);
473 	if (ret != TEE_SUCCESS)
474 		goto end_pss_encode;
475 
476 	CRYPTO_DUMPBUF("maskedDB", dbMask.data, dbMask.length);
477 
478 	/*
479 	 * Step 11
480 	 * Set the leftmost 8emLen - emBits of the leftmost octet
481 	 * in maskedDB to 0'
482 	 */
483 	EM->data[0] &= (UINT8_MAX >> ((EM->length * 8) - emBits));
484 
485 	/*
486 	 * Step 12
487 	 * EM = maskedDB || H || 0xbc
488 	 */
489 	EM->data[EM->length - 1] = 0xbc;
490 
491 	CRYPTO_DUMPBUF("EM", EM->data, EM->length);
492 
493 	ret = TEE_SUCCESS;
494 end_pss_encode:
495 	free(msg_db);
496 	free(salt);
497 
498 	return ret;
499 }
500 
501 /*
502  * PSS - Verify the message using a Probabilistic Signature Scheme (PSS)
503  * Refer to EMSA-PSS (verification) chapter of the PKCS#1 v2.1
504  *
505  * @ssa_data  RSA data to encode
506  * @emBits    EM size in bits
507  * @EM        [out] Encoded Message
508  */
emsa_pss_verify(struct drvcrypt_rsa_ssa * ssa_data,size_t emBits,struct drvcrypt_buf * EM)509 static TEE_Result emsa_pss_verify(struct drvcrypt_rsa_ssa *ssa_data,
510 				  size_t emBits, struct drvcrypt_buf *EM)
511 {
512 	TEE_Result ret = TEE_ERROR_GENERIC;
513 	struct drvcrypt_rsa_mgf mgf_data = { };
514 	struct drvcrypt_buf hash = { };
515 	struct drvcrypt_buf hash_gen = { };
516 	size_t db_size = 0;
517 	size_t ps_size = 0;
518 	size_t msg_size = 0;
519 	uint8_t *msg_db = NULL;
520 	uint8_t *salt = NULL;
521 	uint8_t *buf = NULL;
522 	struct drvcrypt_mod_op mod_op = { };
523 
524 	/*
525 	 * EM = maskedDB || H || 0xbc
526 	 *
527 	 * where
528 	 *    maskedDB = DB xor dbMask
529 	 *       DB     = PS || 0x01 || salt
530 	 *       dbMask = MGF(H, emLen - hLen - 1)
531 	 *
532 	 *    H  = Hash(M')
533 	 *       M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt
534 	 *
535 	 * PS size = emLen - sLen - hLen - 2 (may be = 0)
536 	 */
537 
538 	/*
539 	 * Calculate the M' and DB size to allocate a temporary buffer
540 	 * used for both object
541 	 */
542 	ps_size = EM->length - ssa_data->digest_size - ssa_data->salt_len - 2;
543 	db_size = EM->length - ssa_data->digest_size - 1;
544 	msg_size = 8 + ssa_data->digest_size + ssa_data->salt_len;
545 
546 	CRYPTO_TRACE("PS Len = %zu, DB Len = %zu, M' Len = %zu", ps_size,
547 		     db_size, msg_size);
548 
549 	msg_db = malloc(MAX(db_size, msg_size));
550 	if (!msg_db)
551 		return TEE_ERROR_OUT_OF_MEMORY;
552 
553 	/*
554 	 * Step 4
555 	 * Check if rightmost octet of EM is 0xbc
556 	 */
557 	if (EM->data[EM->length - 1] != 0xbc) {
558 		CRYPTO_TRACE("rigthmost octet != 0xbc (0x%" PRIX8 ")",
559 			     EM->data[EM->length - 1]);
560 		ret = TEE_ERROR_SIGNATURE_INVALID;
561 		goto end_pss_verify;
562 	}
563 
564 	/*
565 	 * Step 6
566 	 * Check if the leftmost 8emLen - emBits of the leftmost octet
567 	 * in maskedDB are 0's
568 	 */
569 	if (EM->data[0] & ~(UINT8_MAX >> (EM->length * 8 - emBits))) {
570 		CRYPTO_TRACE("Error leftmost octet maskedDB not 0's");
571 		CRYPTO_TRACE("EM[0] = 0x%" PRIX8
572 			     " - EM Len = %zu, emBits = %zu",
573 			     EM->data[0], EM->length, emBits);
574 		ret = TEE_ERROR_SIGNATURE_INVALID;
575 		goto end_pss_verify;
576 	}
577 
578 	hash.data = &EM->data[db_size];
579 	hash.length = ssa_data->digest_size;
580 
581 	/*
582 	 * Step 7
583 	 * dbMask = MGF(H, emLen - hLen - 1)
584 	 *
585 	 * Note: Will use the same buffer for the dbMask and DB
586 	 */
587 	mgf_data.hash_algo = ssa_data->hash_algo;
588 	mgf_data.digest_size = ssa_data->digest_size;
589 	mgf_data.seed.data = hash.data;
590 	mgf_data.seed.length = hash.length;
591 	mgf_data.mask.data = msg_db;
592 	mgf_data.mask.length = db_size;
593 	ret = ssa_data->mgf(&mgf_data);
594 
595 	CRYPTO_TRACE("dbMask = MGF(H, emLen - hLen - 1) returned 0x%08" PRIx32,
596 		     ret);
597 	if (ret != TEE_SUCCESS)
598 		goto end_pss_verify;
599 
600 	CRYPTO_DUMPBUF("dbMask", msg_db, db_size);
601 
602 	/*
603 	 * Step 8
604 	 * DB = maskedDB xor dbMask
605 	 *
606 	 *
607 	 * Note: maskedDB is in the EM input
608 	 */
609 	mod_op.n.length = db_size;
610 	mod_op.a.data = EM->data;
611 	mod_op.a.length = db_size;
612 	mod_op.b.data = msg_db;
613 	mod_op.b.length = db_size;
614 	mod_op.result.data = msg_db;
615 	mod_op.result.length = db_size;
616 
617 	ret = drvcrypt_xor_mod_n(&mod_op);
618 	CRYPTO_TRACE("DB = maskedDB xor dbMask returned 0x%08" PRIx32, ret);
619 	if (ret != TEE_SUCCESS)
620 		goto end_pss_verify;
621 
622 	/*
623 	 * Step 9
624 	 * Set the leftmost 8emLen - emBits of the leftmost octet in
625 	 * DB to zero
626 	 */
627 	*msg_db &= UINT8_MAX >> (EM->length * 8 - emBits);
628 
629 	CRYPTO_DUMPBUF("DB", msg_db, db_size);
630 
631 	/*
632 	 * Step 10
633 	 * Expected to have
634 	 *       DB     = PS || 0x01 || salt
635 	 *
636 	 * PS must be 0
637 	 * PS size = emLen - sLen - hLen - 2 (may be = 0)
638 	 */
639 	buf = msg_db;
640 	while (buf < msg_db + ps_size) {
641 		if (*buf++ != 0) {
642 			ret = TEE_ERROR_SIGNATURE_INVALID;
643 			goto end_pss_verify;
644 		}
645 	}
646 
647 	if (*buf++ != 0x01) {
648 		ret = TEE_ERROR_SIGNATURE_INVALID;
649 		goto end_pss_verify;
650 	}
651 
652 	/*
653 	 * Step 11
654 	 * Get the salt value
655 	 */
656 	if (ssa_data->salt_len) {
657 		salt = malloc(ssa_data->salt_len);
658 		if (!salt) {
659 			ret = TEE_ERROR_OUT_OF_MEMORY;
660 			goto end_pss_verify;
661 		}
662 
663 		memcpy(salt, buf, ssa_data->salt_len);
664 	}
665 
666 	/*
667 	 * Step 12
668 	 * Generate the M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt
669 	 *
670 	 * where
671 	 *   mHash is the input message (already hash)
672 	 *   salt is a random number of salt_len (input data) can be empty
673 	 */
674 	buf = msg_db;
675 
676 	memset(buf, 0, 8);
677 	buf += 8;
678 
679 	memcpy(buf, ssa_data->message.data, ssa_data->message.length);
680 	buf += ssa_data->message.length;
681 
682 	if (ssa_data->salt_len)
683 		memcpy(buf, salt, ssa_data->salt_len);
684 
685 	/*
686 	 * Step 13
687 	 * Hash the M' generated new message
688 	 * H' = hash(M')
689 	 *
690 	 * Note: reuse the msg_db buffer as Hash result
691 	 */
692 	hash_gen.data = msg_db;
693 	hash_gen.length = ssa_data->digest_size;
694 
695 	ret = tee_hash_createdigest(ssa_data->hash_algo, msg_db, msg_size,
696 				    hash_gen.data, hash_gen.length);
697 
698 	CRYPTO_TRACE("H' = hash(M') returned 0x%08" PRIx32, ret);
699 	if (ret != TEE_SUCCESS)
700 		goto end_pss_verify;
701 
702 	CRYPTO_DUMPBUF("H' = hash(M')", hash_gen.data, hash_gen.length);
703 
704 	/*
705 	 * Step 14
706 	 * Compare H and H'
707 	 */
708 	ret = TEE_ERROR_SIGNATURE_INVALID;
709 	if (!memcmp(hash_gen.data, hash.data, hash_gen.length))
710 		ret = TEE_SUCCESS;
711 
712 end_pss_verify:
713 	free(msg_db);
714 	free(salt);
715 
716 	return ret;
717 }
718 
719 /*
720  * PSS - Signature of RSA message and encodes the signature.
721  * Refer to RSASSA-PSS chapter of the PKCS#1 v2.1
722  *
723  * @ssa_data   [in/out] RSA data to sign / Signature
724  */
rsassa_pss_sign(struct drvcrypt_rsa_ssa * ssa_data)725 static TEE_Result rsassa_pss_sign(struct drvcrypt_rsa_ssa *ssa_data)
726 {
727 	TEE_Result ret = TEE_ERROR_GENERIC;
728 	struct rsa_keypair *key = NULL;
729 	struct drvcrypt_buf EM = { };
730 	size_t modBits = 0;
731 	struct drvcrypt_rsa_ed rsa_data = { };
732 	struct drvcrypt_rsa *rsa = NULL;
733 
734 	key = ssa_data->key.key;
735 
736 	/* Get modulus length in bits */
737 	modBits = crypto_bignum_num_bits(key->n);
738 	if (modBits <= 0)
739 		return TEE_ERROR_BAD_PARAMETERS;
740 
741 	/*
742 	 * EM Length = (modBits - 1) / 8
743 	 * if (modBits - 1) is not divisible by 8, one more byte is needed
744 	 */
745 	modBits--;
746 	EM.length = ROUNDUP(modBits, 8) / 8;
747 
748 	if (EM.length < ssa_data->digest_size + ssa_data->salt_len + 2)
749 		return TEE_ERROR_BAD_PARAMETERS;
750 
751 	EM.data = malloc(EM.length);
752 	if (!EM.data)
753 		return TEE_ERROR_OUT_OF_MEMORY;
754 
755 	CRYPTO_TRACE("modBits = %zu, hence EM Length = %zu", modBits + 1,
756 		     EM.length);
757 
758 	/* Encode the Message */
759 	ret = emsa_pss_encode(ssa_data, modBits, &EM);
760 	CRYPTO_TRACE("EMSA PSS Encode returned 0x%08" PRIx32, ret);
761 
762 	/*
763 	 * RSA Encrypt/Decrypt are doing the same operation
764 	 * expect that the decrypt takes a RSA Private key in parameter
765 	 */
766 	if (ret == TEE_SUCCESS) {
767 		rsa_data.key.key = ssa_data->key.key;
768 		rsa_data.key.isprivate = true;
769 		rsa_data.key.n_size = ssa_data->key.n_size;
770 
771 		rsa = drvcrypt_get_ops(CRYPTO_RSA);
772 		if (rsa) {
773 			/* Prepare the decryption data parameters */
774 			rsa_data.rsa_id = DRVCRYPT_RSASSA_PSS;
775 			rsa_data.message.data = ssa_data->signature.data;
776 			rsa_data.message.length = ssa_data->signature.length;
777 			rsa_data.cipher.data = EM.data;
778 			rsa_data.cipher.length = EM.length;
779 			rsa_data.algo = ssa_data->algo;
780 
781 			ret = rsa->decrypt(&rsa_data);
782 
783 			/* Set the message decrypted size */
784 			ssa_data->signature.length = rsa_data.message.length;
785 		} else {
786 			ret = TEE_ERROR_NOT_IMPLEMENTED;
787 		}
788 	}
789 	free(EM.data);
790 
791 	return ret;
792 }
793 
794 /*
795  * PSS - Signature verification of RSA message.
796  * Refer to RSASSA-PSS chapter of the PKCS#1 v2.1
797  *
798  * @ssa_data   [in/out] RSA Signature vs. message to verify
799  */
rsassa_pss_verify(struct drvcrypt_rsa_ssa * ssa_data)800 static TEE_Result rsassa_pss_verify(struct drvcrypt_rsa_ssa *ssa_data)
801 {
802 	TEE_Result ret = TEE_ERROR_GENERIC;
803 	struct rsa_public_key *key = NULL;
804 	struct drvcrypt_buf EM = { };
805 	size_t modBits = 0;
806 	struct drvcrypt_rsa_ed rsa_data = { };
807 	struct drvcrypt_rsa *rsa = NULL;
808 
809 	key = ssa_data->key.key;
810 
811 	/* Get modulus length in bits */
812 	modBits = crypto_bignum_num_bits(key->n);
813 	if (modBits <= 0)
814 		return TEE_ERROR_BAD_PARAMETERS;
815 
816 	/*
817 	 * EM Length = (modBits - 1) / 8
818 	 * if (modBits - 1) is not divisible by 8, one more byte is needed
819 	 */
820 	modBits--;
821 	EM.length = ROUNDUP(modBits, 8) / 8;
822 
823 	if (EM.length < ssa_data->digest_size + ssa_data->salt_len + 2)
824 		return TEE_ERROR_BAD_PARAMETERS;
825 
826 	EM.data = malloc(EM.length);
827 	if (!EM.data)
828 		return TEE_ERROR_OUT_OF_MEMORY;
829 
830 	CRYPTO_TRACE("modBits = %zu, hence EM Length = %zu", modBits + 1,
831 		     EM.length);
832 
833 	/*
834 	 * RSA Encrypt/Decrypt are doing the same operation
835 	 * expect that the encrypt takes a RSA Public key in parameter
836 	 */
837 	rsa_data.key.key = ssa_data->key.key;
838 	rsa_data.key.isprivate = false;
839 	rsa_data.key.n_size = ssa_data->key.n_size;
840 
841 	rsa = drvcrypt_get_ops(CRYPTO_RSA);
842 	if (rsa) {
843 		/* Prepare the encryption data parameters */
844 		rsa_data.rsa_id = DRVCRYPT_RSASSA_PSS;
845 		rsa_data.message.data = ssa_data->signature.data;
846 		rsa_data.message.length = ssa_data->signature.length;
847 		rsa_data.cipher.data = EM.data;
848 		rsa_data.cipher.length = EM.length;
849 		rsa_data.algo = ssa_data->algo;
850 
851 		ret = rsa->encrypt(&rsa_data);
852 
853 		/* Set the cipher size */
854 		EM.length = rsa_data.cipher.length;
855 	} else {
856 		ret = TEE_ERROR_NOT_IMPLEMENTED;
857 		goto end_pss_verify;
858 	}
859 
860 	if (ret == TEE_SUCCESS) {
861 		/* Verify the Message */
862 		ret = emsa_pss_verify(ssa_data, modBits, &EM);
863 		CRYPTO_TRACE("EMSA PSS Verify returned 0x%08" PRIx32, ret);
864 	} else {
865 		CRYPTO_TRACE("RSA NO PAD returned 0x%08" PRIx32, ret);
866 		ret = TEE_ERROR_SIGNATURE_INVALID;
867 	}
868 
869 end_pss_verify:
870 	free(EM.data);
871 
872 	return ret;
873 }
874 
drvcrypt_rsassa_sign(struct drvcrypt_rsa_ssa * ssa_data)875 TEE_Result drvcrypt_rsassa_sign(struct drvcrypt_rsa_ssa *ssa_data)
876 {
877 	switch (ssa_data->algo) {
878 	case TEE_ALG_RSASSA_PKCS1_V1_5:
879 	case TEE_ALG_RSASSA_PKCS1_V1_5_MD5:
880 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA1:
881 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA224:
882 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA256:
883 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA384:
884 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA512:
885 		return rsassa_pkcs1_v1_5_sign(ssa_data);
886 
887 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA1:
888 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA224:
889 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256:
890 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA384:
891 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA512:
892 		return rsassa_pss_sign(ssa_data);
893 
894 	default:
895 		break;
896 	}
897 
898 	return TEE_ERROR_BAD_PARAMETERS;
899 }
900 
drvcrypt_rsassa_verify(struct drvcrypt_rsa_ssa * ssa_data)901 TEE_Result drvcrypt_rsassa_verify(struct drvcrypt_rsa_ssa *ssa_data)
902 {
903 	switch (ssa_data->algo) {
904 	case TEE_ALG_RSASSA_PKCS1_V1_5:
905 	case TEE_ALG_RSASSA_PKCS1_V1_5_MD5:
906 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA1:
907 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA224:
908 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA256:
909 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA384:
910 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA512:
911 		return rsassa_pkcs1_v1_5_verify(ssa_data);
912 
913 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA1:
914 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA224:
915 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256:
916 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA384:
917 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA512:
918 		return rsassa_pss_verify(ssa_data);
919 
920 	default:
921 		break;
922 	}
923 
924 	return TEE_ERROR_BAD_PARAMETERS;
925 }
926