Skip to content

Commit 1c0b146

Browse files
tholopcopybara-github
authored andcommitted
Use interior mutability for Decryptor PRNG too; add constructor; clean up imports.
PiperOrigin-RevId: 860138561
1 parent 7ec4944 commit 1c0b146

File tree

10 files changed

+40
-67
lines changed

10 files changed

+40
-67
lines changed

willow/benches/shell_benchmarks.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@ use messages::{
2828
PartialDecryptionRequest,
2929
};
3030
use parameters_shell::{create_shell_ahe_config, create_shell_kahe_config};
31-
use prng_traits::SecurePrng;
3231
use server_traits::SecureAggregationServer;
33-
use single_thread_hkdf::SingleThreadHkdfPrng;
3432
use testing_utils::{generate_random_nonce, generate_random_unsigned_vector};
3533
use vahe_shell::ShellVahe;
3634
use verifier_traits::SecureAggregationVerifier;
@@ -135,10 +133,8 @@ fn setup_base(args: &Args) -> BaseInputs {
135133

136134
// Create decryptor.
137135
let vahe = ShellVahe::new(ahe_config.clone(), CONTEXT_STRING).unwrap();
138-
let seed = SingleThreadHkdfPrng::generate_seed().unwrap();
139-
let prng = SingleThreadHkdfPrng::create(&seed).unwrap();
140136
let mut decryptor_state = DecryptorState::default();
141-
let mut decryptor = WillowV1Decryptor { vahe, prng };
137+
let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe).unwrap();
142138

143139
// Create server.
144140
let kahe = ShellKahe::new(kahe_config.clone(), CONTEXT_STRING).unwrap();

willow/src/api/client.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,9 @@ use client_traits::SecureAggregationClient;
2222
use kahe_shell::ShellKahe;
2323
use kahe_traits::KaheBase;
2424
use parameters_shell::create_shell_configs;
25-
use prng_traits::SecurePrng;
2625
use proto_serialization_traits::{FromProto, ToProto};
2726
use protobuf::prelude::*;
2827
use shell_ciphertexts_rust_proto::ShellAhePublicKey;
29-
use single_thread_hkdf::SingleThreadHkdfPrng;
3028
use status::ffi::FfiStatus;
3129
use status::StatusError;
3230
use std::collections::HashMap;

willow/src/testing_utils/shell_testing_decryptor.rs

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ use protobuf::prelude::*;
3131
use single_thread_hkdf::SingleThreadHkdfPrng;
3232
use status::ffi::FfiStatus;
3333
use status::{StatusError, StatusErrorCode};
34+
use std::cell::RefCell;
3435
use vahe_shell::ShellVahe;
3536
use vahe_traits::Recover;
3637
use vahe_traits::{HasVahe, VaheBase};
@@ -41,7 +42,7 @@ use vahe_traits::{HasVahe, VaheBase};
4142
pub struct ShellTestingDecryptor {
4243
kahe: ShellKahe,
4344
vahe: ShellVahe,
44-
prng: SingleThreadHkdfPrng,
45+
prng: RefCell<SingleThreadHkdfPrng>,
4546
secret_key: Option<<ShellVahe as AheBase>::SecretKeyShare>,
4647
}
4748

@@ -64,14 +65,14 @@ impl ShellTestingDecryptor {
6465
let vahe = ShellVahe::new(ahe_config, context_bytes)?;
6566
let seed = SingleThreadHkdfPrng::generate_seed()?;
6667
let prng = SingleThreadHkdfPrng::create(&seed)?;
67-
Ok(ShellTestingDecryptor { kahe, vahe, prng, secret_key: None })
68+
Ok(ShellTestingDecryptor { kahe, vahe, prng: RefCell::new(prng), secret_key: None })
6869
}
6970

7071
/// Generates a new AHE public key, and stores the corresponding secret key.
7172
pub fn generate_public_key(
7273
&mut self,
7374
) -> Result<<ShellVahe as AheBase>::PublicKey, StatusError> {
74-
let (sk_share, pk_share, _) = self.vahe.key_gen(&mut self.prng)?;
75+
let (sk_share, pk_share, _) = self.vahe.key_gen(&mut self.prng.borrow_mut())?;
7576
self.secret_key = Some(sk_share);
7677
let public_key = self.vahe.aggregate_public_key_shares(&[pk_share])?;
7778
Ok(public_key)
@@ -81,7 +82,7 @@ impl ShellTestingDecryptor {
8182
/// the AHE ciphertext and then decrypting the KAHE ciphertext. Does not verify the client proof
8283
/// contained in the message.
8384
pub fn decrypt(
84-
&mut self,
85+
&self,
8586
client_message: &ClientMessage<ShellKahe, ShellVahe>,
8687
) -> Result<<ShellKahe as KaheBase>::Plaintext, StatusError> {
8788
let partial_dec_ciphertext =
@@ -94,8 +95,11 @@ impl ShellTestingDecryptor {
9495
"No secret key available",
9596
)),
9697
Some(sk_share) => {
97-
let partial_decryption =
98-
self.vahe.partial_decrypt(&partial_dec_ciphertext, sk_share, &mut self.prng)?;
98+
let partial_decryption = self.vahe.partial_decrypt(
99+
&partial_dec_ciphertext,
100+
sk_share,
101+
&mut self.prng.borrow_mut(),
102+
)?;
99103
let decrypted_kahe_key =
100104
self.vahe.recover(&partial_decryption, &rest_of_ciphertext, None)?;
101105
let decrypted_kahe_key = self.kahe.try_secret_key_from(decrypted_kahe_key)?;
@@ -134,7 +138,7 @@ impl ShellTestingDecryptor {
134138
}
135139

136140
fn decrypt_serialized(
137-
&mut self,
141+
&self,
138142
contribution: &[u8],
139143
) -> Result<Vec<ffi::EncodedDataEntry>, StatusError> {
140144
let client_message_proto = ClientMessageProto::parse(contribution)
@@ -192,7 +196,7 @@ impl ShellTestingDecryptor {
192196
let partial_decryption = self.vahe.partial_decrypt(
193197
&request.partial_dec_ciphertext,
194198
sk_share,
195-
&mut self.prng,
199+
&mut self.prng.borrow_mut(),
196200
)?;
197201
Ok(PartialDecryptionResponse { partial_decryption })
198202
}

willow/src/traits/decryptor.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,14 @@ pub trait SecureAggregationDecryptor: HasVahe {
2424
/// Creates a public key share to be sent to the Server, updating the
2525
/// decryptor state.
2626
fn create_public_key_share(
27-
&mut self,
27+
&self,
2828
decryptor_state: &mut Self::DecryptorState,
2929
) -> Result<DecryptorPublicKeyShare<<Self as HasVahe>::Vahe>, StatusError>;
3030

3131
/// Handles a partial decryption request received from the Server. Returns a
3232
/// partial decryption to the Server.
3333
fn handle_partial_decryption_request(
34-
&mut self,
34+
&self,
3535
partial_decryption_request: PartialDecryptionRequest<<Self as HasVahe>::Vahe>,
3636
decryptor_state: &Self::DecryptorState,
3737
) -> Result<PartialDecryptionResponse<<Self as HasVahe>::Vahe>, StatusError>;

willow/src/willow_v1/BUILD

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,10 @@ rust_test(
4444
"//willow/src/api:aggregation_config",
4545
"//willow/src/shell:kahe_shell",
4646
"//willow/src/shell:parameters_shell",
47-
"//willow/src/shell:single_thread_hkdf",
4847
"//willow/src/shell:vahe_shell",
4948
"//willow/src/testing_utils",
5049
"//willow/src/testing_utils:shell_testing_decryptor",
5150
"//willow/src/testing_utils:shell_testing_parameters",
52-
"//willow/src/traits:prng_traits",
5351
],
5452
)
5553

@@ -59,11 +57,9 @@ rust_test(
5957
deps = [
6058
"@crate_index//:googletest",
6159
"//willow/src/shell:parameters_shell",
62-
"//willow/src/shell:single_thread_hkdf",
6360
"//willow/src/shell:vahe_shell",
6461
"//willow/src/traits:ahe_traits",
6562
"//willow/src/traits:decryptor_traits",
66-
"//willow/src/traits:prng_traits",
6763
"//willow/src/traits:proto_serialization_traits",
6864
],
6965
)
@@ -81,6 +77,7 @@ rust_library(
8177
"//willow/src/traits:ahe_traits",
8278
"//willow/src/traits:decryptor_traits",
8379
"//willow/src/traits:messages",
80+
"//willow/src/traits:prng_traits",
8481
"//willow/src/traits:proto_serialization_traits",
8582
"//willow/src/traits:vahe_traits",
8683
],
@@ -96,13 +93,11 @@ rust_test(
9693
"@crate_index//:googletest",
9794
"//willow/src/shell:kahe_shell",
9895
"//willow/src/shell:parameters_shell",
99-
"//willow/src/shell:single_thread_hkdf",
10096
"//willow/src/shell:vahe_shell",
10197
"//willow/src/testing_utils",
10298
"//willow/src/traits:ahe_traits",
10399
"//willow/src/traits:client_traits",
104100
"//willow/src/traits:decryptor_traits",
105-
"//willow/src/traits:prng_traits",
106101
"//willow/src/traits:proto_serialization_traits",
107102
"//willow/src/traits:server_traits",
108103
"//willow/src/traits:verifier_traits",
@@ -158,15 +153,13 @@ rust_test(
158153
"//shell_wrapper:status_matchers_rs",
159154
"//willow/src/shell:kahe_shell",
160155
"//willow/src/shell:parameters_shell",
161-
"//willow/src/shell:single_thread_hkdf",
162156
"//willow/src/shell:vahe_shell",
163157
"//willow/src/testing_utils",
164158
"//willow/src/testing_utils:shell_testing_parameters",
165159
"//willow/src/traits:ahe_traits",
166160
"//willow/src/traits:client_traits",
167161
"//willow/src/traits:decryptor_traits",
168162
"//willow/src/traits:kahe_traits",
169-
"//willow/src/traits:prng_traits",
170163
"//willow/src/traits:proto_serialization_traits",
171164
"//willow/src/traits:server_traits",
172165
"//willow/src/traits:vahe_traits",

willow/src/willow_v1/client.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,7 @@ mod test {
102102
use googletest::{gtest, verify_eq, verify_that};
103103
use kahe_shell::ShellKahe;
104104
use parameters_shell::create_shell_configs;
105-
use prng_traits::SecurePrng;
106105
use shell_testing_decryptor::ShellTestingDecryptor;
107-
use single_thread_hkdf::SingleThreadHkdfPrng;
108106
use std::collections::HashMap;
109107
use testing_utils::generate_random_nonce;
110108
use vahe_shell::ShellVahe;

willow/src/willow_v1/decryptor.rs

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,19 @@ use ahe_traits::{AheKeygen, PartialDec};
1616
use decryptor_traits::SecureAggregationDecryptor;
1717
use messages::{DecryptorPublicKeyShare, PartialDecryptionRequest, PartialDecryptionResponse};
1818
use messages_rust_proto::DecryptorStateProto;
19+
use prng_traits::SecurePrng;
1920
use proto_serialization_traits::{FromProto, ToProto};
2021
use protobuf::AsView;
2122
use shell_ciphertexts_rust_proto::ShellAheSecretKeyShare;
2223
use status::StatusError;
24+
use std::cell::RefCell;
2325
use vahe_traits::{EncryptVerify, HasVahe, VaheBase};
2426

2527
/// Lightweight decryptor directly exposing KAHE/VAHE types. It verifies only the client proofs,
2628
/// does not provide verifiable partial decryptions.
2729
pub struct WillowV1Decryptor<Vahe: VaheBase> {
2830
pub vahe: Vahe,
29-
pub prng: Vahe::Rng,
31+
pub prng: RefCell<Vahe::Rng>,
3032
}
3133

3234
impl<Vahe: VaheBase> HasVahe for WillowV1Decryptor<Vahe> {
@@ -36,6 +38,14 @@ impl<Vahe: VaheBase> HasVahe for WillowV1Decryptor<Vahe> {
3638
}
3739
}
3840

41+
impl<Vahe: VaheBase> WillowV1Decryptor<Vahe> {
42+
pub fn new_with_randomly_generated_seed(vahe: Vahe) -> Result<Self, status::StatusError> {
43+
let seed = Vahe::Rng::generate_seed()?;
44+
let prng = RefCell::new(Vahe::Rng::create(&seed)?);
45+
Ok(Self { vahe, prng })
46+
}
47+
}
48+
3949
pub struct DecryptorState<Vahe: VaheBase> {
4050
sk_share: Option<Vahe::SecretKeyShare>,
4151
}
@@ -97,18 +107,18 @@ where
97107
/// Creates a public key share to be sent to the Server, updating the
98108
/// decryptor state.
99109
fn create_public_key_share(
100-
&mut self,
110+
&self,
101111
decryptor_state: &mut Self::DecryptorState,
102112
) -> Result<DecryptorPublicKeyShare<Vahe>, status::StatusError> {
103-
let (sk_share, pk_share, _) = self.vahe.key_gen(&mut self.prng)?;
113+
let (sk_share, pk_share, _) = self.vahe.key_gen(&mut self.prng.borrow_mut())?;
104114
decryptor_state.sk_share = Some(sk_share);
105115
Ok(pk_share)
106116
}
107117

108118
/// Handles a partial decryption request received from the Server. Returns a
109119
/// partial decryption to the Server.
110120
fn handle_partial_decryption_request(
111-
&mut self,
121+
&self,
112122
partial_decryption_request: PartialDecryptionRequest<Vahe>,
113123
decryptor_state: &Self::DecryptorState,
114124
) -> Result<PartialDecryptionResponse<Vahe>, status::StatusError> {
@@ -121,7 +131,7 @@ where
121131
let pd = self.vahe.partial_decrypt(
122132
&partial_decryption_request.partial_dec_ciphertext,
123133
sk_share,
124-
&mut self.prng,
134+
&mut self.prng.borrow_mut(),
125135
)?;
126136
Ok(PartialDecryptionResponse { partial_decryption: pd })
127137
}
@@ -134,19 +144,15 @@ mod tests {
134144
use decryptor_traits::SecureAggregationDecryptor;
135145
use googletest::{gtest, verify_true};
136146
use parameters_shell::create_shell_ahe_config;
137-
use prng_traits::SecurePrng;
138147
use proto_serialization_traits::{FromProto, ToProto};
139-
use single_thread_hkdf::SingleThreadHkdfPrng;
140148
use vahe_shell::ShellVahe;
141149

142150
const CONTEXT_STRING: &[u8] = b"testing_context_string";
143151

144152
#[gtest]
145153
fn decryptor_state_serialization_roundtrip() -> googletest::Result<()> {
146154
let vahe = ShellVahe::new(create_shell_ahe_config(1).unwrap(), CONTEXT_STRING).unwrap();
147-
let seed = SingleThreadHkdfPrng::generate_seed()?;
148-
let prng = SingleThreadHkdfPrng::create(&seed)?;
149-
let mut decryptor = WillowV1Decryptor { vahe, prng };
155+
let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe)?;
150156
let mut decryptor_state = DecryptorState::default();
151157

152158
// Check empty state serialization.

willow/src/willow_v1/server.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -362,10 +362,8 @@ mod tests {
362362
use googletest::{gtest, verify_true};
363363
use kahe_shell::ShellKahe;
364364
use parameters_shell::{create_shell_ahe_config, create_shell_kahe_config};
365-
use prng_traits::SecurePrng;
366365
use proto_serialization_traits::{FromProto, ToProto};
367366
use server_traits::SecureAggregationServer;
368-
use single_thread_hkdf::SingleThreadHkdfPrng;
369367
use std::collections::HashMap;
370368
use testing_utils::{generate_aggregation_config, generate_random_nonce};
371369
use vahe_shell::ShellVahe;
@@ -400,10 +398,8 @@ mod tests {
400398
CONTEXT_STRING,
401399
)
402400
.unwrap();
403-
let seed = SingleThreadHkdfPrng::generate_seed()?;
404-
let prng = SingleThreadHkdfPrng::create(&seed)?;
405401
let mut decryptor_state = DecryptorState::default();
406-
let mut decryptor = WillowV1Decryptor { vahe, prng };
402+
let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe)?;
407403

408404
// Create server.
409405
let kahe =

willow/src/willow_v1/verifier.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -271,10 +271,8 @@ mod tests {
271271
use kahe_shell::ShellKahe;
272272
use kahe_traits::KaheBase;
273273
use parameters_shell::{create_shell_ahe_config, create_shell_kahe_config};
274-
use prng_traits::SecurePrng;
275274
use proto_serialization_traits::{FromProto, ToProto};
276275
use server_traits::SecureAggregationServer;
277-
use single_thread_hkdf::SingleThreadHkdfPrng;
278276
use status_matchers_rs::status_is;
279277
use std::collections::HashMap;
280278
use testing_utils::{generate_aggregation_config, generate_random_nonce};
@@ -314,10 +312,8 @@ mod tests {
314312
CONTEXT_STRING,
315313
)
316314
.unwrap();
317-
let seed = SingleThreadHkdfPrng::generate_seed()?;
318-
let prng = SingleThreadHkdfPrng::create(&seed)?;
319315
let mut decryptor_state = DecryptorState::default();
320-
let mut decryptor = WillowV1Decryptor { vahe, prng };
316+
let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe)?;
321317

322318
// Create server.
323319
let kahe =

0 commit comments

Comments
 (0)