Skip to content

Commit 2d80d63

Browse files
tholopcopybara-github
authored andcommitted
Remove zero padding when unpacking messages in KAHE decryption.
This requires storing the unpacked vector length in KAHE config. PiperOrigin-RevId: 850474871
1 parent 98d9af8 commit 2d80d63

File tree

10 files changed

+155
-48
lines changed

10 files changed

+155
-48
lines changed

shell_wrapper/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ cc_library(
242242
":status_macros",
243243
"@abseil-cpp//absl/status",
244244
"@abseil-cpp//absl/status:statusor",
245+
"@abseil-cpp//absl/strings:str_format",
245246
"@abseil-cpp//absl/strings:string_view",
246247
"@abseil-cpp//absl/types:span",
247248
"@shell-encryption//shell_encryption/rns:coefficient_encoder",

shell_wrapper/kahe.cc

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
#include "absl/status/status.h"
2727
#include "absl/status/statusor.h"
28+
#include "absl/strings/str_format.h"
2829
#include "absl/strings/string_view.h"
2930
#include "absl/types/span.h"
3031
#include "include/cxx.h"
@@ -347,6 +348,7 @@ FfiStatus PackMessagesRaw(rust::Slice<const uint64_t> messages,
347348

348349
FfiStatus UnpackMessagesRaw(uint64_t packing_base, uint64_t packing_dimension,
349350
uint64_t num_packed_values,
351+
uint64_t num_unpacked_values,
350352
BigIntVectorWrapper& packed_values,
351353
rust::Vec<uint64_t>& out) {
352354
// Validate the wrappers.
@@ -358,15 +360,25 @@ FfiStatus UnpackMessagesRaw(uint64_t packing_base, uint64_t packing_dimension,
358360
return MakeFfiStatus(
359361
absl::InvalidArgumentError("insufficient number of packed values."));
360362
}
363+
364+
// `unpacked_messages` is padded with zeros if needed.
361365
std::vector<uint64_t> unpacked_messages =
362366
rlwe::UnpackMessagesFlat<secure_aggregation::Integer,
363367
secure_aggregation::BigInteger>(
364368
absl::MakeSpan(*packed_values.ptr).subspan(0, num_packed_values),
365369
packing_base, packing_dimension);
366370
packed_values.ptr->erase(packed_values.ptr->begin(),
367371
packed_values.ptr->begin() + num_packed_values);
368-
for (auto& val : unpacked_messages) {
369-
out.push_back(val);
372+
373+
// Remove padding and copy values to Rust output vector.
374+
if (unpacked_messages.size() < num_unpacked_values) {
375+
return MakeFfiStatus(absl::InvalidArgumentError(
376+
absl::StrFormat("unpacked messages is too short (%d) for the requested "
377+
"number of unpacked values (%d)",
378+
unpacked_messages.size(), num_unpacked_values)));
379+
}
380+
for (size_t i = 0; i < num_unpacked_values; ++i) {
381+
out.push_back(unpacked_messages[i]);
370382
}
371383
return MakeFfiStatus();
372384
}

shell_wrapper/kahe.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,14 +154,16 @@ FfiStatus PackMessagesRaw(rust::Slice<const uint64_t> messages,
154154
uint64_t num_packed_values,
155155
BigIntVectorWrapper* packed_values);
156156

157-
// Unpacks messages stored at `packed_values[0..num_packed_values]` and appends
158-
// them to `out`, and removes these packed values from `packed_values`.
157+
// Unpacks messages stored at `packed_values[0..num_packed_values]`, removes
158+
// these packed values from `packed_values` and appends the first
159+
// `num_unpacked_values` messages to `out`.
159160
// Expects `packed_values.ptr` to be a valid pointer to the vector of packed
160161
// values, and expects packing_base > 1, packing_dimension > 0,
161162
// num_packed_values > 0, packing_base^packing_dimension <
162163
// std::numeric_limits<BigInteger>::max().
163164
FfiStatus UnpackMessagesRaw(uint64_t packing_base, uint64_t packing_dimension,
164165
uint64_t num_packed_values,
166+
uint64_t num_unpacked_values,
165167
BigIntVectorWrapper& packed_values,
166168
rust::Vec<uint64_t>& out);
167169

shell_wrapper/kahe.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,22 @@ use std::collections::HashMap;
2222
use std::marker::PhantomData;
2323
use std::mem::MaybeUninit;
2424

25+
/// Configuration for packing and unpacking. Used to convert a long vector of `length` small
26+
/// integers in [0, `base`) into a short vector of `num_packed_coeffs` large integers in
27+
/// [0, `base`^`dimension`), and vice versa.
2528
#[derive(Debug, PartialEq, Clone)]
2629
pub struct PackedVectorConfig {
30+
/// Base for packing.
2731
pub base: u64,
32+
33+
/// Number of elements packed into each coefficient.
2834
pub dimension: u64,
35+
36+
/// Number of coefficients in the packed vector.
2937
pub num_packed_coeffs: u64,
38+
39+
/// Number of elements in the plaintext vector before packing.
40+
pub length: u64,
3041
}
3142

3243
#[cxx::bridge]
@@ -93,6 +104,7 @@ mod ffi {
93104
packing_base: u64,
94105
packing_dimension: u64,
95106
num_packed_values: u64,
107+
num_unpacked_values: u64,
96108
packed_values: &mut BigIntVectorWrapper,
97109
out: &mut Vec<u64>,
98110
) -> FfiStatus;
@@ -260,6 +272,7 @@ pub fn decrypt(
260272
packed_vector_config.base,
261273
packed_vector_config.dimension,
262274
packed_vector_config.num_packed_coeffs,
275+
packed_vector_config.length,
263276
&mut packed_values,
264277
&mut unpacked_values,
265278
)

shell_wrapper/kahe_test.cc

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,8 @@ TEST(KaheTest, UnpackMessagesRawRemovesConsumedPackedValues) {
337337
rust::Vec<Integer> unpacked_messages;
338338
SECAGG_EXPECT_OK(UnwrapFfiStatus(
339339
UnpackMessagesRaw(packing_base, packing_dimension, num_packed_values,
340-
packed_values, unpacked_messages)));
340+
num_packed_values * packing_dimension, packed_values,
341+
unpacked_messages)));
341342
EXPECT_EQ(packed_values.ptr->size(), num_packed_values);
342343
EXPECT_EQ(unpacked_messages.size(), num_packed_values);
343344
// Unpacked values should match the first half of the original packed values.
@@ -399,8 +400,9 @@ TEST(KaheTest, PackAndEncrypt) {
399400
.ptr = std::make_unique<std::vector<BigInteger>>(std::move(decrypted))};
400401
rust::Vec<Integer> unpacked_messages;
401402
SECAGG_ASSERT_OK(UnwrapFfiStatus(
402-
UnpackMessagesRaw(packing_base, num_packing, packed_messages.size(),
403-
decrypted_wrapper, unpacked_messages)));
403+
UnpackMessagesRaw(packing_base, num_packing, num_packed_messages,
404+
num_packed_messages * num_packing, decrypted_wrapper,
405+
unpacked_messages)));
404406
EXPECT_EQ(absl::MakeSpan(unpacked_messages.data(), num_messages),
405407
absl::MakeSpan(expected_unpacked_messages.data(), num_messages));
406408
// Check against the original input messages.
@@ -461,7 +463,8 @@ TEST(KaheTest, RawVectorEncryptOnePolynomial) {
461463
rust::Vec<Integer> unpacked_decrypted_messages;
462464
SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw(
463465
packing_base, num_packing, decrypted_wrapper.ptr->size(),
464-
decrypted_wrapper, unpacked_decrypted_messages)));
466+
decrypted_wrapper.ptr->size() * num_packing, decrypted_wrapper,
467+
unpacked_decrypted_messages)));
465468

466469
// Filled the whole buffer with right messages.
467470
EXPECT_EQ(absl::MakeSpan(unpacked_decrypted_messages.data(), num_messages),
@@ -481,6 +484,7 @@ TEST(KaheTest, RawVectorEncryptOnePolynomial) {
481484
unpacked_decrypted_long_messages.reserve(buffer_length);
482485
SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw(
483486
packing_base, num_packing, decrypted_long_messages_wrapper.ptr->size(),
487+
decrypted_long_messages_wrapper.ptr->size() * num_packing,
484488
decrypted_long_messages_wrapper, unpacked_decrypted_long_messages)));
485489

486490
// The non-zero messages are identical.
@@ -493,6 +497,38 @@ TEST(KaheTest, RawVectorEncryptOnePolynomial) {
493497
unpacked_decrypted_long_messages.size())
494498
.subspan(num_messages, padded_length - num_messages),
495499
::testing::Each(::testing::Eq(0)));
500+
501+
// Prepare a fresh packed plaintext.
502+
BigIntVectorWrapper packed_messages_2{
503+
.ptr = std::make_unique<std::vector<BigInteger>>()};
504+
SECAGG_ASSERT_OK(
505+
UnwrapFfiStatus(Decrypt(ciphertexts, key, params, &packed_messages_2)));
506+
SECAGG_ASSERT_OK(UnwrapFfiStatus(
507+
Encrypt(packed_messages_2, key, params, &prng, &ciphertexts)));
508+
509+
// Now unpack and directly pass the right length to remove padding.
510+
rust::Vec<Integer> unpacked_messages_2;
511+
SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw(
512+
packing_base, num_packing, packed_messages_2.ptr->size(), num_messages,
513+
packed_messages_2, unpacked_messages_2)));
514+
EXPECT_EQ(
515+
absl::MakeSpan(unpacked_messages_2.data(), unpacked_messages_2.size()),
516+
absl::MakeSpan(input_messages));
517+
518+
// Finally, check that we fail if we request too many unpacked messages
519+
BigIntVectorWrapper packed_messages_3{
520+
.ptr = std::make_unique<std::vector<BigInteger>>()};
521+
SECAGG_ASSERT_OK(
522+
UnwrapFfiStatus(Decrypt(ciphertexts, key, params, &packed_messages_2)));
523+
SECAGG_ASSERT_OK(UnwrapFfiStatus(
524+
Encrypt(packed_messages_3, key, params, &prng, &ciphertexts)));
525+
rust::Vec<Integer> unpacked_messages_3;
526+
int num_unpacked_messages_3 = packed_messages_3.ptr->size() * num_packing + 1;
527+
EXPECT_THAT(
528+
UnwrapFfiStatus(UnpackMessagesRaw(
529+
packing_base, num_packing, packed_messages_3.ptr->size(),
530+
num_unpacked_messages_3, packed_messages_3, unpacked_messages_3)),
531+
StatusIs(absl::StatusCode::kInvalidArgument));
496532
}
497533

498534
TEST(KaheTest, RawVectorEncryptTwoPolynomials) {
@@ -538,7 +574,7 @@ TEST(KaheTest, RawVectorEncryptTwoPolynomials) {
538574
UnwrapFfiStatus(Decrypt(ciphertexts, key, params, &decrypted_wrapper)));
539575
rust::Vec<Integer> unpacked_decrypted_messages;
540576
SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw(
541-
packing_base, num_packing, decrypted_wrapper.ptr->size(),
577+
packing_base, num_packing, decrypted_wrapper.ptr->size(), num_messages,
542578
decrypted_wrapper, unpacked_decrypted_messages)));
543579

544580
EXPECT_GE(unpacked_decrypted_messages.size(), num_messages);
@@ -643,7 +679,8 @@ TEST(KaheTest, UnpackMessagesRawFailsIfUnallocatedPackedValues) {
643679
rust::Vec<Integer> unpacked_messages;
644680
EXPECT_THAT(UnwrapFfiStatus(UnpackMessagesRaw(
645681
packing_base, packing_dimension, num_packed_messages,
646-
bad_packed_values, unpacked_messages)),
682+
num_packed_messages * packing_dimension, bad_packed_values,
683+
unpacked_messages)),
647684
StatusIs(absl::StatusCode::kInvalidArgument));
648685
}
649686

@@ -658,7 +695,8 @@ TEST(KaheTest, UnpackMessagesRawFailsIfPackedValuesTooShort) {
658695
rust::Vec<Integer> unpacked_messages;
659696
EXPECT_THAT(UnwrapFfiStatus(UnpackMessagesRaw(
660697
packing_base, packing_dimension, num_packed_messages,
661-
bad_packed_values, unpacked_messages)),
698+
num_packed_messages * packing_dimension, bad_packed_values,
699+
unpacked_messages)),
662700
StatusIs(absl::StatusCode::kInvalidArgument));
663701
}
664702

@@ -738,7 +776,7 @@ TEST(KaheTest, AddInPlacePolynomial) {
738776
rust::Vec<Integer> unpacked_decrypted_messages;
739777
unpacked_decrypted_messages.reserve(num_messages);
740778
SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw(
741-
packing_base, num_packing, decrypted_wrapper.ptr->size(),
779+
packing_base, num_packing, decrypted_wrapper.ptr->size(), num_messages,
742780
decrypted_wrapper, unpacked_decrypted_messages)));
743781
for (int i = 0; i < num_messages; ++i) {
744782
EXPECT_EQ(input_values1[i] + input_values2[i],

shell_wrapper/kahe_test.rs

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@ fn encrypt_decrypt() -> Result<()> {
4747
let plaintext = HashMap::from([(DEFAULT_ID, input_values.as_slice())]);
4848
let packed_vector_configs = HashMap::from([(
4949
DEFAULT_ID.to_string(),
50-
PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 2 },
50+
PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 2, length: 3 },
5151
)]);
5252
let ciphertext = encrypt(&plaintext, &packed_vector_configs, &secret_key, &params, &mut prng)?;
5353

5454
let output_values = decrypt(&ciphertext, &secret_key, &params, &packed_vector_configs)?;
5555
expect_that!(output_values.contains_key(DEFAULT_ID), eq(true));
56-
expect_that!(output_values[DEFAULT_ID][..3], container_eq(input_values));
56+
expect_that!(output_values[DEFAULT_ID], container_eq(input_values));
5757
Ok(())
5858
}
5959

@@ -80,14 +80,16 @@ fn encrypt_decrypt_padding() -> Result<()> {
8080
let input_values: Vec<u64> =
8181
(0..num_input_values).map(|_| rand::thread_rng().gen_range(0..input_domain)).collect();
8282

83-
// Encrypt the vector.
83+
// Encrypt the vector. Pass a longer length than what we need.
84+
let padded_length = (num_packed_coeffs * packing_dimension) as usize;
8485
let plaintext = HashMap::from([(DEFAULT_ID, input_values.as_slice())]);
8586
let packed_vector_configs = HashMap::from([(
8687
DEFAULT_ID.to_string(),
8788
PackedVectorConfig {
8889
base: input_domain as u64,
8990
dimension: packing_dimension as u64,
9091
num_packed_coeffs: num_packed_coeffs as u64,
92+
length: padded_length as u64,
9193
},
9294
)]);
9395
let ciphertext = encrypt(&plaintext, &packed_vector_configs, &secret_key, &params, &mut prng)?;
@@ -97,7 +99,6 @@ fn encrypt_decrypt_padding() -> Result<()> {
9799
let output_values = &decrypted[DEFAULT_ID];
98100

99101
// Check that message is correctly decrypted with right padding.
100-
let padded_length = (num_packed_coeffs * packing_dimension) as usize;
101102
expect_that!(output_values.len(), eq(padded_length));
102103
expect_that!(output_values.len(), gt(num_input_values));
103104
expect_that!(output_values[..num_input_values], container_eq(input_values));
@@ -138,22 +139,17 @@ fn encrypt_decrypt_long() -> Result<()> {
138139
base: input_domain as u64,
139140
dimension: packing_dimension as u64,
140141
num_packed_coeffs: num_packed_coeffs as u64,
142+
length: num_input_values as u64,
141143
},
142144
)]);
143145
let ciphertext = encrypt(&plaintext, &packed_vector_configs, &secret_key, &params, &mut prng)?;
144146

145147
let decrypted = decrypt(&ciphertext, &secret_key, &params, &packed_vector_configs)?;
146148
let output_values = &decrypted[DEFAULT_ID];
147149

148-
// Check that message is correctly decrypted with right padding.
149-
let padded_length = num_packed_coeffs * packing_dimension;
150-
expect_that!(output_values.len(), eq(padded_length));
151-
expect_that!(output_values.len(), gt(num_input_values));
152-
expect_that!(output_values[..num_input_values], container_eq(input_values));
153-
expect_that!(
154-
output_values[num_input_values..],
155-
container_eq(vec![0; padded_length - num_input_values])
156-
);
150+
// Check that message is correctly decrypted (no padding).
151+
expect_that!(output_values.len(), eq(num_input_values));
152+
expect_that!(output_values, container_eq(input_values));
157153

158154
// If the input is too long, we should fail.
159155
let num_values_too_long = num_public_polynomials * poly_capacity + 1;
@@ -195,6 +191,7 @@ fn encrypt_decrypt_two_vectors() -> Result<()> {
195191
base: input_domains[0] as u64,
196192
dimension: packing_dimensions[0] as u64,
197193
num_packed_coeffs: num_packed_coeffs[0] as u64,
194+
length: num_input_values[0] as u64,
198195
},
199196
),
200197
(
@@ -203,6 +200,7 @@ fn encrypt_decrypt_two_vectors() -> Result<()> {
203200
base: input_domains[1] as u64,
204201
dimension: packing_dimensions[1] as u64,
205202
num_packed_coeffs: num_packed_coeffs[1] as u64,
203+
length: num_input_values[1] as u64,
206204
},
207205
),
208206
]);
@@ -218,26 +216,16 @@ fn encrypt_decrypt_two_vectors() -> Result<()> {
218216
HashMap::from([(ID0, input_values0.as_slice()), (ID1, input_values1.as_slice())]);
219217
let ciphertext = encrypt(&plaintext, &packed_vector_configs, &secret_key, &params, &mut prng)?;
220218

221-
// Decrypt and check the output contains the two vectors that are padded correctly.
219+
// Decrypt and check the output contains the two vectors.
222220
let decrypted = decrypt(&ciphertext, &secret_key, &params, &packed_vector_configs)?;
223221
verify_that!(decrypted.contains_key(ID0), eq(true))?;
224222
verify_that!(decrypted.contains_key(ID1), eq(true))?;
225223

226224
let output_values0 = &decrypted[ID0];
227225
let output_values1 = &decrypted[ID1];
228-
expect_that!(output_values0.len(), eq(num_packed_coeffs[0] * packing_dimensions[0]));
229-
expect_that!(output_values0.len(), gt(num_input_values[0]));
230-
expect_that!(output_values0[..num_input_values[0]], container_eq(input_values0));
231-
expect_that!(
232-
output_values0[num_input_values[0]..],
233-
container_eq(vec![0; num_packed_coeffs[0] * packing_dimensions[0] - num_input_values[0]])
234-
);
235-
expect_that!(output_values1.len(), eq(num_packed_coeffs[1] * packing_dimensions[1]));
236-
expect_that!(output_values1.len(), gt(num_input_values[1]));
237-
expect_that!(output_values1[..num_input_values[1]], container_eq(input_values1));
238-
expect_that!(
239-
output_values1[num_input_values[1]..],
240-
container_eq(vec![0; num_packed_coeffs[1] * packing_dimensions[1] - num_input_values[1]])
241-
);
226+
expect_that!(output_values0.len(), eq(num_input_values[0]));
227+
expect_that!(output_values0, container_eq(input_values0));
228+
expect_that!(output_values1.len(), eq(num_input_values[1]));
229+
expect_that!(output_values1, container_eq(input_values1));
242230
Ok(())
243231
}

willow/proto/shell/parameters.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ message PackedVectorConfigProto {
2828
int64 base = 1;
2929
int64 dimension = 2;
3030
int64 num_packed_coeffs = 3;
31+
int64 length = 4;
3132
}
3233

3334
// This proto defines the parameters for instantiating the KAHE scheme

0 commit comments

Comments
 (0)