Skip to content

Commit 7af55b5

Browse files
tholopcopybara-github
authored andcommitted
Remove zero padding when unpacking messages in KAHE decryption.
This requires storing the unpacked vector length in KAHE config. PiperOrigin-RevId: 850560265
1 parent 98d9af8 commit 7af55b5

File tree

10 files changed

+217
-60
lines changed

10 files changed

+217
-60
lines changed

shell_wrapper/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ cc_library(
242242
":status_macros",
243243
"@abseil-cpp//absl/status",
244244
"@abseil-cpp//absl/status:statusor",
245+
"@abseil-cpp//absl/strings:str_format",
245246
"@abseil-cpp//absl/strings:string_view",
246247
"@abseil-cpp//absl/types:span",
247248
"@shell-encryption//shell_encryption/rns:coefficient_encoder",

shell_wrapper/kahe.cc

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
#include "absl/status/status.h"
2727
#include "absl/status/statusor.h"
28+
#include "absl/strings/str_format.h"
2829
#include "absl/strings/string_view.h"
2930
#include "absl/types/span.h"
3031
#include "include/cxx.h"
@@ -347,6 +348,7 @@ FfiStatus PackMessagesRaw(rust::Slice<const uint64_t> messages,
347348

348349
FfiStatus UnpackMessagesRaw(uint64_t packing_base, uint64_t packing_dimension,
349350
uint64_t num_packed_values,
351+
uint64_t num_unpacked_values,
350352
BigIntVectorWrapper& packed_values,
351353
rust::Vec<uint64_t>& out) {
352354
// Validate the wrappers.
@@ -358,15 +360,25 @@ FfiStatus UnpackMessagesRaw(uint64_t packing_base, uint64_t packing_dimension,
358360
return MakeFfiStatus(
359361
absl::InvalidArgumentError("insufficient number of packed values."));
360362
}
363+
364+
// `unpacked_messages` is padded with zeros if needed.
361365
std::vector<uint64_t> unpacked_messages =
362366
rlwe::UnpackMessagesFlat<secure_aggregation::Integer,
363367
secure_aggregation::BigInteger>(
364368
absl::MakeSpan(*packed_values.ptr).subspan(0, num_packed_values),
365369
packing_base, packing_dimension);
366370
packed_values.ptr->erase(packed_values.ptr->begin(),
367371
packed_values.ptr->begin() + num_packed_values);
368-
for (auto& val : unpacked_messages) {
369-
out.push_back(val);
372+
373+
// Remove padding and copy values to Rust output vector.
374+
if (unpacked_messages.size() < num_unpacked_values) {
375+
return MakeFfiStatus(absl::InvalidArgumentError(
376+
absl::StrFormat("unpacked messages is too short (%d) for the requested "
377+
"number of unpacked values (%d)",
378+
unpacked_messages.size(), num_unpacked_values)));
379+
}
380+
for (size_t i = 0; i < num_unpacked_values; ++i) {
381+
out.push_back(unpacked_messages[i]);
370382
}
371383
return MakeFfiStatus();
372384
}

shell_wrapper/kahe.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,14 +154,16 @@ FfiStatus PackMessagesRaw(rust::Slice<const uint64_t> messages,
154154
uint64_t num_packed_values,
155155
BigIntVectorWrapper* packed_values);
156156

157-
// Unpacks messages stored at `packed_values[0..num_packed_values]` and appends
158-
// them to `out`, and removes these packed values from `packed_values`.
157+
// Unpacks messages stored at `packed_values[0..num_packed_values]`, removes
158+
// these packed values from `packed_values` and appends the first
159+
// `num_unpacked_values` messages to `out`.
159160
// Expects `packed_values.ptr` to be a valid pointer to the vector of packed
160161
// values, and expects packing_base > 1, packing_dimension > 0,
161162
// num_packed_values > 0, packing_base^packing_dimension <
162163
// std::numeric_limits<BigInteger>::max().
163164
FfiStatus UnpackMessagesRaw(uint64_t packing_base, uint64_t packing_dimension,
164165
uint64_t num_packed_values,
166+
uint64_t num_unpacked_values,
165167
BigIntVectorWrapper& packed_values,
166168
rust::Vec<uint64_t>& out);
167169

shell_wrapper/kahe.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,22 @@ use std::collections::HashMap;
2222
use std::marker::PhantomData;
2323
use std::mem::MaybeUninit;
2424

25+
/// Configuration for packing and unpacking. Used to convert a long vector of `length` small
26+
/// integers in [0, `base`) into a short vector of `num_packed_coeffs` large integers in
27+
/// [0, `base`^`dimension`), and vice versa.
2528
#[derive(Debug, PartialEq, Clone)]
2629
pub struct PackedVectorConfig {
30+
/// Base for packing.
2731
pub base: u64,
32+
33+
/// Number of elements packed into each coefficient.
2834
pub dimension: u64,
35+
36+
/// Number of coefficients in the packed vector.
2937
pub num_packed_coeffs: u64,
38+
39+
/// Number of elements in the plaintext vector before packing.
40+
pub length: u64,
3041
}
3142

3243
#[cxx::bridge]
@@ -93,6 +104,7 @@ mod ffi {
93104
packing_base: u64,
94105
packing_dimension: u64,
95106
num_packed_values: u64,
107+
num_unpacked_values: u64,
96108
packed_values: &mut BigIntVectorWrapper,
97109
out: &mut Vec<u64>,
98110
) -> FfiStatus;
@@ -260,6 +272,7 @@ pub fn decrypt(
260272
packed_vector_config.base,
261273
packed_vector_config.dimension,
262274
packed_vector_config.num_packed_coeffs,
275+
packed_vector_config.length,
263276
&mut packed_values,
264277
&mut unpacked_values,
265278
)

shell_wrapper/kahe_test.cc

Lines changed: 104 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,8 @@ TEST(KaheTest, UnpackMessagesRawRemovesConsumedPackedValues) {
337337
rust::Vec<Integer> unpacked_messages;
338338
SECAGG_EXPECT_OK(UnwrapFfiStatus(
339339
UnpackMessagesRaw(packing_base, packing_dimension, num_packed_values,
340-
packed_values, unpacked_messages)));
340+
num_packed_values * packing_dimension, packed_values,
341+
unpacked_messages)));
341342
EXPECT_EQ(packed_values.ptr->size(), num_packed_values);
342343
EXPECT_EQ(unpacked_messages.size(), num_packed_values);
343344
// Unpacked values should match the first half of the original packed values.
@@ -349,6 +350,94 @@ TEST(KaheTest, UnpackMessagesRawRemovesConsumedPackedValues) {
349350
absl::MakeSpan(packed).subspan(num_packed_values));
350351
}
351352

353+
TEST(KaheTest, RawEncryptDecryptPadding) {
354+
constexpr int num_packing = 2;
355+
constexpr int num_public_polynomials = 2;
356+
constexpr int num_messages = 9;
357+
constexpr Integer packing_base = 10;
358+
// 2 messages per coefficient, last coefficient has only 1 message.
359+
constexpr int num_packed_messages = 5;
360+
361+
std::unique_ptr<std::string> public_seed;
362+
SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed)));
363+
KahePublicParametersWrapper params;
364+
SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper(
365+
kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials,
366+
ToRustSlice(*public_seed), &params)));
367+
std::unique_ptr<std::string> private_seed;
368+
SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(private_seed)));
369+
SingleThreadHkdfWrapper prng;
370+
SECAGG_ASSERT_OK(UnwrapFfiStatus(
371+
CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng)));
372+
RnsPolynomialWrapper key;
373+
SECAGG_ASSERT_OK(
374+
UnwrapFfiStatus(GenerateSecretKeyWrapper(params, &prng, &key)));
375+
376+
// Pack messages that don't fully occupy the packed coefficients.
377+
std::vector<Integer> input_messages =
378+
rlwe::testing::SampleMessages(num_messages, packing_base);
379+
BigIntVectorWrapper packed_messages{
380+
.ptr = std::make_unique<std::vector<BigInteger>>()};
381+
SECAGG_ASSERT_OK(UnwrapFfiStatus(
382+
PackMessagesRaw(ToRustSlice(input_messages), packing_base, num_packing,
383+
num_packed_messages, &packed_messages)));
384+
385+
// Encrypt the packed messages.
386+
RnsPolynomialVecWrapper ciphertexts;
387+
SECAGG_ASSERT_OK(UnwrapFfiStatus(
388+
Encrypt(packed_messages, key, params, &prng, &ciphertexts)));
389+
390+
// Decrypt to get a packed plaintext.
391+
BigIntVectorWrapper packed_messages_1{
392+
.ptr = std::make_unique<std::vector<BigInteger>>()};
393+
SECAGG_ASSERT_OK(
394+
UnwrapFfiStatus(Decrypt(ciphertexts, key, params, &packed_messages_1)));
395+
396+
// Unpack and retrieve the original messages plus padding.
397+
rust::Vec<Integer> unpacked_messages_1;
398+
SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw(
399+
packing_base, num_packing, packed_messages_1.ptr->size(),
400+
num_packed_messages * num_packing, packed_messages_1,
401+
unpacked_messages_1)));
402+
403+
// Decrypted messages are padded to zero up to the end of the polynomial.
404+
EXPECT_THAT(
405+
absl::MakeSpan(unpacked_messages_1.data(), unpacked_messages_1.size())
406+
.subspan(num_messages,
407+
num_packed_messages * num_packing - num_messages),
408+
::testing::Each(::testing::Eq(0)));
409+
410+
// Decrypt to obtain a fresh packed plaintext.
411+
BigIntVectorWrapper packed_messages_2{
412+
.ptr = std::make_unique<std::vector<BigInteger>>()};
413+
SECAGG_ASSERT_OK(
414+
UnwrapFfiStatus(Decrypt(ciphertexts, key, params, &packed_messages_2)));
415+
416+
// Now unpack and directly pass the right length to remove padding.
417+
rust::Vec<Integer> unpacked_messages_2;
418+
SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw(
419+
packing_base, num_packing, packed_messages_2.ptr->size(), num_messages,
420+
packed_messages_2, unpacked_messages_2)));
421+
EXPECT_EQ(
422+
absl::MakeSpan(unpacked_messages_2.data(), unpacked_messages_2.size()),
423+
absl::MakeSpan(input_messages));
424+
425+
// Finally, check that we fail if we request too many unpacked messages
426+
BigIntVectorWrapper packed_messages_3{
427+
.ptr = std::make_unique<std::vector<BigInteger>>()};
428+
SECAGG_ASSERT_OK(
429+
UnwrapFfiStatus(Decrypt(ciphertexts, key, params, &packed_messages_2)));
430+
SECAGG_ASSERT_OK(UnwrapFfiStatus(
431+
Encrypt(packed_messages_3, key, params, &prng, &ciphertexts)));
432+
rust::Vec<Integer> unpacked_messages_3;
433+
int num_unpacked_messages_3 = packed_messages_3.ptr->size() * num_packing + 1;
434+
EXPECT_THAT(
435+
UnwrapFfiStatus(UnpackMessagesRaw(
436+
packing_base, num_packing, packed_messages_3.ptr->size(),
437+
num_unpacked_messages_3, packed_messages_3, unpacked_messages_3)),
438+
StatusIs(absl::StatusCode::kInvalidArgument));
439+
}
440+
352441
TEST(KaheTest, PackAndEncrypt) {
353442
constexpr int num_packing = 8;
354443
constexpr int num_public_polynomials = 2;
@@ -399,19 +488,14 @@ TEST(KaheTest, PackAndEncrypt) {
399488
.ptr = std::make_unique<std::vector<BigInteger>>(std::move(decrypted))};
400489
rust::Vec<Integer> unpacked_messages;
401490
SECAGG_ASSERT_OK(UnwrapFfiStatus(
402-
UnpackMessagesRaw(packing_base, num_packing, packed_messages.size(),
403-
decrypted_wrapper, unpacked_messages)));
491+
UnpackMessagesRaw(packing_base, num_packing, num_packed_messages,
492+
num_messages, decrypted_wrapper, unpacked_messages)));
404493
EXPECT_EQ(absl::MakeSpan(unpacked_messages.data(), num_messages),
405494
absl::MakeSpan(expected_unpacked_messages.data(), num_messages));
495+
406496
// Check against the original input messages.
407497
EXPECT_EQ(absl::MakeSpan(unpacked_messages.data(), num_messages),
408-
absl::MakeSpan(input_messages).subspan(0, num_messages));
409-
// Check unpacked messages are padded with zeros.
410-
ASSERT_GE(expected_unpacked_messages.size(), num_messages);
411-
EXPECT_THAT(
412-
absl::MakeSpan(unpacked_messages.data(), unpacked_messages.size())
413-
.subspan(num_messages, unpacked_messages.size() - num_messages),
414-
::testing::Each(::testing::Eq(0)));
498+
absl::MakeSpan(input_messages));
415499
}
416500

417501
TEST(KaheTest, RawVectorEncryptOnePolynomial) {
@@ -461,7 +545,8 @@ TEST(KaheTest, RawVectorEncryptOnePolynomial) {
461545
rust::Vec<Integer> unpacked_decrypted_messages;
462546
SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw(
463547
packing_base, num_packing, decrypted_wrapper.ptr->size(),
464-
decrypted_wrapper, unpacked_decrypted_messages)));
548+
decrypted_wrapper.ptr->size() * num_packing, decrypted_wrapper,
549+
unpacked_decrypted_messages)));
465550

466551
// Filled the whole buffer with right messages.
467552
EXPECT_EQ(absl::MakeSpan(unpacked_decrypted_messages.data(), num_messages),
@@ -481,6 +566,7 @@ TEST(KaheTest, RawVectorEncryptOnePolynomial) {
481566
unpacked_decrypted_long_messages.reserve(buffer_length);
482567
SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw(
483568
packing_base, num_packing, decrypted_long_messages_wrapper.ptr->size(),
569+
decrypted_long_messages_wrapper.ptr->size() * num_packing,
484570
decrypted_long_messages_wrapper, unpacked_decrypted_long_messages)));
485571

486572
// The non-zero messages are identical.
@@ -538,10 +624,10 @@ TEST(KaheTest, RawVectorEncryptTwoPolynomials) {
538624
UnwrapFfiStatus(Decrypt(ciphertexts, key, params, &decrypted_wrapper)));
539625
rust::Vec<Integer> unpacked_decrypted_messages;
540626
SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw(
541-
packing_base, num_packing, decrypted_wrapper.ptr->size(),
627+
packing_base, num_packing, decrypted_wrapper.ptr->size(), num_messages,
542628
decrypted_wrapper, unpacked_decrypted_messages)));
543629

544-
EXPECT_GE(unpacked_decrypted_messages.size(), num_messages);
630+
EXPECT_EQ(unpacked_decrypted_messages.size(), num_messages);
545631
EXPECT_EQ(absl::MakeSpan(input_messages),
546632
absl::MakeSpan(unpacked_decrypted_messages.data(), num_messages));
547633
}
@@ -643,7 +729,8 @@ TEST(KaheTest, UnpackMessagesRawFailsIfUnallocatedPackedValues) {
643729
rust::Vec<Integer> unpacked_messages;
644730
EXPECT_THAT(UnwrapFfiStatus(UnpackMessagesRaw(
645731
packing_base, packing_dimension, num_packed_messages,
646-
bad_packed_values, unpacked_messages)),
732+
num_packed_messages * packing_dimension, bad_packed_values,
733+
unpacked_messages)),
647734
StatusIs(absl::StatusCode::kInvalidArgument));
648735
}
649736

@@ -658,7 +745,8 @@ TEST(KaheTest, UnpackMessagesRawFailsIfPackedValuesTooShort) {
658745
rust::Vec<Integer> unpacked_messages;
659746
EXPECT_THAT(UnwrapFfiStatus(UnpackMessagesRaw(
660747
packing_base, packing_dimension, num_packed_messages,
661-
bad_packed_values, unpacked_messages)),
748+
num_packed_messages * packing_dimension, bad_packed_values,
749+
unpacked_messages)),
662750
StatusIs(absl::StatusCode::kInvalidArgument));
663751
}
664752

@@ -738,7 +826,7 @@ TEST(KaheTest, AddInPlacePolynomial) {
738826
rust::Vec<Integer> unpacked_decrypted_messages;
739827
unpacked_decrypted_messages.reserve(num_messages);
740828
SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw(
741-
packing_base, num_packing, decrypted_wrapper.ptr->size(),
829+
packing_base, num_packing, decrypted_wrapper.ptr->size(), num_messages,
742830
decrypted_wrapper, unpacked_decrypted_messages)));
743831
for (int i = 0; i < num_messages; ++i) {
744832
EXPECT_EQ(input_values1[i] + input_values2[i],

0 commit comments

Comments
 (0)