Skip to content

Commit 2e8a2f7

Browse files
schoppmpcopybara-github
authored andcommitted
Wrap Kahe and Vahe struct members with Rc
PiperOrigin-RevId: 859733028
1 parent 745d549 commit 2e8a2f7

File tree

8 files changed

+190
-273
lines changed

8 files changed

+190
-273
lines changed

willow/benches/shell_benchmarks.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
use clap::Parser;
1616
use std::collections::HashMap;
1717
use std::hint::black_box;
18+
use std::rc::Rc;
1819
use std::time::Duration;
1920

2021
use aggregation_config::AggregationConfig;
@@ -128,27 +129,26 @@ fn setup_base(args: &Args) -> BaseInputs {
128129
let ahe_config = create_shell_ahe_config(aggregation_config.max_number_of_decryptors).unwrap();
129130
let kahe_config = create_shell_kahe_config(&aggregation_config).unwrap();
130131

132+
// Create common KAHE/VAHE instances.
133+
let kahe = Rc::new(ShellKahe::new(kahe_config.clone(), CONTEXT_STRING).unwrap());
134+
let vahe = Rc::new(ShellVahe::new(ahe_config.clone(), CONTEXT_STRING).unwrap());
135+
131136
// Create client.
132-
let kahe = ShellKahe::new(kahe_config.clone(), CONTEXT_STRING).unwrap();
133-
let vahe = ShellVahe::new(ahe_config.clone(), CONTEXT_STRING).unwrap();
134-
let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe).unwrap();
137+
let client =
138+
WillowV1Client::new_with_randomly_generated_seed(kahe.clone(), vahe.clone()).unwrap();
135139

136140
// Create decryptor.
137-
let vahe = ShellVahe::new(ahe_config.clone(), CONTEXT_STRING).unwrap();
138141
let seed = SingleThreadHkdfPrng::generate_seed().unwrap();
139142
let prng = SingleThreadHkdfPrng::create(&seed).unwrap();
140143
let mut decryptor_state = DecryptorState::default();
141-
let mut decryptor = WillowV1Decryptor { vahe, prng };
144+
let mut decryptor = WillowV1Decryptor { vahe: vahe.clone(), prng };
142145

143146
// Create server.
144-
let kahe = ShellKahe::new(kahe_config.clone(), CONTEXT_STRING).unwrap();
145-
let vahe = ShellVahe::new(ahe_config.clone(), CONTEXT_STRING).unwrap();
146-
let server = WillowV1Server { kahe, vahe };
147+
let server = WillowV1Server { kahe: kahe.clone(), vahe: vahe.clone() };
147148
let mut server_state = ServerState::default();
148149

149150
// Create verifier.
150-
let vahe = ShellVahe::new(ahe_config.clone(), CONTEXT_STRING).unwrap();
151-
let verifier = WillowV1Verifier { vahe };
151+
let verifier = WillowV1Verifier { vahe: vahe.clone() };
152152
let verifier_state = VerifierState::default();
153153

154154
// Decryptor generates public key share.

willow/src/api/client.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use single_thread_hkdf::SingleThreadHkdfPrng;
3030
use status::ffi::FfiStatus;
3131
use status::StatusError;
3232
use std::collections::HashMap;
33+
use std::rc::Rc;
3334
use vahe_shell::ShellVahe;
3435
use willow_v1_client::WillowV1Client;
3536

@@ -84,8 +85,8 @@ impl WillowShellClient {
8485
let aggregation_config = AggregationConfig::from_proto(aggregation_config_proto, ())?;
8586
let (kahe_config, ahe_config) = create_shell_configs(&aggregation_config)?;
8687
let context_bytes = &aggregation_config.key_id;
87-
let kahe = ShellKahe::new(kahe_config, &context_bytes)?;
88-
let vahe = ShellVahe::new(ahe_config, &context_bytes)?;
88+
let kahe = Rc::new(ShellKahe::new(kahe_config, &context_bytes)?);
89+
let vahe = Rc::new(ShellVahe::new(ahe_config, &context_bytes)?);
8990
let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe)?;
9091
Ok(WillowShellClient(client))
9192
}
@@ -104,7 +105,7 @@ impl WillowShellClient {
104105
}
105106
let public_key_proto = ShellAhePublicKey::parse(public_key.as_bytes())
106107
.map_err(|e| status::internal(format!("Failed to parse ShellAhePublicKey: {}", e)))?;
107-
let public_key_rust = PublicKey::from_proto(public_key_proto, &self.0.vahe)?;
108+
let public_key_rust = PublicKey::from_proto(public_key_proto, self.0.vahe.as_ref())?;
108109
let message = self.0.create_client_message(&plaintext_slice, &public_key_rust, nonce)?;
109110
Ok(message
110111
.to_proto(&self.0)?

willow/src/api/server_accumulator.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ use server_traits::SecureAggregationServer;
3131
use status::StatusError;
3232
use std::collections::BTreeMap;
3333
use std::ops::Range;
34+
use std::rc::Rc;
3435
use vahe_shell::ShellVahe;
3536
use verifier_traits::SecureAggregationVerifier;
3637
use willow_v1_server::{ServerState, WillowV1Server};
@@ -147,11 +148,10 @@ impl ServerAccumulator {
147148
fn new(aggregation_config: AggregationConfig) -> Result<Self, StatusError> {
148149
let (kahe_config, vahe_config) = create_shell_configs(&aggregation_config)?;
149150
let context_bytes = &aggregation_config.key_id;
150-
let server_kahe = ShellKahe::new(kahe_config, context_bytes)?;
151-
let server_vahe = ShellVahe::new(vahe_config.clone(), context_bytes)?;
152-
let verifier_vahe = ShellVahe::new(vahe_config, context_bytes)?;
153-
let server = WillowV1Server { kahe: server_kahe, vahe: server_vahe };
154-
let verifier = WillowV1Verifier { vahe: verifier_vahe };
151+
let kahe = Rc::new(ShellKahe::new(kahe_config, &context_bytes)?);
152+
let vahe = Rc::new(ShellVahe::new(vahe_config, &context_bytes)?);
153+
let server = WillowV1Server { kahe: kahe.clone(), vahe: vahe.clone() };
154+
let verifier = WillowV1Verifier { vahe };
155155
Ok(Self {
156156
server: server,
157157
server_state: Default::default(),
@@ -661,8 +661,8 @@ impl FinalResultDecryptor {
661661
let aggregation_config = AggregationConfig::from_proto(aggregation_config_proto, ())?;
662662
let (kahe_config, vahe_config) = create_shell_configs(&aggregation_config)?;
663663
let context_bytes = &aggregation_config.key_id;
664-
let kahe = ShellKahe::new(kahe_config, context_bytes)?;
665-
let vahe = ShellVahe::new(vahe_config, context_bytes)?;
664+
let kahe = Rc::new(ShellKahe::new(kahe_config, context_bytes)?);
665+
let vahe = Rc::new(ShellVahe::new(vahe_config, context_bytes)?);
666666
let server = WillowV1Server { kahe, vahe };
667667
let server_state = ServerState::from_proto(server_state_proto, &server)?;
668668

willow/src/willow_v1/client.rs

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,33 +17,34 @@ use kahe_traits::{HasKahe, KaheBase, KaheEncrypt, KaheKeygen, TrySecretKeyInto};
1717
use messages::{ClientMessage, DecryptorPublicKey};
1818
use prng_traits::SecurePrng;
1919
use std::cell::RefCell;
20+
use std::rc::Rc;
2021
use vahe_traits::{HasVahe, VaheBase, VerifiableEncrypt};
2122

2223
/// Lightweight client directly exposing KAHE/VAHE types.
2324
pub struct WillowV1Client<Kahe: KaheBase, Vahe: VaheBase> {
24-
pub kahe: Kahe,
25-
pub vahe: Vahe,
25+
pub kahe: Rc<Kahe>,
26+
pub vahe: Rc<Vahe>,
2627
pub prng: RefCell<Kahe::Rng>, // Using a single PRNG for both VAHE and KAHE.
2728
}
2829

2930
impl<Kahe: KaheBase, Vahe: VaheBase> HasKahe for WillowV1Client<Kahe, Vahe> {
3031
type Kahe = Kahe;
3132
fn kahe(&self) -> &Self::Kahe {
32-
&self.kahe
33+
self.kahe.as_ref()
3334
}
3435
}
3536

3637
impl<Kahe: KaheBase, Vahe: VaheBase> HasVahe for WillowV1Client<Kahe, Vahe> {
3738
type Vahe = Vahe;
3839
fn vahe(&self) -> &Self::Vahe {
39-
&self.vahe
40+
self.vahe.as_ref()
4041
}
4142
}
4243

4344
impl<Kahe: KaheBase, Vahe: VaheBase> WillowV1Client<Kahe, Vahe> {
4445
pub fn new_with_randomly_generated_seed(
45-
kahe: Kahe,
46-
vahe: Vahe,
46+
kahe: Rc<Kahe>,
47+
vahe: Rc<Vahe>,
4748
) -> Result<Self, status::StatusError> {
4849
let seed = Kahe::Rng::generate_seed()?;
4950
let prng = RefCell::new(Kahe::Rng::create(&seed)?);
@@ -124,8 +125,8 @@ mod test {
124125

125126
// Create a client.
126127
let (kahe_config, ahe_config) = create_shell_configs(&aggregation_config)?;
127-
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
128-
let vahe = ShellVahe::new(ahe_config, CONTEXT_STRING)?;
128+
let kahe = Rc::new(ShellKahe::new(kahe_config, CONTEXT_STRING)?);
129+
let vahe = Rc::new(ShellVahe::new(ahe_config, CONTEXT_STRING)?);
129130
let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe)?;
130131

131132
// Generate AHE keys.
@@ -161,16 +162,13 @@ mod test {
161162
key_id: b"test".to_vec(),
162163
};
163164

164-
// Create a client.
165+
// Create common KAHE/VAHE instances.
165166
let (kahe_config, ahe_config) = create_shell_configs(&aggregation_config)?;
166-
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
167-
let vahe = ShellVahe::new(ahe_config, CONTEXT_STRING)?;
168-
let client1 = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe)?;
167+
let kahe = Rc::new(ShellKahe::new(kahe_config, CONTEXT_STRING)?);
168+
let vahe = Rc::new(ShellVahe::new(ahe_config, CONTEXT_STRING)?);
169169

170-
// Create a second client.
171-
let (kahe_config, ahe_config) = create_shell_configs(&aggregation_config)?;
172-
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
173-
let vahe = ShellVahe::new(ahe_config, CONTEXT_STRING)?;
170+
// Create clients.
171+
let client1 = WillowV1Client::new_with_randomly_generated_seed(kahe.clone(), vahe.clone())?;
174172
let client2 = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe)?;
175173

176174
// Generate AHE keys.

willow/src/willow_v1/decryptor.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,13 @@ use proto_serialization_traits::{FromProto, ToProto};
2020
use protobuf::AsView;
2121
use shell_ciphertexts_rust_proto::ShellAheSecretKeyShare;
2222
use status::StatusError;
23+
use std::rc::Rc;
2324
use vahe_traits::{EncryptVerify, HasVahe, VaheBase};
2425

2526
/// Lightweight decryptor directly exposing KAHE/VAHE types. It verifies only the client proofs,
2627
/// does not provide verifiable partial decryptions.
2728
pub struct WillowV1Decryptor<Vahe: VaheBase> {
28-
pub vahe: Vahe,
29+
pub vahe: Rc<Vahe>,
2930
pub prng: Vahe::Rng,
3031
}
3132

@@ -137,13 +138,15 @@ mod tests {
137138
use prng_traits::SecurePrng;
138139
use proto_serialization_traits::{FromProto, ToProto};
139140
use single_thread_hkdf::SingleThreadHkdfPrng;
141+
use std::rc::Rc;
140142
use vahe_shell::ShellVahe;
141143

142144
const CONTEXT_STRING: &[u8] = b"testing_context_string";
143145

144146
#[gtest]
145147
fn decryptor_state_serialization_roundtrip() -> googletest::Result<()> {
146-
let vahe = ShellVahe::new(create_shell_ahe_config(1).unwrap(), CONTEXT_STRING).unwrap();
148+
let vahe =
149+
Rc::new(ShellVahe::new(create_shell_ahe_config(1).unwrap(), CONTEXT_STRING).unwrap());
147150
let seed = SingleThreadHkdfPrng::generate_seed()?;
148151
let prng = SingleThreadHkdfPrng::create(&seed)?;
149152
let mut decryptor = WillowV1Decryptor { vahe, prng };

willow/src/willow_v1/server.rs

Lines changed: 20 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,15 @@ use shell_ciphertexts_rust_proto::{
2828
};
2929
use status::StatusError;
3030
use std::collections::HashMap;
31+
use std::rc::Rc;
3132
use vahe_traits::{EncryptVerify, HasVahe, Recover, VaheBase};
3233

3334
/// Implements the `server` role in the Willow protocol. This includes aggregating public key shares
3435
/// from the decryptors, aggregating client ciphertexts, and recovering the aggregation result after
3536
/// receiving partial decryption responses from the decryptors.
3637
pub struct WillowV1Server<Kahe: KaheBase, Vahe: VaheBase> {
37-
pub kahe: Kahe,
38-
pub vahe: Vahe,
38+
pub kahe: Rc<Kahe>,
39+
pub vahe: Rc<Vahe>,
3940
}
4041

4142
impl<Kahe: KaheBase, Vahe: VaheBase> HasKahe for WillowV1Server<Kahe, Vahe> {
@@ -383,47 +384,34 @@ mod tests {
383384
generate_aggregation_config(DEFAULT_VECTOR_ID.to_string(), 16, 10, 1, 1);
384385
let max_number_of_decryptors = aggregation_config.max_number_of_decryptors;
385386

386-
// Create client.
387-
let kahe =
387+
// Create common KAHE/VAHE instances.
388+
let kahe = Rc::new(
388389
ShellKahe::new(create_shell_kahe_config(&aggregation_config).unwrap(), CONTEXT_STRING)
389-
.unwrap();
390-
let vahe = ShellVahe::new(
391-
create_shell_ahe_config(max_number_of_decryptors).unwrap(),
392-
CONTEXT_STRING,
393-
)
394-
.unwrap();
395-
let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe)?;
390+
.unwrap(),
391+
);
392+
let vahe = Rc::new(
393+
ShellVahe::new(
394+
create_shell_ahe_config(max_number_of_decryptors).unwrap(),
395+
CONTEXT_STRING,
396+
)
397+
.unwrap(),
398+
);
399+
400+
// Create client.
401+
let client = WillowV1Client::new_with_randomly_generated_seed(kahe.clone(), vahe.clone())?;
396402

397403
// Create decryptor.
398-
let vahe = ShellVahe::new(
399-
create_shell_ahe_config(max_number_of_decryptors).unwrap(),
400-
CONTEXT_STRING,
401-
)
402-
.unwrap();
403404
let seed = SingleThreadHkdfPrng::generate_seed()?;
404405
let prng = SingleThreadHkdfPrng::create(&seed)?;
405406
let mut decryptor_state = DecryptorState::default();
406-
let mut decryptor = WillowV1Decryptor { vahe, prng };
407+
let mut decryptor = WillowV1Decryptor { vahe: vahe.clone(), prng };
407408

408409
// Create server.
409-
let kahe =
410-
ShellKahe::new(create_shell_kahe_config(&aggregation_config).unwrap(), CONTEXT_STRING)
411-
.unwrap();
412-
let vahe = ShellVahe::new(
413-
create_shell_ahe_config(max_number_of_decryptors).unwrap(),
414-
CONTEXT_STRING,
415-
)
416-
.unwrap();
417-
let server = WillowV1Server { kahe, vahe };
410+
let server = WillowV1Server { kahe: kahe.clone(), vahe: vahe.clone() };
418411
let mut server_state = ServerState::default();
419412

420413
// Create verifier.
421-
let vahe = ShellVahe::new(
422-
create_shell_ahe_config(max_number_of_decryptors).unwrap(),
423-
CONTEXT_STRING,
424-
)
425-
.unwrap();
426-
let verifier = WillowV1Verifier { vahe };
414+
let verifier = WillowV1Verifier { vahe: vahe.clone() };
427415
let mut verifier_state = VerifierState::default();
428416

429417
// Check empty state serialization

willow/src/willow_v1/verifier.rs

Lines changed: 24 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@ use protobuf::{proto, AsView};
1919
use shell_ciphertexts_rust_proto::ShellAhePartialDecCiphertext;
2020
use status::StatusError;
2121
use std::fmt::Debug;
22+
use std::rc::Rc;
2223
use vahe_traits::{EncryptVerify, HasVahe, VaheBase};
2324
use verifier_traits::SecureAggregationVerifier;
2425

2526
/// The verifier struct, containing a WillowCommon instance.
2627
pub struct WillowV1Verifier<Vahe: VaheBase> {
27-
pub vahe: Vahe,
28+
pub vahe: Rc<Vahe>,
2829
}
2930

3031
impl<Vahe: VaheBase> HasVahe for WillowV1Verifier<Vahe> {
@@ -296,48 +297,35 @@ mod tests {
296297
generate_aggregation_config(DEFAULT_VECTOR_ID.to_string(), 16, 10, 1, 1);
297298
let max_number_of_decryptors = aggregation_config.max_number_of_decryptors;
298299

299-
// Create client.
300-
let kahe =
300+
// Create common KAHE/VAHE instances.
301+
let kahe = Rc::new(
301302
ShellKahe::new(create_shell_kahe_config(&aggregation_config).unwrap(), CONTEXT_STRING)
302-
.unwrap();
303-
let vahe = ShellVahe::new(
304-
create_shell_ahe_config(max_number_of_decryptors).unwrap(),
305-
CONTEXT_STRING,
306-
)
307-
.unwrap();
308-
let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe)?;
309-
310-
// Create decryptor, which needs its own `vahe` (with same public polynomials
311-
// generated from the seeds) and `prng`.
312-
let vahe = ShellVahe::new(
313-
create_shell_ahe_config(max_number_of_decryptors).unwrap(),
314-
CONTEXT_STRING,
315-
)
316-
.unwrap();
317-
let seed = SingleThreadHkdfPrng::generate_seed()?;
318-
let prng = SingleThreadHkdfPrng::create(&seed)?;
303+
.unwrap(),
304+
);
305+
let vahe = Rc::new(
306+
ShellVahe::new(
307+
create_shell_ahe_config(max_number_of_decryptors).unwrap(),
308+
CONTEXT_STRING,
309+
)
310+
.unwrap(),
311+
);
312+
313+
// Create client.
314+
let client =
315+
WillowV1Client::new_with_randomly_generated_seed(kahe.clone(), vahe.clone()).unwrap();
316+
317+
// Create decryptor.
318+
let seed = SingleThreadHkdfPrng::generate_seed().unwrap();
319+
let prng = SingleThreadHkdfPrng::create(&seed).unwrap();
319320
let mut decryptor_state = DecryptorState::default();
320-
let mut decryptor = WillowV1Decryptor { vahe, prng };
321+
let mut decryptor = WillowV1Decryptor { vahe: vahe.clone(), prng };
321322

322323
// Create server.
323-
let kahe =
324-
ShellKahe::new(create_shell_kahe_config(&aggregation_config).unwrap(), CONTEXT_STRING)
325-
.unwrap();
326-
let vahe = ShellVahe::new(
327-
create_shell_ahe_config(max_number_of_decryptors).unwrap(),
328-
CONTEXT_STRING,
329-
)
330-
.unwrap();
331-
let server = WillowV1Server { kahe, vahe };
324+
let server = WillowV1Server { kahe: kahe.clone(), vahe: vahe.clone() };
332325
let mut server_state = ServerState::default();
333326

334327
// Create verifier.
335-
let vahe = ShellVahe::new(
336-
create_shell_ahe_config(max_number_of_decryptors).unwrap(),
337-
CONTEXT_STRING,
338-
)
339-
.unwrap();
340-
let verifier = WillowV1Verifier { vahe };
328+
let verifier = WillowV1Verifier { vahe: vahe.clone() };
341329

342330
// Decryptor generates public key share.
343331
let public_key_share = decryptor.create_public_key_share(&mut decryptor_state)?;

0 commit comments

Comments
 (0)