diff --git a/library/psa_crypto.c b/library/psa_crypto.c index 8ca787ed645..0fb9b956097 100644 --- a/library/psa_crypto.c +++ b/library/psa_crypto.c @@ -1500,32 +1500,10 @@ psa_status_t psa_export_key_internal( PSA_KEY_TYPE_IS_RSA(type) || PSA_KEY_TYPE_IS_ECC(type) || PSA_KEY_TYPE_IS_DH(type) || - (PSA_KEY_TYPE_IS_ML_KEM(type) && PSA_KEY_TYPE_IS_PUBLIC_KEY(type))) { + PSA_KEY_TYPE_IS_ML_KEM(type)) { return psa_export_key_buffer_internal( key_buffer, key_buffer_size, data, data_size, data_length); - } else if (PSA_KEY_TYPE_IS_ML_KEM(type)) { -#if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_MLKEM_KEY_PAIR_EXPORT) - psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; - mbedtls_mlkem_context *mlkem = NULL; - - status = mbedtls_psa_mlkem_load_representation( - type, attributes->bits, - key_buffer, key_buffer_size, &mlkem); - if (status != PSA_SUCCESS) { - goto exit; - } - - status = mbedtls_psa_mlkem_export_key(PSA_KEY_TYPE_ML_KEM_KEY_PAIR, attributes->bits, mlkem, data, data_size, data_length); -exit: - if (status != PSA_SUCCESS) { - mbedtls_free(mlkem); - } - return status; -#else - /* We don't know how to export a MLKEM key. */ - return PSA_ERROR_NOT_SUPPORTED; -#endif /* MBEDTLS_PSA_BUILTIN_KEY_TYPE_MLKEM_KEY_PAIR_EXPORT */ } else { /* This shouldn't happen in the reference implementation, but it is valid for a special-purpose implementation to omit @@ -8038,7 +8016,7 @@ psa_status_t psa_generate_key_internal( #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_MLKEM_KEY_PAIR_GENERATE) if (PSA_KEY_TYPE_IS_ML_KEM(type) && PSA_KEY_TYPE_IS_KEY_PAIR(type)) { - return mbedtls_psa_mlkem_generate_key(attributes->bits, + return mbedtls_psa_mlkem_generate_key(attributes, key_buffer, key_buffer_size, key_buffer_length); @@ -8302,7 +8280,7 @@ psa_status_t psa_decapsulate(psa_key_id_t key, } #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_MLKEM_KEY_DECAPSULATE) - status = mbedtls_psa_mlkem_decapsulate(slot->attr.bits, + status = mbedtls_psa_mlkem_decapsulate(&slot->attr, slot->key.data, slot->key.bytes, ciphertext, diff --git a/library/psa_crypto_mlkem.c b/library/psa_crypto_mlkem.c index 057ba037aea..2e29ad0d980 100644 --- a/library/psa_crypto_mlkem.c +++ b/library/psa_crypto_mlkem.c @@ -97,18 +97,18 @@ psa_status_t mbedtls_psa_mlkem_load_representation( mbedtls_mlkem_init(*p_mlkem); if (PSA_KEY_TYPE_IS_PUBLIC_KEY(type)) { - return PSA_ERROR_NOT_SUPPORTED; + (*p_mlkem)->encaps_key.key_data = (uint32_t *)data; + (*p_mlkem)->encaps_key.key_len = data_length; } else { - (*p_mlkem)->decaps_key.key_data = (uint32_t *)data; - (*p_mlkem)->decaps_key.key_len = PSA_KEY_EXPORT_ML_KEM_PRIVATE_KEY_SIZE(bits); - (*p_mlkem)->d.key_data = (uint32_t *)(data + (*p_mlkem)->decaps_key.key_len); + (*p_mlkem)->d.key_data = (uint32_t *)data; (*p_mlkem)->d.key_len = PSA_ML_KEM_SEED_SIZE; - (*p_mlkem)->z.key_data = (uint32_t *)(data + (*p_mlkem)->decaps_key.key_len + (*p_mlkem)->d.key_len); + (*p_mlkem)->z.key_data = (uint32_t *)(data + (*p_mlkem)->d.key_len); (*p_mlkem)->z.key_len = PSA_ML_KEM_SEED_SIZE; + (*p_mlkem)->decaps_key.key_data = (uint32_t *)(data + (*p_mlkem)->d.key_len + (*p_mlkem)->z.key_len); + (*p_mlkem)->decaps_key.key_len = PSA_KEY_EXPORT_ML_KEM_PRIVATE_KEY_SIZE(bits); } - return PSA_SUCCESS; } #endif @@ -227,35 +227,42 @@ psa_status_t mbedtls_psa_mlkem_export_public_key( #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_MLKEM_KEY_PAIR_GENERATE) psa_status_t mbedtls_psa_mlkem_generate_key( - const psa_key_bits_t bits, + const psa_key_attributes_t *attributes, uint8_t *key_buffer, size_t key_buffer_size, size_t *key_buffer_length) { - int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; - mbedtls_mlkem_context mlkem; + psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; + mbedtls_mlkem_context *mlkem = NULL; #if defined(MBEDTLS_MLKEM_TEST_FIXED_TRNG) random_call_count = 0; #endif - - mbedtls_mlkem_init(&mlkem); - mlkem.decaps_key.key_data = (uint32_t *)key_buffer; - mlkem.decaps_key.key_len = PSA_KEY_EXPORT_ML_KEM_PRIVATE_KEY_SIZE(bits); - mlkem.d.key_data = (uint32_t *)(key_buffer + mlkem.decaps_key.key_len); - mlkem.d.key_len = PSA_ML_KEM_SEED_SIZE; - mlkem.z.key_data = (uint32_t *)(key_buffer + mlkem.decaps_key.key_len + mlkem.d.key_len); - mlkem.z.key_len = PSA_ML_KEM_SEED_SIZE; - - if (key_buffer_size < mlkem.decaps_key.key_len + mlkem.d.key_len + mlkem.z.key_len) { - return PSA_ERROR_BUFFER_TOO_SMALL; + + /* Parse input */ + status = mbedtls_psa_mlkem_load_representation(attributes->type, + attributes->bits, + key_buffer, + key_buffer_size, + &mlkem); + if (PSA_SUCCESS != status) { + goto exit; } - ret = mbedtls_mlkem_generate_key(&mlkem, bits, mbedtls_mlkem_get_random); - if (ret != 0) { - return mbedtls_to_psa_error(ret); + if (key_buffer_size < mlkem->decaps_key.key_len + mlkem->d.key_len + mlkem->z.key_len) { + status = PSA_ERROR_BUFFER_TOO_SMALL; + goto exit; } - *key_buffer_length = PSA_ML_KEM_SEED_SIZE + PSA_ML_KEM_SEED_SIZE + mlkem.decaps_key.key_len; + int ret = mbedtls_mlkem_generate_key(mlkem, attributes->bits, mbedtls_mlkem_get_random); + if (ret != 0) { + status = mbedtls_to_psa_error(ret); + goto exit; + } - return mbedtls_to_psa_error(ret); + *key_buffer_length = PSA_ML_KEM_SEED_SIZE + PSA_ML_KEM_SEED_SIZE + mlkem->decaps_key.key_len; +exit: + if (status != PSA_SUCCESS) { + mbedtls_free(mlkem); + } + return status; } #endif /* MBEDTLS_PSA_BUILTIN_KEY_TYPE_MLKEM_KEY_PAIR_GENERATE */ @@ -270,24 +277,26 @@ psa_status_t mbedtls_psa_mlkem_encapsulate( size_t ciphertext_size, size_t *ciphertext_length) { - int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; - mbedtls_mlkem_context mlkem; + psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; + mbedtls_mlkem_context *mlkem = NULL; mbedtls_mlkem_data_t cipher; mbedtls_mlkem_data_t shared_key; #if defined(MBEDTLS_MLKEM_TEST_FIXED_TRNG) random_call_count = 3; #endif - - mbedtls_mlkem_init(&mlkem); - if (PSA_KEY_TYPE_IS_PUBLIC_KEY(attributes->type)) { - mlkem.encaps_key.key_data = (uint32_t *)key_buffer; - mlkem.encaps_key.key_len = key_buffer_size; + + /* Parse input */ + status = mbedtls_psa_mlkem_load_representation(attributes->type, + attributes->bits, + key_buffer, + key_buffer_size, + &mlkem); + if (PSA_SUCCESS != status) { + goto exit; } - else - { - mlkem.decaps_key.key_data = (uint32_t *)key_buffer; - mlkem.decaps_key.key_len = PSA_KEY_EXPORT_ML_KEM_PRIVATE_KEY_SIZE(attributes->bits); - mbedtls_mlkem_export_public_key(&mlkem, attributes->bits); + + if (PSA_KEY_TYPE_IS_KEY_PAIR(attributes->type)) { + mbedtls_mlkem_export_public_key(mlkem, attributes->bits); } cipher.key_data = (uint32_t *)ciphertext; @@ -295,29 +304,37 @@ psa_status_t mbedtls_psa_mlkem_encapsulate( shared_key.key_data = (uint32_t *)output_key_buffer; shared_key.key_len = output_key_buffer_size; - if (key_buffer_size < mlkem.decaps_key.key_len) { - return PSA_ERROR_BUFFER_TOO_SMALL; + if (key_buffer_size < mlkem->decaps_key.key_len) { + status = PSA_ERROR_BUFFER_TOO_SMALL; + goto exit; } - ret = mbedtls_mlkem_encapsulate(&mlkem, attributes->bits, &cipher, &shared_key, mbedtls_mlkem_get_random); + int ret = mbedtls_mlkem_encapsulate(mlkem, attributes->bits, &cipher, &shared_key, mbedtls_mlkem_get_random); if (ret != 0) { - return mbedtls_to_psa_error(ret); + status = mbedtls_to_psa_error(ret); + goto exit; } if (shared_key.key_len > output_key_buffer_size) { - return PSA_ERROR_BUFFER_TOO_SMALL; + status = PSA_ERROR_BUFFER_TOO_SMALL; + goto exit; } if (cipher.key_len > ciphertext_size) { - return PSA_ERROR_BUFFER_TOO_SMALL; + status = PSA_ERROR_BUFFER_TOO_SMALL; + goto exit; } *ciphertext_length = cipher.key_len; - return mbedtls_to_psa_error(ret); +exit: + if (status != PSA_SUCCESS) { + mbedtls_free(mlkem); + } + return status; } #endif /* MBEDTLS_PSA_BUILTIN_KEY_TYPE_MLKEM_KEY_ENCAPSULATE */ #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_MLKEM_KEY_DECAPSULATE) psa_status_t mbedtls_psa_mlkem_decapsulate( - const psa_key_bits_t bits, + const psa_key_attributes_t *attributes, uint8_t *key_buffer, size_t key_buffer_size, const uint8_t *ciphertext, @@ -325,40 +342,55 @@ psa_status_t mbedtls_psa_mlkem_decapsulate( uint8_t *shared_secret, size_t *shared_secret_len) { - int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; - mbedtls_mlkem_context mlkem; + psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; + mbedtls_mlkem_context *mlkem = NULL; mbedtls_mlkem_data_t cipher; mbedtls_mlkem_data_t shared_key; #if defined(MBEDTLS_MLKEM_TEST_FIXED_TRNG) random_call_count = 3; #endif + + /* Parse input */ + status = mbedtls_psa_mlkem_load_representation(attributes->type, + attributes->bits, + key_buffer, + key_buffer_size, + &mlkem); + if (PSA_SUCCESS != status) { + goto exit; + } - mbedtls_mlkem_init(&mlkem); - mlkem.decaps_key.key_data = (uint32_t *)key_buffer; - mlkem.decaps_key.key_len = PSA_KEY_EXPORT_ML_KEM_PRIVATE_KEY_SIZE(bits); cipher.key_data = (uint32_t *)ciphertext; cipher.key_len = ciphertext_len; shared_key.key_data = (uint32_t *)shared_secret; shared_key.key_len = *shared_secret_len; - if (key_buffer_size < mlkem.decaps_key.key_len) { - return PSA_ERROR_BUFFER_TOO_SMALL; + if (key_buffer_size < mlkem->decaps_key.key_len) { + status = PSA_ERROR_BUFFER_TOO_SMALL; + goto exit; } - if (key_buffer_size < mlkem.decaps_key.key_len) { - return PSA_ERROR_BUFFER_TOO_SMALL; + if (key_buffer_size < mlkem->decaps_key.key_len) { + status = PSA_ERROR_BUFFER_TOO_SMALL; + goto exit; } - ret = mbedtls_mlkem_decapsulate(&mlkem, bits, &cipher, &shared_key, mbedtls_mlkem_get_random); + int ret = mbedtls_mlkem_decapsulate(mlkem, attributes->bits, &cipher, &shared_key, mbedtls_mlkem_get_random); if (ret != 0) { - return mbedtls_to_psa_error(ret); + status = mbedtls_to_psa_error(ret); + goto exit; } if (shared_key.key_len > *shared_secret_len) { - return PSA_ERROR_BUFFER_TOO_SMALL; + status = PSA_ERROR_BUFFER_TOO_SMALL; + goto exit; } *shared_secret_len = shared_key.key_len; - return mbedtls_to_psa_error(ret); +exit: + if (status != PSA_SUCCESS) { + mbedtls_free(mlkem); + } + return status; } #endif /* MBEDTLS_PSA_BUILTIN_KEY_TYPE_MLKEM_KEY_DECAPSULATE */ diff --git a/library/psa_crypto_mlkem.h b/library/psa_crypto_mlkem.h index 2f6816b327f..5d1b3f0dc84 100644 --- a/library/psa_crypto_mlkem.h +++ b/library/psa_crypto_mlkem.h @@ -112,7 +112,7 @@ psa_status_t mbedtls_psa_mlkem_export_public_key( * \note The signature of the function is that of a PSA driver generate_key * entry point. * - * \param[in] bits The algorithm strength in bits. + * \param[in] attributes The attributes for the key to generate. * \param[out] key_buffer Buffer where the key data is to be written. * \param[in] key_buffer_size Size of \p key_buffer in bytes. * \param[out] key_buffer_length On success, the number of bytes written in @@ -126,7 +126,7 @@ psa_status_t mbedtls_psa_mlkem_export_public_key( * The size of \p key_buffer is too small. */ psa_status_t mbedtls_psa_mlkem_generate_key( - const psa_key_bits_t bits, + const psa_key_attributes_t *attributes, uint8_t *key_buffer, size_t key_buffer_size, size_t *key_buffer_length); /** @@ -162,7 +162,7 @@ psa_status_t mbedtls_psa_mlkem_encapsulate( /** * \brief Decapsulate MLKEM ciphertext. * - * \param[in] bits The algorithm strength in bits. + * \param[in] attributes The attributes for the key to decapsulate. * \param[in] key_buffer Buffer holding the key data. * \param[in] key_buffer_size Size of \p key_buffer in bytes. * \param[in] ciphertext Buffer holding the ciphertext data. @@ -177,7 +177,7 @@ psa_status_t mbedtls_psa_mlkem_encapsulate( * Key length or type not supported. */ psa_status_t mbedtls_psa_mlkem_decapsulate( - const psa_key_bits_t bits, + const psa_key_attributes_t *attributes, uint8_t *key_buffer, size_t key_buffer_size, const uint8_t *ciphertext,