Skip to content

Commit 5ff316b

Browse files
committed
commit before fri fold in each basefold_fri_round
1 parent 5e13a5c commit 5ff316b

File tree

4 files changed

+39
-97
lines changed

4 files changed

+39
-97
lines changed

crates/mpcs/src/basefold.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,13 +406,13 @@ where
406406
assert_eq!(commits.len(), num_rounds);
407407
let sumcheck_messages = proof.sumcheck_proof.as_ref().unwrap();
408408
for i in 0..num_rounds {
409-
write_digest_to_transcript(&commits[i], transcript);
410409
transcript.append_field_element_exts(sumcheck_messages[i].evaluations.as_slice());
411410
fold_challenges.push(
412411
transcript
413412
.sample_and_append_challenge(b"commit round")
414413
.elements,
415414
);
415+
write_digest_to_transcript(&commits[i], transcript);
416416
}
417417
#[cfg(debug_assertions)]
418418
{

crates/mpcs/src/basefold/commit_phase.rs

Lines changed: 36 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -150,30 +150,6 @@ where
150150
.collect_vec(),
151151
);
152152
exit_span!(batch_codeword_span);
153-
154-
// commit to codewords with max height using mmcs_ext
155-
let max_height = batched_codewords
156-
.front()
157-
.expect("empty batched_codewords")
158-
.height();
159-
let mut highest_codeword = batched_codewords.pop_front().unwrap();
160-
while let Some(new_codeword) = batched_codewords.front() {
161-
if new_codeword.height() == max_height {
162-
let new_codeword = batched_codewords.pop_front().unwrap();
163-
// sum up the rows in each codeword
164-
highest_codeword
165-
.par_rows_mut()
166-
.zip(new_codeword.par_rows())
167-
.for_each(|(row_acc, row)| {
168-
row_acc.iter_mut().zip(row).for_each(|(acc, v)| *acc += v);
169-
});
170-
}
171-
}
172-
let (commit, mmcs) = mmcs_ext.commit_matrix(highest_codeword);
173-
write_digest_to_transcript(&commit, transcript);
174-
trees.push(mmcs);
175-
commits.push(commit);
176-
177153
exit_span!(prepare_span);
178154

179155
// eq is the evaluation representation of the eq(X,r) polynomial over the hypercube
@@ -219,6 +195,7 @@ where
219195
let mut sumcheck_messages = Vec::with_capacity(num_rounds);
220196

221197
let mut challenge = None;
198+
let mut running_codeword_opt: Option<RowMajorMatrix<E>> = None;
222199
let sumcheck_phase1 = entered_span!("sumcheck_phase1");
223200
let phase1_rounds = num_rounds.min(max_num_vars - log2_num_threads);
224201

@@ -228,6 +205,7 @@ where
228205
&mut prover_states,
229206
challenge,
230207
&mut sumcheck_messages,
208+
&mut running_codeword_opt,
231209
&mut batched_codewords,
232210
transcript,
233211
&mut trees,
@@ -264,6 +242,7 @@ where
264242
&mut prover_states,
265243
challenge,
266244
&mut sumcheck_messages,
245+
&mut running_codeword_opt,
267246
&mut batched_codewords,
268247
transcript,
269248
&mut trees,
@@ -326,6 +305,7 @@ where
326305
#[allow(clippy::too_many_arguments)]
327306
pub(crate) fn basefold_fri_round<E: ExtensionField, Spec: BasefoldSpec<E>>(
328307
pp: &<Spec::EncodingScheme as EncodingScheme<E>>::ProverParameters,
308+
running_codeword_opt: &mut Option<RowMajorMatrix<E>>,
329309
codewords: &mut VecDeque<RowMajorMatrix<E>>,
330310
trees: &mut Vec<MerkleTreeExt<E>>,
331311
commits: &mut Vec<<Poseidon2ExtMerkleMmcs<E> as Mmcs<E>>::Commitment>,
@@ -342,11 +322,8 @@ pub(crate) fn basefold_fri_round<E: ExtensionField, Spec: BasefoldSpec<E>>(
342322
<Poseidon2ExtMerkleMmcs<E> as Mmcs<E>>::Commitment:
343323
IntoIterator<Item = E::BaseField> + PartialEq,
344324
{
345-
let running_codeword_opt = trees
346-
.last()
347-
.and_then(|mktree| mmcs_ext.get_matrices(mktree).pop())
348-
.map(|m| m.as_view());
349325
let target_len = running_codeword_opt
326+
.as_ref()
350327
.map(|running_codeword| running_codeword.values.len())
351328
.unwrap_or_else(|| {
352329
codewords
@@ -355,7 +332,6 @@ pub(crate) fn basefold_fri_round<E: ExtensionField, Spec: BasefoldSpec<E>>(
355332
.max()
356333
.expect("empty codeword")
357334
});
358-
let next_level_target_len = target_len >> 1;
359335
let level = log2_strict_usize(target_len) - 1;
360336
let folding_coeffs =
361337
<Spec::EncodingScheme as EncodingScheme<E>>::prover_folding_coeffs_level(pp, level);
@@ -365,76 +341,39 @@ pub(crate) fn basefold_fri_round<E: ExtensionField, Spec: BasefoldSpec<E>>(
365341
// take codewords match with target length then fold
366342
let codewords_matched =
367343
pop_front_while(codewords, |codeword| codeword.values.len() == target_len);
368-
// take codewords match next target length in preparation of being committed together
369-
let codewords_next_level_matched = pop_front_while(codewords, |codeword| {
370-
codeword.values.len() == next_level_target_len
371-
});
372344

373-
// optimize for single codeword match
374-
let folded_codeword = if (usize::from(running_codeword_opt.is_some()) + codewords_matched.len())
375-
== 1
376-
&& codewords_next_level_matched.is_empty()
377-
{
378-
RowMajorMatrix::new(
345+
// aggregate codeword with same length
346+
let codeword_to_fold = (0..target_len)
347+
.into_par_iter()
348+
.map(|index| {
379349
running_codeword_opt
380-
.or_else(|| codewords_matched.first().map(|m| m.as_view()))
381-
.unwrap()
382-
.values
383-
.par_chunks_exact(2)
384-
.zip(folding_coeffs)
385-
.map(|(ys, coeff)| codeword_fold_with_challenge(ys, challenge, *coeff, inv_2))
386-
.collect::<Vec<_>>(),
387-
2,
388-
)
389-
} else {
390-
// aggregate codeword with same length
391-
let codeword_to_fold = (0..target_len)
392-
.into_par_iter()
393-
.map(|index| {
394-
running_codeword_opt
395-
.into_iter()
396-
.chain(codewords_matched.iter().map(|m| m.as_view()))
397-
.map(|codeword| codeword.values[index])
398-
.sum::<E>()
399-
})
400-
.collect::<Vec<E>>();
401-
402-
RowMajorMatrix::new(
403-
(0..target_len)
404-
.into_par_iter()
405-
.step_by(2)
406-
.map(|index| {
407-
let coeff = &folding_coeffs[index >> 1];
408-
409-
// 1st part folded with challenge then sum
410-
let cur_same_pos_sum = codeword_fold_with_challenge(
411-
&codeword_to_fold[index..index + 2],
412-
challenge,
413-
*coeff,
414-
inv_2,
415-
);
416-
// 2nd part: retrieve respective index then sum
417-
let next_same_pos_sum = codewords_next_level_matched
418-
.iter()
419-
.map(|codeword| codeword.values[index >> 1])
420-
.sum::<E>();
421-
cur_same_pos_sum + next_same_pos_sum
422-
})
423-
.collect::<Vec<_>>(),
424-
2,
425-
)
426-
};
350+
.iter()
351+
.chain(codewords_matched.iter())
352+
.map(|codeword| codeword.values[index])
353+
.sum::<E>()
354+
})
355+
.collect::<Vec<E>>();
356+
357+
// commit
358+
let codeword_as_matrix = RowMajorMatrix::new(codeword_to_fold.clone(), 2);
359+
let (commitment, merkle_tree) = mmcs_ext.commit_matrix(codeword_as_matrix);
360+
write_digest_to_transcript(&commitment, transcript);
361+
commits.push(commitment);
362+
trees.push(merkle_tree);
363+
364+
// fri fold and add codewords with next target length
365+
*running_codeword_opt = Some(RowMajorMatrix::new(
366+
codeword_to_fold
367+
.par_chunks_exact(2)
368+
.zip(folding_coeffs)
369+
.map(|(ys, coeff)| codeword_fold_with_challenge(ys, challenge, *coeff, inv_2))
370+
.collect::<Vec<_>>(),
371+
2,
372+
));
427373

428374
if cfg!(feature = "sanity-check") && is_last_round {
429-
let (commitment, merkle_tree) = mmcs_ext.commit_matrix(folded_codeword.clone());
430-
commits.push(commitment);
431-
trees.push(merkle_tree);
432-
}
433-
434-
// skip last round commitment as verifer need to derive encode(final_message) = final_codeword itself
435-
if !is_last_round {
436-
let (commitment, merkle_tree) = mmcs_ext.commit_matrix(folded_codeword);
437-
write_digest_to_transcript(&commitment, transcript);
375+
let (commitment, merkle_tree) =
376+
mmcs_ext.commit_matrix(running_codeword_opt.as_ref().map(|c| c.clone()).unwrap());
438377
commits.push(commitment);
439378
trees.push(merkle_tree);
440379
}
@@ -447,6 +386,7 @@ fn basefold_one_round<E: ExtensionField, Spec: BasefoldSpec<E>>(
447386
prover_states: &mut Vec<IOPProverState<'_, E>>,
448387
challenge: Option<Challenge<E>>,
449388
sumcheck_messages: &mut Vec<IOPProverMessage<E>>,
389+
running_codeword_opt: &mut Option<RowMajorMatrix<E>>,
450390
codewords: &mut VecDeque<RowMajorMatrix<E>>,
451391
transcript: &mut impl Transcript<E>,
452392
trees: &mut Vec<MerkleTreeExt<E>>,
@@ -490,6 +430,7 @@ where
490430
let fri_round_span = entered_span!("basefold::fri_one_round");
491431
basefold_fri_round::<E, Spec>(
492432
pp,
433+
running_codeword_opt,
493434
codewords,
494435
trees,
495436
commits,

crates/mpcs/src/basefold/encoding/rs.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ mod tests {
334334
izip!(&codeword.values, &codeword_ext.values).all(|(base, ext)| E::from(*base) == *ext)
335335
);
336336

337+
let mut running_codeword_opt = None;
337338
let mut codeword_ext = VecDeque::from(vec![codeword_ext]);
338339
let mut transcript = BasicTranscript::new(b"test");
339340

@@ -342,6 +343,7 @@ mod tests {
342343
let r = E::from_canonical_u64(97);
343344
basefold_fri_round::<E, BasefoldRSParams>(
344345
&pp,
346+
&mut running_codeword_opt,
345347
&mut codeword_ext,
346348
&mut prove_data,
347349
&mut vec![],

crates/mpcs/src/basefold/query_phase.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ where
4949
log2_max_codeword_size,
5050
);
5151

52-
println!("log2_max_codeword_size: {}", log2_max_codeword_size);
5352
queries
5453
.iter()
5554
.map(|idx| {

0 commit comments

Comments
 (0)