diff --git a/src/crypto/CHIPCryptoPALPSA.cpp b/src/crypto/CHIPCryptoPALPSA.cpp index 604125e7efd..4e27bffa72f 100644 --- a/src/crypto/CHIPCryptoPALPSA.cpp +++ b/src/crypto/CHIPCryptoPALPSA.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -71,45 +72,28 @@ CHIP_ERROR AES_CCM_encrypt(const uint8_t * plaintext, size_t plaintext_length, c const psa_algorithm_t algorithm = PSA_ALG_AEAD_WITH_SHORTENED_TAG(PSA_ALG_CCM, tag_length); psa_status_t status = PSA_SUCCESS; - psa_aead_operation_t operation = PSA_AEAD_OPERATION_INIT; - size_t out_length; - size_t tag_out_length; + size_t ciphertext_length = 0; - status = psa_aead_encrypt_setup(&operation, key.As(), algorithm); - VerifyOrReturnError(status == PSA_SUCCESS, CHIP_ERROR_INTERNAL); + // The PSA single-part API appends the tag to the ciphertext output buffer. + // So we need to allocate a temporary buffer to hold ciphertext + tag, then split them. + size_t total_output_size = plaintext_length + tag_length; + uint8_t * temp_output = static_cast(chip::Platform::MemoryAlloc(total_output_size)); + VerifyOrReturnError(temp_output != nullptr, CHIP_ERROR_NO_MEMORY); - status = psa_aead_set_lengths(&operation, aad_length, plaintext_length); - VerifyOrReturnError(status == PSA_SUCCESS, CHIP_ERROR_INTERNAL); + status = psa_aead_encrypt(key.As(), algorithm, nonce, nonce_length, aad, aad_length, plaintext, plaintext_length, + temp_output, total_output_size, &ciphertext_length); - status = psa_aead_set_nonce(&operation, nonce, nonce_length); - VerifyOrReturnError(status == PSA_SUCCESS, CHIP_ERROR_INTERNAL); - - if (aad_length != 0) - { - status = psa_aead_update_ad(&operation, aad, aad_length); - VerifyOrReturnError(status == PSA_SUCCESS, CHIP_ERROR_INTERNAL); - } - else + if (status == PSA_SUCCESS && ciphertext_length == total_output_size) { - ChipLogDetail(Crypto, "AES_CCM_encrypt: Using aad == null path"); - } - - if (plaintext_length != 0) - { - status = psa_aead_update(&operation, plaintext, plaintext_length, ciphertext, - PSA_AEAD_UPDATE_OUTPUT_SIZE(PSA_KEY_TYPE_AES, algorithm, plaintext_length), &out_length); - VerifyOrReturnError(status == PSA_SUCCESS, CHIP_ERROR_INTERNAL); - - ciphertext += out_length; - - status = psa_aead_finish(&operation, ciphertext, PSA_AEAD_FINISH_OUTPUT_SIZE(PSA_KEY_TYPE_AES, algorithm), &out_length, tag, - tag_length, &tag_out_length); - } - else - { - status = psa_aead_finish(&operation, nullptr, 0, &out_length, tag, tag_length, &tag_out_length); + // Copy ciphertext and tag to their respective output buffers + if (plaintext_length > 0) + { + memcpy(ciphertext, temp_output, plaintext_length); + } + memcpy(tag, temp_output + plaintext_length, tag_length); } - VerifyOrReturnError(status == PSA_SUCCESS && tag_length == tag_out_length, CHIP_ERROR_INTERNAL); + chip::Platform::MemoryFree(temp_output); + VerifyOrReturnError(status == PSA_SUCCESS && ciphertext_length == total_output_size, CHIP_ERROR_INTERNAL); return CHIP_NO_ERROR; } @@ -125,45 +109,26 @@ CHIP_ERROR AES_CCM_decrypt(const uint8_t * ciphertext, size_t ciphertext_length, const psa_algorithm_t algorithm = PSA_ALG_AEAD_WITH_SHORTENED_TAG(PSA_ALG_CCM, tag_length); psa_status_t status = PSA_SUCCESS; - psa_aead_operation_t operation = PSA_AEAD_OPERATION_INIT; - size_t outLength; - - status = psa_aead_decrypt_setup(&operation, key.As(), algorithm); - VerifyOrReturnError(status == PSA_SUCCESS, CHIP_ERROR_INTERNAL); - - status = psa_aead_set_lengths(&operation, aad_length, ciphertext_length); - VerifyOrReturnError(status == PSA_SUCCESS, CHIP_ERROR_INTERNAL); + size_t plaintext_length = 0; - status = psa_aead_set_nonce(&operation, nonce, nonce_length); - VerifyOrReturnError(status == PSA_SUCCESS, CHIP_ERROR_INTERNAL); + // The PSA single-part API expects the tag to be appended to the ciphertext input buffer. + // So we need to allocate a temporary buffer to hold ciphertext + tag. + size_t total_input_size = ciphertext_length + tag_length; + uint8_t * temp_input = static_cast(chip::Platform::MemoryAlloc(total_input_size)); + VerifyOrReturnError(temp_input != nullptr, CHIP_ERROR_NO_MEMORY); - if (aad_length != 0) - { - status = psa_aead_update_ad(&operation, aad, aad_length); - VerifyOrReturnError(status == PSA_SUCCESS, CHIP_ERROR_INTERNAL); - } - else + if (ciphertext_length > 0) { - ChipLogDetail(Crypto, "AES_CCM_decrypt: Using aad == null path"); + memcpy(temp_input, ciphertext, ciphertext_length); } + memcpy(temp_input + ciphertext_length, tag, tag_length); - if (ciphertext_length != 0) - { - status = psa_aead_update(&operation, ciphertext, ciphertext_length, plaintext, - PSA_AEAD_UPDATE_OUTPUT_SIZE(PSA_KEY_TYPE_AES, algorithm, ciphertext_length), &outLength); - VerifyOrReturnError(status == PSA_SUCCESS, CHIP_ERROR_INTERNAL); - - plaintext += outLength; + status = psa_aead_decrypt(key.As(), algorithm, nonce, nonce_length, aad, aad_length, temp_input, total_input_size, + plaintext, ciphertext_length, &plaintext_length); - status = psa_aead_verify(&operation, plaintext, PSA_AEAD_VERIFY_OUTPUT_SIZE(PSA_KEY_TYPE_AES, algorithm), &outLength, tag, - tag_length); - } - else - { - status = psa_aead_verify(&operation, nullptr, 0, &outLength, tag, tag_length); - } - - VerifyOrReturnError(status == PSA_SUCCESS, CHIP_ERROR_INTERNAL); + chip::Platform::MemoryFree(temp_input); + VerifyOrReturnError(status == PSA_SUCCESS && plaintext_length == ciphertext_length, + (status == PSA_ERROR_INVALID_SIGNATURE) ? CHIP_ERROR_INVALID_SIGNATURE : CHIP_ERROR_INTERNAL); return CHIP_NO_ERROR; }