@@ -14,6 +14,26 @@ use zeroize::Zeroize;
1414impl AddAssign < & Poly > for Poly {
1515 fn add_assign ( & mut self , p : & Poly ) {
1616 assert ! ( !self . has_lazy_coefficients && !p. has_lazy_coefficients) ;
17+
18+ // p and self must have the same context.
19+ debug_assert_eq ! ( self . ctx, p. ctx, "Incompatible contexts" ) ;
20+
21+ // p and q must have comptatible representations.
22+ match self . representation {
23+ Representation :: PowerBasis => assert_eq ! (
24+ p. representation,
25+ Representation :: PowerBasis ,
26+ "Incompatible representations"
27+ ) ,
28+ Representation :: Ntt | Representation :: NttShoup => assert ! (
29+ p. representation == Representation :: Ntt
30+ || p. representation == Representation :: NttShoup ,
31+ "Incompatible representations"
32+ ) ,
33+ }
34+
35+ // If the representation is NttShoup, drop the Shoup coefficients
36+ // and switch to Ntt representation.
1737 if self . representation == Representation :: NttShoup {
1838 self . coefficients_shoup
1939 . as_mut ( )
@@ -24,19 +44,6 @@ impl AddAssign<&Poly> for Poly {
2444 unsafe { self . override_representation ( Representation :: Ntt ) }
2545 }
2646
27- if self . representation == Representation :: Ntt {
28- assert ! (
29- p. representation == Representation :: Ntt
30- || p. representation == Representation :: NttShoup ,
31- "Incompatible representations"
32- )
33- } else {
34- assert_eq ! (
35- self . representation, p. representation,
36- "Incompatible representations"
37- ) ;
38- }
39- debug_assert_eq ! ( self . ctx, p. ctx, "Incompatible contexts" ) ;
4047 self . allow_variable_time_computations |= p. allow_variable_time_computations ;
4148 if self . allow_variable_time_computations {
4249 izip ! (
@@ -63,26 +70,15 @@ impl AddAssign<&Poly> for Poly {
6370impl Add < & Poly > for & Poly {
6471 type Output = Poly ;
6572 fn add ( self , p : & Poly ) -> Poly {
66- match self . representation {
67- Representation :: NttShoup => {
68- let mut q = self . clone ( ) ;
69- if q. representation == Representation :: NttShoup {
70- q. coefficients_shoup
71- . as_mut ( )
72- . unwrap ( )
73- . as_slice_mut ( )
74- . unwrap ( )
75- . zeroize ( ) ;
76- unsafe { q. override_representation ( Representation :: Ntt ) }
77- }
78- q += p;
79- q
80- }
81- Representation :: PowerBasis | Representation :: Ntt => {
82- let mut q = self . clone ( ) ;
83- q += p;
84- q
85- }
73+ // if self is in NttShoup representation, let's copy `p` instead
74+ if self . representation == Representation :: NttShoup {
75+ let mut q = p. clone ( ) ;
76+ q += self ;
77+ q
78+ } else {
79+ let mut q = self . clone ( ) ;
80+ q += p;
81+ q
8682 }
8783 }
8884}
@@ -98,24 +94,36 @@ impl Add for Poly {
9894impl SubAssign < & Poly > for Poly {
9995 fn sub_assign ( & mut self , p : & Poly ) {
10096 assert ! ( !self . has_lazy_coefficients && !p. has_lazy_coefficients) ;
101- assert_ne ! (
102- self . representation,
103- Representation :: NttShoup ,
104- "Cannot subtract from a polynomial in NttShoup representation"
105- ) ;
106- if self . representation == Representation :: Ntt {
107- assert ! (
97+
98+ // p and self must have the same context.
99+ debug_assert_eq ! ( self . ctx, p. ctx, "Incompatible contexts" ) ;
100+
101+ // p and q must have comptatible representations.
102+ match self . representation {
103+ Representation :: PowerBasis => assert_eq ! (
104+ p. representation,
105+ Representation :: PowerBasis ,
106+ "Incompatible representations"
107+ ) ,
108+ Representation :: Ntt | Representation :: NttShoup => assert ! (
108109 p. representation == Representation :: Ntt
109110 || p. representation == Representation :: NttShoup ,
110111 "Incompatible representations"
111- )
112- } else {
113- assert_eq ! (
114- self . representation, p. representation,
115- "Incompatible representations"
116- ) ;
112+ ) ,
117113 }
118- debug_assert_eq ! ( self . ctx, p. ctx, "Incompatible contexts" ) ;
114+
115+ // If the representation is NttShoup, drop the Shoup coefficients
116+ // and switch to Ntt representation.
117+ if self . representation == Representation :: NttShoup {
118+ self . coefficients_shoup
119+ . as_mut ( )
120+ . unwrap ( )
121+ . as_slice_mut ( )
122+ . unwrap ( )
123+ . zeroize ( ) ;
124+ unsafe { self . override_representation ( Representation :: Ntt ) }
125+ }
126+
119127 self . allow_variable_time_computations |= p. allow_variable_time_computations ;
120128 if self . allow_variable_time_computations {
121129 izip ! (
@@ -142,27 +150,9 @@ impl SubAssign<&Poly> for Poly {
142150impl Sub < & Poly > for & Poly {
143151 type Output = Poly ;
144152 fn sub ( self , p : & Poly ) -> Poly {
145- match self . representation {
146- Representation :: NttShoup => {
147- let mut q = self . clone ( ) ;
148- if q. representation == Representation :: NttShoup {
149- q. coefficients_shoup
150- . as_mut ( )
151- . unwrap ( )
152- . as_slice_mut ( )
153- . unwrap ( )
154- . zeroize ( ) ;
155- unsafe { q. override_representation ( Representation :: Ntt ) }
156- }
157- q -= p;
158- q
159- }
160- Representation :: PowerBasis | Representation :: Ntt => {
161- let mut q = self . clone ( ) ;
162- q -= p;
163- q
164- }
165- }
153+ let mut q = self . clone ( ) ;
154+ q -= p;
155+ q
166156 }
167157}
168158
@@ -255,11 +245,17 @@ impl MulAssign<&Poly> for Poly {
255245
256246impl MulAssign < & BigUint > for Poly {
257247 fn mul_assign ( & mut self , p : & BigUint ) {
258- assert_ne ! (
259- self . representation,
260- Representation :: NttShoup ,
261- "Cannot multiply a polynomial in NttShoup representation by a scalar"
262- ) ;
248+ // If the representation is NttShoup, drop the Shoup coefficients
249+ // and switch to Ntt representation.
250+ if self . representation == Representation :: NttShoup {
251+ self . coefficients_shoup
252+ . as_mut ( )
253+ . unwrap ( )
254+ . as_slice_mut ( )
255+ . unwrap ( )
256+ . zeroize ( ) ;
257+ unsafe { self . override_representation ( Representation :: Ntt ) }
258+ }
263259
264260 // Project the scalar into its CRT representation (reduced modulo each prime)
265261 let scalar_crt = self . ctx . rns . project ( p) ;
@@ -291,27 +287,15 @@ impl MulAssign<&BigUint> for Poly {
291287impl Mul < & Poly > for & Poly {
292288 type Output = Poly ;
293289 fn mul ( self , p : & Poly ) -> Poly {
294- match self . representation {
295- Representation :: NttShoup => {
296- // TODO: To test, and do the same thing for add, sub, and neg
297- let mut q = p. clone ( ) ;
298- if q. representation == Representation :: NttShoup {
299- q. coefficients_shoup
300- . as_mut ( )
301- . unwrap ( )
302- . as_slice_mut ( )
303- . unwrap ( )
304- . zeroize ( ) ;
305- unsafe { q. override_representation ( Representation :: Ntt ) }
306- }
307- q *= self ;
308- q
309- }
310- Representation :: PowerBasis | Representation :: Ntt => {
311- let mut q = self . clone ( ) ;
312- q *= p;
313- q
314- }
290+ // if self is in NttShoup representation, let's copy `p` instead
291+ if self . representation == Representation :: NttShoup {
292+ let mut q = p. clone ( ) ;
293+ q *= self ;
294+ q
295+ } else {
296+ let mut q = self . clone ( ) ;
297+ q *= p;
298+ q
315299 }
316300 }
317301}
@@ -852,17 +836,18 @@ mod tests {
852836 }
853837
854838 #[ test]
855- #[ should_panic(
856- expected = "Cannot multiply a polynomial in NttShoup representation by a scalar"
857- ) ]
858- fn mul_scalar_ntt_shoup_panic ( ) {
839+ fn mul_scalar_ntt_shoup ( ) {
859840 use num_bigint:: BigUint ;
860841
861842 let ctx = Arc :: new ( Context :: new ( MODULI , 16 ) . unwrap ( ) ) ;
862843 let mut p = Poly :: random ( & ctx, Representation :: NttShoup , & mut rng ( ) ) ;
844+ let mut p_ntt = p. clone ( ) ;
845+ p_ntt. change_representation ( Representation :: Ntt ) ;
863846 let scalar = BigUint :: from ( 42u64 ) ;
864847
865- // This should panic with the assertion message
866848 p *= & scalar;
849+
850+ assert_eq ! ( p. representation, Representation :: Ntt ) ;
851+ assert_eq ! ( & p_ntt * & scalar, p) ;
867852 }
868853}
0 commit comments