Skip to content

Commit 599d077

Browse files
tholopcopybara-github
authored andcommitted
Replace session_id by key_id; use key_id as Kahe/Vahe context directly.
PiperOrigin-RevId: 859265873
1 parent 9400890 commit 599d077

18 files changed

+64
-76
lines changed

willow/benches/shell_benchmarks.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ fn setup_base(args: &Args) -> BaseInputs {
123123
max_number_of_decryptors: 1,
124124
max_number_of_clients: args.max_num_clients as i64,
125125
max_decryptor_dropouts: 0,
126-
session_id: String::from("benchmark"),
126+
key_id: b"benchmark".to_vec(),
127127
};
128128
let ahe_config = create_shell_ahe_config(aggregation_config.max_number_of_decryptors).unwrap();
129129
let kahe_config = create_shell_kahe_config(&aggregation_config).unwrap();

willow/proto/willow/aggregation_config.proto

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@ option java_outer_classname = "AggregationConfigProto";
2222
// The configuration of the aggregation as a proto.
2323
message AggregationConfigProto {
2424
map<string, VectorConfig> vector_configs = 1;
25-
int64 max_number_of_decryptors = 5;
2625
int64 max_decryptor_dropouts = 2;
2726
int64 max_number_of_clients = 3;
28-
string session_id = 4;
27+
string session_id = 4 [deprecated = true];
28+
int64 max_number_of_decryptors = 5;
29+
bytes key_id = 6;
2930
}
3031

3132
// The configuration for a single vector in an aggregation.

willow/src/api/aggregation_config.rs

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,16 @@ use std::collections::HashMap;
2929
/// aggregation failing.
3030
/// max_number_of_clients: The maximum number of clients that will participate in the
3131
/// aggregation.
32-
/// session_id: The session id of the aggregation.
32+
/// key_id: The key id of the aggregation, used as context_bytes to seed Kahe
33+
/// and Vahe public parameters. Must be unique for each instantiation.
3334
/// willow_version: The version of the willow protocol.
3435
#[derive(Debug, Clone, PartialEq, Eq)]
3536
pub struct AggregationConfig {
3637
pub vector_lengths_and_bounds: HashMap<String, (isize, i64)>,
3738
pub max_number_of_decryptors: i64,
3839
pub max_decryptor_dropouts: i64,
3940
pub max_number_of_clients: i64,
40-
pub session_id: String,
41+
pub key_id: Vec<u8>,
4142
}
4243

4344
impl FromProto for AggregationConfig {
@@ -57,7 +58,7 @@ impl FromProto for AggregationConfig {
5758
max_number_of_decryptors: proto.max_number_of_decryptors(),
5859
max_decryptor_dropouts: proto.max_decryptor_dropouts(),
5960
max_number_of_clients: proto.max_number_of_clients(),
60-
session_id: proto.session_id().to_string(),
61+
key_id: proto.key_id().to_vec(),
6162
})
6263
}
6364
}
@@ -71,7 +72,7 @@ impl ToProto for AggregationConfig {
7172
max_number_of_decryptors: self.max_number_of_decryptors,
7273
max_decryptor_dropouts: self.max_decryptor_dropouts,
7374
max_number_of_clients: self.max_number_of_clients,
74-
session_id: self.session_id.clone(),
75+
key_id: self.key_id.clone(),
7576
});
7677
aggregation_config_proto.vector_configs_mut().copy_from(
7778
self.vector_lengths_and_bounds.iter().map(|(key, (length, bound))| {
@@ -82,19 +83,6 @@ impl ToProto for AggregationConfig {
8283
}
8384
}
8485

85-
impl AggregationConfig {
86-
/// Computes context bytes by hashing the session ID in the config.
87-
pub fn compute_context_bytes(&self) -> Result<Vec<u8>, StatusError> {
88-
let context_seed = single_thread_hkdf::compute_hkdf(
89-
self.session_id.as_bytes(),
90-
b"",
91-
b"AggregationConfig.context_string",
92-
single_thread_hkdf::seed_length(),
93-
)?;
94-
Ok(context_seed.as_bytes().to_vec())
95-
}
96-
}
97-
9886
#[cfg(test)]
9987
mod tests {
10088
use crate::AggregationConfig;
@@ -109,7 +97,7 @@ mod tests {
10997
max_number_of_decryptors: 1,
11098
max_decryptor_dropouts: 0,
11199
max_number_of_clients: 1,
112-
session_id: String::from("test"),
100+
key_id: b"test".to_vec(),
113101
};
114102

115103
verify_that!(

willow/src/api/client.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ impl WillowShellClient {
8383
})?;
8484
let aggregation_config = AggregationConfig::from_proto(aggregation_config_proto, ())?;
8585
let (kahe_config, ahe_config) = create_shell_configs(&aggregation_config)?;
86-
let context_bytes = aggregation_config.compute_context_bytes()?;
86+
let context_bytes = &aggregation_config.key_id;
8787
let kahe = ShellKahe::new(kahe_config, &context_bytes)?;
8888
let vahe = ShellVahe::new(ahe_config, &context_bytes)?;
8989
let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe)?;

willow/src/api/client_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ AggregationConfigProto CreateTestConfig() {
4949
(*config.mutable_vector_configs())["metric1"] = vector_config;
5050
config.set_max_number_of_decryptors(1);
5151
config.set_max_number_of_clients(10);
52-
config.set_session_id("test");
52+
config.set_key_id("test");
5353
return config;
5454
}
5555

willow/src/api/server_accumulator.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,11 @@ pub struct ServerAccumulator {
145145

146146
impl ServerAccumulator {
147147
fn new(aggregation_config: AggregationConfig) -> Result<Self, StatusError> {
148-
let context_string = aggregation_config.compute_context_bytes()?;
149148
let (kahe_config, vahe_config) = create_shell_configs(&aggregation_config)?;
150-
let server_kahe = ShellKahe::new(kahe_config, &context_string)?;
151-
let server_vahe = ShellVahe::new(vahe_config.clone(), &context_string)?;
152-
let verifier_vahe = ShellVahe::new(vahe_config, &context_string)?;
149+
let context_bytes = &aggregation_config.key_id;
150+
let server_kahe = ShellKahe::new(kahe_config, context_bytes)?;
151+
let server_vahe = ShellVahe::new(vahe_config.clone(), context_bytes)?;
152+
let verifier_vahe = ShellVahe::new(vahe_config, context_bytes)?;
153153
let server = WillowV1Server { kahe: server_kahe, vahe: server_vahe };
154154
let verifier = WillowV1Verifier { vahe: verifier_vahe };
155155
Ok(Self {
@@ -659,10 +659,10 @@ impl FinalResultDecryptor {
659659

660660
// Build server that holds the necessary KAHE and AHE contexts, and recover server state.
661661
let aggregation_config = AggregationConfig::from_proto(aggregation_config_proto, ())?;
662-
let context_string = aggregation_config.compute_context_bytes()?;
663662
let (kahe_config, vahe_config) = create_shell_configs(&aggregation_config)?;
664-
let kahe = ShellKahe::new(kahe_config, &context_string)?;
665-
let vahe = ShellVahe::new(vahe_config, &context_string)?;
663+
let context_bytes = &aggregation_config.key_id;
664+
let kahe = ShellKahe::new(kahe_config, context_bytes)?;
665+
let vahe = ShellVahe::new(vahe_config, context_bytes)?;
666666
let server = WillowV1Server { kahe, vahe };
667667
let server_state = ServerState::from_proto(server_state_proto, &server)?;
668668

willow/src/api/server_accumulator_test.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ AggregationConfigProto CreateValidConfig() {
4545
(*config.mutable_vector_configs())["test_vector"] = vector_config;
4646
config.set_max_number_of_decryptors(1);
4747
config.set_max_number_of_clients(10);
48-
config.set_session_id("test_session");
48+
config.set_key_id("test_key");
4949
return config;
5050
}
5151

@@ -67,7 +67,7 @@ TEST(BasicServerAccumulatorTest, ToSerializedStateHasCorrectConfig) {
6767
ASSERT_TRUE(state.ParseFromString(*serialized_state_or));
6868
// Check if the config matches. We serialize and deserialize to compare protos
6969
// easily or check fields.
70-
EXPECT_EQ(state.aggregation_config().session_id(), config.session_id());
70+
EXPECT_EQ(state.aggregation_config().key_id(), config.key_id());
7171
EXPECT_EQ(state.aggregation_config().max_number_of_clients(),
7272
config.max_number_of_clients());
7373
}
@@ -382,7 +382,7 @@ TEST_F(ServerAccumulatorTest, MergeFailsWithOverlappingRanges) {
382382

383383
TEST_F(ServerAccumulatorTest, MergeFailsWithConfigMismatch) {
384384
AggregationConfigProto config2 = config_;
385-
config2.set_session_id("other_session");
385+
config2.set_key_id("other_key");
386386
SECAGG_ASSERT_OK_AND_ASSIGN(auto accumulator2,
387387
ServerAccumulator::Create(config2));
388388

willow/src/shell/ahe.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -510,10 +510,10 @@ impl AheBase for ShellAhe {
510510

511511
type Config = ShellAheConfig;
512512

513-
fn new(config: Self::Config, context_string: &[u8]) -> Result<Self, status::StatusError> {
513+
fn new(config: Self::Config, context_bytes: &[u8]) -> Result<Self, status::StatusError> {
514514
let num_coeffs = 1 << config.log_n;
515515
let public_seed = single_thread_hkdf::compute_hkdf(
516-
context_string,
516+
context_bytes,
517517
b"",
518518
b"ShellAhe.public_seed",
519519
single_thread_hkdf::seed_length(),
@@ -783,13 +783,13 @@ mod test {
783783
const NUM_DECRYPTORS: usize = 3;
784784
const NUM_CLIENTS: usize = 1000;
785785
const MAX_ABSOLUTE_VALUE: i64 = 72;
786-
const CONTEXT_STRING: &[u8] = b"test_context_string";
786+
const CONTEXT_BYTES: &[u8] = b"test_context_bytes";
787787

788788
#[gtest]
789789
fn test_encrypt_decrypt_one() -> googletest::Result<()> {
790790
const NUM_VALUES: usize = 100;
791791

792-
let ahe = ShellAhe::new(make_ahe_config(), CONTEXT_STRING)?;
792+
let ahe = ShellAhe::new(make_ahe_config(), CONTEXT_BYTES)?;
793793

794794
let pt = vec![1, 2, 3, 4, 5, 6, 7, 8];
795795
let seed = SingleThreadHkdfPrng::generate_seed()?;
@@ -811,7 +811,7 @@ mod test {
811811
fn test_encrypt_decrypt_serialized() -> googletest::Result<()> {
812812
const NUM_VALUES: usize = 100;
813813

814-
let ahe = ShellAhe::new(make_ahe_config(), CONTEXT_STRING)?;
814+
let ahe = ShellAhe::new(make_ahe_config(), CONTEXT_BYTES)?;
815815

816816
let pt = vec![1, 2, 3, 4, 5, 6, 7, 8];
817817
let seed = SingleThreadHkdfPrng::generate_seed()?;
@@ -853,7 +853,7 @@ mod test {
853853
let config = make_ahe_config();
854854
let t = config.t; // Keep a copy of the plaintext modulus.
855855

856-
let ahe = ShellAhe::new(config, CONTEXT_STRING)?;
856+
let ahe = ShellAhe::new(config, CONTEXT_BYTES)?;
857857
let seed = SingleThreadHkdfPrng::generate_seed()?;
858858
let mut prng = SingleThreadHkdfPrng::create(&seed)?;
859859

@@ -920,7 +920,7 @@ mod test {
920920

921921
#[gtest]
922922
fn test_errors() -> googletest::Result<()> {
923-
let ahe = ShellAhe::new(make_ahe_config(), CONTEXT_STRING)?;
923+
let ahe = ShellAhe::new(make_ahe_config(), CONTEXT_BYTES)?;
924924
let seed = SingleThreadHkdfPrng::generate_seed()?;
925925
let mut prng = SingleThreadHkdfPrng::create(&seed)?;
926926

@@ -998,7 +998,7 @@ mod test {
998998
let config = make_ahe_config();
999999
let q: i128 = config.qs.iter().map(|x| *x as i128).product();
10001000

1001-
let ahe = ShellAhe::new(config, CONTEXT_STRING)?;
1001+
let ahe = ShellAhe::new(config, CONTEXT_BYTES)?;
10021002
let seed = SingleThreadHkdfPrng::generate_seed()?;
10031003
let mut prng = SingleThreadHkdfPrng::create(&seed)?;
10041004
let (_, pk_share, _) = ahe.key_gen(&mut prng)?;
@@ -1040,7 +1040,7 @@ mod test {
10401040
#[gtest]
10411041
fn test_export_ciphertext_has_right_order() -> googletest::Result<()> {
10421042
let config = make_ahe_config();
1043-
let ahe = ShellAhe::new(config, CONTEXT_STRING)?;
1043+
let ahe = ShellAhe::new(config, CONTEXT_BYTES)?;
10441044
let seed = SingleThreadHkdfPrng::generate_seed()?;
10451045
let mut prng = SingleThreadHkdfPrng::create(&seed)?;
10461046
let (_, pk_share, _) = ahe.key_gen(&mut prng)?;

willow/src/shell/kahe.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -180,12 +180,12 @@ impl KaheBase for ShellKahe {
180180

181181
fn new(
182182
shell_kahe_config: Self::Config,
183-
context_string: &[u8],
183+
context_bytes: &[u8],
184184
) -> Result<Self, status::StatusError> {
185185
Self::validate_kahe_config(&shell_kahe_config)?;
186186
let num_coeffs = 1 << shell_kahe_config.log_n;
187187
let public_seed = single_thread_hkdf::compute_hkdf(
188-
context_string,
188+
context_bytes,
189189
b"",
190190
b"ShellKahe.public_seed",
191191
single_thread_hkdf::seed_length(),
@@ -395,7 +395,7 @@ mod test {
395395
/// Default ID used in tests.
396396
const DEFAULT_ID: &str = "default";
397397

398-
const CONTEXT_STRING: &[u8] = b"test_context_string";
398+
const CONTEXT_BYTES: &[u8] = b"test_context_bytes";
399399

400400
#[gtest]
401401
fn test_encrypt_decrypt_short() -> googletest::Result<()> {
@@ -405,7 +405,7 @@ mod test {
405405
PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5, length: 10 },
406406
)]);
407407
let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?;
408-
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
408+
let kahe = ShellKahe::new(kahe_config, CONTEXT_BYTES)?;
409409

410410
let pt = HashMap::from([(DEFAULT_ID.to_string(), vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9])]);
411411
let seed = SingleThreadHkdfPrng::generate_seed()?;
@@ -425,7 +425,7 @@ mod test {
425425
PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5, length: 8 },
426426
)]);
427427
let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?;
428-
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
428+
let kahe = ShellKahe::new(kahe_config, CONTEXT_BYTES)?;
429429

430430
let pt = HashMap::from([(DEFAULT_ID.to_string(), vec![0, 1, 2, 3, 4, 5, 6, 7])]);
431431
let seed = SingleThreadHkdfPrng::generate_seed()?;
@@ -445,7 +445,7 @@ mod test {
445445
PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5, length: 10 },
446446
)]);
447447
let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?;
448-
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
448+
let kahe = ShellKahe::new(kahe_config, CONTEXT_BYTES)?;
449449

450450
let pt = HashMap::from([(DEFAULT_ID.to_string(), vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9])]);
451451
let seed = SingleThreadHkdfPrng::generate_seed()?;
@@ -484,7 +484,7 @@ mod test {
484484
packed_vector_config.length = num_messages;
485485
set_kahe_num_public_polynomials(&mut kahe_config);
486486

487-
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
487+
let kahe = ShellKahe::new(kahe_config, CONTEXT_BYTES)?;
488488

489489
let seed = SingleThreadHkdfPrng::generate_seed()?;
490490
let mut prng = SingleThreadHkdfPrng::create(&seed)?;
@@ -518,7 +518,7 @@ mod test {
518518
)]);
519519
let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?;
520520

521-
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
521+
let kahe = ShellKahe::new(kahe_config, CONTEXT_BYTES)?;
522522
let seed = SingleThreadHkdfPrng::generate_seed()?;
523523
let mut prng = SingleThreadHkdfPrng::create(&seed)?;
524524

@@ -556,7 +556,7 @@ mod test {
556556
let packed_vector_configs = BTreeMap::from([]);
557557
let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?;
558558

559-
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
559+
let kahe = ShellKahe::new(kahe_config, CONTEXT_BYTES)?;
560560
let seed = SingleThreadHkdfPrng::generate_seed()?;
561561
let mut prng = SingleThreadHkdfPrng::create(&seed)?;
562562

@@ -600,7 +600,7 @@ mod test {
600600
PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5, length: 10 },
601601
)]);
602602
let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?;
603-
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
603+
let kahe = ShellKahe::new(kahe_config, CONTEXT_BYTES)?;
604604

605605
let pt = HashMap::from([(String::from(DEFAULT_ID), vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9])]);
606606
let seed = SingleThreadHkdfPrng::generate_seed()?;
@@ -626,7 +626,7 @@ mod test {
626626
let plaintext_modulus_bits = 39;
627627
let packed_vector_configs = BTreeMap::from([]);
628628
let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?;
629-
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
629+
let kahe = ShellKahe::new(kahe_config, CONTEXT_BYTES)?;
630630

631631
// The seed used to sample the secret keys.
632632
let seed = SingleThreadHkdfPrng::generate_seed()?;

willow/src/shell/parameters_generation.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ mod test {
109109
max_number_of_decryptors: 1,
110110
max_decryptor_dropouts: 0,
111111
max_number_of_clients: 1,
112-
session_id: String::from("test"),
112+
key_id: b"test".to_vec(),
113113
};
114114
let invalid_plaintext_bits = 0;
115115
let result = generate_packing_config(invalid_plaintext_bits, &agg_config);
@@ -130,7 +130,7 @@ mod test {
130130
max_number_of_decryptors: 1,
131131
max_decryptor_dropouts: 0,
132132
max_number_of_clients: 1,
133-
session_id: String::from("test"),
133+
key_id: b"test".to_vec(),
134134
};
135135
let result = generate_packing_config(plaintext_bits, &bad_agg_config);
136136
expect_true!(result.is_err());
@@ -151,7 +151,7 @@ mod test {
151151
max_number_of_decryptors: 1,
152152
max_decryptor_dropouts: 0,
153153
max_number_of_clients: 0,
154-
session_id: String::from("test"),
154+
key_id: b"test".to_vec(),
155155
};
156156
let result = generate_packing_config(plaintext_bits, &bad_agg_config);
157157
expect_true!(result.is_err());
@@ -168,7 +168,7 @@ mod test {
168168
max_number_of_decryptors: 1,
169169
max_decryptor_dropouts: 0,
170170
max_number_of_clients: 2,
171-
session_id: String::from("test"),
171+
key_id: b"test".to_vec(),
172172
};
173173
let result = generate_packing_config(plaintext_bits, &agg_config);
174174
expect_true!(result.is_err());
@@ -187,7 +187,7 @@ mod test {
187187
max_number_of_decryptors: 1,
188188
max_decryptor_dropouts: 0,
189189
max_number_of_clients: 1 << 8,
190-
session_id: String::from("test"),
190+
key_id: b"test".to_vec(),
191191
};
192192
let plaintext_bits = 24;
193193
let packed_vector_configs = generate_packing_config(plaintext_bits, &agg_config)?;

0 commit comments

Comments
 (0)