11// === Standard Library ===
2- use std:: sync:: Arc ;
2+ use std:: { marker :: PhantomData , sync:: Arc } ;
33
44// === External Crates ===
55use kms_grpc:: {
@@ -12,7 +12,7 @@ use threshold_fhe::{
1212 algebra:: galois_rings:: degree_4:: { ResiduePolyF4Z128 , ResiduePolyF4Z64 } ,
1313 execution:: {
1414 runtime:: session:: ParameterHandles ,
15- small_execution:: prss:: { PRSSInit , PRSSSetup , RobustSecurePrssInit } ,
15+ small_execution:: prss:: { PRSSInit , PRSSSetup } ,
1616 } ,
1717 networking:: NetworkMode ,
1818} ;
@@ -33,16 +33,26 @@ use crate::{
3333// === Current Module Imports ===
3434use super :: { session:: SessionPreparer , RealThresholdKms } ;
3535
36- pub struct RealInitiator < PrivS : Storage + Send + Sync + ' static > {
36+ pub struct RealInitiator <
37+ PrivS : Storage + Send + Sync + ' static ,
38+ Init : PRSSInit < ResiduePolyF4Z64 > + PRSSInit < ResiduePolyF4Z128 > ,
39+ > {
3740 // TODO eventually add mode to allow for nlarge as well.
3841 pub prss_setup_z128 : Arc < RwLock < Option < PRSSSetup < ResiduePolyF4Z128 > > > > ,
3942 pub prss_setup_z64 : Arc < RwLock < Option < PRSSSetup < ResiduePolyF4Z64 > > > > ,
4043 pub private_storage : Arc < Mutex < PrivS > > ,
4144 pub session_preparer : SessionPreparer ,
42- pub health_reporter : Arc < RwLock < HealthReporter > > ,
45+ pub health_reporter : HealthReporter ,
46+ pub ( crate ) _init : PhantomData < Init > ,
4347}
4448
45- impl < PrivS : Storage + Send + Sync + ' static > RealInitiator < PrivS > {
49+ impl <
50+ PrivS : Storage + Send + Sync + ' static ,
51+ Init : PRSSInit < ResiduePolyF4Z64 , OutputType = PRSSSetup < ResiduePolyF4Z64 > >
52+ + PRSSInit < ResiduePolyF4Z128 , OutputType = PRSSSetup < ResiduePolyF4Z128 > >
53+ + Default ,
54+ > RealInitiator < PrivS , Init >
55+ {
4656 pub async fn init_prss_from_disk ( & self , req_id : & RequestId ) -> anyhow:: Result < ( ) > {
4757 let prss_setup_z128_from_file = {
4858 let guarded_private_storage = self . private_storage . lock ( ) . await ;
@@ -109,8 +119,6 @@ impl<PrivS: Storage + Send + Sync + 'static> RealInitiator<PrivS> {
109119 {
110120 // Notice that this is a hack to get the health reporter to report serving. The type `PrivS` has no influence on the service name.
111121 self . health_reporter
112- . write ( )
113- . await
114122 . set_serving :: < CoreServiceEndpointServer < RealThresholdKms < PrivS , PrivS > > > ( )
115123 . await ;
116124 }
@@ -141,13 +149,14 @@ impl<PrivS: Storage + Send + Sync + 'static> RealInitiator<PrivS> {
141149 "Role assignments: {:?}" ,
142150 base_session. parameters. role_assignments( )
143151 ) ;
144- let prss_setup_obj_z128: PRSSSetup < ResiduePolyF4Z128 > = RobustSecurePrssInit :: default ( )
145- . init ( & mut base_session)
146- . await ?;
147152
148- let prss_setup_obj_z64: PRSSSetup < ResiduePolyF4Z64 > = RobustSecurePrssInit :: default ( )
149- . init ( & mut base_session)
150- . await ?;
153+ // It seems we cannot do something like
154+ // `Init::default().init(&mut base_session).await?;`
155+ // as the type inference gets confused even when using the correct return type.
156+ let prss_setup_obj_z128: PRSSSetup < ResiduePolyF4Z128 > =
157+ PRSSInit :: < ResiduePolyF4Z128 > :: init ( & Init :: default ( ) , & mut base_session) . await ?;
158+ let prss_setup_obj_z64: PRSSSetup < ResiduePolyF4Z64 > =
159+ PRSSInit :: < ResiduePolyF4Z64 > :: init ( & Init :: default ( ) , & mut base_session) . await ?;
151160
152161 let mut guarded_prss_setup = self . prss_setup_z128 . write ( ) . await ;
153162 * guarded_prss_setup = Some ( prss_setup_obj_z128. clone ( ) ) ;
@@ -186,8 +195,6 @@ impl<PrivS: Storage + Send + Sync + 'static> RealInitiator<PrivS> {
186195 {
187196 // Notice that this is a hack to get the health reporter to report serving. The type `PrivS` has no influence on the service name.
188197 self . health_reporter
189- . write ( )
190- . await
191198 . set_serving :: < CoreServiceEndpointServer < RealThresholdKms < PrivS , PrivS > > > ( )
192199 . await ;
193200 }
@@ -197,7 +204,13 @@ impl<PrivS: Storage + Send + Sync + 'static> RealInitiator<PrivS> {
197204}
198205
199206#[ tonic:: async_trait]
200- impl < PrivS : Storage + Send + Sync + ' static > Initiator for RealInitiator < PrivS > {
207+ impl <
208+ PrivS : Storage + Send + Sync + ' static ,
209+ Init : PRSSInit < ResiduePolyF4Z64 , OutputType = PRSSSetup < ResiduePolyF4Z64 > >
210+ + PRSSInit < ResiduePolyF4Z128 , OutputType = PRSSSetup < ResiduePolyF4Z128 > >
211+ + Default ,
212+ > Initiator for RealInitiator < PrivS , Init >
213+ {
201214 async fn init ( & self , request : Request < v1:: InitRequest > ) -> Result < Response < Empty > , Status > {
202215 let inner = request. into_inner ( ) ;
203216 let request_id = tonic_some_or_err (
@@ -219,14 +232,38 @@ impl<PrivS: Storage + Send + Sync + 'static> Initiator for RealInitiator<PrivS>
219232
220233#[ cfg( test) ]
221234mod tests {
235+ use super :: * ;
236+
222237 use crate :: {
223238 client:: test_tools:: { self } ,
224239 consts:: PRSS_INIT_REQ_ID ,
225- engine:: base:: derive_request_id,
226240 util:: key_setup:: test_tools:: purge,
227- vault:: storage:: { file:: FileStorage , StorageType } ,
241+ vault:: storage:: { file:: FileStorage , ram, StorageType } ,
242+ } ;
243+ use aes_prng:: AesRng ;
244+ use kms_grpc:: kms:: v1:: InitRequest ;
245+ use rand:: SeedableRng ;
246+ use threshold_fhe:: {
247+ execution:: runtime:: party:: Role ,
248+ malicious_execution:: small_execution:: malicious_prss:: { EmptyPrss , FailingPrss } ,
228249 } ;
229- use threshold_fhe:: execution:: runtime:: party:: Role ;
250+
251+ impl <
252+ Init : PRSSInit < ResiduePolyF4Z64 , OutputType = PRSSSetup < ResiduePolyF4Z64 > >
253+ + PRSSInit < ResiduePolyF4Z128 , OutputType = PRSSSetup < ResiduePolyF4Z128 > > ,
254+ > RealInitiator < ram:: RamStorage , Init >
255+ {
256+ fn init_test ( session_preparer : SessionPreparer ) -> Self {
257+ Self {
258+ prss_setup_z128 : Arc :: new ( RwLock :: new ( None ) ) ,
259+ prss_setup_z64 : Arc :: new ( RwLock :: new ( None ) ) ,
260+ private_storage : Arc :: new ( Mutex :: new ( ram:: RamStorage :: new ( ) ) ) ,
261+ session_preparer,
262+ health_reporter : HealthReporter :: new ( ) ,
263+ _init : PhantomData ,
264+ }
265+ }
266+ }
230267
231268 #[ tokio:: test]
232269 #[ serial_test:: serial]
@@ -316,4 +353,75 @@ mod tests {
316353 "Initializing threshold KMS server with PRSS Setup Z64 from disk"
317354 ) ) ;
318355 }
356+
357+ #[ tokio:: test]
358+ async fn sunshine ( ) {
359+ let session_preparer = SessionPreparer :: new_test_session ( false ) ;
360+ let initiator = RealInitiator :: < ram:: RamStorage , EmptyPrss > :: init_test ( session_preparer) ;
361+
362+ let mut rng = AesRng :: seed_from_u64 ( 42 ) ;
363+ let req_id = RequestId :: new_random ( & mut rng) ;
364+ initiator
365+ . init ( tonic:: Request :: new ( InitRequest {
366+ request_id : Some ( req_id. into ( ) ) ,
367+ } ) )
368+ . await
369+ . unwrap ( ) ;
370+ }
371+ #[ tokio:: test]
372+ async fn invalid_argument ( ) {
373+ let session_preparer = SessionPreparer :: new_test_session ( false ) ;
374+ let initiator = RealInitiator :: < ram:: RamStorage , EmptyPrss > :: init_test ( session_preparer) ;
375+
376+ let bad_req_id = kms_grpc:: kms:: v1:: RequestId {
377+ request_id : "bad req id" . to_string ( ) ,
378+ } ;
379+ assert_eq ! (
380+ initiator
381+ . init( tonic:: Request :: new( InitRequest {
382+ request_id: Some ( bad_req_id)
383+ } ) )
384+ . await
385+ . unwrap_err( )
386+ . code( ) ,
387+ tonic:: Code :: InvalidArgument
388+ ) ;
389+ }
390+
391+ #[ tokio:: test]
392+ async fn aborted ( ) {
393+ let session_preparer = SessionPreparer :: new_test_session ( false ) ;
394+ let initiator = RealInitiator :: < ram:: RamStorage , EmptyPrss > :: init_test ( session_preparer) ;
395+
396+ assert_eq ! (
397+ initiator
398+ . init( tonic:: Request :: new( InitRequest {
399+ // this is set to none
400+ request_id: None
401+ } ) )
402+ . await
403+ . unwrap_err( )
404+ . code( ) ,
405+ tonic:: Code :: Aborted
406+ ) ;
407+ }
408+
409+ #[ tokio:: test]
410+ async fn internal ( ) {
411+ let session_preparer = SessionPreparer :: new_test_session ( false ) ;
412+ let initiator = RealInitiator :: < ram:: RamStorage , FailingPrss > :: init_test ( session_preparer) ;
413+
414+ let mut rng = AesRng :: seed_from_u64 ( 42 ) ;
415+ let req_id = RequestId :: new_random ( & mut rng) ;
416+ assert_eq ! (
417+ initiator
418+ . init( tonic:: Request :: new( InitRequest {
419+ request_id: Some ( req_id. into( ) )
420+ } ) )
421+ . await
422+ . unwrap_err( )
423+ . code( ) ,
424+ tonic:: Code :: Internal
425+ ) ;
426+ }
319427}
0 commit comments