Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/mpcs/src/basefold/commit_phase.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ where
let evaluations: AdditiveVec<E> =
prover_msgs
.into_iter()
.fold(AdditiveVec::new(3), |mut acc, prover_msg| {
.fold(AdditiveVec::new(2), |mut acc, prover_msg| {
acc += AdditiveVec(prover_msg.evaluations);
acc
});
Expand Down
34 changes: 20 additions & 14 deletions crates/mpcs/src/basefold/query_phase.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,22 +282,28 @@ pub fn batch_verifier_query_phase<E: ExtensionField, S: EncodingScheme<E>>(
.sum::<E>();
}
}
assert_eq!(expected_sum, {
sumcheck_messages[0].evaluations[0] + sumcheck_messages[0].evaluations[1]
});
// 2. check every round of sumcheck match with prev claims
for i in 0..fold_challenges.len() - 1 {
assert_eq!(
extrapolate_uni_poly(&sumcheck_messages[i].evaluations, fold_challenges[i]),
{ sumcheck_messages[i + 1].evaluations[0] + sumcheck_messages[i + 1].evaluations[1] }
);

assert_eq!(
sumcheck_messages.len(),
fold_challenges.len(),
"sumcheck messages and fold challenges mismatch"
);

// Reconstruct the implicit P(0) evaluation for each round and update the claim in place.
let mut current_claim = expected_sum;
for (msg, challenge) in sumcheck_messages.iter().zip(fold_challenges.iter()) {
let eval_1 = msg
.evaluations
.first()
.copied()
.expect("sumcheck prover message missing evaluations");
let eval_0 = current_claim - eval_1;
current_claim = extrapolate_uni_poly(eval_0, &msg.evaluations, *challenge);
}
// 3. check final evaluation are correct

// check final evaluation are correct
assert_eq!(
extrapolate_uni_poly(
&sumcheck_messages[fold_challenges.len() - 1].evaluations,
fold_challenges[fold_challenges.len() - 1]
),
current_claim,
// \sum_i eq(p,[r,i]) * f(r,i)
izip!(
final_message,
Expand Down
17 changes: 13 additions & 4 deletions crates/sumcheck/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl<'a, E: ExtensionField> Phase1Workers<'a, E> {
.workers_states
.par_iter_mut()
.map(|state| state.run_round())
.reduce(|| AdditiveVec::new(max_degree + 1), |a, b| a + b);
.reduce(|| AdditiveVec::new(max_degree), |a, b| a + b);

transcript.append_field_element_exts(&evaluations.0);

Expand Down Expand Up @@ -353,7 +353,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
// Step 2: generate sum for the partial evaluated polynomial:
// f(r_1, ... r_m,, x_{m+1}... x_n)
let span = entered_span!("build_uni_poly");
let AdditiveVec(uni_polys) = self.poly.products.iter().fold(
let AdditiveVec(mut uni_polys) = self.poly.products.iter().fold(
AdditiveVec::new(self.poly.aux_info.max_degree + 1),
|mut uni_polys, MonomialTerms { terms }| {
for Term {
Expand Down Expand Up @@ -405,6 +405,11 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {

exit_span!(start);

assert!(uni_polys.len() > 1);
// NOTE remove uni_polys.eval(0) from lagrange domain
// as verifier can derive via claim - uni_polys.eval(1)
uni_polys.remove(0);

IOPProverMessage {
evaluations: uni_polys,
}
Expand Down Expand Up @@ -603,7 +608,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
// Step 2: generate sum for the partial evaluated polynomial:
// f(r_1, ... r_m,, x_{m+1}... x_n)
let span = entered_span!("build_uni_poly");
let AdditiveVec(uni_polys) = self
let AdditiveVec(mut uni_polys) = self
.poly
.products
.par_iter()
Expand Down Expand Up @@ -654,9 +659,13 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
.reduce_with(|acc, item| acc + item)
.unwrap();
exit_span!(span);

exit_span!(start);

assert!(uni_polys.len() > 1);
// NOTE remove uni_polys.eval(0) from lagrange domain
// as verifier can derive via claim - uni_polys.eval(1)
uni_polys.remove(0);

IOPProverMessage {
evaluations: uni_polys,
}
Expand Down
5 changes: 0 additions & 5 deletions crates/sumcheck/src/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,6 @@ use transcript::Challenge;
pub struct IOPProof<E: ExtensionField> {
pub proofs: Vec<IOPProverMessage<E>>,
}
impl<E: ExtensionField> IOPProof<E> {
pub fn extract_sum(&self) -> E {
self.proofs[0].evaluations[0] + self.proofs[0].evaluations[1]
}
}

/// A message from the prover to the verifier at a given round
/// is a list of evaluations.
Expand Down
20 changes: 4 additions & 16 deletions crates/sumcheck/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,21 +190,6 @@ fn test_normal_polynomial_helper<E: ExtensionField>() {
test_sumcheck_internal::<E>(nv, num_multiplicands_range, num_products);
}

#[test]
fn test_extract_sum() {
test_extract_sum_helper::<GoldilocksExt2>();
test_extract_sum_helper::<BabyBearExt4>();
}

fn test_extract_sum_helper<E: ExtensionField>() {
let mut rng = thread_rng();
let mut transcript = BasicTranscript::new(b"test");
let (poly, asserted_sum) = VirtualPolynomial::<E>::random(&[8], (2, 3), 3, &mut rng);
#[allow(deprecated)]
let (proof, _) = IOPProverState::<E>::prove_parallel(poly, &mut transcript);
assert_eq!(proof.extract_sum(), asserted_sum);
}

struct DensePolynomial(Vec<GoldilocksExt2>);

impl DensePolynomial {
Expand Down Expand Up @@ -236,7 +221,10 @@ fn test_extrapolation() {
.map(|i| poly.evaluate(&GoldilocksExt2::from_canonical_u64(i as u64)))
.collect::<Vec<_>>();
let query = GoldilocksExt2::random(&mut prng);
assert_eq!(poly.evaluate(&query), extrapolate_uni_poly(&evals, query));
assert_eq!(
poly.evaluate(&query),
extrapolate_uni_poly(evals[0], &evals[1..], query)
);
}

run_extrapolation_test(1);
Expand Down
55 changes: 30 additions & 25 deletions crates/sumcheck/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pub fn extrapolate_from_table<E: ExtensionField>(uni_variate: &mut [E], start: u
}
}

fn extrapolate_uni_poly_deg_1<F: Field>(p_i: &[F; 2], eval_at: F) -> F {
fn extrapolate_uni_poly_deg_1<F: Field>(p0: F, p1: F, eval_at: F) -> F {
let x0 = F::ZERO;
let x1 = F::ONE;

Expand All @@ -69,13 +69,13 @@ fn extrapolate_uni_poly_deg_1<F: Field>(p_i: &[F; 2], eval_at: F) -> F {
let inv_d0 = d0.inverse();
let inv_d1 = d1.inverse();

let t0 = w0 * p_i[0] * inv_d0;
let t1 = w1 * p_i[1] * inv_d1;
let t0 = w0 * p0 * inv_d0;
let t1 = w1 * p1 * inv_d1;

l * (t0 + t1)
}

fn extrapolate_uni_poly_deg_2<F: Field>(p_i: &[F; 3], eval_at: F) -> F {
fn extrapolate_uni_poly_deg_2<F: Field>(p0: F, p1: F, p2: F, eval_at: F) -> F {
let x0 = F::from_canonical_u64(0);
let x1 = F::from_canonical_u64(1);
let x2 = F::from_canonical_u64(2);
Expand All @@ -97,14 +97,14 @@ fn extrapolate_uni_poly_deg_2<F: Field>(p_i: &[F; 3], eval_at: F) -> F {
let inv_d1 = d1.inverse();
let inv_d2 = d2.inverse();

let t0 = w0 * p_i[0] * inv_d0;
let t1 = w1 * p_i[1] * inv_d1;
let t2 = w2 * p_i[2] * inv_d2;
let t0 = w0 * p0 * inv_d0;
let t1 = w1 * p1 * inv_d1;
let t2 = w2 * p2 * inv_d2;

l * (t0 + t1 + t2)
}

fn extrapolate_uni_poly_deg_3<F: Field>(p_i: &[F; 4], eval_at: F) -> F {
fn extrapolate_uni_poly_deg_3<F: Field>(p0: F, p1: F, p2: F, p3: F, eval_at: F) -> F {
let x0 = F::from_canonical_u64(0);
let x1 = F::from_canonical_u64(1);
let x2 = F::from_canonical_u64(2);
Expand All @@ -131,15 +131,15 @@ fn extrapolate_uni_poly_deg_3<F: Field>(p_i: &[F; 4], eval_at: F) -> F {
let inv_d2 = d2.inverse();
let inv_d3 = d3.inverse();

let t0 = w0 * p_i[0] * inv_d0;
let t1 = w1 * p_i[1] * inv_d1;
let t2 = w2 * p_i[2] * inv_d2;
let t3 = w3 * p_i[3] * inv_d3;
let t0 = w0 * p0 * inv_d0;
let t1 = w1 * p1 * inv_d1;
let t2 = w2 * p2 * inv_d2;
let t3 = w3 * p3 * inv_d3;

l * (t0 + t1 + t2 + t3)
}

fn extrapolate_uni_poly_deg_4<F: Field>(p_i: &[F; 5], eval_at: F) -> F {
fn extrapolate_uni_poly_deg_4<F: Field>(p0: F, p1: F, p2: F, p3: F, p4: F, eval_at: F) -> F {
let x0 = F::from_canonical_u64(0);
let x1 = F::from_canonical_u64(1);
let x2 = F::from_canonical_u64(2);
Expand Down Expand Up @@ -171,11 +171,11 @@ fn extrapolate_uni_poly_deg_4<F: Field>(p_i: &[F; 5], eval_at: F) -> F {
let inv_d3 = d3.inverse();
let inv_d4 = d4.inverse();

let t0 = w0 * p_i[0] * inv_d0;
let t1 = w1 * p_i[1] * inv_d1;
let t2 = w2 * p_i[2] * inv_d2;
let t3 = w3 * p_i[3] * inv_d3;
let t4 = w4 * p_i[4] * inv_d4;
let t0 = w0 * p0 * inv_d0;
let t1 = w1 * p1 * inv_d1;
let t2 = w2 * p2 * inv_d2;
let t3 = w3 * p3 * inv_d3;
let t4 = w4 * p4 * inv_d4;

l * (t0 + t1 + t2 + t3 + t4)
}
Expand All @@ -195,18 +195,23 @@ fn extrapolate_uni_poly_deg_4<F: Field>(p_i: &[F; 5], eval_at: F) -> F {
/// with unrolled loops for performance
///
/// # Arguments
/// * `p_i` - Values of the polynomial at consecutive integer points.
/// * `p0` - Polynomial evaluation at point 0.
/// * `p_i` - Values of the polynomial at consecutive integer points starting from 1.
/// * `eval_at` - The point at which to evaluate the interpolated polynomial.
///
/// # Returns
/// The value of the polynomial `eval_at`.
pub fn extrapolate_uni_poly<F: Field>(p: &[F], eval_at: F) -> F {
pub fn extrapolate_uni_poly<F: Field>(p0: F, p: &[F], eval_at: F) -> F {
assert!(
!p.is_empty(),
"at least one evaluation beyond p(0) is required"
);
match p.len() {
2 => extrapolate_uni_poly_deg_1(p.try_into().unwrap(), eval_at),
3 => extrapolate_uni_poly_deg_2(p.try_into().unwrap(), eval_at),
4 => extrapolate_uni_poly_deg_3(p.try_into().unwrap(), eval_at),
5 => extrapolate_uni_poly_deg_4(p.try_into().unwrap(), eval_at),
_ => unimplemented!("Extrapolation for degree {} not implemented", p.len() - 1),
1 => extrapolate_uni_poly_deg_1(p0, p[0], eval_at),
2 => extrapolate_uni_poly_deg_2(p0, p[0], p[1], eval_at),
3 => extrapolate_uni_poly_deg_3(p0, p[0], p[1], p[2], eval_at),
4 => extrapolate_uni_poly_deg_4(p0, p[0], p[1], p[2], p[3], eval_at),
_ => unimplemented!("Extrapolation for degree {} not implemented", p.len()),
}
}

Expand Down
49 changes: 29 additions & 20 deletions crates/sumcheck/src/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,39 +122,48 @@ impl<E: ExtensionField> IOPVerifierState<E> {

// the deferred check during the interactive phase:
// 2. set `expected` to P(r)`
let mut expected_vec = self
let (expected_vec, evals_0) = self
.polynomials_received
.iter()
.zip(self.challenges.iter())
.map(|(evaluations, challenge)| {
if evaluations.len() != self.max_degree + 1 {
panic!(
"incorrect number of evaluations: {} vs {}",
evaluations.len(),
self.max_degree + 1
);
}
extrapolate_uni_poly::<E>(evaluations, challenge.elements)
})
.collect::<Vec<_>>();

// l-append asserted_sum to the first position of the expected vector
expected_vec.insert(0, *asserted_sum);

for (i, (evaluations, &expected)) in self
.fold(
(vec![*asserted_sum], vec![]),
|(mut claims, mut evals_0), (evaluations, challenge)| {
let last_claim = claims.last().copied().unwrap();
if evaluations.len() != self.max_degree {
panic!(
"incorrect number of evaluations: {} vs {}",
evaluations.len(),
self.max_degree
);
}
// https://eprint.iacr.org/2024/108.pdf sec 3.1 derive eval_0 = claim - eval_1
let eval_0 = last_claim - evaluations.first().copied().unwrap();
evals_0.push(eval_0);
claims.push(extrapolate_uni_poly::<E>(
eval_0,
evaluations,
challenge.elements,
));
(claims, evals_0)
},
);

for (i, ((evaluations, &expected), eval_0)) in self
.polynomials_received
.iter()
.zip(expected_vec.iter())
.zip(&expected_vec)
.zip(&evals_0)
.enumerate()
.take(self.num_vars)
{
// the deferred check during the interactive phase:
// 1. check if the received 'P(0) + P(1) = expected`.
if evaluations[0] + evaluations[1] != expected {
if *eval_0 + evaluations[0] != expected {
panic!(
"{}th round's prover message is not consistent with the claim. {:?} {:?}",
i,
evaluations[0] + evaluations[1],
*eval_0 + evaluations[0],
expected
);
}
Expand Down