@@ -337,7 +337,8 @@ TEST(KaheTest, UnpackMessagesRawRemovesConsumedPackedValues) {
337337 rust::Vec<Integer> unpacked_messages;
338338 SECAGG_EXPECT_OK (UnwrapFfiStatus (
339339 UnpackMessagesRaw (packing_base, packing_dimension, num_packed_values,
340- packed_values, unpacked_messages)));
340+ num_packed_values * packing_dimension, packed_values,
341+ unpacked_messages)));
341342 EXPECT_EQ (packed_values.ptr ->size (), num_packed_values);
342343 EXPECT_EQ (unpacked_messages.size (), num_packed_values);
343344 // Unpacked values should match the first half of the original packed values.
@@ -399,8 +400,9 @@ TEST(KaheTest, PackAndEncrypt) {
399400 .ptr = std::make_unique<std::vector<BigInteger>>(std::move (decrypted))};
400401 rust::Vec<Integer> unpacked_messages;
401402 SECAGG_ASSERT_OK (UnwrapFfiStatus (
402- UnpackMessagesRaw (packing_base, num_packing, packed_messages.size (),
403- decrypted_wrapper, unpacked_messages)));
403+ UnpackMessagesRaw (packing_base, num_packing, num_packed_messages,
404+ num_packed_messages * num_packing, decrypted_wrapper,
405+ unpacked_messages)));
404406 EXPECT_EQ (absl::MakeSpan (unpacked_messages.data (), num_messages),
405407 absl::MakeSpan (expected_unpacked_messages.data (), num_messages));
406408 // Check against the original input messages.
@@ -461,7 +463,8 @@ TEST(KaheTest, RawVectorEncryptOnePolynomial) {
461463 rust::Vec<Integer> unpacked_decrypted_messages;
462464 SECAGG_ASSERT_OK (UnwrapFfiStatus (UnpackMessagesRaw (
463465 packing_base, num_packing, decrypted_wrapper.ptr ->size (),
464- decrypted_wrapper, unpacked_decrypted_messages)));
466+ decrypted_wrapper.ptr ->size () * num_packing, decrypted_wrapper,
467+ unpacked_decrypted_messages)));
465468
466469 // Filled the whole buffer with right messages.
467470 EXPECT_EQ (absl::MakeSpan (unpacked_decrypted_messages.data (), num_messages),
@@ -481,6 +484,7 @@ TEST(KaheTest, RawVectorEncryptOnePolynomial) {
481484 unpacked_decrypted_long_messages.reserve (buffer_length);
482485 SECAGG_ASSERT_OK (UnwrapFfiStatus (UnpackMessagesRaw (
483486 packing_base, num_packing, decrypted_long_messages_wrapper.ptr ->size (),
487+ decrypted_long_messages_wrapper.ptr ->size () * num_packing,
484488 decrypted_long_messages_wrapper, unpacked_decrypted_long_messages)));
485489
486490 // The non-zero messages are identical.
@@ -493,6 +497,38 @@ TEST(KaheTest, RawVectorEncryptOnePolynomial) {
493497 unpacked_decrypted_long_messages.size ())
494498 .subspan (num_messages, padded_length - num_messages),
495499 ::testing::Each (::testing::Eq(0 )));
500+
501+ // Prepare a fresh packed plaintext.
502+ BigIntVectorWrapper packed_messages_2{
503+ .ptr = std::make_unique<std::vector<BigInteger>>()};
504+ SECAGG_ASSERT_OK (
505+ UnwrapFfiStatus (Decrypt (ciphertexts, key, params, &packed_messages_2)));
506+ SECAGG_ASSERT_OK (UnwrapFfiStatus (
507+ Encrypt (packed_messages_2, key, params, &prng, &ciphertexts)));
508+
509+ // Now unpack and directly pass the right length to remove padding.
510+ rust::Vec<Integer> unpacked_messages_2;
511+ SECAGG_ASSERT_OK (UnwrapFfiStatus (UnpackMessagesRaw (
512+ packing_base, num_packing, packed_messages_2.ptr ->size (), num_messages,
513+ packed_messages_2, unpacked_messages_2)));
514+ EXPECT_EQ (
515+ absl::MakeSpan (unpacked_messages_2.data (), unpacked_messages_2.size ()),
516+ absl::MakeSpan (input_messages));
517+
518+ // Finally, check that we fail if we request too many unpacked messages
519+ BigIntVectorWrapper packed_messages_3{
520+ .ptr = std::make_unique<std::vector<BigInteger>>()};
521+ SECAGG_ASSERT_OK (
522+ UnwrapFfiStatus (Decrypt (ciphertexts, key, params, &packed_messages_2)));
523+ SECAGG_ASSERT_OK (UnwrapFfiStatus (
524+ Encrypt (packed_messages_3, key, params, &prng, &ciphertexts)));
525+ rust::Vec<Integer> unpacked_messages_3;
526+ int num_unpacked_messages_3 = packed_messages_3.ptr ->size () * num_packing + 1 ;
527+ EXPECT_THAT (
528+ UnwrapFfiStatus (UnpackMessagesRaw (
529+ packing_base, num_packing, packed_messages_3.ptr ->size (),
530+ num_unpacked_messages_3, packed_messages_3, unpacked_messages_3)),
531+ StatusIs (absl::StatusCode::kInvalidArgument ));
496532}
497533
498534TEST (KaheTest, RawVectorEncryptTwoPolynomials) {
@@ -538,7 +574,7 @@ TEST(KaheTest, RawVectorEncryptTwoPolynomials) {
538574 UnwrapFfiStatus (Decrypt (ciphertexts, key, params, &decrypted_wrapper)));
539575 rust::Vec<Integer> unpacked_decrypted_messages;
540576 SECAGG_ASSERT_OK (UnwrapFfiStatus (UnpackMessagesRaw (
541- packing_base, num_packing, decrypted_wrapper.ptr ->size (),
577+ packing_base, num_packing, decrypted_wrapper.ptr ->size (), num_messages,
542578 decrypted_wrapper, unpacked_decrypted_messages)));
543579
544580 EXPECT_GE (unpacked_decrypted_messages.size (), num_messages);
@@ -643,7 +679,8 @@ TEST(KaheTest, UnpackMessagesRawFailsIfUnallocatedPackedValues) {
643679 rust::Vec<Integer> unpacked_messages;
644680 EXPECT_THAT (UnwrapFfiStatus (UnpackMessagesRaw (
645681 packing_base, packing_dimension, num_packed_messages,
646- bad_packed_values, unpacked_messages)),
682+ num_packed_messages * packing_dimension, bad_packed_values,
683+ unpacked_messages)),
647684 StatusIs (absl::StatusCode::kInvalidArgument ));
648685}
649686
@@ -658,7 +695,8 @@ TEST(KaheTest, UnpackMessagesRawFailsIfPackedValuesTooShort) {
658695 rust::Vec<Integer> unpacked_messages;
659696 EXPECT_THAT (UnwrapFfiStatus (UnpackMessagesRaw (
660697 packing_base, packing_dimension, num_packed_messages,
661- bad_packed_values, unpacked_messages)),
698+ num_packed_messages * packing_dimension, bad_packed_values,
699+ unpacked_messages)),
662700 StatusIs (absl::StatusCode::kInvalidArgument ));
663701}
664702
@@ -738,7 +776,7 @@ TEST(KaheTest, AddInPlacePolynomial) {
738776 rust::Vec<Integer> unpacked_decrypted_messages;
739777 unpacked_decrypted_messages.reserve (num_messages);
740778 SECAGG_ASSERT_OK (UnwrapFfiStatus (UnpackMessagesRaw (
741- packing_base, num_packing, decrypted_wrapper.ptr ->size (),
779+ packing_base, num_packing, decrypted_wrapper.ptr ->size (), num_messages,
742780 decrypted_wrapper, unpacked_decrypted_messages)));
743781 for (int i = 0 ; i < num_messages; ++i) {
744782 EXPECT_EQ (input_values1[i] + input_values2[i],
0 commit comments