@@ -150,30 +150,6 @@ where
150150 . collect_vec ( ) ,
151151 ) ;
152152 exit_span ! ( batch_codeword_span) ;
153-
154- // commit to codewords with max height using mmcs_ext
155- let max_height = batched_codewords
156- . front ( )
157- . expect ( "empty batched_codewords" )
158- . height ( ) ;
159- let mut highest_codeword = batched_codewords. pop_front ( ) . unwrap ( ) ;
160- while let Some ( new_codeword) = batched_codewords. front ( ) {
161- if new_codeword. height ( ) == max_height {
162- let new_codeword = batched_codewords. pop_front ( ) . unwrap ( ) ;
163- // sum up the rows in each codeword
164- highest_codeword
165- . par_rows_mut ( )
166- . zip ( new_codeword. par_rows ( ) )
167- . for_each ( |( row_acc, row) | {
168- row_acc. iter_mut ( ) . zip ( row) . for_each ( |( acc, v) | * acc += v) ;
169- } ) ;
170- }
171- }
172- let ( commit, mmcs) = mmcs_ext. commit_matrix ( highest_codeword) ;
173- write_digest_to_transcript ( & commit, transcript) ;
174- trees. push ( mmcs) ;
175- commits. push ( commit) ;
176-
177153 exit_span ! ( prepare_span) ;
178154
179155 // eq is the evaluation representation of the eq(X,r) polynomial over the hypercube
@@ -219,6 +195,7 @@ where
219195 let mut sumcheck_messages = Vec :: with_capacity ( num_rounds) ;
220196
221197 let mut challenge = None ;
198+ let mut running_codeword_opt: Option < RowMajorMatrix < E > > = None ;
222199 let sumcheck_phase1 = entered_span ! ( "sumcheck_phase1" ) ;
223200 let phase1_rounds = num_rounds. min ( max_num_vars - log2_num_threads) ;
224201
@@ -228,6 +205,7 @@ where
228205 & mut prover_states,
229206 challenge,
230207 & mut sumcheck_messages,
208+ & mut running_codeword_opt,
231209 & mut batched_codewords,
232210 transcript,
233211 & mut trees,
@@ -264,6 +242,7 @@ where
264242 & mut prover_states,
265243 challenge,
266244 & mut sumcheck_messages,
245+ & mut running_codeword_opt,
267246 & mut batched_codewords,
268247 transcript,
269248 & mut trees,
@@ -326,6 +305,7 @@ where
326305#[ allow( clippy:: too_many_arguments) ]
327306pub ( crate ) fn basefold_fri_round < E : ExtensionField , Spec : BasefoldSpec < E > > (
328307 pp : & <Spec :: EncodingScheme as EncodingScheme < E > >:: ProverParameters ,
308+ running_codeword_opt : & mut Option < RowMajorMatrix < E > > ,
329309 codewords : & mut VecDeque < RowMajorMatrix < E > > ,
330310 trees : & mut Vec < MerkleTreeExt < E > > ,
331311 commits : & mut Vec < <Poseidon2ExtMerkleMmcs < E > as Mmcs < E > >:: Commitment > ,
@@ -342,11 +322,8 @@ pub(crate) fn basefold_fri_round<E: ExtensionField, Spec: BasefoldSpec<E>>(
342322 <Poseidon2ExtMerkleMmcs < E > as Mmcs < E > >:: Commitment :
343323 IntoIterator < Item = E :: BaseField > + PartialEq ,
344324{
345- let running_codeword_opt = trees
346- . last ( )
347- . and_then ( |mktree| mmcs_ext. get_matrices ( mktree) . pop ( ) )
348- . map ( |m| m. as_view ( ) ) ;
349325 let target_len = running_codeword_opt
326+ . as_ref ( )
350327 . map ( |running_codeword| running_codeword. values . len ( ) )
351328 . unwrap_or_else ( || {
352329 codewords
@@ -355,7 +332,6 @@ pub(crate) fn basefold_fri_round<E: ExtensionField, Spec: BasefoldSpec<E>>(
355332 . max ( )
356333 . expect ( "empty codeword" )
357334 } ) ;
358- let next_level_target_len = target_len >> 1 ;
359335 let level = log2_strict_usize ( target_len) - 1 ;
360336 let folding_coeffs =
361337 <Spec :: EncodingScheme as EncodingScheme < E > >:: prover_folding_coeffs_level ( pp, level) ;
@@ -365,76 +341,39 @@ pub(crate) fn basefold_fri_round<E: ExtensionField, Spec: BasefoldSpec<E>>(
365341 // take codewords match with target length then fold
366342 let codewords_matched =
367343 pop_front_while ( codewords, |codeword| codeword. values . len ( ) == target_len) ;
368- // take codewords match next target length in preparation of being committed together
369- let codewords_next_level_matched = pop_front_while ( codewords, |codeword| {
370- codeword. values . len ( ) == next_level_target_len
371- } ) ;
372344
373- // optimize for single codeword match
374- let folded_codeword = if ( usize:: from ( running_codeword_opt. is_some ( ) ) + codewords_matched. len ( ) )
375- == 1
376- && codewords_next_level_matched. is_empty ( )
377- {
378- RowMajorMatrix :: new (
345+ // aggregate codeword with same length
346+ let codeword_to_fold = ( 0 ..target_len)
347+ . into_par_iter ( )
348+ . map ( |index| {
379349 running_codeword_opt
380- . or_else ( || codewords_matched. first ( ) . map ( |m| m. as_view ( ) ) )
381- . unwrap ( )
382- . values
383- . par_chunks_exact ( 2 )
384- . zip ( folding_coeffs)
385- . map ( |( ys, coeff) | codeword_fold_with_challenge ( ys, challenge, * coeff, inv_2) )
386- . collect :: < Vec < _ > > ( ) ,
387- 2 ,
388- )
389- } else {
390- // aggregate codeword with same length
391- let codeword_to_fold = ( 0 ..target_len)
392- . into_par_iter ( )
393- . map ( |index| {
394- running_codeword_opt
395- . into_iter ( )
396- . chain ( codewords_matched. iter ( ) . map ( |m| m. as_view ( ) ) )
397- . map ( |codeword| codeword. values [ index] )
398- . sum :: < E > ( )
399- } )
400- . collect :: < Vec < E > > ( ) ;
401-
402- RowMajorMatrix :: new (
403- ( 0 ..target_len)
404- . into_par_iter ( )
405- . step_by ( 2 )
406- . map ( |index| {
407- let coeff = & folding_coeffs[ index >> 1 ] ;
408-
409- // 1st part folded with challenge then sum
410- let cur_same_pos_sum = codeword_fold_with_challenge (
411- & codeword_to_fold[ index..index + 2 ] ,
412- challenge,
413- * coeff,
414- inv_2,
415- ) ;
416- // 2nd part: retrieve respective index then sum
417- let next_same_pos_sum = codewords_next_level_matched
418- . iter ( )
419- . map ( |codeword| codeword. values [ index >> 1 ] )
420- . sum :: < E > ( ) ;
421- cur_same_pos_sum + next_same_pos_sum
422- } )
423- . collect :: < Vec < _ > > ( ) ,
424- 2 ,
425- )
426- } ;
350+ . iter ( )
351+ . chain ( codewords_matched. iter ( ) )
352+ . map ( |codeword| codeword. values [ index] )
353+ . sum :: < E > ( )
354+ } )
355+ . collect :: < Vec < E > > ( ) ;
356+
357+ // commit
358+ let codeword_as_matrix = RowMajorMatrix :: new ( codeword_to_fold. clone ( ) , 2 ) ;
359+ let ( commitment, merkle_tree) = mmcs_ext. commit_matrix ( codeword_as_matrix) ;
360+ write_digest_to_transcript ( & commitment, transcript) ;
361+ commits. push ( commitment) ;
362+ trees. push ( merkle_tree) ;
363+
364+ // fri fold and add codewords with next target length
365+ * running_codeword_opt = Some ( RowMajorMatrix :: new (
366+ codeword_to_fold
367+ . par_chunks_exact ( 2 )
368+ . zip ( folding_coeffs)
369+ . map ( |( ys, coeff) | codeword_fold_with_challenge ( ys, challenge, * coeff, inv_2) )
370+ . collect :: < Vec < _ > > ( ) ,
371+ 2 ,
372+ ) ) ;
427373
428374 if cfg ! ( feature = "sanity-check" ) && is_last_round {
429- let ( commitment, merkle_tree) = mmcs_ext. commit_matrix ( folded_codeword. clone ( ) ) ;
430- commits. push ( commitment) ;
431- trees. push ( merkle_tree) ;
432- }
433-
434- // skip last round commitment as verifer need to derive encode(final_message) = final_codeword itself
435- if !is_last_round {
436- let ( commitment, merkle_tree) = mmcs_ext. commit_matrix ( folded_codeword) ;
437- write_digest_to_transcript ( & commitment, transcript) ;
375+ let ( commitment, merkle_tree) =
376+ mmcs_ext. commit_matrix ( running_codeword_opt. as_ref ( ) . map ( |c| c. clone ( ) ) . unwrap ( ) ) ;
438377 commits. push ( commitment) ;
439378 trees. push ( merkle_tree) ;
440379 }
@@ -447,6 +386,7 @@ fn basefold_one_round<E: ExtensionField, Spec: BasefoldSpec<E>>(
447386 prover_states : & mut Vec < IOPProverState < ' _ , E > > ,
448387 challenge : Option < Challenge < E > > ,
449388 sumcheck_messages : & mut Vec < IOPProverMessage < E > > ,
389+ running_codeword_opt : & mut Option < RowMajorMatrix < E > > ,
450390 codewords : & mut VecDeque < RowMajorMatrix < E > > ,
451391 transcript : & mut impl Transcript < E > ,
452392 trees : & mut Vec < MerkleTreeExt < E > > ,
@@ -490,6 +430,7 @@ where
490430 let fri_round_span = entered_span ! ( "basefold::fri_one_round" ) ;
491431 basefold_fri_round :: < E , Spec > (
492432 pp,
433+ running_codeword_opt,
493434 codewords,
494435 trees,
495436 commits,
0 commit comments