Skip to content

Commit 8f3b645

Browse files
committed
Simplify and cleanup
1 parent 336f23b commit 8f3b645

2 files changed

Lines changed: 90 additions & 103 deletions

File tree

crates/fhe-math/src/rq/ops.rs

Lines changed: 84 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,26 @@ use zeroize::Zeroize;
1414
impl AddAssign<&Poly> for Poly {
1515
fn add_assign(&mut self, p: &Poly) {
1616
assert!(!self.has_lazy_coefficients && !p.has_lazy_coefficients);
17+
18+
// p and self must have the same context.
19+
debug_assert_eq!(self.ctx, p.ctx, "Incompatible contexts");
20+
21+
// p and q must have comptatible representations.
22+
match self.representation {
23+
Representation::PowerBasis => assert_eq!(
24+
p.representation,
25+
Representation::PowerBasis,
26+
"Incompatible representations"
27+
),
28+
Representation::Ntt | Representation::NttShoup => assert!(
29+
p.representation == Representation::Ntt
30+
|| p.representation == Representation::NttShoup,
31+
"Incompatible representations"
32+
),
33+
}
34+
35+
// If the representation is NttShoup, drop the Shoup coefficients
36+
// and switch to Ntt representation.
1737
if self.representation == Representation::NttShoup {
1838
self.coefficients_shoup
1939
.as_mut()
@@ -24,19 +44,6 @@ impl AddAssign<&Poly> for Poly {
2444
unsafe { self.override_representation(Representation::Ntt) }
2545
}
2646

27-
if self.representation == Representation::Ntt {
28-
assert!(
29-
p.representation == Representation::Ntt
30-
|| p.representation == Representation::NttShoup,
31-
"Incompatible representations"
32-
)
33-
} else {
34-
assert_eq!(
35-
self.representation, p.representation,
36-
"Incompatible representations"
37-
);
38-
}
39-
debug_assert_eq!(self.ctx, p.ctx, "Incompatible contexts");
4047
self.allow_variable_time_computations |= p.allow_variable_time_computations;
4148
if self.allow_variable_time_computations {
4249
izip!(
@@ -63,26 +70,15 @@ impl AddAssign<&Poly> for Poly {
6370
impl Add<&Poly> for &Poly {
6471
type Output = Poly;
6572
fn add(self, p: &Poly) -> Poly {
66-
match self.representation {
67-
Representation::NttShoup => {
68-
let mut q = self.clone();
69-
if q.representation == Representation::NttShoup {
70-
q.coefficients_shoup
71-
.as_mut()
72-
.unwrap()
73-
.as_slice_mut()
74-
.unwrap()
75-
.zeroize();
76-
unsafe { q.override_representation(Representation::Ntt) }
77-
}
78-
q += p;
79-
q
80-
}
81-
Representation::PowerBasis | Representation::Ntt => {
82-
let mut q = self.clone();
83-
q += p;
84-
q
85-
}
73+
// if self is in NttShoup representation, let's copy `p` instead
74+
if self.representation == Representation::NttShoup {
75+
let mut q = p.clone();
76+
q += self;
77+
q
78+
} else {
79+
let mut q = self.clone();
80+
q += p;
81+
q
8682
}
8783
}
8884
}
@@ -98,24 +94,36 @@ impl Add for Poly {
9894
impl SubAssign<&Poly> for Poly {
9995
fn sub_assign(&mut self, p: &Poly) {
10096
assert!(!self.has_lazy_coefficients && !p.has_lazy_coefficients);
101-
assert_ne!(
102-
self.representation,
103-
Representation::NttShoup,
104-
"Cannot subtract from a polynomial in NttShoup representation"
105-
);
106-
if self.representation == Representation::Ntt {
107-
assert!(
97+
98+
// p and self must have the same context.
99+
debug_assert_eq!(self.ctx, p.ctx, "Incompatible contexts");
100+
101+
// p and q must have comptatible representations.
102+
match self.representation {
103+
Representation::PowerBasis => assert_eq!(
104+
p.representation,
105+
Representation::PowerBasis,
106+
"Incompatible representations"
107+
),
108+
Representation::Ntt | Representation::NttShoup => assert!(
108109
p.representation == Representation::Ntt
109110
|| p.representation == Representation::NttShoup,
110111
"Incompatible representations"
111-
)
112-
} else {
113-
assert_eq!(
114-
self.representation, p.representation,
115-
"Incompatible representations"
116-
);
112+
),
117113
}
118-
debug_assert_eq!(self.ctx, p.ctx, "Incompatible contexts");
114+
115+
// If the representation is NttShoup, drop the Shoup coefficients
116+
// and switch to Ntt representation.
117+
if self.representation == Representation::NttShoup {
118+
self.coefficients_shoup
119+
.as_mut()
120+
.unwrap()
121+
.as_slice_mut()
122+
.unwrap()
123+
.zeroize();
124+
unsafe { self.override_representation(Representation::Ntt) }
125+
}
126+
119127
self.allow_variable_time_computations |= p.allow_variable_time_computations;
120128
if self.allow_variable_time_computations {
121129
izip!(
@@ -142,27 +150,9 @@ impl SubAssign<&Poly> for Poly {
142150
impl Sub<&Poly> for &Poly {
143151
type Output = Poly;
144152
fn sub(self, p: &Poly) -> Poly {
145-
match self.representation {
146-
Representation::NttShoup => {
147-
let mut q = self.clone();
148-
if q.representation == Representation::NttShoup {
149-
q.coefficients_shoup
150-
.as_mut()
151-
.unwrap()
152-
.as_slice_mut()
153-
.unwrap()
154-
.zeroize();
155-
unsafe { q.override_representation(Representation::Ntt) }
156-
}
157-
q -= p;
158-
q
159-
}
160-
Representation::PowerBasis | Representation::Ntt => {
161-
let mut q = self.clone();
162-
q -= p;
163-
q
164-
}
165-
}
153+
let mut q = self.clone();
154+
q -= p;
155+
q
166156
}
167157
}
168158

@@ -255,11 +245,17 @@ impl MulAssign<&Poly> for Poly {
255245

256246
impl MulAssign<&BigUint> for Poly {
257247
fn mul_assign(&mut self, p: &BigUint) {
258-
assert_ne!(
259-
self.representation,
260-
Representation::NttShoup,
261-
"Cannot multiply a polynomial in NttShoup representation by a scalar"
262-
);
248+
// If the representation is NttShoup, drop the Shoup coefficients
249+
// and switch to Ntt representation.
250+
if self.representation == Representation::NttShoup {
251+
self.coefficients_shoup
252+
.as_mut()
253+
.unwrap()
254+
.as_slice_mut()
255+
.unwrap()
256+
.zeroize();
257+
unsafe { self.override_representation(Representation::Ntt) }
258+
}
263259

264260
// Project the scalar into its CRT representation (reduced modulo each prime)
265261
let scalar_crt = self.ctx.rns.project(p);
@@ -291,27 +287,15 @@ impl MulAssign<&BigUint> for Poly {
291287
impl Mul<&Poly> for &Poly {
292288
type Output = Poly;
293289
fn mul(self, p: &Poly) -> Poly {
294-
match self.representation {
295-
Representation::NttShoup => {
296-
// TODO: To test, and do the same thing for add, sub, and neg
297-
let mut q = p.clone();
298-
if q.representation == Representation::NttShoup {
299-
q.coefficients_shoup
300-
.as_mut()
301-
.unwrap()
302-
.as_slice_mut()
303-
.unwrap()
304-
.zeroize();
305-
unsafe { q.override_representation(Representation::Ntt) }
306-
}
307-
q *= self;
308-
q
309-
}
310-
Representation::PowerBasis | Representation::Ntt => {
311-
let mut q = self.clone();
312-
q *= p;
313-
q
314-
}
290+
// if self is in NttShoup representation, let's copy `p` instead
291+
if self.representation == Representation::NttShoup {
292+
let mut q = p.clone();
293+
q *= self;
294+
q
295+
} else {
296+
let mut q = self.clone();
297+
q *= p;
298+
q
315299
}
316300
}
317301
}
@@ -852,17 +836,18 @@ mod tests {
852836
}
853837

854838
#[test]
855-
#[should_panic(
856-
expected = "Cannot multiply a polynomial in NttShoup representation by a scalar"
857-
)]
858-
fn mul_scalar_ntt_shoup_panic() {
839+
fn mul_scalar_ntt_shoup() {
859840
use num_bigint::BigUint;
860841

861842
let ctx = Arc::new(Context::new(MODULI, 16).unwrap());
862843
let mut p = Poly::random(&ctx, Representation::NttShoup, &mut rng());
844+
let mut p_ntt = p.clone();
845+
p_ntt.change_representation(Representation::Ntt);
863846
let scalar = BigUint::from(42u64);
864847

865-
// This should panic with the assertion message
866848
p *= &scalar;
849+
850+
assert_eq!(p.representation, Representation::Ntt);
851+
assert_eq!(&p_ntt * &scalar, p);
867852
}
868853
}

crates/fhe-math/tests/ntt_shoup_ops.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
//! Unit test for polynomial Shoup operations.
2+
13
use fhe_math::rq::{Context, Poly, Representation};
24
use rand::rng;
35
use std::sync::Arc;
@@ -16,8 +18,8 @@ fn test_ntt_shoup_add_sub_neg() {
1618
let mut q = p.clone();
1719
if *q.representation() == Representation::NttShoup {
1820
unsafe { q.override_representation(Representation::Ntt) };
19-
// Note: override_representation handles shoup cleanup if needed or just switch enum
20-
// But strict conversion:
21+
// Note: override_representation handles shoup cleanup if needed or just switch
22+
// enum But strict conversion:
2123
q.change_representation(Representation::Ntt);
2224
}
2325
q
@@ -37,8 +39,8 @@ fn test_ntt_shoup_add_sub_neg() {
3739

3840
// Case 3: NttShoup + NttShoup (should work if we relaxed AddAssign correctly)
3941
// Wait, AddAssign on LHS=NttShoup is forbidden.
40-
// But Add(&NttShoup, &NttShoup) -> converts LHS to Ntt, then adds RHS (NttShoup).
41-
// So LHS becomes Ntt. Ntt += NttShoup. This should work now.
42+
// But Add(&NttShoup, &NttShoup) -> converts LHS to Ntt, then adds RHS
43+
// (NttShoup). So LHS becomes Ntt. Ntt += NttShoup. This should work now.
4244
let p_shoup2 = Poly::random(&ctx, Representation::NttShoup, &mut rng);
4345
let p_shoup2_as_ntt = to_ntt(&p_shoup2);
4446

0 commit comments

Comments
 (0)