@@ -337,7 +337,8 @@ TEST(KaheTest, UnpackMessagesRawRemovesConsumedPackedValues) {
337337 rust::Vec<Integer> unpacked_messages;
338338 SECAGG_EXPECT_OK (UnwrapFfiStatus (
339339 UnpackMessagesRaw (packing_base, packing_dimension, num_packed_values,
340- packed_values, unpacked_messages)));
340+ num_packed_values * packing_dimension, packed_values,
341+ unpacked_messages)));
341342 EXPECT_EQ (packed_values.ptr ->size (), num_packed_values);
342343 EXPECT_EQ (unpacked_messages.size (), num_packed_values);
343344 // Unpacked values should match the first half of the original packed values.
@@ -349,6 +350,94 @@ TEST(KaheTest, UnpackMessagesRawRemovesConsumedPackedValues) {
349350 absl::MakeSpan (packed).subspan (num_packed_values));
350351}
351352
353+ TEST (KaheTest, RawEncryptDecryptPadding) {
354+ constexpr int num_packing = 2 ;
355+ constexpr int num_public_polynomials = 2 ;
356+ constexpr int num_messages = 9 ;
357+ constexpr Integer packing_base = 10 ;
358+ // 2 messages per coefficient, last coefficient has only 1 message.
359+ constexpr int num_packed_messages = 5 ;
360+
361+ std::unique_ptr<std::string> public_seed;
362+ SECAGG_ASSERT_OK (UnwrapFfiStatus (GenerateSingleThreadHkdfSeed (public_seed)));
363+ KahePublicParametersWrapper params;
364+ SECAGG_ASSERT_OK (UnwrapFfiStatus (CreateKahePublicParametersWrapper (
365+ kLogN , kLogT , ToRustSlice (kQs ), num_public_polynomials,
366+ ToRustSlice (*public_seed), ¶ms)));
367+ std::unique_ptr<std::string> private_seed;
368+ SECAGG_ASSERT_OK (UnwrapFfiStatus (GenerateSingleThreadHkdfSeed (private_seed)));
369+ SingleThreadHkdfWrapper prng;
370+ SECAGG_ASSERT_OK (UnwrapFfiStatus (
371+ CreateSingleThreadHkdf (ToRustSlice (*private_seed), prng)));
372+ RnsPolynomialWrapper key;
373+ SECAGG_ASSERT_OK (
374+ UnwrapFfiStatus (GenerateSecretKeyWrapper (params, &prng, &key)));
375+
376+ // Pack messages that don't fully occupy the packed coefficients.
377+ std::vector<Integer> input_messages =
378+ rlwe::testing::SampleMessages (num_messages, packing_base);
379+ BigIntVectorWrapper packed_messages{
380+ .ptr = std::make_unique<std::vector<BigInteger>>()};
381+ SECAGG_ASSERT_OK (UnwrapFfiStatus (
382+ PackMessagesRaw (ToRustSlice (input_messages), packing_base, num_packing,
383+ num_packed_messages, &packed_messages)));
384+
385+ // Encrypt the packed messages.
386+ RnsPolynomialVecWrapper ciphertexts;
387+ SECAGG_ASSERT_OK (UnwrapFfiStatus (
388+ Encrypt (packed_messages, key, params, &prng, &ciphertexts)));
389+
390+ // Decrypt to get a packed plaintext.
391+ BigIntVectorWrapper packed_messages_1{
392+ .ptr = std::make_unique<std::vector<BigInteger>>()};
393+ SECAGG_ASSERT_OK (
394+ UnwrapFfiStatus (Decrypt (ciphertexts, key, params, &packed_messages_1)));
395+
396+ // Unpack and retrieve the original messages plus padding.
397+ rust::Vec<Integer> unpacked_messages_1;
398+ SECAGG_ASSERT_OK (UnwrapFfiStatus (UnpackMessagesRaw (
399+ packing_base, num_packing, packed_messages_1.ptr ->size (),
400+ num_packed_messages * num_packing, packed_messages_1,
401+ unpacked_messages_1)));
402+
403+ // Decrypted messages are padded to zero up to the end of the polynomial.
404+ EXPECT_THAT (
405+ absl::MakeSpan (unpacked_messages_1.data (), unpacked_messages_1.size ())
406+ .subspan (num_messages,
407+ num_packed_messages * num_packing - num_messages),
408+ ::testing::Each (::testing::Eq(0 )));
409+
410+ // Decrypt to obtain a fresh packed plaintext.
411+ BigIntVectorWrapper packed_messages_2{
412+ .ptr = std::make_unique<std::vector<BigInteger>>()};
413+ SECAGG_ASSERT_OK (
414+ UnwrapFfiStatus (Decrypt (ciphertexts, key, params, &packed_messages_2)));
415+
416+ // Now unpack and directly pass the right length to remove padding.
417+ rust::Vec<Integer> unpacked_messages_2;
418+ SECAGG_ASSERT_OK (UnwrapFfiStatus (UnpackMessagesRaw (
419+ packing_base, num_packing, packed_messages_2.ptr ->size (), num_messages,
420+ packed_messages_2, unpacked_messages_2)));
421+ EXPECT_EQ (
422+ absl::MakeSpan (unpacked_messages_2.data (), unpacked_messages_2.size ()),
423+ absl::MakeSpan (input_messages));
424+
425+ // Finally, check that we fail if we request too many unpacked messages
426+ BigIntVectorWrapper packed_messages_3{
427+ .ptr = std::make_unique<std::vector<BigInteger>>()};
428+ SECAGG_ASSERT_OK (
429+ UnwrapFfiStatus (Decrypt (ciphertexts, key, params, &packed_messages_2)));
430+ SECAGG_ASSERT_OK (UnwrapFfiStatus (
431+ Encrypt (packed_messages_3, key, params, &prng, &ciphertexts)));
432+ rust::Vec<Integer> unpacked_messages_3;
433+ int num_unpacked_messages_3 = packed_messages_3.ptr ->size () * num_packing + 1 ;
434+ EXPECT_THAT (
435+ UnwrapFfiStatus (UnpackMessagesRaw (
436+ packing_base, num_packing, packed_messages_3.ptr ->size (),
437+ num_unpacked_messages_3, packed_messages_3, unpacked_messages_3)),
438+ StatusIs (absl::StatusCode::kInvalidArgument ));
439+ }
440+
352441TEST (KaheTest, PackAndEncrypt) {
353442 constexpr int num_packing = 8 ;
354443 constexpr int num_public_polynomials = 2 ;
@@ -399,19 +488,14 @@ TEST(KaheTest, PackAndEncrypt) {
399488 .ptr = std::make_unique<std::vector<BigInteger>>(std::move (decrypted))};
400489 rust::Vec<Integer> unpacked_messages;
401490 SECAGG_ASSERT_OK (UnwrapFfiStatus (
402- UnpackMessagesRaw (packing_base, num_packing, packed_messages. size () ,
403- decrypted_wrapper, unpacked_messages)));
491+ UnpackMessagesRaw (packing_base, num_packing, num_packed_messages ,
492+ num_messages, decrypted_wrapper, unpacked_messages)));
404493 EXPECT_EQ (absl::MakeSpan (unpacked_messages.data (), num_messages),
405494 absl::MakeSpan (expected_unpacked_messages.data (), num_messages));
495+
406496 // Check against the original input messages.
407497 EXPECT_EQ (absl::MakeSpan (unpacked_messages.data (), num_messages),
408- absl::MakeSpan (input_messages).subspan (0 , num_messages));
409- // Check unpacked messages are padded with zeros.
410- ASSERT_GE (expected_unpacked_messages.size (), num_messages);
411- EXPECT_THAT (
412- absl::MakeSpan (unpacked_messages.data (), unpacked_messages.size ())
413- .subspan (num_messages, unpacked_messages.size () - num_messages),
414- ::testing::Each (::testing::Eq(0 )));
498+ absl::MakeSpan (input_messages));
415499}
416500
417501TEST (KaheTest, RawVectorEncryptOnePolynomial) {
@@ -461,7 +545,8 @@ TEST(KaheTest, RawVectorEncryptOnePolynomial) {
461545 rust::Vec<Integer> unpacked_decrypted_messages;
462546 SECAGG_ASSERT_OK (UnwrapFfiStatus (UnpackMessagesRaw (
463547 packing_base, num_packing, decrypted_wrapper.ptr ->size (),
464- decrypted_wrapper, unpacked_decrypted_messages)));
548+ decrypted_wrapper.ptr ->size () * num_packing, decrypted_wrapper,
549+ unpacked_decrypted_messages)));
465550
466551 // Filled the whole buffer with right messages.
467552 EXPECT_EQ (absl::MakeSpan (unpacked_decrypted_messages.data (), num_messages),
@@ -481,6 +566,7 @@ TEST(KaheTest, RawVectorEncryptOnePolynomial) {
481566 unpacked_decrypted_long_messages.reserve (buffer_length);
482567 SECAGG_ASSERT_OK (UnwrapFfiStatus (UnpackMessagesRaw (
483568 packing_base, num_packing, decrypted_long_messages_wrapper.ptr ->size (),
569+ decrypted_long_messages_wrapper.ptr ->size () * num_packing,
484570 decrypted_long_messages_wrapper, unpacked_decrypted_long_messages)));
485571
486572 // The non-zero messages are identical.
@@ -538,10 +624,10 @@ TEST(KaheTest, RawVectorEncryptTwoPolynomials) {
538624 UnwrapFfiStatus (Decrypt (ciphertexts, key, params, &decrypted_wrapper)));
539625 rust::Vec<Integer> unpacked_decrypted_messages;
540626 SECAGG_ASSERT_OK (UnwrapFfiStatus (UnpackMessagesRaw (
541- packing_base, num_packing, decrypted_wrapper.ptr ->size (),
627+ packing_base, num_packing, decrypted_wrapper.ptr ->size (), num_messages,
542628 decrypted_wrapper, unpacked_decrypted_messages)));
543629
544- EXPECT_GE (unpacked_decrypted_messages.size (), num_messages);
630+ EXPECT_EQ (unpacked_decrypted_messages.size (), num_messages);
545631 EXPECT_EQ (absl::MakeSpan (input_messages),
546632 absl::MakeSpan (unpacked_decrypted_messages.data (), num_messages));
547633}
@@ -643,7 +729,8 @@ TEST(KaheTest, UnpackMessagesRawFailsIfUnallocatedPackedValues) {
643729 rust::Vec<Integer> unpacked_messages;
644730 EXPECT_THAT (UnwrapFfiStatus (UnpackMessagesRaw (
645731 packing_base, packing_dimension, num_packed_messages,
646- bad_packed_values, unpacked_messages)),
732+ num_packed_messages * packing_dimension, bad_packed_values,
733+ unpacked_messages)),
647734 StatusIs (absl::StatusCode::kInvalidArgument ));
648735}
649736
@@ -658,7 +745,8 @@ TEST(KaheTest, UnpackMessagesRawFailsIfPackedValuesTooShort) {
658745 rust::Vec<Integer> unpacked_messages;
659746 EXPECT_THAT (UnwrapFfiStatus (UnpackMessagesRaw (
660747 packing_base, packing_dimension, num_packed_messages,
661- bad_packed_values, unpacked_messages)),
748+ num_packed_messages * packing_dimension, bad_packed_values,
749+ unpacked_messages)),
662750 StatusIs (absl::StatusCode::kInvalidArgument ));
663751}
664752
@@ -738,7 +826,7 @@ TEST(KaheTest, AddInPlacePolynomial) {
738826 rust::Vec<Integer> unpacked_decrypted_messages;
739827 unpacked_decrypted_messages.reserve (num_messages);
740828 SECAGG_ASSERT_OK (UnwrapFfiStatus (UnpackMessagesRaw (
741- packing_base, num_packing, decrypted_wrapper.ptr ->size (),
829+ packing_base, num_packing, decrypted_wrapper.ptr ->size (), num_messages,
742830 decrypted_wrapper, unpacked_decrypted_messages)));
743831 for (int i = 0 ; i < num_messages; ++i) {
744832 EXPECT_EQ (input_values1[i] + input_values2[i],
0 commit comments