Skip to content

Commit e03a5a1

Browse files
Make center_vt constant time and rename to center
- Implement `center` as a constant-time method in `Modulus`. - Implement `center_vec` as a constant-time vector method. - Update `center_vec_vt` to call `center_vec`. - Remove private `center_vt`. - Add tests. Co-authored-by: tlepoint <1345502+tlepoint@users.noreply.github.com>
1 parent bd67af0 commit e03a5a1

1 file changed

Lines changed: 46 additions & 16 deletions

File tree

crates/fhe-math/src/zq/mod.rs

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -440,20 +440,30 @@ impl Modulus {
440440
.dispatch(|| a.iter_mut().for_each(|ai| *ai = self.reduce(*ai)))
441441
}
442442

443-
/// Center a value modulo p as i64 in variable time.
444-
/// TODO: To test and to make constant time?
443+
/// Center a value modulo p as i64 in constant time.
445444
///
446-
/// # Safety
447-
/// This function is not constant time and its timing may reveal information
448-
/// about the value being centered.
449-
const unsafe fn center_vt(&self, a: u64) -> i64 {
445+
/// The output is in the interval `[-p/2, p/2)`.
446+
/// Aborts if `a >= p` in debug mode.
447+
#[must_use]
448+
pub const fn center(&self, a: u64) -> i64 {
450449
debug_assert!(a < self.p);
451450

452-
if a >= self.p >> 1 {
453-
(a as i64) - (self.p as i64)
454-
} else {
455-
a as i64
456-
}
451+
let threshold = self.p >> 1;
452+
let cond = a >= threshold;
453+
let on_true = (a as i64).wrapping_sub(self.p as i64) as u64;
454+
let on_false = a as u64;
455+
456+
const_time_cond_select(on_true, on_false, cond) as i64
457+
}
458+
459+
/// Center a vector in constant time.
460+
#[must_use]
461+
pub fn center_vec(&self, a: &[u64]) -> Vec<i64> {
462+
self.arch.dispatch(|| {
463+
a.iter()
464+
.map(|ai| self.center(*ai))
465+
.collect_vec()
466+
})
457467
}
458468

459469
/// Center a vector in variable time.
@@ -463,11 +473,7 @@ impl Modulus {
463473
/// about the values being centered.
464474
#[must_use]
465475
pub unsafe fn center_vec_vt(&self, a: &[u64]) -> Vec<i64> {
466-
self.arch.dispatch(|| {
467-
a.iter()
468-
.map(|ai| unsafe { self.center_vt(*ai) })
469-
.collect_vec()
470-
})
476+
self.center_vec(a)
471477
}
472478

473479
/// Reduce a vector in place in variable time.
@@ -1093,6 +1099,30 @@ mod tests {
10931099
let c = p.deserialize_vec(&b);
10941100
prop_assert_eq!(a, c);
10951101
}
1102+
1103+
#[test]
1104+
fn center(p in valid_moduli(), a: u64) {
1105+
let a = p.reduce(a);
1106+
let b = p.center(a);
1107+
if a >= *p >> 1 {
1108+
prop_assert_eq!(b, (a as i64) - (*p as i64));
1109+
} else {
1110+
prop_assert_eq!(b, a as i64);
1111+
}
1112+
prop_assert_eq!(p.reduce_i64(b), a);
1113+
}
1114+
1115+
#[test]
1116+
fn center_vec(p in valid_moduli(), a: Vec<u64>) {
1117+
let mut a = a.clone();
1118+
p.reduce_vec(&mut a);
1119+
let b = p.center_vec(&a);
1120+
prop_assert_eq!(b.len(), a.len());
1121+
for (ai, bi) in izip!(a.iter(), b.iter()) {
1122+
prop_assert_eq!(p.center(*ai), *bi);
1123+
}
1124+
unsafe { prop_assert_eq!(p.center_vec_vt(&a), b); }
1125+
}
10961126
}
10971127

10981128
// TODO: Make a proptest.

0 commit comments

Comments
 (0)