Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions crates/mpcs/src/basefold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,17 +403,16 @@ where
// prepare folding challenges via sumcheck round msg + FRI commitment
let mut fold_challenges: Vec<E> = Vec::with_capacity(max_num_var);
let commits = &proof.commits;
assert_eq!(commits.len(), num_rounds);
let sumcheck_messages = proof.sumcheck_proof.as_ref().unwrap();
for i in 0..num_rounds {
write_digest_to_transcript(&commits[i], transcript);
transcript.append_field_element_exts(sumcheck_messages[i].evaluations.as_slice());
fold_challenges.push(
transcript
.sample_and_append_challenge(b"commit round")
.elements,
);
if i < num_rounds - 1 {
write_digest_to_transcript(&commits[i], transcript);
}
}
#[cfg(debug_assertions)]
{
Expand Down
31 changes: 27 additions & 4 deletions crates/mpcs/src/basefold/commit_phase.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ where
let mmcs_ext = ExtensionMmcs::<E::BaseField, E, _>::new(poseidon2_merkle_tree::<E>());
let mmcs = poseidon2_merkle_tree::<E>();
let mut trees: Vec<MerkleTreeExt<E>> = Vec::with_capacity(max_num_vars);
let mut commits = Vec::with_capacity(max_num_vars);

let total_num_polys = rounds
.iter()
Expand Down Expand Up @@ -96,9 +97,8 @@ where
for (i, mat) in mats.into_iter().enumerate() {
let (point, _) = &point_and_evals[i];
let polys = &pcs_data.polys[i];
// the actual ith row and (i+n/2)th row are packed in same row
let num_rows = mat.height() * 2;
let num_polys = mat.width() / 2;
let num_rows = mat.height();
let num_polys = mat.width();
let coeffs = batch_coeffs_iter
.by_ref()
.take(num_polys)
Expand All @@ -118,6 +118,7 @@ where
2,
);
let num_vars = polys[0].num_vars();
assert_eq!(1 << (num_vars + Spec::get_rate_log()), num_rows);
let mle_base_vec = polys
.iter()
.map(|mle| mle.get_base_field_vec())
Expand Down Expand Up @@ -150,6 +151,29 @@ where
);
exit_span!(batch_codeword_span);

// commit to codewords with max height using mmcs_ext
let max_height = batched_codewords
.front()
.expect("empty batched_codewords")
.height();
let mut highest_codeword = batched_codewords.pop_front().unwrap();
while let Some(new_codeword) = batched_codewords.front() {
if new_codeword.height() == max_height {
let new_codeword = batched_codewords.pop_front().unwrap();
// sum up the rows in each codeword
highest_codeword
.par_rows_mut()
.zip(new_codeword.par_rows())
.for_each(|(row_acc, row)| {
row_acc.iter_mut().zip(row).for_each(|(acc, v)| *acc += v);
});
}
}
let (commit, mmcs) = mmcs_ext.commit_matrix(highest_codeword);
write_digest_to_transcript(&commit, transcript);
trees.push(mmcs);
commits.push(commit);

exit_span!(prepare_span);

// eq is the evaluation representation of the eq(X,r) polynomial over the hypercube
Expand Down Expand Up @@ -193,7 +217,6 @@ where
})
.collect::<Vec<_>>();
let mut sumcheck_messages = Vec::with_capacity(num_rounds);
let mut commits = Vec::with_capacity(num_rounds - 1);

let mut challenge = None;
let sumcheck_phase1 = entered_span!("sumcheck_phase1");
Expand Down
4 changes: 1 addition & 3 deletions crates/mpcs/src/basefold/encoding/rs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,7 @@ where
.bit_reverse_rows()
.to_row_major_matrix()
.values;
// to make 2 consecutive position to be open together, we trickily "concat" 2 consecutive leafs
// so both can be open under same row index
let codeword = DenseMatrix::new(codeword, num_polys * 2);
let codeword = DenseMatrix::new(codeword, num_polys);
Ok(codeword)
}

Expand Down
88 changes: 33 additions & 55 deletions crates/mpcs/src/basefold/query_phase.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ where
log2_max_codeword_size,
);

println!("log2_max_codeword_size: {}", log2_max_codeword_size);
queries
.iter()
.map(|idx| {
Expand All @@ -60,11 +61,8 @@ where
// the oracle values are committed in a row-bit-reversed format.
// rounding `idx` to an even value is equivalent to retrieving the "left-hand" side `j` index
// in the original (non-row-bit-reversed) format.
//
// however, since `p_d[j]` and `p_d[j + n_{d-1}]` are already concatenated in the same merkle leaf,
// we can simply mask out the least significant bit (lsb) by performing a right shift by 1.
let idx_shift = log2_max_codeword_size - pcs_data.log2_max_codeword_size;
let idx = idx >> (idx_shift + 1);
let idx = idx >> idx_shift;
let (opened_values, opening_proof) = mmcs.open_batch(idx, &pcs_data.codeword);
BatchOpening {
opened_values,
Expand All @@ -73,20 +71,19 @@ where
})
.collect_vec();

// this is equivalent with "idx = idx % n_{d-1}" operation in non row bit reverse format
let idx = idx >> 1;
let (_, commit_phase_openings) =
trees
.iter()
.fold((idx, vec![]), |(idx, mut commit_phase_openings), tree| {
.fold((*idx, vec![]), |(idx, mut commit_phase_openings), tree| {
let leaf_idx = idx >> 1;
// differentiate interpolate to left or right position at next layer
let is_interpolate_to_right_index = (idx & 1) == 1;
// mask the least significant bit (LSB) for the same reason as above:
// 1. we only need the even part of the index.
// 2. since even and odd parts are concatenated in the same leaf,
// the overall merkle tree height is effectively halved,
// so we divide by 2.
let (mut values, opening_proof) = mmcs_ext.open_batch(idx >> 1, tree);
let (mut values, opening_proof) = mmcs_ext.open_batch(leaf_idx, tree);
let leafs = values.pop().unwrap();
debug_assert_eq!(leafs.len(), 2);
let sibling_value = leafs[(!is_interpolate_to_right_index) as usize];
Expand Down Expand Up @@ -149,23 +146,20 @@ pub fn batch_verifier_query_phase<E: ExtensionField, S: EncodingScheme<E>>(
},
)| {
// verify base oracle query proof
// refer to prover documentation for the reason of right shift by 1
let mut idx = idx >> 1;
let mut idx = *idx;

let mut reduced_openings_by_height: Vec<Option<(E, E)>> =
vec![None; log2_max_codeword_size];
let mut reduced_openings_by_height: Vec<Option<E>> =
vec![None; log2_max_codeword_size + 1];
let mut batch_coeffs_iter = batch_coeffs.iter();

for ((commit, batch_opening), input_proof) in
rounds.iter().zip_eq(input_proofs.iter())
{
let dimensions = batch_opening
.iter()
.map(|(num_var, (_, evals))| {
Dimensions {
width: evals.len() * 2, // we pack two rows into one in the mmcs
height: 1 << (num_var + log2_blowup - 1),
}
.map(|(num_var, (_, evals))| Dimensions {
width: evals.len(),
height: 1 << (num_var + log2_blowup),
})
.collect_vec();
let bits_reduced = log2_max_codeword_size - commit.log2_max_codeword_size;
Expand All @@ -184,53 +178,37 @@ pub fn batch_verifier_query_phase<E: ExtensionField, S: EncodingScheme<E>>(
for (mat, dimension) in
input_proof.opened_values.iter().zip_eq(dimensions.iter())
{
let width = mat.len() / 2;
let width = mat.len();
assert_eq!(dimension.width, mat.len());
assert_eq!(width * 2, mat.len());
let batch_coeffs = batch_coeffs_iter
.by_ref()
.take(width)
.copied()
.collect_vec();
let (lo, hi): (&[E::BaseField], &[E::BaseField]) = mat.split_at(width);
let low = dot_product::<E, _, _>(
let eval = dot_product::<E, _, _>(
batch_coeffs.iter().copied(),
lo.iter().copied(),
);
let high = dot_product::<E, _, _>(
batch_coeffs.iter().copied(),
hi.iter().copied(),
mat.iter().copied(),
);
let log2_height = log2_strict_usize(dimension.height);

if let Some((low_acc, high_acc)) =
reduced_openings_by_height[log2_height].as_mut()
{
if let Some(eval_acc) = reduced_openings_by_height[log2_height].as_mut() {
// accumulate low and high values for the same log2_height
*low_acc += low;
*high_acc += high;
*eval_acc += eval;
} else {
reduced_openings_by_height[log2_height] = Some((low, high));
reduced_openings_by_height[log2_height] = Some(eval);
}
}
}

// fold and query
let mut cur_num_var = max_num_var;
let mut log2_height = cur_num_var + log2_blowup - 1;
// -1 because for there are only #max_num_var-1 openings proof
let rounds = cur_num_var - S::get_basecode_msg_size_log() - 1;
let mut log2_height = max_num_var + log2_blowup;
let rounds = max_num_var - S::get_basecode_msg_size_log();

assert_eq!(rounds, fold_challenges.len() - 1);
assert_eq!(rounds, fold_challenges.len());
assert_eq!(rounds, proof.commits.len(),);
assert_eq!(rounds, opening_ext.len(),);

// first folding challenge
let r = fold_challenges.first().unwrap();
let coeff = S::verifier_folding_coeffs(vp, log2_height, idx);
let (lo, hi) = reduced_openings_by_height[log2_height].unwrap();
let mut folded = codeword_fold_with_challenge(&[lo, hi], *r, coeff, inv_2);

let mut folded = E::ZERO;
for (
(pi_comm, r),
CommitPhaseProofStep {
Expand All @@ -240,35 +218,35 @@ pub fn batch_verifier_query_phase<E: ExtensionField, S: EncodingScheme<E>>(
) in proof
.commits
.iter()
.zip_eq(fold_challenges.iter().skip(1))
.zip_eq(fold_challenges.iter())
.zip_eq(opening_ext)
{
cur_num_var -= 1;
log2_height -= 1;

let idx_sibling = idx & 0x01;
let folded_idx = idx & 0x01;
let mut leafs = vec![*sibling_value; 2];
leafs[idx_sibling] = folded;
if let Some((lo, hi)) = reduced_openings_by_height[log2_height].as_mut() {
leafs[idx_sibling] += if idx_sibling == 1 { *hi } else { *lo };
leafs[folded_idx] = folded;
if let Some(eval) = reduced_openings_by_height[log2_height] {
leafs[folded_idx] += eval;
}

idx >>= 1;
let leaf_idx = idx >> 1;
mmcs_ext
.verify_batch(
pi_comm,
&[Dimensions {
width: 2,
// width is 2, thus height divide by 2 via right shift
height: 1 << log2_height,
height: 1 << (log2_height - 1),
}],
idx,
leaf_idx,
slice::from_ref(&leafs),
proof,
)
.expect("verify failed");
let coeff = S::verifier_folding_coeffs(vp, log2_height, idx);

let coeff = S::verifier_folding_coeffs(vp, log2_height, leaf_idx);
folded = codeword_fold_with_challenge(&[leafs[0], leafs[1]], *r, coeff, inv_2);
log2_height -= 1;
idx >>= 1;
}
assert!(
final_codeword.values[idx] == folded,
Expand Down
3 changes: 1 addition & 2 deletions crates/mpcs/src/basefold/structure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,10 @@ where
polys: Vec<Vec<ArcMultilinearExtension<'static, E>>>,
) -> Self {
let mmcs = poseidon2_merkle_tree::<E>();
// size = height * 2 because we split codeword leafs into left/right, concat and commit under same row index
let log2_max_codeword_size = log2_strict_usize(
mmcs.get_matrices(&codeword)
.iter()
.map(|m| m.height() * 2)
.map(|m| m.height())
.max()
.unwrap(),
);
Expand Down