@@ -13,6 +13,8 @@ use p3_field::{ExtensionField, PrimeField64, dot_product};
1313use tracing:: instrument;
1414use utils:: { FSProver , FSVerifier } ;
1515
16+ use crate :: MIN_VARS_FOR_PACKING ;
17+
1618/*
1719Custom GKR to compute sum of fractions.
1820
@@ -38,64 +40,100 @@ with: U0 = AB(0 0 --- )
3840#[ instrument( skip_all) ]
3941pub fn prove_gkr_quotient < EF > (
4042 prover_state : & mut FSProver < EF , impl FSChallenger < EF > > ,
41- numerators : & [ EFPacking < EF > ] ,
43+ numerators : & MleRef < ' _ , EF > ,
4244 ( c, denominator_indexes) : ( EF , & [ PF < EF > ] ) ,
4345 n_non_zeros_numerator : Option < usize > , // final_layer[n_non_zeros_numerator..n / 2] are zeros
4446) -> ( Evaluation < EF > , EF , EF )
4547where
4648 EF : ExtensionField < PF < EF > > ,
4749 PF < EF > : PrimeField64 ,
4850{
49- let n = log2_strict_usize ( numerators. len ( ) ) + packing_log_width :: < EF > ( ) + 1 ;
50- let n_non_zeros_numerator = n_non_zeros_numerator. unwrap_or ( numerators. len ( ) ) ;
51- let mut layers_packed = Vec :: new ( ) ;
52- assert ! (
53- n >= 5 + packing_log_width:: <EF >( ) ,
54- "TODO small GKR, no packing"
55- ) ;
56- let mut layers_not_packed = Vec :: new ( ) ;
57- let last_packed = n - ( 4 + packing_log_width :: < EF > ( ) ) ;
58- let denominator_indexes_packed = PFPacking :: < EF > :: pack_slice ( denominator_indexes) ;
59- let c_packed = EFPacking :: < EF > :: from ( c) ;
60- layers_packed. push ( sum_quotients_2_by_2_num_and_den (
61- numerators,
62- |i| c_packed - denominator_indexes_packed[ i] ,
63- Some ( n_non_zeros_numerator) ,
64- ) ) ;
65- for i in 0 ..last_packed - 1 {
66- layers_packed. push ( sum_quotients_2_by_2 ( & layers_packed[ i] , None ) ) ;
51+ let n = numerators. n_vars ( ) + 1 ;
52+ assert ! ( n >= 2 ) ;
53+ let n_non_zeros_numerator = n_non_zeros_numerator. unwrap_or ( numerators. packed_len ( ) ) ;
54+ let mut layers = Vec :: new ( ) ;
55+ match numerators {
56+ MleRef :: ExtensionPacked ( numerators) => {
57+ let denominator_indexes_packed = PFPacking :: < EF > :: pack_slice ( denominator_indexes) ;
58+ layers. push ( MleOwned :: ExtensionPacked ( sum_quotients_2_by_2_num_and_den (
59+ numerators,
60+ |i| EFPacking :: < EF > :: from ( c) - denominator_indexes_packed[ i] ,
61+ Some ( n_non_zeros_numerator) ,
62+ ) ) ) ;
63+ }
64+ MleRef :: Extension ( numerators) => {
65+ layers. push ( MleOwned :: Extension ( sum_quotients_2_by_2_num_and_den (
66+ numerators,
67+ |i| c - denominator_indexes[ i] ,
68+ Some ( n_non_zeros_numerator) ,
69+ ) ) ) ;
70+ }
71+ _ => unreachable ! ( ) ,
6772 }
68- layers_not_packed. push ( sum_quotients_2_by_2 (
69- & unpack_extension ( & layers_packed[ last_packed - 1 ] ) ,
70- None ,
71- ) ) ;
72- for i in 0 ..n - last_packed - 2 {
73- layers_not_packed. push ( sum_quotients_2_by_2 ( & layers_not_packed[ i] , None ) ) ;
73+
74+ loop {
75+ let prev_layer: Mle < ' _ , EF > = layers. last ( ) . unwrap ( ) . by_ref ( ) . into ( ) ;
76+ let prev_layer = if prev_layer. is_packed ( ) && prev_layer. n_vars ( ) < MIN_VARS_FOR_PACKING {
77+ prev_layer. unpack ( )
78+ } else {
79+ prev_layer
80+ } ;
81+ if prev_layer. n_vars ( ) == 1 {
82+ break ;
83+ }
84+ layers. push ( match prev_layer. by_ref ( ) {
85+ MleRef :: ExtensionPacked ( prev_layer) => {
86+ MleOwned :: ExtensionPacked ( sum_quotients_2_by_2 ( prev_layer, None ) )
87+ }
88+ MleRef :: Extension ( numerators) => {
89+ MleOwned :: Extension ( sum_quotients_2_by_2 ( numerators, None ) )
90+ }
91+ _ => unreachable ! ( ) ,
92+ } )
7493 }
7594
76- assert_eq ! ( layers_not_packed [ n - last_packed - 2 ] . len ( ) , 2 ) ;
77- prover_state. add_extension_scalars ( & layers_not_packed [ n - last_packed - 2 ] ) ;
95+ assert_eq ! ( layers . last ( ) . unwrap ( ) . n_vars ( ) , 1 ) ;
96+ prover_state. add_extension_scalars ( & layers . last ( ) . unwrap ( ) . by_ref ( ) . as_extension ( ) . unwrap ( ) ) ;
7897
7998 let point = MultilinearPoint ( vec ! [ prover_state. sample( ) ] ) ;
80- let mut claim = Evaluation :: new (
81- point. clone ( ) ,
82- layers_not_packed[ n - last_packed - 2 ] . evaluate ( & point) ,
83- ) ;
84-
85- for layer in layers_not_packed. iter ( ) . rev ( ) . skip ( 1 ) {
86- ( claim, _, _) = prove_gkr_quotient_step ( prover_state, layer, & claim) ;
87- }
88- for layer in layers_packed. iter ( ) . rev ( ) {
89- ( claim, _, _) = prove_gkr_quotient_step_packed ( prover_state, layer, & claim) ;
99+ let mut claim = Evaluation :: new ( point. clone ( ) , layers. last ( ) . unwrap ( ) . evaluate ( & point) ) ;
100+
101+ for layer in layers. iter ( ) . rev ( ) . skip ( 1 ) {
102+ match layer {
103+ MleOwned :: Extension ( layer) => {
104+ ( claim, _, _) = prove_gkr_quotient_step ( prover_state, layer, & claim) ;
105+ }
106+ MleOwned :: ExtensionPacked ( layer) => {
107+ ( claim, _, _) = prove_gkr_quotient_step_packed ( prover_state, layer, & claim) ;
108+ }
109+ _ => unreachable ! ( ) ,
110+ }
90111 }
91112 let ( up_layer_eval_left, up_layer_eval_right) ;
92- ( claim, up_layer_eval_left, up_layer_eval_right) = prove_gkr_quotient_step_packed_first_round (
93- prover_state,
94- numerators,
95- ( c_packed, denominator_indexes_packed) ,
96- & claim,
97- Some ( n_non_zeros_numerator) ,
98- ) ;
113+
114+ match numerators {
115+ MleRef :: ExtensionPacked ( numerators) => {
116+ let denominator_indexes_packed = PFPacking :: < EF > :: pack_slice ( denominator_indexes) ;
117+ ( claim, up_layer_eval_left, up_layer_eval_right) =
118+ prove_gkr_quotient_step_packed_first_round (
119+ prover_state,
120+ numerators,
121+ ( EFPacking :: < EF > :: from ( c) , denominator_indexes_packed) ,
122+ & claim,
123+ Some ( n_non_zeros_numerator) ,
124+ ) ;
125+ }
126+ MleRef :: Extension ( numerators) => {
127+ let mut layer = EF :: zero_vec ( numerators. len ( ) * 2 ) ;
128+ layer[ ..numerators. len ( ) ] . copy_from_slice ( numerators) ;
129+ for i in 0 ..denominator_indexes. len ( ) {
130+ layer[ numerators. len ( ) + i] = c - denominator_indexes[ i] ;
131+ }
132+ ( claim, up_layer_eval_left, up_layer_eval_right) =
133+ prove_gkr_quotient_step ( prover_state, & layer, & claim) ;
134+ }
135+ _ => unreachable ! ( ) ,
136+ }
99137
100138 ( claim, up_layer_eval_left, up_layer_eval_right)
101139}
@@ -474,7 +512,7 @@ where
474512 let mid_len_packed = len_packed / 2 ;
475513 let quarter_len_packed = mid_len_packed / 2 ;
476514
477- let mut eq_poly_packed = eval_eq_packed ( & claim. point . 0 [ 1 ..] ) ;
515+ let eq_poly_packed = eval_eq_packed ( & claim. point . 0 [ 1 ..] ) ;
478516
479517 let up_layer_octics = split_at_many (
480518 up_layer_packed,
@@ -613,7 +651,6 @@ where
613651 let sumcheck_challenge_2 = prover_state. sample ( ) ;
614652 let sum_2 = sumcheck_polynomial_2. evaluate ( sumcheck_challenge_2) ;
615653
616- eq_poly_packed. resize ( eq_poly_packed. len ( ) / 4 , Default :: default ( ) ) ;
617654 missing_mul_factor *= ( ( EF :: ONE - claim. point [ 1 ] ) * ( EF :: ONE - sumcheck_challenge_2)
618655 + claim. point [ 1 ] * sumcheck_challenge_2)
619656 / ( EF :: ONE - claim. point . get ( 2 ) . copied ( ) . unwrap_or_default ( ) ) ;
@@ -631,7 +668,7 @@ where
631668 & [ ] ,
632669 Some ( (
633670 claim. point . 0 [ 2 ..] . to_vec ( ) ,
634- Some ( MleOwned :: ExtensionPacked ( eq_poly_packed) ) ,
671+ Some ( MleOwned :: ExtensionPacked ( eq_poly_packed) . halve ( ) . halve ( ) ) ,
635672 ) ) ,
636673 false ,
637674 prover_state,
@@ -854,7 +891,7 @@ mod tests {
854891
855892 let _ = prove_gkr_quotient (
856893 & mut prover_state,
857- & pack_extension ( & numerators) ,
894+ & MleRef :: ExtensionPacked ( & pack_extension ( & numerators) ) ,
858895 ( c, & denominators_indexes) ,
859896 None ,
860897 ) ;
0 commit comments