@@ -5,7 +5,7 @@ use crate::*;
55use air:: prove_air;
66use itertools:: Itertools ;
77use lean_vm:: * ;
8- use lookup:: { compute_pushforward, prove_gkr_quotient , prove_logup_star} ;
8+ use lookup:: { compute_pushforward, prove_logup_star} ;
99use multilinear_toolkit:: prelude:: * ;
1010
1111use p3_util:: log2_ceil_usize;
@@ -132,39 +132,41 @@ pub fn prove_execution(
132132 let bus_challenge = prover_state. sample ( ) ;
133133 let fingerprint_challenge = prover_state. sample ( ) ;
134134
135- let mut bus_quotients: BTreeMap < Table , EF > = Default :: default ( ) ;
136- let mut air_points: BTreeMap < Table , MultilinearPoint < EF > > = Default :: default ( ) ;
137- let mut evals_f: BTreeMap < Table , Vec < EF > > = Default :: default ( ) ;
138- let mut evals_ef: BTreeMap < Table , Vec < EF > > = Default :: default ( ) ;
139-
135+ let mut bus_numerators = vec ! [ ] ;
136+ let mut bus_denominators = vec ! [ ] ;
140137 for ( table, trace) in & traces {
141- let ( this_bus_quotient, this_air_point, this_evals_f, this_evals_ef) =
142- prove_bus_and_air ( & mut prover_state, table, trace, bus_challenge, fingerprint_challenge) ;
143- bus_quotients. insert ( * table, this_bus_quotient) ;
144- air_points. insert ( * table, this_air_point) ;
145- evals_f. insert ( * table, this_evals_f) ;
146- evals_ef. insert ( * table, this_evals_ef) ;
138+ for bus in table. buses ( ) {
139+ let numerator = trace. base [ bus. selector ]
140+ . par_iter ( )
141+ . map ( |& selector| match bus. direction {
142+ BusDirection :: Pull => -selector,
143+ BusDirection :: Push => selector,
144+ } )
145+ . collect :: < Vec < _ > > ( ) ;
146+ let denominator = ( 0 ..trace. n_rows_padded ( ) )
147+ . into_par_iter ( )
148+ . map ( |i| {
149+ bus_challenge
150+ + finger_print (
151+ match & bus. table {
152+ BusTable :: Constant ( table) => table. embed ( ) ,
153+ BusTable :: Variable ( col) => trace. base [ * col] [ i] ,
154+ } ,
155+ bus. data
156+ . iter ( )
157+ . map ( |col| trace. base [ * col] [ i] )
158+ . collect :: < Vec < _ > > ( )
159+ . as_slice ( ) ,
160+ fingerprint_challenge,
161+ )
162+ } )
163+ . collect :: < Vec < _ > > ( ) ;
164+
165+ bus_numerators. push ( numerator) ;
166+ bus_denominators. push ( denominator) ;
167+ }
147168 }
148169
149- assert_eq ! ( bus_quotients. values( ) . copied( ) . sum:: <EF >( ) , EF :: ZERO ) ;
150-
151- let bytecode_compression_challenges =
152- MultilinearPoint ( prover_state. sample_vec ( log2_ceil_usize ( N_INSTRUCTION_COLUMNS ) ) ) ;
153-
154- let folded_bytecode = fold_bytecode ( bytecode, & bytecode_compression_challenges) ;
155-
156- let bytecode_lookup_claim = Evaluation :: new (
157- air_points[ & Table :: execution ( ) ] . clone ( ) ,
158- padd_with_zero_to_next_power_of_two ( & evals_f[ & Table :: execution ( ) ] [ ..N_INSTRUCTION_COLUMNS ] )
159- . evaluate ( & bytecode_compression_challenges) ,
160- ) ;
161- let bytecode_poly_eq_point = eval_eq ( & air_points[ & Table :: execution ( ) ] ) ;
162- let bytecode_pushforward = MleOwned :: Extension ( compute_pushforward (
163- & traces[ & Table :: execution ( ) ] . base [ COL_INDEX_PC ] ,
164- folded_bytecode. len ( ) ,
165- & bytecode_poly_eq_point,
166- ) ) ;
167-
168170 let mut lookup_into_memory = CustomLookupProver :: run :: < EF , DIMENSION , VECTOR_LEN > (
169171 & mut prover_state,
170172 & memory,
@@ -193,7 +195,53 @@ pub fn prove_execution(
193195 . iter ( )
194196 . flat_map ( |( table, trace) | table. vector_lookup_values_columns ( trace) )
195197 . collect ( ) ,
198+ collect_refs ( & bus_numerators) ,
199+ collect_refs ( & bus_denominators) ,
200+ UNIVARIATE_SKIPS ,
201+ ) ;
202+
203+ let mut air_points: BTreeMap < Table , MultilinearPoint < EF > > = Default :: default ( ) ;
204+ let mut evals_f: BTreeMap < Table , Vec < EF > > = Default :: default ( ) ;
205+ let mut evals_ef: BTreeMap < Table , Vec < EF > > = Default :: default ( ) ;
206+
207+ let mut bus_offset = 0 ;
208+ for ( table, trace) in & traces {
209+ let ( this_air_point, this_evals_f, this_evals_ef) = prove_bus_and_air (
210+ & mut prover_state,
211+ table,
212+ trace,
213+ bus_challenge,
214+ fingerprint_challenge,
215+ & lookup_into_memory. on_bus_numerators [ bus_offset..] [ ..table. buses ( ) . len ( ) ] ,
216+ & lookup_into_memory. on_bus_denominators [ bus_offset..] [ ..table. buses ( ) . len ( ) ] ,
217+ ) ;
218+ air_points. insert ( * table, this_air_point) ;
219+ evals_f. insert ( * table, this_evals_f) ;
220+ evals_ef. insert ( * table, this_evals_ef) ;
221+ bus_offset += table. buses ( ) . len ( ) ;
222+ }
223+ assert_eq_many ! (
224+ bus_offset,
225+ lookup_into_memory. on_bus_numerators. len( ) ,
226+ lookup_into_memory. on_bus_denominators. len( )
227+ ) ;
228+
229+ let bytecode_compression_challenges =
230+ MultilinearPoint ( prover_state. sample_vec ( log2_ceil_usize ( N_INSTRUCTION_COLUMNS ) ) ) ;
231+
232+ let folded_bytecode = fold_bytecode ( bytecode, & bytecode_compression_challenges) ;
233+
234+ let bytecode_lookup_claim = Evaluation :: new (
235+ air_points[ & Table :: execution ( ) ] . clone ( ) ,
236+ padd_with_zero_to_next_power_of_two ( & evals_f[ & Table :: execution ( ) ] [ ..N_INSTRUCTION_COLUMNS ] )
237+ . evaluate ( & bytecode_compression_challenges) ,
196238 ) ;
239+ let bytecode_poly_eq_point = eval_eq ( & air_points[ & Table :: execution ( ) ] ) ;
240+ let bytecode_pushforward = MleOwned :: Extension ( compute_pushforward (
241+ & traces[ & Table :: execution ( ) ] . base [ COL_INDEX_PC ] ,
242+ folded_bytecode. len ( ) ,
243+ & bytecode_poly_eq_point,
244+ ) ) ;
197245
198246 let bytecode_pushforward_commitment =
199247 WhirConfig :: new ( whir_config_builder_b ( ) , log2_ceil_usize ( bytecode. instructions . len ( ) ) )
@@ -294,135 +342,31 @@ fn prove_bus_and_air(
294342 trace : & TableTrace ,
295343 bus_challenge : EF ,
296344 fingerprint_challenge : EF ,
297- ) -> ( EF , MultilinearPoint < EF > , Vec < EF > , Vec < EF > ) {
298- let n_buses = t. buses ( ) . len ( ) ;
299- let n_buses_padded = n_buses. next_power_of_two ( ) ;
300- let log_n_buses = log2_ceil_usize ( n_buses) ;
301- let n_rows = trace. n_rows_padded ( ) ;
302- let log_n_rows = trace. log_padded ( ) ;
303-
304- assert ! ( n_buses > 0 , "Table {} has no buses" , t. name( ) ) ;
305-
306- let mut numerators = F :: zero_vec ( n_buses_padded * n_rows) ;
307- for ( bus, numerators_chunk) in t. buses ( ) . iter ( ) . zip ( numerators. chunks_mut ( n_rows) ) {
308- assert ! ( bus. selector < trace. base. len( ) ) ;
309- trace. base [ bus. selector ]
310- . par_iter ( )
311- . zip ( numerators_chunk)
312- . for_each ( |( & selector, v) | {
313- * v = match bus. direction {
314- BusDirection :: Pull => -selector,
315- BusDirection :: Push => selector,
316- }
317- } ) ;
318- }
319-
320- let mut denominators = unsafe { uninitialized_vec ( n_buses_padded * n_rows) } ;
321- for ( bus, denomniators_chunk) in t. buses ( ) . iter ( ) . zip ( denominators. chunks_exact_mut ( n_rows) ) {
322- denomniators_chunk. par_iter_mut ( ) . enumerate ( ) . for_each ( |( i, v) | {
323- * v = bus_challenge
324- + finger_print (
325- match & bus. table {
326- BusTable :: Constant ( table) => table. embed ( ) ,
327- BusTable :: Variable ( col) => trace. base [ * col] [ i] ,
328- } ,
329- bus. data
330- . iter ( )
331- . map ( |col| trace. base [ * col] [ i] )
332- . collect :: < Vec < _ > > ( )
333- . as_slice ( ) ,
334- fingerprint_challenge,
335- ) ;
336- } ) ;
337- }
338- denominators[ n_rows * n_buses..]
339- . par_iter_mut ( )
340- . for_each ( |v| * v = EF :: ONE ) ;
341-
342- // TODO avoid embedding !!
343- let numerators_embedded = numerators. par_iter ( ) . copied ( ) . map ( EF :: from) . collect :: < Vec < _ > > ( ) ;
344-
345- // TODO avoid reallocation due to packing (pack directly when constructing)
346- let numerators_packed = pack_extension ( & numerators_embedded) ;
347- let denominators_packed = pack_extension ( & denominators) ;
348- let ( quotient, bus_point_global, numerator_value_global, denominator_value_global) =
349- prove_gkr_quotient :: < _ , TWO_POW_UNIVARIATE_SKIPS > (
350- prover_state,
351- & MleGroupRef :: ExtensionPacked ( vec ! [ & numerators_packed, & denominators_packed] ) ,
352- ) ;
353-
354- let ( bus_point, bus_selector_values, bus_data_values) = if n_buses == 1 {
355- // easy case
356- (
357- bus_point_global,
358- vec ! [ numerator_value_global] ,
359- vec ! [ denominator_value_global] ,
360- )
361- } else {
362- let uni_selectors = univariate_selectors :: < F > ( UNIVARIATE_SKIPS ) ;
363-
364- let sub_numerators_evals = numerators
365- . par_chunks_exact ( 1 << ( log_n_rows - UNIVARIATE_SKIPS ) )
366- . take ( n_buses << UNIVARIATE_SKIPS )
367- . map ( |chunk| chunk. evaluate ( & MultilinearPoint ( bus_point_global[ 1 + log_n_buses..] . to_vec ( ) ) ) )
368- . collect :: < Vec < _ > > ( ) ;
369- prover_state. add_extension_scalars ( & sub_numerators_evals) ;
370- // sanity check:
371- assert_eq ! (
372- numerator_value_global,
373- evaluate_univariate_multilinear:: <_, _, _, false >(
374- & padd_with_zero_to_next_power_of_two( & sub_numerators_evals) ,
375- & bus_point_global[ ..1 + log_n_buses] ,
376- & uni_selectors,
377- None
378- ) ,
379- ) ;
380-
381- let sub_denominators_evals = denominators
382- . par_chunks_exact ( 1 << ( log_n_rows - UNIVARIATE_SKIPS ) )
383- . take ( n_buses << UNIVARIATE_SKIPS )
384- . map ( |chunk| chunk. evaluate ( & MultilinearPoint ( bus_point_global[ 1 + log_n_buses..] . to_vec ( ) ) ) )
385- . collect :: < Vec < _ > > ( ) ;
386- prover_state. add_extension_scalars ( & sub_denominators_evals) ;
387- // sanity check:
388- assert_eq ! (
389- denominator_value_global,
390- evaluate_univariate_multilinear:: <_, _, _, false >(
391- & padd_to_next_power_of_two( & sub_denominators_evals, EF :: ONE ) ,
392- & bus_point_global[ ..1 + log_n_buses] ,
393- & uni_selectors,
394- None
395- ) ,
396- ) ;
397-
398- let epsilon = prover_state. sample ( ) ;
399- let bus_point = MultilinearPoint ( [ vec ! [ epsilon] , bus_point_global[ 1 + log_n_buses..] . to_vec ( ) ] . concat ( ) ) ;
400-
401- let bus_selector_values = sub_numerators_evals
402- . chunks_exact ( 1 << UNIVARIATE_SKIPS )
403- . map ( |chunk| evaluate_univariate_multilinear :: < _ , _ , _ , false > ( chunk, & [ epsilon] , & uni_selectors, None ) )
404- . collect ( ) ;
405- let bus_data_values = sub_denominators_evals
406- . chunks_exact ( 1 << UNIVARIATE_SKIPS )
407- . map ( |chunk| evaluate_univariate_multilinear :: < _ , _ , _ , false > ( chunk, & [ epsilon] , & uni_selectors, None ) )
408- . collect ( ) ;
409-
410- ( bus_point, bus_selector_values, bus_data_values)
411- } ;
345+ bus_numerator_statements : & [ Evaluation < EF > ] ,
346+ bus_denominator_statements : & [ Evaluation < EF > ] ,
347+ ) -> ( MultilinearPoint < EF > , Vec < EF > , Vec < EF > ) {
348+ assert_eq ! ( t. buses( ) . len( ) , bus_numerator_statements. len( ) ) ;
349+ let bus_point = bus_numerator_statements[ 0 ] . point . clone ( ) ;
350+ assert ! ( t. buses( ) . iter( ) . all( |_| bus_numerator_statements[ 0 ] . point == bus_point) ) ;
351+ assert ! (
352+ t. buses( )
353+ . iter( )
354+ . all( |_| bus_denominator_statements[ 0 ] . point == bus_point)
355+ ) ;
412356
413357 let bus_beta = prover_state. sample ( ) ;
414358
415- let bus_final_values = bus_selector_values
359+ let bus_final_values = bus_numerator_statements
416360 . iter ( )
417- . zip_eq ( & bus_data_values )
418- . zip_eq ( & t. buses ( ) )
419- . map ( |( ( & bus_selector_value , & bus_data_value ) , bus) | {
420- bus_selector_value
361+ . zip_eq ( bus_denominator_statements )
362+ . zip_eq ( t. buses ( ) )
363+ . map ( |( ( bus_selector_statement , bus_data_statement ) , bus) | {
364+ bus_selector_statement . value
421365 * match bus. direction {
422366 BusDirection :: Pull => EF :: NEG_ONE ,
423367 BusDirection :: Push => EF :: ONE ,
424368 }
425- + bus_beta * ( bus_data_value - bus_challenge)
369+ + bus_beta * ( bus_data_statement . value - bus_challenge)
426370 } )
427371 . collect :: < Vec < _ > > ( ) ;
428372
@@ -438,7 +382,7 @@ fn prove_bus_and_air(
438382 alpha_powers : vec ! [ ] , // filled later
439383 } ;
440384
441- let ( air_point, evals_f, evals_ef) = info_span ! ( "Table AIR proof" , table = t. name( ) ) . in_scope ( || {
385+ let ( air_point, evals_f, evals_ef) = info_span ! ( "AIR proof" , table = t. name( ) ) . in_scope ( || {
442386 macro_rules! prove_air_for_table {
443387 ( $t: expr) => {
444388 prove_air(
@@ -458,5 +402,5 @@ fn prove_bus_and_air(
458402 delegate_to_inner ! ( t => prove_air_for_table)
459403 } ) ;
460404
461- ( quotient , air_point, evals_f, evals_ef)
405+ ( air_point, evals_f, evals_ef)
462406}
0 commit comments