Skip to content

Commit 84b2c49

Browse files
authored
Merge branch 'main' into center_vt_constant_time-3415539913779892970
2 parents 62fd78c + 89d4411 commit 84b2c49

1 file changed

Lines changed: 35 additions & 59 deletions

File tree

crates/fhe-math/src/zq/mod.rs

Lines changed: 35 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -795,7 +795,7 @@ impl Modulus {
795795

796796
#[cfg(test)]
797797
mod tests {
798-
use super::{Modulus, primes};
798+
use super::Modulus;
799799
use itertools::{Itertools, izip};
800800
use proptest::collection::vec as prop_vec;
801801
use proptest::prelude::{BoxedStrategy, Just, Strategy, any};
@@ -807,6 +807,10 @@ mod tests {
807807
any::<u64>().prop_filter_map("filter invalid moduli", |p| Modulus::new(p).ok())
808808
}
809809

810+
fn valid_moduli_opt() -> impl Strategy<Value = Modulus> {
811+
valid_moduli().prop_filter("filter moduli not supporting opt", |p| p.supports_opt)
812+
}
813+
810814
fn vecs() -> BoxedStrategy<(Vec<u64>, Vec<u64>)> {
811815
prop_vec(any::<u64>(), 1..100)
812816
.prop_flat_map(|vec| {
@@ -1108,77 +1112,49 @@ mod tests {
11081112
for (ai, bi) in izip!(a.iter(), b.iter()) {
11091113
prop_assert_eq!(p.center(*ai), *bi);
11101114
}
1111-
}
1112-
}
1113-
1114-
// TODO: Make a proptest.
1115-
#[test]
1116-
fn mul_opt() {
1117-
let ntests = 100;
1118-
let mut rng = rand::rng();
1115+
}
11191116

1120-
for p in [4611686018326724609] {
1121-
let q = Modulus::new(p).unwrap();
1122-
assert!(primes::supports_opt(p));
1117+
#[test]
1118+
fn mul_opt(p in valid_moduli_opt(), mut a: u64, mut b: u64) {
1119+
a = p.reduce(a);
1120+
b = p.reduce(b);
11231121

1124-
assert_eq!(q.mul_opt(0, 1), 0);
1125-
assert_eq!(q.mul_opt(1, 1), 1);
1126-
assert_eq!(q.mul_opt(2 % p, 3 % p), 6 % p);
1127-
assert_eq!(q.mul_opt(p - 1, 1), p - 1);
1128-
assert_eq!(q.mul_opt(p - 1, 2 % p), p - 2);
1122+
prop_assert_eq!(p.mul_opt(a, b) as u128, ((a as u128) * (b as u128)) % (*p as u128));
1123+
unsafe { prop_assert_eq!(p.mul_opt_vt(a, b) as u128, ((a as u128) * (b as u128)) % (*p as u128)) }
11291124

11301125
#[cfg(debug_assertions)]
11311126
{
1132-
assert!(std::panic::catch_unwind(|| q.mul_opt(p, 1)).is_err());
1133-
assert!(std::panic::catch_unwind(|| q.mul_opt(p << 1, 1)).is_err());
1134-
assert!(std::panic::catch_unwind(|| q.mul_opt(0, p)).is_err());
1135-
assert!(std::panic::catch_unwind(|| q.mul_opt(0, p << 1)).is_err());
1136-
}
1137-
1138-
for _ in 0..ntests {
1139-
let a = rng.next_u64() % p;
1140-
let b = rng.next_u64() % p;
1141-
assert_eq!(
1142-
q.mul_opt(a, b),
1143-
(((a as u128) * (b as u128)) % (p as u128)) as u64
1144-
);
1127+
prop_assert!(std::panic::catch_unwind(|| p.mul_opt(*p, a)).is_err());
1128+
prop_assert!(std::panic::catch_unwind(|| p.mul_opt(a, *p)).is_err());
1129+
prop_assert!(std::panic::catch_unwind(|| p.mul_opt(*p + 1, a)).is_err());
1130+
prop_assert!(std::panic::catch_unwind(|| p.mul_opt(a, *p + 1)).is_err());
11451131
}
11461132
}
1147-
}
11481133

1149-
// TODO: Make a proptest.
1150-
#[test]
1151-
fn pow() {
1152-
let ntests = 10;
1153-
let mut rng = rand::rng();
1134+
#[test]
1135+
fn pow(p in valid_moduli(), mut a: u64, mut b: u64) {
1136+
a = p.reduce(a);
1137+
b = p.reduce(b);
11541138

1155-
for p in [2u64, 3, 17, 1987, 4611686018326724609] {
1156-
let q = Modulus::new(p).unwrap();
1139+
prop_assert_eq!(p.pow(a, 0), 1);
1140+
prop_assert_eq!(p.pow(a, 1), a);
1141+
if *p > 2 {
1142+
prop_assert_eq!(p.pow(a, 2), p.mul(a, a));
1143+
}
11571144

1158-
assert_eq!(q.pow(p - 1, 0), 1);
1159-
assert_eq!(q.pow(p - 1, 1), p - 1);
1160-
assert_eq!(q.pow(p - 1, 2 % p), 1);
1161-
assert_eq!(q.pow(1, p - 2), 1);
1162-
assert_eq!(q.pow(1, p - 1), 1);
1145+
let b_small = b % 1000;
1146+
let mut r = 1;
1147+
for _ in 0..b_small {
1148+
r = p.mul(r, a);
1149+
}
1150+
prop_assert_eq!(p.pow(a, b_small), r);
11631151

11641152
#[cfg(debug_assertions)]
11651153
{
1166-
assert!(std::panic::catch_unwind(|| q.pow(p, 1)).is_err());
1167-
assert!(std::panic::catch_unwind(|| q.pow(p << 1, 1)).is_err());
1168-
assert!(std::panic::catch_unwind(|| q.pow(0, p)).is_err());
1169-
assert!(std::panic::catch_unwind(|| q.pow(0, p << 1)).is_err());
1170-
}
1171-
1172-
for _ in 0..ntests {
1173-
let a = rng.next_u64() % p;
1174-
let b = (rng.next_u64() % p) % 1000;
1175-
let mut c = b;
1176-
let mut r = 1;
1177-
while c > 0 {
1178-
r = q.mul(r, a);
1179-
c -= 1;
1180-
}
1181-
assert_eq!(q.pow(a, b), r);
1154+
prop_assert!(std::panic::catch_unwind(|| p.pow(*p, 1)).is_err());
1155+
prop_assert!(std::panic::catch_unwind(|| p.pow(*p << 1, 1)).is_err());
1156+
prop_assert!(std::panic::catch_unwind(|| p.pow(0, *p)).is_err());
1157+
prop_assert!(std::panic::catch_unwind(|| p.pow(0, *p << 1)).is_err());
11821158
}
11831159
}
11841160
}

0 commit comments

Comments
 (0)