55// modified for our needs.
66//
77
8- use crate :: types;
98use crate :: types:: { to_scalar, IndexedValue , ShareIndex } ;
109use fastcrypto:: error:: { FastCryptoError , FastCryptoResult } ;
1110use fastcrypto:: groups:: { GroupElement , MultiScalarMul , Scalar } ;
1211use fastcrypto:: traits:: AllowedRng ;
13- use itertools:: { Either , Itertools } ;
12+ use itertools:: Itertools ;
1413use serde:: { Deserialize , Serialize } ;
1514use std:: borrow:: Borrow ;
16- use std:: collections:: HashSet ;
1715use std:: mem:: swap;
1816use std:: num:: NonZeroU16 ;
19- use std:: ops:: { Add , AddAssign , Index , Mul , MulAssign , SubAssign } ;
17+ use std:: ops:: { Add , AddAssign , Div , Index , Mul , MulAssign , SubAssign } ;
2018
2119/// Types
2220
@@ -171,73 +169,80 @@ impl<C: GroupElement> Poly<C> {
171169 ) )
172170 }
173171
174- // Multiply using u128 if possible, otherwise just convert one element to the group element and return the other.
175- pub ( crate ) fn fast_mult ( x : u128 , y : u128 ) -> Either < ( C :: ScalarType , u128 ) , u128 > {
176- if x. leading_zeros ( ) >= ( 128 - y. leading_zeros ( ) ) {
177- Either :: Right ( x * y)
172+ /// Multiply x.1 with y using u128s if possible, otherwise convert x.1 to the group element and multiply.
173+ /// Invariant: If res = fast_mult(x1, x2, y) then x.0 * x.1 * y = res.0 * res.1.
174+ pub ( crate ) fn fast_mult ( x : ( C :: ScalarType , u128 ) , y : u128 ) -> ( C :: ScalarType , u128 ) {
175+ if x. 1 . leading_zeros ( ) >= ( 128 - y. leading_zeros ( ) ) {
176+ ( x. 0 , x. 1 * y)
178177 } else {
179- Either :: Left ( ( C :: ScalarType :: from ( x) , y) )
178+ ( x . 0 * C :: ScalarType :: from ( x. 1 ) , y)
180179 }
181180 }
182181
183- // Expects exactly t unique shares.
182+ /// Compute initial * \prod factors.
183+ pub ( crate ) fn fast_product (
184+ initial : C :: ScalarType ,
185+ factors : impl Iterator < Item = u128 > ,
186+ ) -> C :: ScalarType {
187+ let ( result, remaining) = factors. fold ( ( initial, 1 ) , |acc, factor| {
188+ debug_assert_ne ! ( factor, 0 ) ;
189+ Self :: fast_mult ( acc, factor)
190+ } ) ;
191+ debug_assert_ne ! ( remaining, 0 ) ;
192+ result * C :: ScalarType :: from ( remaining)
193+ }
194+
184195 fn get_lagrange_coefficients_for_c0 (
185196 t : u16 ,
186- mut shares : impl Iterator < Item = impl Borrow < Eval < C > > > ,
187- ) -> FastCryptoResult < Vec < C :: ScalarType > > {
188- let mut ids_set = HashSet :: new ( ) ;
189- let ( shares_size_lower, shares_size_upper) = shares. size_hint ( ) ;
190- let indices = shares. try_fold (
191- Vec :: with_capacity ( shares_size_upper. unwrap_or ( shares_size_lower) ) ,
192- |mut vec, s| {
193- // Check for duplicates.
194- if !ids_set. insert ( s. borrow ( ) . index ) {
195- return Err ( FastCryptoError :: InvalidInput ) ; // expected unique ids
196- }
197- vec. push ( s. borrow ( ) . index . get ( ) as u128 ) ;
198- Ok ( vec)
199- } ,
200- ) ?;
201- if indices. len ( ) != t as usize {
197+ shares : impl Iterator < Item = impl Borrow < Eval < C > > > ,
198+ ) -> FastCryptoResult < ( C :: ScalarType , Vec < C :: ScalarType > ) > {
199+ Self :: get_lagrange_coefficients_for ( 0 , t, shares)
200+ }
201+
202+ /// Expects exactly t unique shares.
203+ /// Returns an error if x is one of the indices.
204+ fn get_lagrange_coefficients_for (
205+ x : u128 ,
206+ t : u16 ,
207+ shares : impl Iterator < Item = impl Borrow < Eval < C > > > ,
208+ ) -> FastCryptoResult < ( C :: ScalarType , Vec < C :: ScalarType > ) > {
209+ let indices = shares. map ( |s| s. borrow ( ) . index . get ( ) as u128 ) . collect_vec ( ) ;
210+ if !indices. iter ( ) . all_unique ( ) || indices. len ( ) != t as usize || indices. contains ( & x) {
202211 return Err ( FastCryptoError :: InvalidInput ) ;
203212 }
204213
205- let full_numerator =
206- C :: ScalarType :: product ( indices. iter ( ) . map ( |i| C :: ScalarType :: from ( * i) ) ) ;
207-
208- let mut coeffs = Vec :: new ( ) ;
209- for i in & indices {
210- let mut negative = false ;
211- let ( mut denominator, remaining) = indices. iter ( ) . filter ( |j| * j != i) . fold (
212- ( C :: ScalarType :: from ( * i) , 1u128 ) ,
213- |( prev_acc, remaining) , j| {
214- let diff = if i > j {
215- negative = !negative;
216- i - j
217- } else {
218- // i < j (but not equal)
219- j - i
220- } ;
221- debug_assert_ne ! ( diff, 0 ) ;
222- let either = Self :: fast_mult ( remaining, diff) ;
223- match either {
224- Either :: Left ( ( remaining_as_scalar, diff) ) => {
225- ( prev_acc * remaining_as_scalar, diff)
226- }
227- Either :: Right ( new_remaining) => ( prev_acc, new_remaining) ,
214+ let x_as_scalar = C :: ScalarType :: from ( x) ;
215+ let full_numerator = C :: ScalarType :: product (
216+ indices
217+ . iter ( )
218+ . map ( |i| C :: ScalarType :: from ( * i) - x_as_scalar) ,
219+ ) ;
220+
221+ Ok ( (
222+ full_numerator,
223+ indices
224+ . iter ( )
225+ . map ( |i| {
226+ let mut negative = false ;
227+ let mut denominator = Self :: fast_product (
228+ C :: ScalarType :: from ( * i) - x_as_scalar,
229+ indices. iter ( ) . filter ( |j| * j != i) . map ( |j| {
230+ if i > j {
231+ negative = !negative;
232+ i - j
233+ } else {
234+ // i < j (but not equal)
235+ j - i
236+ }
237+ } ) ,
238+ ) ;
239+ if negative {
240+ denominator = -denominator;
228241 }
229- } ,
230- ) ;
231- debug_assert_ne ! ( remaining, 0 ) ;
232- denominator = denominator * C :: ScalarType :: from ( remaining) ;
233- if negative {
234- denominator = -denominator;
235- }
236- // TODO: Consider returning full_numerator and dividing once outside instead of here per iteration.
237- let coeff = full_numerator / denominator;
238- coeffs. push ( coeff. expect ( "safe since i != j" ) ) ;
239- }
240- Ok ( coeffs)
242+ denominator. inverse ( ) . expect ( "safe since i != j" )
243+ } )
244+ . collect ( ) ,
245+ ) )
241246 }
242247
243248 /// Given exactly `t` polynomial evaluations, it will recover the polynomial's constant term.
@@ -247,9 +252,12 @@ impl<C: GroupElement> Poly<C> {
247252 shares : impl Iterator < Item = impl Borrow < Eval < C > > > + Clone ,
248253 ) -> FastCryptoResult < C > {
249254 let coeffs = Self :: get_lagrange_coefficients_for_c0 ( t, shares. clone ( ) ) ?;
250- let plain_shares = shares. map ( |s| s. borrow ( ) . value ) ;
251- let res = C :: sum ( coeffs. iter ( ) . zip ( plain_shares) . map ( |( c, s) | s * c) ) ;
252- Ok ( res)
255+ Ok ( C :: sum (
256+ shares
257+ . map ( |s| s. borrow ( ) . value )
258+ . zip ( coeffs. 1 )
259+ . map ( |( c, s) | c * s) ,
260+ ) * coeffs. 0 )
253261 }
254262
255263 /// Checks if a given share is valid.
@@ -331,35 +339,23 @@ impl<C: Scalar> Poly<C> {
331339 /// Returns an error if the input is invalid (e.g., empty or duplicate indices).
332340 ///
333341 /// This is faster than first recovering the polynomial and then evaluating it at the given index.
334- pub fn interpolate_at_index (
335- index : ShareIndex ,
336- points : & [ Eval < C > ] ,
337- ) -> FastCryptoResult < Eval < C > > {
338- if points. is_empty ( ) {
339- return Err ( FastCryptoError :: InvalidInput ) ;
340- }
341- if !points. iter ( ) . map ( |p| p. index ) . all_unique ( ) {
342- return Err ( FastCryptoError :: InvalidInput ) ;
342+ pub fn recover_at ( index : ShareIndex , points : & [ Eval < C > ] ) -> FastCryptoResult < Eval < C > > {
343+ // If the index we're looking for is already given, we can return that
344+ if let Some ( point) = points. iter ( ) . find ( |p| p. index == index) {
345+ return Ok ( point. clone ( ) ) ;
343346 }
344- let x: C = to_scalar ( index) ;
345-
346- // Convert indices to scalars for interpolation.
347- let indices = points
348- . iter ( )
349- . map ( |p| to_scalar ( p. index ) )
350- . collect :: < Vec < _ > > ( ) ;
351-
352- let value = C :: sum ( indices. iter ( ) . enumerate ( ) . map ( |( j, x_j) | {
353- let numerator = C :: product ( indices. iter ( ) . filter ( |x_i| * x_i != x_j) . map ( |x_i| x - x_i) ) ;
354- let denominator = C :: product (
355- indices
356- . iter ( )
357- . filter ( |x_i| * x_i != x_j)
358- . map ( |x_i| * x_j - x_i) ,
359- ) ;
360- points[ j] . value * ( numerator / denominator) . unwrap ( )
361- } ) ) ;
362-
347+ let lagrange_coefficients = Self :: get_lagrange_coefficients_for (
348+ index. get ( ) as u128 ,
349+ points. len ( ) as u16 ,
350+ points. iter ( ) ,
351+ ) ?;
352+ let value = C :: sum (
353+ lagrange_coefficients
354+ . 1
355+ . iter ( )
356+ . zip ( points. iter ( ) . map ( |p| p. value ) )
357+ . map ( |( c, s) | s * c) ,
358+ ) * lagrange_coefficients. 0 ;
363359 Ok ( Eval { index, value } )
364360 }
365361
@@ -371,37 +367,45 @@ impl<C: Scalar> Poly<C> {
371367 if points. is_empty ( ) || !points. iter ( ) . map ( |p| p. index ) . all_unique ( ) {
372368 return Err ( FastCryptoError :: InvalidInput ) ;
373369 }
374- let x: Vec < C > = points
375- . iter ( )
376- . map ( |e| types:: to_scalar ( e. index ) )
377- . collect_vec ( ) ;
378370
379371 // Compute the full numerator polynomial: (x - x_1)(x - x_2)...(x - x_t)
380372 let mut full_numerator = Poly :: one ( ) ;
381- for x_i in & x {
382- full_numerator *= MonicLinear ( -* x_i ) ;
373+ for point in points {
374+ full_numerator *= MonicLinear ( -to_scalar :: < C > ( point . index ) ) ;
383375 }
384376
385- Ok ( Poly :: sum ( points. iter ( ) . enumerate ( ) . map ( |( j, p_j) | {
386- let denominator = C :: product (
387- x. iter ( )
377+ Ok ( Poly :: sum ( points. iter ( ) . enumerate ( ) . map ( |( i, p_i) | {
378+ let x_i = p_i. index . get ( ) as u128 ;
379+ let mut negative = false ;
380+ let mut denominator = Self :: fast_product (
381+ C :: ScalarType :: generator ( ) ,
382+ points
383+ . iter ( )
388384 . enumerate ( )
389- . filter ( |( i, _) | * i != j)
390- . map ( |( _, x_i) | x[ j] - x_i) ,
385+ . filter ( |( j, _) | * j != i)
386+ . map ( |( _, p_j) | {
387+ let x_j = p_j. index . get ( ) as u128 ;
388+ if x_i > x_j {
389+ negative = !negative;
390+ x_i - x_j
391+ } else {
392+ x_j - x_i
393+ }
394+ } ) ,
391395 ) ;
392- // Safe since (x - x[j]) divides full_numerator per definition
393- div_exact ( & full_numerator, & MonicLinear ( -x[ j] ) ) * & ( p_j. value / denominator) . unwrap ( )
396+ if negative {
397+ denominator = -denominator;
398+ }
399+ ( & full_numerator / MonicLinear ( -to_scalar :: < C > ( p_i. index ) ) )
400+ * & ( p_i. value / denominator) . unwrap ( )
394401 } ) ) )
395402 }
396403
397404 /// Returns the leading term of the polynomial.
398405 /// If the polynomial is zero, returns a monomial with coefficient zero and degree zero.
399406 fn lead ( & self ) -> Monomial < C > {
400407 if self . is_zero ( ) {
401- return Monomial {
402- coefficient : C :: zero ( ) ,
403- degree : 0 ,
404- } ;
408+ return Monomial :: zero ( ) ;
405409 }
406410 let degree = self . degree ( ) ;
407411 Monomial {
@@ -419,17 +423,11 @@ impl<C: Scalar> Poly<C> {
419423 let mut remainder = self . clone ( ) ;
420424 let mut quotient = Self :: zero ( ) ;
421425
422- let lead_inverse = divisor. lead ( ) . coefficient . inverse ( ) ?;
423-
424426 // Function to divide a term by the leading term of the divisor.
425- // This panics if the degree of the given term is less than that of the divisor.
426- let divider = |p : Monomial < C > | Monomial {
427- coefficient : p. coefficient * lead_inverse,
428- degree : p. degree - divisor. degree ( ) ,
429- } ;
427+ let divider = divisor. lead ( ) . divider ( ) ;
430428
431429 while !remainder. is_zero ( ) && remainder. degree ( ) >= divisor. degree ( ) {
432- let tmp = divider ( remainder. lead ( ) ) ;
430+ let tmp = divider ( & remainder. lead ( ) ) ;
433431 quotient += & tmp;
434432 remainder -= divisor * & tmp;
435433 remainder. reduce ( ) ;
@@ -477,12 +475,12 @@ impl<C: GroupElement + MultiScalarMul> Poly<C> {
477475 ) -> Result < C , FastCryptoError > {
478476 let coeffs = Self :: get_lagrange_coefficients_for_c0 ( t, shares. clone ( ) ) ?;
479477 let plain_shares = shares. map ( |s| s. borrow ( ) . value ) . collect :: < Vec < _ > > ( ) ;
480- let res = C :: multi_scalar_mul ( & coeffs, & plain_shares) . expect ( "sizes match" ) ;
478+ let res = C :: multi_scalar_mul ( & coeffs. 1 , & plain_shares) . expect ( "sizes match" ) * coeffs . 0 ;
481479 Ok ( res)
482480 }
483481}
484482
485- /// This represents a monomial, e.g., 3 * x ^2, where 3 is the coefficient and 2 is the degree.
483+ /// This represents a monomial, e.g., 3x ^2, where 3 is the coefficient and 2 is the degree.
486484struct Monomial < C > {
487485 coefficient : C ,
488486 degree : usize ,
@@ -512,6 +510,25 @@ impl<C: Scalar> Mul<&Monomial<C>> for &Poly<C> {
512510 }
513511}
514512
513+ impl < C : Scalar > Monomial < C > {
514+ /// Returns a closure which on input `x` computes the division `x / self`.
515+ /// Panics if the degree of `x` is smaller than `self` or if `self` is zero.
516+ fn divider ( self ) -> impl Fn ( & Monomial < C > ) -> Monomial < C > {
517+ let inverse = self . coefficient . inverse ( ) . unwrap ( ) ;
518+ move |p : & Monomial < C > | Monomial {
519+ coefficient : p. coefficient * inverse,
520+ degree : p. degree - self . degree ,
521+ }
522+ }
523+
524+ fn zero ( ) -> Self {
525+ Monomial {
526+ coefficient : C :: zero ( ) ,
527+ degree : 0 ,
528+ }
529+ }
530+ }
531+
515532/// Represents a monic linear polynomial of the form x + c.
516533pub ( crate ) struct MonicLinear < C > ( pub C ) ;
517534
@@ -529,16 +546,16 @@ impl<C: Scalar> MulAssign<MonicLinear<C>> for Poly<C> {
529546 }
530547}
531548
532- /// Assuming that `d` divides `n` exactly (or, that `d.0` is a root in `n`), return the quotient `n / d`.
533- fn div_exact < C : Scalar > ( n : & Poly < C > , d : & MonicLinear < C > ) -> Poly < C > {
534- if n. is_zero ( ) {
535- return Poly :: zero ( ) ;
536- }
537- let mut result = n. 0 [ 1 ..] . to_vec ( ) ;
538- for i in ( 0 ..result. len ( ) - 1 ) . rev ( ) {
539- result[ i] = result[ i] - result[ i + 1 ] * d. 0 ;
549+ impl < C : Scalar > Div < MonicLinear < C > > for & Poly < C > {
550+ type Output = Poly < C > ;
551+
552+ fn div ( self , rhs : MonicLinear < C > ) -> Self :: Output {
553+ let mut result = self . 0 [ 1 ..] . to_vec ( ) ;
554+ for i in ( 0 ..result. len ( ) - 1 ) . rev ( ) {
555+ result[ i] = result[ i] - result[ i + 1 ] * rhs. 0 ;
556+ }
557+ Poly :: from ( result)
540558 }
541- Poly :: from ( result)
542559}
543560
544561#[ cfg( test) ]
0 commit comments