1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3  * Copyright (C) 2018, ARM Limited
4  * Copyright (C) 2019, Linaro Limited
5  */
6 
#include <assert.h>
#include <crypto/crypto.h>
#include <limits.h>
#include <mbedtls/ctr_drbg.h>
#include <mbedtls/entropy.h>
#include <mbedtls/pk.h>
#include <mbedtls/pk_internal.h>
#include <stdlib.h>
#include <string.h>
#include <tee/tee_cryp_utl.h>
#include <utee_defines.h>

#include "mbed_helpers.h"
19 
get_tee_result(int lmd_res)20 static TEE_Result get_tee_result(int lmd_res)
21 {
22 	switch (lmd_res) {
23 	case 0:
24 		return TEE_SUCCESS;
25 	case MBEDTLS_ERR_RSA_PRIVATE_FAILED +
26 		MBEDTLS_ERR_MPI_BAD_INPUT_DATA:
27 	case MBEDTLS_ERR_RSA_BAD_INPUT_DATA:
28 	case MBEDTLS_ERR_RSA_INVALID_PADDING:
29 	case MBEDTLS_ERR_PK_TYPE_MISMATCH:
30 		return TEE_ERROR_BAD_PARAMETERS;
31 	case MBEDTLS_ERR_RSA_OUTPUT_TOO_LARGE:
32 		return TEE_ERROR_SHORT_BUFFER;
33 	default:
34 		return TEE_ERROR_BAD_STATE;
35 	}
36 }
37 
tee_algo_to_mbedtls_hash_algo(uint32_t algo)38 static uint32_t tee_algo_to_mbedtls_hash_algo(uint32_t algo)
39 {
40 	switch (algo) {
41 #if defined(CFG_CRYPTO_SHA1)
42 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA1:
43 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA1:
44 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA1:
45 	case TEE_ALG_SHA1:
46 	case TEE_ALG_DSA_SHA1:
47 	case TEE_ALG_HMAC_SHA1:
48 		return MBEDTLS_MD_SHA1;
49 #endif
50 #if defined(CFG_CRYPTO_MD5)
51 	case TEE_ALG_RSASSA_PKCS1_V1_5_MD5:
52 	case TEE_ALG_MD5:
53 	case TEE_ALG_HMAC_MD5:
54 		return MBEDTLS_MD_MD5;
55 #endif
56 #if defined(CFG_CRYPTO_SHA224)
57 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA224:
58 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA224:
59 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA224:
60 	case TEE_ALG_SHA224:
61 	case TEE_ALG_DSA_SHA224:
62 	case TEE_ALG_HMAC_SHA224:
63 		return MBEDTLS_MD_SHA224;
64 #endif
65 #if defined(CFG_CRYPTO_SHA256)
66 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA256:
67 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256:
68 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA256:
69 	case TEE_ALG_SHA256:
70 	case TEE_ALG_DSA_SHA256:
71 	case TEE_ALG_HMAC_SHA256:
72 		return MBEDTLS_MD_SHA256;
73 #endif
74 #if defined(CFG_CRYPTO_SHA384)
75 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA384:
76 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA384:
77 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA384:
78 	case TEE_ALG_SHA384:
79 	case TEE_ALG_HMAC_SHA384:
80 		return MBEDTLS_MD_SHA384;
81 #endif
82 #if defined(CFG_CRYPTO_SHA512)
83 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA512:
84 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA512:
85 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA512:
86 	case TEE_ALG_SHA512:
87 	case TEE_ALG_HMAC_SHA512:
88 		return MBEDTLS_MD_SHA512;
89 #endif
90 	default:
91 		return MBEDTLS_MD_NONE;
92 	}
93 }
94 
/*
 * Load the key material of @key into the mbedTLS context @rsa.
 *
 * The MPIs are struct-copied, so @rsa refers to the key's storage instead
 * of owning a copy. Such a context must be released with mbd_rsa_free(),
 * which detaches the borrowed MPIs first. The CRT components (p, q, dp,
 * dq, qp) are taken over only when @key has a non-NULL, non-empty p.
 */
static void rsa_init_from_key_pair(mbedtls_rsa_context *rsa,
				   struct rsa_keypair *key)
{
	int with_crt = 0;

	mbedtls_rsa_init(rsa, 0, 0);

	/* Modulus and both exponents are always present */
	rsa->N = *(mbedtls_mpi *)key->n;
	rsa->E = *(mbedtls_mpi *)key->e;
	rsa->D = *(mbedtls_mpi *)key->d;

	/* An absent or empty p means the key carries no CRT parameters */
	with_crt = key->p && crypto_bignum_num_bytes(key->p);
	if (with_crt) {
		rsa->P = *(mbedtls_mpi *)key->p;
		rsa->Q = *(mbedtls_mpi *)key->q;
		rsa->DP = *(mbedtls_mpi *)key->dp;
		rsa->DQ = *(mbedtls_mpi *)key->dq;
		rsa->QP = *(mbedtls_mpi *)key->qp;
	}

	/* Byte length of the modulus */
	rsa->len = mbedtls_mpi_size(&rsa->N);
}
112 
/*
 * Tear down a context prepared by rsa_init_from_key_pair().
 *
 * The MPIs borrowed from the key are reset to empty before
 * mbedtls_rsa_free() runs, so that call releases only what the context
 * itself allocated. The key's bignums are left for the key's owner to
 * free. The CRT MPIs are treated as borrowed whenever P is non-empty.
 */
static void mbd_rsa_free(mbedtls_rsa_context *rsa)
{
	size_t p_len = mbedtls_mpi_size(&rsa->P);

	mbedtls_mpi_init(&rsa->N);
	mbedtls_mpi_init(&rsa->E);
	mbedtls_mpi_init(&rsa->D);

	if (p_len != 0) {
		mbedtls_mpi_init(&rsa->P);
		mbedtls_mpi_init(&rsa->Q);
		mbedtls_mpi_init(&rsa->DP);
		mbedtls_mpi_init(&rsa->DQ);
		mbedtls_mpi_init(&rsa->QP);
	}

	mbedtls_rsa_free(rsa);
}
128 
crypto_acipher_alloc_rsa_keypair(struct rsa_keypair * s,size_t key_size_bits)129 TEE_Result crypto_acipher_alloc_rsa_keypair(struct rsa_keypair *s,
130 					    size_t key_size_bits)
131 {
132 	memset(s, 0, sizeof(*s));
133 	s->e = crypto_bignum_allocate(key_size_bits);
134 	if (!s->e)
135 		goto err;
136 	s->d = crypto_bignum_allocate(key_size_bits);
137 	if (!s->d)
138 		goto err;
139 	s->n = crypto_bignum_allocate(key_size_bits);
140 	if (!s->n)
141 		goto err;
142 	s->p = crypto_bignum_allocate(key_size_bits);
143 	if (!s->p)
144 		goto err;
145 	s->q = crypto_bignum_allocate(key_size_bits);
146 	if (!s->q)
147 		goto err;
148 	s->qp = crypto_bignum_allocate(key_size_bits);
149 	if (!s->qp)
150 		goto err;
151 	s->dp = crypto_bignum_allocate(key_size_bits);
152 	if (!s->dp)
153 		goto err;
154 	s->dq = crypto_bignum_allocate(key_size_bits);
155 	if (!s->dq)
156 		goto err;
157 
158 	return TEE_SUCCESS;
159 err:
160 	crypto_acipher_free_rsa_keypair(s);
161 	return TEE_ERROR_OUT_OF_MEMORY;
162 }
163 
crypto_acipher_alloc_rsa_public_key(struct rsa_public_key * s,size_t key_size_bits)164 TEE_Result crypto_acipher_alloc_rsa_public_key(struct rsa_public_key *s,
165 					       size_t key_size_bits)
166 {
167 	memset(s, 0, sizeof(*s));
168 	s->e = crypto_bignum_allocate(key_size_bits);
169 	if (!s->e)
170 		return TEE_ERROR_OUT_OF_MEMORY;
171 	s->n = crypto_bignum_allocate(key_size_bits);
172 	if (!s->n)
173 		goto err;
174 	return TEE_SUCCESS;
175 err:
176 	crypto_bignum_free(s->e);
177 	return TEE_ERROR_OUT_OF_MEMORY;
178 }
179 
crypto_acipher_free_rsa_public_key(struct rsa_public_key * s)180 void crypto_acipher_free_rsa_public_key(struct rsa_public_key *s)
181 {
182 	if (!s)
183 		return;
184 	crypto_bignum_free(s->n);
185 	crypto_bignum_free(s->e);
186 }
187 
crypto_acipher_free_rsa_keypair(struct rsa_keypair * s)188 void crypto_acipher_free_rsa_keypair(struct rsa_keypair *s)
189 {
190 	if (!s)
191 		return;
192 	crypto_bignum_free(s->e);
193 	crypto_bignum_free(s->d);
194 	crypto_bignum_free(s->n);
195 	crypto_bignum_free(s->p);
196 	crypto_bignum_free(s->q);
197 	crypto_bignum_free(s->qp);
198 	crypto_bignum_free(s->dp);
199 	crypto_bignum_free(s->dq);
200 }
201 
crypto_acipher_gen_rsa_key(struct rsa_keypair * key,size_t key_size)202 TEE_Result crypto_acipher_gen_rsa_key(struct rsa_keypair *key, size_t key_size)
203 {
204 	TEE_Result res = TEE_SUCCESS;
205 	mbedtls_rsa_context rsa;
206 	int lmd_res = 0;
207 	uint32_t e = 0;
208 
209 	memset(&rsa, 0, sizeof(rsa));
210 	mbedtls_rsa_init(&rsa, 0, 0);
211 
212 	/* get the public exponent */
213 	mbedtls_mpi_write_binary((mbedtls_mpi *)key->e,
214 				 (unsigned char *)&e, sizeof(uint32_t));
215 
216 	e = TEE_U32_FROM_BIG_ENDIAN(e);
217 	lmd_res = mbedtls_rsa_gen_key(&rsa, mbd_rand, NULL, key_size, (int)e);
218 	if (lmd_res != 0) {
219 		res = get_tee_result(lmd_res);
220 	} else if ((size_t)mbedtls_mpi_bitlen(&rsa.N) != key_size) {
221 		res = TEE_ERROR_BAD_PARAMETERS;
222 	} else {
223 		/* Copy the key */
224 		crypto_bignum_copy(key->e, (void *)&rsa.E);
225 		crypto_bignum_copy(key->d, (void *)&rsa.D);
226 		crypto_bignum_copy(key->n, (void *)&rsa.N);
227 		crypto_bignum_copy(key->p, (void *)&rsa.P);
228 
229 		crypto_bignum_copy(key->q, (void *)&rsa.Q);
230 		crypto_bignum_copy(key->qp, (void *)&rsa.QP);
231 		crypto_bignum_copy(key->dp, (void *)&rsa.DP);
232 		crypto_bignum_copy(key->dq, (void *)&rsa.DQ);
233 
234 		res = TEE_SUCCESS;
235 	}
236 
237 	mbedtls_rsa_free(&rsa);
238 
239 	return res;
240 }
241 
crypto_acipher_rsanopad_encrypt(struct rsa_public_key * key,const uint8_t * src,size_t src_len,uint8_t * dst,size_t * dst_len)242 TEE_Result crypto_acipher_rsanopad_encrypt(struct rsa_public_key *key,
243 					   const uint8_t *src, size_t src_len,
244 					   uint8_t *dst, size_t *dst_len)
245 {
246 	TEE_Result res = TEE_SUCCESS;
247 	mbedtls_rsa_context rsa;
248 	int lmd_res = 0;
249 	uint8_t *buf = NULL;
250 	unsigned long blen = 0;
251 	unsigned long offset = 0;
252 
253 	memset(&rsa, 0, sizeof(rsa));
254 	mbedtls_rsa_init(&rsa, 0, 0);
255 
256 	rsa.E = *(mbedtls_mpi *)key->e;
257 	rsa.N = *(mbedtls_mpi *)key->n;
258 
259 	rsa.len = crypto_bignum_num_bytes((void *)&rsa.N);
260 
261 	blen = CFG_CORE_BIGNUM_MAX_BITS / 8;
262 	buf = malloc(blen);
263 	if (!buf) {
264 		res = TEE_ERROR_OUT_OF_MEMORY;
265 		goto out;
266 	}
267 
268 	memset(buf, 0, blen);
269 	memcpy(buf + rsa.len - src_len, src, src_len);
270 
271 	lmd_res = mbedtls_rsa_public(&rsa, buf, buf);
272 	if (lmd_res != 0) {
273 		FMSG("mbedtls_rsa_public() returned 0x%x", -lmd_res);
274 		res = get_tee_result(lmd_res);
275 		goto out;
276 	}
277 
278 	/* Remove the zero-padding (leave one zero if buff is all zeroes) */
279 	offset = 0;
280 	while ((offset < rsa.len - 1) && (buf[offset] == 0))
281 		offset++;
282 
283 	if (*dst_len < rsa.len - offset) {
284 		*dst_len = rsa.len - offset;
285 		res = TEE_ERROR_SHORT_BUFFER;
286 		goto out;
287 	}
288 	*dst_len = rsa.len - offset;
289 	memcpy(dst, buf + offset, *dst_len);
290 out:
291 	free(buf);
292 	/* Reset mpi to skip freeing here, those mpis will be freed with key */
293 	mbedtls_mpi_init(&rsa.E);
294 	mbedtls_mpi_init(&rsa.N);
295 	mbedtls_rsa_free(&rsa);
296 
297 	return res;
298 }
299 
crypto_acipher_rsanopad_decrypt(struct rsa_keypair * key,const uint8_t * src,size_t src_len,uint8_t * dst,size_t * dst_len)300 TEE_Result crypto_acipher_rsanopad_decrypt(struct rsa_keypair *key,
301 					   const uint8_t *src, size_t src_len,
302 					   uint8_t *dst, size_t *dst_len)
303 {
304 	TEE_Result res = TEE_SUCCESS;
305 	mbedtls_rsa_context rsa;
306 	int lmd_res = 0;
307 	uint8_t *buf = NULL;
308 	unsigned long blen = 0;
309 	unsigned long offset = 0;
310 
311 	memset(&rsa, 0, sizeof(rsa));
312 	rsa_init_from_key_pair(&rsa, key);
313 
314 	blen = CFG_CORE_BIGNUM_MAX_BITS / 8;
315 	buf = malloc(blen);
316 	if (!buf) {
317 		res = TEE_ERROR_OUT_OF_MEMORY;
318 		goto out;
319 	}
320 
321 	memset(buf, 0, blen);
322 	memcpy(buf + rsa.len - src_len, src, src_len);
323 
324 	lmd_res = mbedtls_rsa_private(&rsa, NULL, NULL, buf, buf);
325 	if (lmd_res != 0) {
326 		FMSG("mbedtls_rsa_private() returned 0x%x", -lmd_res);
327 		res = get_tee_result(lmd_res);
328 		goto out;
329 	}
330 
331 	/* Remove the zero-padding (leave one zero if buff is all zeroes) */
332 	offset = 0;
333 	while ((offset < rsa.len - 1) && (buf[offset] == 0))
334 		offset++;
335 
336 	if (*dst_len < rsa.len - offset) {
337 		*dst_len = rsa.len - offset;
338 		res = TEE_ERROR_SHORT_BUFFER;
339 		goto out;
340 	}
341 	*dst_len = rsa.len - offset;
342 	memcpy(dst, (char *)buf + offset, *dst_len);
343 out:
344 	if (buf)
345 		free(buf);
346 	mbd_rsa_free(&rsa);
347 	return res;
348 }
349 
crypto_acipher_rsaes_decrypt(uint32_t algo,struct rsa_keypair * key,const uint8_t * label __unused,size_t label_len __unused,const uint8_t * src,size_t src_len,uint8_t * dst,size_t * dst_len)350 TEE_Result crypto_acipher_rsaes_decrypt(uint32_t algo, struct rsa_keypair *key,
351 					const uint8_t *label __unused,
352 					size_t label_len __unused,
353 					const uint8_t *src, size_t src_len,
354 					uint8_t *dst, size_t *dst_len)
355 {
356 	TEE_Result res = TEE_SUCCESS;
357 	int lmd_res = 0;
358 	int lmd_padding = 0;
359 	size_t blen = 0;
360 	size_t mod_size = 0;
361 	void *buf = NULL;
362 	mbedtls_rsa_context rsa;
363 	const mbedtls_pk_info_t *pk_info = NULL;
364 	uint32_t md_algo = MBEDTLS_MD_NONE;
365 
366 	memset(&rsa, 0, sizeof(rsa));
367 	rsa_init_from_key_pair(&rsa, key);
368 
369 	/*
370 	 * Use a temporary buffer since we don't know exactly how large
371 	 * the required size of the out buffer without doing a partial
372 	 * decrypt. We know the upper bound though.
373 	 */
374 	if (algo == TEE_ALG_RSAES_PKCS1_V1_5) {
375 		mod_size = crypto_bignum_num_bytes(key->n);
376 		blen = mod_size - 11;
377 		lmd_padding = MBEDTLS_RSA_PKCS_V15;
378 	} else {
379 		/* Decoded message is always shorter than encrypted message */
380 		blen = src_len;
381 		lmd_padding = MBEDTLS_RSA_PKCS_V21;
382 	}
383 
384 	buf = malloc(blen);
385 	if (!buf) {
386 		res = TEE_ERROR_OUT_OF_MEMORY;
387 		goto out;
388 	}
389 
390 	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
391 	if (!pk_info) {
392 		res = TEE_ERROR_NOT_SUPPORTED;
393 		goto out;
394 	}
395 
396 	/*
397 	 * TEE_ALG_RSAES_PKCS1_V1_5 is invalid in hash. But its hash algo will
398 	 * not be used in rsa, so skip it here.
399 	 */
400 	if (algo != TEE_ALG_RSAES_PKCS1_V1_5) {
401 		md_algo = tee_algo_to_mbedtls_hash_algo(algo);
402 		if (md_algo == MBEDTLS_MD_NONE) {
403 			res = TEE_ERROR_NOT_SUPPORTED;
404 			goto out;
405 		}
406 	}
407 
408 	mbedtls_rsa_set_padding(&rsa, lmd_padding, md_algo);
409 
410 	if (lmd_padding == MBEDTLS_RSA_PKCS_V15)
411 		lmd_res = pk_info->decrypt_func(&rsa, src, src_len, buf, &blen,
412 						blen, NULL, NULL);
413 	else
414 		lmd_res = pk_info->decrypt_func(&rsa, src, src_len, buf, &blen,
415 						blen, mbd_rand, NULL);
416 	if (lmd_res != 0) {
417 		FMSG("decrypt_func() returned 0x%x", -lmd_res);
418 		res = get_tee_result(lmd_res);
419 		goto out;
420 	}
421 
422 	if (*dst_len < blen) {
423 		*dst_len = blen;
424 		res = TEE_ERROR_SHORT_BUFFER;
425 		goto out;
426 	}
427 
428 	res = TEE_SUCCESS;
429 	*dst_len = blen;
430 	memcpy(dst, buf, blen);
431 out:
432 	if (buf)
433 		free(buf);
434 	mbd_rsa_free(&rsa);
435 	return res;
436 }
437 
crypto_acipher_rsaes_encrypt(uint32_t algo,struct rsa_public_key * key,const uint8_t * label __unused,size_t label_len __unused,const uint8_t * src,size_t src_len,uint8_t * dst,size_t * dst_len)438 TEE_Result crypto_acipher_rsaes_encrypt(uint32_t algo,
439 					struct rsa_public_key *key,
440 					const uint8_t *label __unused,
441 					size_t label_len __unused,
442 					const uint8_t *src, size_t src_len,
443 					uint8_t *dst, size_t *dst_len)
444 {
445 	TEE_Result res = TEE_SUCCESS;
446 	int lmd_res = 0;
447 	int lmd_padding = 0;
448 	size_t mod_size = 0;
449 	mbedtls_rsa_context rsa;
450 	const mbedtls_pk_info_t *pk_info = NULL;
451 	uint32_t md_algo = MBEDTLS_MD_NONE;
452 
453 	memset(&rsa, 0, sizeof(rsa));
454 	mbedtls_rsa_init(&rsa, 0, 0);
455 
456 	rsa.E = *(mbedtls_mpi *)key->e;
457 	rsa.N = *(mbedtls_mpi *)key->n;
458 
459 	mod_size = crypto_bignum_num_bytes(key->n);
460 	if (*dst_len < mod_size) {
461 		*dst_len = mod_size;
462 		res = TEE_ERROR_SHORT_BUFFER;
463 		goto out;
464 	}
465 	*dst_len = mod_size;
466 	rsa.len = mod_size;
467 
468 	if (algo == TEE_ALG_RSAES_PKCS1_V1_5)
469 		lmd_padding = MBEDTLS_RSA_PKCS_V15;
470 	else
471 		lmd_padding = MBEDTLS_RSA_PKCS_V21;
472 
473 	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
474 	if (!pk_info) {
475 		res = TEE_ERROR_NOT_SUPPORTED;
476 		goto out;
477 	}
478 
479 	/*
480 	 * TEE_ALG_RSAES_PKCS1_V1_5 is invalid in hash. But its hash algo will
481 	 * not be used in rsa, so skip it here.
482 	 */
483 	if (algo != TEE_ALG_RSAES_PKCS1_V1_5) {
484 		md_algo = tee_algo_to_mbedtls_hash_algo(algo);
485 		if (md_algo == MBEDTLS_MD_NONE) {
486 			res = TEE_ERROR_NOT_SUPPORTED;
487 			goto out;
488 		}
489 	}
490 
491 	mbedtls_rsa_set_padding(&rsa, lmd_padding, md_algo);
492 
493 	lmd_res = pk_info->encrypt_func(&rsa, src, src_len, dst, dst_len,
494 					*dst_len, mbd_rand, NULL);
495 	if (lmd_res != 0) {
496 		FMSG("encrypt_func() returned 0x%x", -lmd_res);
497 		res = get_tee_result(lmd_res);
498 		goto out;
499 	}
500 	res = TEE_SUCCESS;
501 out:
502 	/* Reset mpi to skip freeing here, those mpis will be freed with key */
503 	mbedtls_mpi_init(&rsa.E);
504 	mbedtls_mpi_init(&rsa.N);
505 	mbedtls_rsa_free(&rsa);
506 	return res;
507 }
508 
crypto_acipher_rsassa_sign(uint32_t algo,struct rsa_keypair * key,int salt_len __unused,const uint8_t * msg,size_t msg_len,uint8_t * sig,size_t * sig_len)509 TEE_Result crypto_acipher_rsassa_sign(uint32_t algo, struct rsa_keypair *key,
510 				      int salt_len __unused, const uint8_t *msg,
511 				      size_t msg_len, uint8_t *sig,
512 				      size_t *sig_len)
513 {
514 	TEE_Result res = TEE_SUCCESS;
515 	int lmd_res = 0;
516 	int lmd_padding = 0;
517 	size_t mod_size = 0;
518 	size_t hash_size = 0;
519 	mbedtls_rsa_context rsa;
520 	const mbedtls_pk_info_t *pk_info = NULL;
521 	uint32_t md_algo = 0;
522 
523 	memset(&rsa, 0, sizeof(rsa));
524 	rsa_init_from_key_pair(&rsa, key);
525 
526 	switch (algo) {
527 	case TEE_ALG_RSASSA_PKCS1_V1_5_MD5:
528 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA1:
529 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA224:
530 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA256:
531 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA384:
532 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA512:
533 		lmd_padding = MBEDTLS_RSA_PKCS_V15;
534 		break;
535 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA1:
536 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA224:
537 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256:
538 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA384:
539 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA512:
540 		lmd_padding = MBEDTLS_RSA_PKCS_V21;
541 		break;
542 	default:
543 		res = TEE_ERROR_BAD_PARAMETERS;
544 		goto err;
545 	}
546 
547 	res = tee_alg_get_digest_size(TEE_DIGEST_HASH_TO_ALGO(algo),
548 				      &hash_size);
549 	if (res != TEE_SUCCESS)
550 		goto err;
551 
552 	if (msg_len != hash_size) {
553 		res = TEE_ERROR_BAD_PARAMETERS;
554 		goto err;
555 	}
556 
557 	mod_size = crypto_bignum_num_bytes(key->n);
558 	if (*sig_len < mod_size) {
559 		*sig_len = mod_size;
560 		res = TEE_ERROR_SHORT_BUFFER;
561 		goto err;
562 	}
563 	rsa.len = mod_size;
564 
565 	md_algo = tee_algo_to_mbedtls_hash_algo(algo);
566 	if (md_algo == MBEDTLS_MD_NONE) {
567 		res = TEE_ERROR_NOT_SUPPORTED;
568 		goto err;
569 	}
570 
571 	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
572 	if (!pk_info) {
573 		res = TEE_ERROR_NOT_SUPPORTED;
574 		goto err;
575 	}
576 
577 	mbedtls_rsa_set_padding(&rsa, lmd_padding, md_algo);
578 
579 	if (lmd_padding == MBEDTLS_RSA_PKCS_V15)
580 		lmd_res = pk_info->sign_func(&rsa, md_algo, msg, msg_len, sig,
581 					     sig_len, NULL, NULL);
582 	else
583 		lmd_res = pk_info->sign_func(&rsa, md_algo, msg, msg_len, sig,
584 					     sig_len, mbd_rand, NULL);
585 	if (lmd_res != 0) {
586 		FMSG("sign_func failed, returned 0x%x", -lmd_res);
587 		res = get_tee_result(lmd_res);
588 		goto err;
589 	}
590 	res = TEE_SUCCESS;
591 err:
592 	mbd_rsa_free(&rsa);
593 	return res;
594 }
595 
crypto_acipher_rsassa_verify(uint32_t algo,struct rsa_public_key * key,int salt_len __unused,const uint8_t * msg,size_t msg_len,const uint8_t * sig,size_t sig_len)596 TEE_Result crypto_acipher_rsassa_verify(uint32_t algo,
597 					struct rsa_public_key *key,
598 					int salt_len __unused,
599 					const uint8_t *msg,
600 					size_t msg_len, const uint8_t *sig,
601 					size_t sig_len)
602 {
603 	TEE_Result res = TEE_SUCCESS;
604 	int lmd_res = 0;
605 	int lmd_padding = 0;
606 	size_t hash_size = 0;
607 	size_t bigint_size = 0;
608 	mbedtls_rsa_context rsa;
609 	const mbedtls_pk_info_t *pk_info = NULL;
610 	uint32_t md_algo = 0;
611 
612 	memset(&rsa, 0, sizeof(rsa));
613 	mbedtls_rsa_init(&rsa, 0, 0);
614 
615 	rsa.E = *(mbedtls_mpi *)key->e;
616 	rsa.N = *(mbedtls_mpi *)key->n;
617 
618 	res = tee_alg_get_digest_size(TEE_DIGEST_HASH_TO_ALGO(algo),
619 				      &hash_size);
620 	if (res != TEE_SUCCESS)
621 		goto err;
622 
623 	if (msg_len != hash_size) {
624 		res = TEE_ERROR_BAD_PARAMETERS;
625 		goto err;
626 	}
627 
628 	bigint_size = crypto_bignum_num_bytes(key->n);
629 	if (sig_len < bigint_size) {
630 		res = TEE_ERROR_SIGNATURE_INVALID;
631 		goto err;
632 	}
633 
634 	rsa.len = bigint_size;
635 
636 	switch (algo) {
637 	case TEE_ALG_RSASSA_PKCS1_V1_5_MD5:
638 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA1:
639 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA224:
640 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA256:
641 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA384:
642 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA512:
643 		lmd_padding = MBEDTLS_RSA_PKCS_V15;
644 		break;
645 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA1:
646 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA224:
647 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256:
648 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA384:
649 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA512:
650 		lmd_padding = MBEDTLS_RSA_PKCS_V21;
651 		break;
652 	default:
653 		res = TEE_ERROR_BAD_PARAMETERS;
654 		goto err;
655 	}
656 
657 	md_algo = tee_algo_to_mbedtls_hash_algo(algo);
658 	if (md_algo == MBEDTLS_MD_NONE) {
659 		res = TEE_ERROR_NOT_SUPPORTED;
660 		goto err;
661 	}
662 
663 	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
664 	if (!pk_info) {
665 		res = TEE_ERROR_NOT_SUPPORTED;
666 		goto err;
667 	}
668 
669 	mbedtls_rsa_set_padding(&rsa, lmd_padding, md_algo);
670 
671 	lmd_res = pk_info->verify_func(&rsa, md_algo, msg, msg_len,
672 				       sig, sig_len);
673 	if (lmd_res != 0) {
674 		FMSG("verify_func failed, returned 0x%x", -lmd_res);
675 		res = TEE_ERROR_SIGNATURE_INVALID;
676 		goto err;
677 	}
678 	res = TEE_SUCCESS;
679 err:
680 	/* Reset mpi to skip freeing here, those mpis will be freed with key */
681 	mbedtls_mpi_init(&rsa.E);
682 	mbedtls_mpi_init(&rsa.N);
683 	mbedtls_rsa_free(&rsa);
684 	return res;
685 }
686