@@ -440,20 +440,30 @@ impl Modulus {
440440 . dispatch ( || a. iter_mut ( ) . for_each ( |ai| * ai = self . reduce ( * ai) ) )
441441 }
442442
443- /// Center a value modulo p as i64 in variable time.
444- /// TODO: To test and to make constant time?
443+ /// Center a value modulo p as i64 in constant time.
445444 ///
446- /// # Safety
447- /// This function is not constant time and its timing may reveal information
448- /// about the value being centered.
449- const unsafe fn center_vt ( & self , a : u64 ) -> i64 {
445+ /// The output is in the interval `[-p/2, p/2)`.
446+ /// Aborts if `a >= p` in debug mode.
447+ # [ must_use ]
448+ pub const fn center ( & self , a : u64 ) -> i64 {
450449 debug_assert ! ( a < self . p) ;
451450
452- if a >= self . p >> 1 {
453- ( a as i64 ) - ( self . p as i64 )
454- } else {
455- a as i64
456- }
451+ let threshold = self . p >> 1 ;
452+ let cond = a >= threshold;
453+ let on_true = ( a as i64 ) . wrapping_sub ( self . p as i64 ) as u64 ;
454+ let on_false = a as u64 ;
455+
456+ const_time_cond_select ( on_true, on_false, cond) as i64
457+ }
458+
459+ /// Center a vector in constant time.
460+ #[ must_use]
461+ pub fn center_vec ( & self , a : & [ u64 ] ) -> Vec < i64 > {
462+ self . arch . dispatch ( || {
463+ a. iter ( )
464+ . map ( |ai| self . center ( * ai) )
465+ . collect_vec ( )
466+ } )
457467 }
458468
459469 /// Center a vector in variable time.
@@ -463,11 +473,7 @@ impl Modulus {
463473 /// about the values being centered.
464474 #[ must_use]
465475 pub unsafe fn center_vec_vt ( & self , a : & [ u64 ] ) -> Vec < i64 > {
466- self . arch . dispatch ( || {
467- a. iter ( )
468- . map ( |ai| unsafe { self . center_vt ( * ai) } )
469- . collect_vec ( )
470- } )
476+ self . center_vec ( a)
471477 }
472478
473479 /// Reduce a vector in place in variable time.
@@ -1093,6 +1099,30 @@ mod tests {
10931099 let c = p. deserialize_vec( & b) ;
10941100 prop_assert_eq!( a, c) ;
10951101 }
1102+
1103+ #[ test]
1104+ fn center( p in valid_moduli( ) , a: u64 ) {
1105+ let a = p. reduce( a) ;
1106+ let b = p. center( a) ;
1107+ if a >= * p >> 1 {
1108+ prop_assert_eq!( b, ( a as i64 ) - ( * p as i64 ) ) ;
1109+ } else {
1110+ prop_assert_eq!( b, a as i64 ) ;
1111+ }
1112+ prop_assert_eq!( p. reduce_i64( b) , a) ;
1113+ }
1114+
1115+ #[ test]
1116+ fn center_vec( p in valid_moduli( ) , a: Vec <u64 >) {
1117+ let mut a = a. clone( ) ;
1118+ p. reduce_vec( & mut a) ;
1119+ let b = p. center_vec( & a) ;
1120+ prop_assert_eq!( b. len( ) , a. len( ) ) ;
1121+ for ( ai, bi) in izip!( a. iter( ) , b. iter( ) ) {
1122+ prop_assert_eq!( p. center( * ai) , * bi) ;
1123+ }
1124+ unsafe { prop_assert_eq!( p. center_vec_vt( & a) , b) ; }
1125+ }
10961126 }
10971127
10981128 // TODO: Make a proptest.
0 commit comments