Skip to content

Commit 4f5a2b9

Browse files
authored
Feat: store codeword with max height before FRI fold (#24)
* add one more ext_mmcs to store codeword with max height * commit before fri fold in each basefold_fri_round * fix test * minimize diff * remove unnecessary clone * require all reduced_openings are consumed
1 parent 6448be5 commit 4f5a2b9

File tree

5 files changed

+89
-144
lines changed

5 files changed

+89
-144
lines changed

crates/mpcs/src/basefold.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ where
403403
// prepare folding challenges via sumcheck round msg + FRI commitment
404404
let mut fold_challenges: Vec<E> = Vec::with_capacity(max_num_var);
405405
let commits = &proof.commits;
406+
assert_eq!(commits.len(), num_rounds);
406407
let sumcheck_messages = proof.sumcheck_proof.as_ref().unwrap();
407408
for i in 0..num_rounds {
408409
transcript.append_field_element_exts(sumcheck_messages[i].evaluations.as_slice());
@@ -411,9 +412,7 @@ where
411412
.sample_and_append_challenge(b"commit round")
412413
.elements,
413414
);
414-
if i < num_rounds - 1 {
415-
write_digest_to_transcript(&commits[i], transcript);
416-
}
415+
write_digest_to_transcript(&commits[i], transcript);
417416
}
418417
#[cfg(debug_assertions)]
419418
{

crates/mpcs/src/basefold/commit_phase.rs

Lines changed: 43 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,8 @@ where
9696
for (i, mat) in mats.into_iter().enumerate() {
9797
let (point, _) = &point_and_evals[i];
9898
let polys = &pcs_data.polys[i];
99-
// the actual ith row and (i+n/2)th row are packed in same row
100-
let num_rows = mat.height() * 2;
101-
let num_polys = mat.width() / 2;
99+
let num_rows = mat.height();
100+
let num_polys = mat.width();
102101
let coeffs = batch_coeffs_iter
103102
.by_ref()
104103
.take(num_polys)
@@ -118,6 +117,7 @@ where
118117
2,
119118
);
120119
let num_vars = polys[0].num_vars();
120+
assert_eq!(1 << (num_vars + Spec::get_rate_log()), num_rows);
121121
let mle_base_vec = polys
122122
.iter()
123123
.map(|mle| mle.get_base_field_vec())
@@ -149,7 +149,6 @@ where
149149
.collect_vec(),
150150
);
151151
exit_span!(batch_codeword_span);
152-
153152
exit_span!(prepare_span);
154153

155154
// eq is the evaluation representation of the eq(X,r) polynomial over the hypercube
@@ -193,9 +192,10 @@ where
193192
})
194193
.collect::<Vec<_>>();
195194
let mut sumcheck_messages = Vec::with_capacity(num_rounds);
196-
let mut commits = Vec::with_capacity(num_rounds - 1);
195+
let mut commits = Vec::with_capacity(max_num_vars);
197196

198197
let mut challenge = None;
198+
let mut running_codeword_opt: Option<RowMajorMatrix<E>> = None;
199199
let sumcheck_phase1 = entered_span!("sumcheck_phase1");
200200
let phase1_rounds = num_rounds.min(max_num_vars - log2_num_threads);
201201

@@ -205,6 +205,7 @@ where
205205
&mut prover_states,
206206
challenge,
207207
&mut sumcheck_messages,
208+
&mut running_codeword_opt,
208209
&mut batched_codewords,
209210
transcript,
210211
&mut trees,
@@ -241,6 +242,7 @@ where
241242
&mut prover_states,
242243
challenge,
243244
&mut sumcheck_messages,
245+
&mut running_codeword_opt,
244246
&mut batched_codewords,
245247
transcript,
246248
&mut trees,
@@ -303,6 +305,7 @@ where
303305
#[allow(clippy::too_many_arguments)]
304306
pub(crate) fn basefold_fri_round<E: ExtensionField, Spec: BasefoldSpec<E>>(
305307
pp: &<Spec::EncodingScheme as EncodingScheme<E>>::ProverParameters,
308+
running_codeword_opt: &mut Option<RowMajorMatrix<E>>,
306309
codewords: &mut VecDeque<RowMajorMatrix<E>>,
307310
trees: &mut Vec<MerkleTreeExt<E>>,
308311
commits: &mut Vec<<Poseidon2ExtMerkleMmcs<E> as Mmcs<E>>::Commitment>,
@@ -319,11 +322,8 @@ pub(crate) fn basefold_fri_round<E: ExtensionField, Spec: BasefoldSpec<E>>(
319322
<Poseidon2ExtMerkleMmcs<E> as Mmcs<E>>::Commitment:
320323
IntoIterator<Item = E::BaseField> + PartialEq,
321324
{
322-
let running_codeword_opt = trees
323-
.last()
324-
.and_then(|mktree| mmcs_ext.get_matrices(mktree).pop())
325-
.map(|m| m.as_view());
326325
let target_len = running_codeword_opt
326+
.as_ref()
327327
.map(|running_codeword| running_codeword.values.len())
328328
.unwrap_or_else(|| {
329329
codewords
@@ -332,7 +332,6 @@ pub(crate) fn basefold_fri_round<E: ExtensionField, Spec: BasefoldSpec<E>>(
332332
.max()
333333
.expect("empty codeword")
334334
});
335-
let next_level_target_len = target_len >> 1;
336335
let level = log2_strict_usize(target_len) - 1;
337336
let folding_coeffs =
338337
<Spec::EncodingScheme as EncodingScheme<E>>::prover_folding_coeffs_level(pp, level);
@@ -342,76 +341,42 @@ pub(crate) fn basefold_fri_round<E: ExtensionField, Spec: BasefoldSpec<E>>(
342341
// take codewords match with target length then fold
343342
let codewords_matched =
344343
pop_front_while(codewords, |codeword| codeword.values.len() == target_len);
345-
// take codewords match next target length in preparation of being committed together
346-
let codewords_next_level_matched = pop_front_while(codewords, |codeword| {
347-
codeword.values.len() == next_level_target_len
348-
});
349344

350-
// optimize for single codeword match
351-
let folded_codeword = if (usize::from(running_codeword_opt.is_some()) + codewords_matched.len())
352-
== 1
353-
&& codewords_next_level_matched.is_empty()
354-
{
355-
RowMajorMatrix::new(
345+
// aggregate codeword with same length
346+
let codeword_to_fold = (0..target_len)
347+
.into_par_iter()
348+
.map(|index| {
356349
running_codeword_opt
357-
.or_else(|| codewords_matched.first().map(|m| m.as_view()))
358-
.unwrap()
359-
.values
360-
.par_chunks_exact(2)
361-
.zip(folding_coeffs)
362-
.map(|(ys, coeff)| codeword_fold_with_challenge(ys, challenge, *coeff, inv_2))
363-
.collect::<Vec<_>>(),
364-
2,
365-
)
366-
} else {
367-
// aggregate codeword with same length
368-
let codeword_to_fold = (0..target_len)
369-
.into_par_iter()
370-
.map(|index| {
371-
running_codeword_opt
372-
.into_iter()
373-
.chain(codewords_matched.iter().map(|m| m.as_view()))
374-
.map(|codeword| codeword.values[index])
375-
.sum::<E>()
376-
})
377-
.collect::<Vec<E>>();
378-
379-
RowMajorMatrix::new(
380-
(0..target_len)
381-
.into_par_iter()
382-
.step_by(2)
383-
.map(|index| {
384-
let coeff = &folding_coeffs[index >> 1];
385-
386-
// 1st part folded with challenge then sum
387-
let cur_same_pos_sum = codeword_fold_with_challenge(
388-
&codeword_to_fold[index..index + 2],
389-
challenge,
390-
*coeff,
391-
inv_2,
392-
);
393-
// 2nd part: retrieve respective index then sum
394-
let next_same_pos_sum = codewords_next_level_matched
395-
.iter()
396-
.map(|codeword| codeword.values[index >> 1])
397-
.sum::<E>();
398-
cur_same_pos_sum + next_same_pos_sum
399-
})
400-
.collect::<Vec<_>>(),
401-
2,
402-
)
403-
};
350+
.iter()
351+
.chain(codewords_matched.iter())
352+
.map(|codeword| codeword.values[index])
353+
.sum::<E>()
354+
})
355+
.collect::<Vec<E>>();
356+
357+
// commit
358+
let codeword_as_matrix = RowMajorMatrix::new(codeword_to_fold.clone(), 2);
359+
let (commitment, merkle_tree) = mmcs_ext.commit_matrix(codeword_as_matrix);
360+
write_digest_to_transcript(&commitment, transcript);
361+
commits.push(commitment);
362+
363+
// codeword_to_fold is owned by merkle_tree in previous step and we can get a reference to it here
364+
let codeword_to_fold = mmcs_ext.get_matrices(&merkle_tree).pop().unwrap();
365+
*running_codeword_opt = Some(RowMajorMatrix::new(
366+
// fri fold
367+
codeword_to_fold
368+
.values
369+
.par_chunks_exact(2)
370+
.zip(folding_coeffs)
371+
.map(|(ys, coeff)| codeword_fold_with_challenge(ys, challenge, *coeff, inv_2))
372+
.collect::<Vec<_>>(),
373+
2,
374+
));
375+
trees.push(merkle_tree);
404376

405377
if cfg!(feature = "sanity-check") && is_last_round {
406-
let (commitment, merkle_tree) = mmcs_ext.commit_matrix(folded_codeword.clone());
407-
commits.push(commitment);
408-
trees.push(merkle_tree);
409-
}
410-
411-
// skip last round commitment as verifer need to derive encode(final_message) = final_codeword itself
412-
if !is_last_round {
413-
let (commitment, merkle_tree) = mmcs_ext.commit_matrix(folded_codeword);
414-
write_digest_to_transcript(&commitment, transcript);
378+
let (commitment, merkle_tree) =
379+
mmcs_ext.commit_matrix(running_codeword_opt.as_ref().map(|c| c.clone()).unwrap());
415380
commits.push(commitment);
416381
trees.push(merkle_tree);
417382
}
@@ -424,6 +389,7 @@ fn basefold_one_round<E: ExtensionField, Spec: BasefoldSpec<E>>(
424389
prover_states: &mut Vec<IOPProverState<'_, E>>,
425390
challenge: Option<Challenge<E>>,
426391
sumcheck_messages: &mut Vec<IOPProverMessage<E>>,
392+
running_codeword_opt: &mut Option<RowMajorMatrix<E>>,
427393
codewords: &mut VecDeque<RowMajorMatrix<E>>,
428394
transcript: &mut impl Transcript<E>,
429395
trees: &mut Vec<MerkleTreeExt<E>>,
@@ -467,6 +433,7 @@ where
467433
let fri_round_span = entered_span!("basefold::fri_one_round");
468434
basefold_fri_round::<E, Spec>(
469435
pp,
436+
running_codeword_opt,
470437
codewords,
471438
trees,
472439
commits,

crates/mpcs/src/basefold/encoding/rs.rs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -212,9 +212,7 @@ where
212212
.bit_reverse_rows()
213213
.to_row_major_matrix()
214214
.values;
215-
// to make 2 consecutive position to be open together, we trickily "concat" 2 consecutive leafs
216-
// so both can be open under same row index
217-
let codeword = DenseMatrix::new(codeword, num_polys * 2);
215+
let codeword = DenseMatrix::new(codeword, num_polys);
218216
Ok(codeword)
219217
}
220218

@@ -294,10 +292,7 @@ mod tests {
294292

295293
use ff_ext::GoldilocksExt2;
296294
use itertools::izip;
297-
use p3::{
298-
commit::{ExtensionMmcs, Mmcs},
299-
goldilocks::Goldilocks,
300-
};
295+
use p3::{commit::ExtensionMmcs, goldilocks::Goldilocks};
301296

302297
use rand::rngs::OsRng;
303298
use transcript::BasicTranscript;
@@ -336,6 +331,7 @@ mod tests {
336331
izip!(&codeword.values, &codeword_ext.values).all(|(base, ext)| E::from(*base) == *ext)
337332
);
338333

334+
let mut running_codeword_opt = None;
339335
let mut codeword_ext = VecDeque::from(vec![codeword_ext]);
340336
let mut transcript = BasicTranscript::new(b"test");
341337

@@ -344,6 +340,7 @@ mod tests {
344340
let r = E::from_canonical_u64(97);
345341
basefold_fri_round::<E, BasefoldRSParams>(
346342
&pp,
343+
&mut running_codeword_opt,
347344
&mut codeword_ext,
348345
&mut prove_data,
349346
&mut vec![],
@@ -365,7 +362,7 @@ mod tests {
365362
),
366363
);
367364
assert_eq!(
368-
&mmcs_ext.get_matrices(&prove_data[0])[0].values,
365+
&running_codeword_opt.as_ref().unwrap().values,
369366
&codeword_from_folded_rmm.values
370367
);
371368
}

0 commit comments

Comments
 (0)