Skip to content

Commit e5a20f7

Browse files
authored
Refactor + use existing functions (#886)
* Misc simplifications * copyright * more + docs * refactor * Align variable name * xlippy * fmt * Faster interpolate * refactor * refactor * simplify * refactor * review * Use all_unique * simplify
1 parent f2deafa commit e5a20f7

File tree

6 files changed

+173
-152
lines changed

6 files changed

+173
-152
lines changed

fastcrypto-tbls/benches/polynomial.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ mod polynomial_benches {
121121

122122
c.bench_function(format!("interpolate at index t={t}").as_str(), |b| {
123123
b.iter(|| {
124-
let _ = Poly::interpolate_at_index(NonZeroU16::new(7).unwrap(), &points).unwrap();
124+
let _ = Poly::recover_at(NonZeroU16::new(307).unwrap(), &points).unwrap();
125125
})
126126
});
127127
}

fastcrypto-tbls/src/polynomial.rs

Lines changed: 145 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,16 @@
55
// modified for our needs.
66
//
77

8-
use crate::types;
98
use crate::types::{to_scalar, IndexedValue, ShareIndex};
109
use fastcrypto::error::{FastCryptoError, FastCryptoResult};
1110
use fastcrypto::groups::{GroupElement, MultiScalarMul, Scalar};
1211
use fastcrypto::traits::AllowedRng;
13-
use itertools::{Either, Itertools};
12+
use itertools::Itertools;
1413
use serde::{Deserialize, Serialize};
1514
use std::borrow::Borrow;
16-
use std::collections::HashSet;
1715
use std::mem::swap;
1816
use std::num::NonZeroU16;
19-
use std::ops::{Add, AddAssign, Index, Mul, MulAssign, SubAssign};
17+
use std::ops::{Add, AddAssign, Div, Index, Mul, MulAssign, SubAssign};
2018

2119
/// Types
2220
@@ -171,73 +169,80 @@ impl<C: GroupElement> Poly<C> {
171169
))
172170
}
173171

174-
// Multiply using u128 if possible, otherwise just convert one element to the group element and return the other.
175-
pub(crate) fn fast_mult(x: u128, y: u128) -> Either<(C::ScalarType, u128), u128> {
176-
if x.leading_zeros() >= (128 - y.leading_zeros()) {
177-
Either::Right(x * y)
172+
/// Multiply x.1 with y using u128s if possible, otherwise convert x.1 to the group element and multiply.
173+
/// Invariant: If res = fast_mult(x1, x2, y) then x.0 * x.1 * y = res.0 * res.1.
174+
pub(crate) fn fast_mult(x: (C::ScalarType, u128), y: u128) -> (C::ScalarType, u128) {
175+
if x.1.leading_zeros() >= (128 - y.leading_zeros()) {
176+
(x.0, x.1 * y)
178177
} else {
179-
Either::Left((C::ScalarType::from(x), y))
178+
(x.0 * C::ScalarType::from(x.1), y)
180179
}
181180
}
182181

183-
// Expects exactly t unique shares.
182+
/// Compute initial * \prod factors.
183+
pub(crate) fn fast_product(
184+
initial: C::ScalarType,
185+
factors: impl Iterator<Item = u128>,
186+
) -> C::ScalarType {
187+
let (result, remaining) = factors.fold((initial, 1), |acc, factor| {
188+
debug_assert_ne!(factor, 0);
189+
Self::fast_mult(acc, factor)
190+
});
191+
debug_assert_ne!(remaining, 0);
192+
result * C::ScalarType::from(remaining)
193+
}
194+
184195
fn get_lagrange_coefficients_for_c0(
185196
t: u16,
186-
mut shares: impl Iterator<Item = impl Borrow<Eval<C>>>,
187-
) -> FastCryptoResult<Vec<C::ScalarType>> {
188-
let mut ids_set = HashSet::new();
189-
let (shares_size_lower, shares_size_upper) = shares.size_hint();
190-
let indices = shares.try_fold(
191-
Vec::with_capacity(shares_size_upper.unwrap_or(shares_size_lower)),
192-
|mut vec, s| {
193-
// Check for duplicates.
194-
if !ids_set.insert(s.borrow().index) {
195-
return Err(FastCryptoError::InvalidInput); // expected unique ids
196-
}
197-
vec.push(s.borrow().index.get() as u128);
198-
Ok(vec)
199-
},
200-
)?;
201-
if indices.len() != t as usize {
197+
shares: impl Iterator<Item = impl Borrow<Eval<C>>>,
198+
) -> FastCryptoResult<(C::ScalarType, Vec<C::ScalarType>)> {
199+
Self::get_lagrange_coefficients_for(0, t, shares)
200+
}
201+
202+
/// Expects exactly t unique shares.
203+
/// Returns an error if x is one of the indices.
204+
fn get_lagrange_coefficients_for(
205+
x: u128,
206+
t: u16,
207+
shares: impl Iterator<Item = impl Borrow<Eval<C>>>,
208+
) -> FastCryptoResult<(C::ScalarType, Vec<C::ScalarType>)> {
209+
let indices = shares.map(|s| s.borrow().index.get() as u128).collect_vec();
210+
if !indices.iter().all_unique() || indices.len() != t as usize || indices.contains(&x) {
202211
return Err(FastCryptoError::InvalidInput);
203212
}
204213

205-
let full_numerator =
206-
C::ScalarType::product(indices.iter().map(|i| C::ScalarType::from(*i)));
207-
208-
let mut coeffs = Vec::new();
209-
for i in &indices {
210-
let mut negative = false;
211-
let (mut denominator, remaining) = indices.iter().filter(|j| *j != i).fold(
212-
(C::ScalarType::from(*i), 1u128),
213-
|(prev_acc, remaining), j| {
214-
let diff = if i > j {
215-
negative = !negative;
216-
i - j
217-
} else {
218-
// i < j (but not equal)
219-
j - i
220-
};
221-
debug_assert_ne!(diff, 0);
222-
let either = Self::fast_mult(remaining, diff);
223-
match either {
224-
Either::Left((remaining_as_scalar, diff)) => {
225-
(prev_acc * remaining_as_scalar, diff)
226-
}
227-
Either::Right(new_remaining) => (prev_acc, new_remaining),
214+
let x_as_scalar = C::ScalarType::from(x);
215+
let full_numerator = C::ScalarType::product(
216+
indices
217+
.iter()
218+
.map(|i| C::ScalarType::from(*i) - x_as_scalar),
219+
);
220+
221+
Ok((
222+
full_numerator,
223+
indices
224+
.iter()
225+
.map(|i| {
226+
let mut negative = false;
227+
let mut denominator = Self::fast_product(
228+
C::ScalarType::from(*i) - x_as_scalar,
229+
indices.iter().filter(|j| *j != i).map(|j| {
230+
if i > j {
231+
negative = !negative;
232+
i - j
233+
} else {
234+
// i < j (but not equal)
235+
j - i
236+
}
237+
}),
238+
);
239+
if negative {
240+
denominator = -denominator;
228241
}
229-
},
230-
);
231-
debug_assert_ne!(remaining, 0);
232-
denominator = denominator * C::ScalarType::from(remaining);
233-
if negative {
234-
denominator = -denominator;
235-
}
236-
// TODO: Consider returning full_numerator and dividing once outside instead of here per iteration.
237-
let coeff = full_numerator / denominator;
238-
coeffs.push(coeff.expect("safe since i != j"));
239-
}
240-
Ok(coeffs)
242+
denominator.inverse().expect("safe since i != j")
243+
})
244+
.collect(),
245+
))
241246
}
242247

243248
/// Given exactly `t` polynomial evaluations, it will recover the polynomial's constant term.
@@ -247,9 +252,12 @@ impl<C: GroupElement> Poly<C> {
247252
shares: impl Iterator<Item = impl Borrow<Eval<C>>> + Clone,
248253
) -> FastCryptoResult<C> {
249254
let coeffs = Self::get_lagrange_coefficients_for_c0(t, shares.clone())?;
250-
let plain_shares = shares.map(|s| s.borrow().value);
251-
let res = C::sum(coeffs.iter().zip(plain_shares).map(|(c, s)| s * c));
252-
Ok(res)
255+
Ok(C::sum(
256+
shares
257+
.map(|s| s.borrow().value)
258+
.zip(coeffs.1)
259+
.map(|(c, s)| c * s),
260+
) * coeffs.0)
253261
}
254262

255263
/// Checks if a given share is valid.
@@ -331,35 +339,23 @@ impl<C: Scalar> Poly<C> {
331339
/// Returns an error if the input is invalid (e.g., empty or duplicate indices).
332340
///
333341
/// This is faster than first recovering the polynomial and then evaluating it at the given index.
334-
pub fn interpolate_at_index(
335-
index: ShareIndex,
336-
points: &[Eval<C>],
337-
) -> FastCryptoResult<Eval<C>> {
338-
if points.is_empty() {
339-
return Err(FastCryptoError::InvalidInput);
340-
}
341-
if !points.iter().map(|p| p.index).all_unique() {
342-
return Err(FastCryptoError::InvalidInput);
342+
pub fn recover_at(index: ShareIndex, points: &[Eval<C>]) -> FastCryptoResult<Eval<C>> {
343+
// If the index we're looking for is already given, we can return that
344+
if let Some(point) = points.iter().find(|p| p.index == index) {
345+
return Ok(point.clone());
343346
}
344-
let x: C = to_scalar(index);
345-
346-
// Convert indices to scalars for interpolation.
347-
let indices = points
348-
.iter()
349-
.map(|p| to_scalar(p.index))
350-
.collect::<Vec<_>>();
351-
352-
let value = C::sum(indices.iter().enumerate().map(|(j, x_j)| {
353-
let numerator = C::product(indices.iter().filter(|x_i| *x_i != x_j).map(|x_i| x - x_i));
354-
let denominator = C::product(
355-
indices
356-
.iter()
357-
.filter(|x_i| *x_i != x_j)
358-
.map(|x_i| *x_j - x_i),
359-
);
360-
points[j].value * (numerator / denominator).unwrap()
361-
}));
362-
347+
let lagrange_coefficients = Self::get_lagrange_coefficients_for(
348+
index.get() as u128,
349+
points.len() as u16,
350+
points.iter(),
351+
)?;
352+
let value = C::sum(
353+
lagrange_coefficients
354+
.1
355+
.iter()
356+
.zip(points.iter().map(|p| p.value))
357+
.map(|(c, s)| s * c),
358+
) * lagrange_coefficients.0;
363359
Ok(Eval { index, value })
364360
}
365361

@@ -371,37 +367,45 @@ impl<C: Scalar> Poly<C> {
371367
if points.is_empty() || !points.iter().map(|p| p.index).all_unique() {
372368
return Err(FastCryptoError::InvalidInput);
373369
}
374-
let x: Vec<C> = points
375-
.iter()
376-
.map(|e| types::to_scalar(e.index))
377-
.collect_vec();
378370

379371
// Compute the full numerator polynomial: (x - x_1)(x - x_2)...(x - x_t)
380372
let mut full_numerator = Poly::one();
381-
for x_i in &x {
382-
full_numerator *= MonicLinear(-*x_i);
373+
for point in points {
374+
full_numerator *= MonicLinear(-to_scalar::<C>(point.index));
383375
}
384376

385-
Ok(Poly::sum(points.iter().enumerate().map(|(j, p_j)| {
386-
let denominator = C::product(
387-
x.iter()
377+
Ok(Poly::sum(points.iter().enumerate().map(|(i, p_i)| {
378+
let x_i = p_i.index.get() as u128;
379+
let mut negative = false;
380+
let mut denominator = Self::fast_product(
381+
C::ScalarType::generator(),
382+
points
383+
.iter()
388384
.enumerate()
389-
.filter(|(i, _)| *i != j)
390-
.map(|(_, x_i)| x[j] - x_i),
385+
.filter(|(j, _)| *j != i)
386+
.map(|(_, p_j)| {
387+
let x_j = p_j.index.get() as u128;
388+
if x_i > x_j {
389+
negative = !negative;
390+
x_i - x_j
391+
} else {
392+
x_j - x_i
393+
}
394+
}),
391395
);
392-
// Safe since (x - x[j]) divides full_numerator per definition
393-
div_exact(&full_numerator, &MonicLinear(-x[j])) * &(p_j.value / denominator).unwrap()
396+
if negative {
397+
denominator = -denominator;
398+
}
399+
(&full_numerator / MonicLinear(-to_scalar::<C>(p_i.index)))
400+
* &(p_i.value / denominator).unwrap()
394401
})))
395402
}
396403

397404
/// Returns the leading term of the polynomial.
398405
/// If the polynomial is zero, returns a monomial with coefficient zero and degree zero.
399406
fn lead(&self) -> Monomial<C> {
400407
if self.is_zero() {
401-
return Monomial {
402-
coefficient: C::zero(),
403-
degree: 0,
404-
};
408+
return Monomial::zero();
405409
}
406410
let degree = self.degree();
407411
Monomial {
@@ -419,17 +423,11 @@ impl<C: Scalar> Poly<C> {
419423
let mut remainder = self.clone();
420424
let mut quotient = Self::zero();
421425

422-
let lead_inverse = divisor.lead().coefficient.inverse()?;
423-
424426
// Function to divide a term by the leading term of the divisor.
425-
// This panics if the degree of the given term is less than that of the divisor.
426-
let divider = |p: Monomial<C>| Monomial {
427-
coefficient: p.coefficient * lead_inverse,
428-
degree: p.degree - divisor.degree(),
429-
};
427+
let divider = divisor.lead().divider();
430428

431429
while !remainder.is_zero() && remainder.degree() >= divisor.degree() {
432-
let tmp = divider(remainder.lead());
430+
let tmp = divider(&remainder.lead());
433431
quotient += &tmp;
434432
remainder -= divisor * &tmp;
435433
remainder.reduce();
@@ -477,12 +475,12 @@ impl<C: GroupElement + MultiScalarMul> Poly<C> {
477475
) -> Result<C, FastCryptoError> {
478476
let coeffs = Self::get_lagrange_coefficients_for_c0(t, shares.clone())?;
479477
let plain_shares = shares.map(|s| s.borrow().value).collect::<Vec<_>>();
480-
let res = C::multi_scalar_mul(&coeffs, &plain_shares).expect("sizes match");
478+
let res = C::multi_scalar_mul(&coeffs.1, &plain_shares).expect("sizes match") * coeffs.0;
481479
Ok(res)
482480
}
483481
}
484482

485-
/// This represents a monomial, e.g., 3 * x^2, where 3 is the coefficient and 2 is the degree.
483+
/// This represents a monomial, e.g., 3x^2, where 3 is the coefficient and 2 is the degree.
486484
struct Monomial<C> {
487485
coefficient: C,
488486
degree: usize,
@@ -512,6 +510,25 @@ impl<C: Scalar> Mul<&Monomial<C>> for &Poly<C> {
512510
}
513511
}
514512

513+
impl<C: Scalar> Monomial<C> {
514+
/// Returns a closure which on input `x` computes the division `x / self`.
515+
/// Panics if the degree of `x` is smaller than `self` or if `self` is zero.
516+
fn divider(self) -> impl Fn(&Monomial<C>) -> Monomial<C> {
517+
let inverse = self.coefficient.inverse().unwrap();
518+
move |p: &Monomial<C>| Monomial {
519+
coefficient: p.coefficient * inverse,
520+
degree: p.degree - self.degree,
521+
}
522+
}
523+
524+
fn zero() -> Self {
525+
Monomial {
526+
coefficient: C::zero(),
527+
degree: 0,
528+
}
529+
}
530+
}
531+
515532
/// Represents a monic linear polynomial of the form x + c.
516533
pub(crate) struct MonicLinear<C>(pub C);
517534

@@ -529,16 +546,16 @@ impl<C: Scalar> MulAssign<MonicLinear<C>> for Poly<C> {
529546
}
530547
}
531548

532-
/// Assuming that `d` divides `n` exactly (or, that `d.0` is a root in `n`), return the quotient `n / d`.
533-
fn div_exact<C: Scalar>(n: &Poly<C>, d: &MonicLinear<C>) -> Poly<C> {
534-
if n.is_zero() {
535-
return Poly::zero();
536-
}
537-
let mut result = n.0[1..].to_vec();
538-
for i in (0..result.len() - 1).rev() {
539-
result[i] = result[i] - result[i + 1] * d.0;
549+
impl<C: Scalar> Div<MonicLinear<C>> for &Poly<C> {
550+
type Output = Poly<C>;
551+
552+
fn div(self, rhs: MonicLinear<C>) -> Self::Output {
553+
let mut result = self.0[1..].to_vec();
554+
for i in (0..result.len() - 1).rev() {
555+
result[i] = result[i] - result[i + 1] * rhs.0;
556+
}
557+
Poly::from(result)
540558
}
541-
Poly::from(result)
542559
}
543560

544561
#[cfg(test)]

0 commit comments

Comments
 (0)