@@ -15,6 +15,8 @@ use utils::left_ref;
1515use utils:: right_ref;
1616use utils:: { FSProver , FSVerifier } ;
1717
18+ use crate :: MIN_VARS_FOR_PACKING ;
19+
1820/*
1921Custom GKR to compute a product.
2022
@@ -27,47 +29,64 @@ A': [a0*a4, a1*a5, a2*a6, a3*a7]
2729#[ instrument( skip_all) ]
2830pub fn prove_gkr_product < EF > (
2931 prover_state : & mut FSProver < EF , impl FSChallenger < EF > > ,
30- final_layer : Vec < EFPacking < EF > > ,
32+ final_layer : & [ EF ] ,
3133) -> ( EF , Evaluation < EF > )
3234where
3335 EF : ExtensionField < PF < EF > > ,
3436 PF < EF > : PrimeField64 ,
3537{
36- let n = ( final_layer. len ( ) * packing_width :: < EF > ( ) ) . ilog2 ( ) as usize ;
37- let mut layers_packed = Vec :: new ( ) ;
38- let mut layers_not_packed = Vec :: new ( ) ;
39- let last_packed = n
40- . checked_sub ( 6 + packing_log_width :: < EF > ( ) )
41- . expect ( "TODO small GKR, no packing" ) ;
42- layers_packed. push ( final_layer) ;
43- for i in 0 ..last_packed {
44- layers_packed. push ( product_2_by_2 ( & layers_packed[ i] ) ) ;
38+ assert ! ( log2_strict_usize( final_layer. len( ) ) >= 1 ) ;
39+ if final_layer. len ( ) == 2 {
40+ prover_state. add_extension_scalars ( & final_layer) ;
41+ let product = final_layer[ 0 ] * final_layer[ 1 ] ;
42+ let point = MultilinearPoint ( vec ! [ prover_state. sample( ) ] ) ;
43+ let claim = Evaluation {
44+ point : point. clone ( ) ,
45+ value : final_layer. evaluate ( & point) ,
46+ } ;
47+ return ( product, claim) ;
4548 }
46- layers_not_packed. push ( product_2_by_2 ( & unpack_extension (
47- & layers_packed[ last_packed] ,
48- ) ) ) ;
49- for i in 0 ..n - last_packed - 2 {
50- layers_not_packed. push ( product_2_by_2 ( & layers_not_packed[ i] ) ) ;
49+
50+ let final_layer: Mle < ' _ , EF > = if final_layer. len ( ) >= 1 << MIN_VARS_FOR_PACKING {
51+ // TODO packing beforehand
52+ MleOwned :: ExtensionPacked ( pack_extension ( final_layer) ) . into ( )
53+ } else {
54+ MleRef :: Extension ( final_layer) . into ( )
55+ } ;
56+ if final_layer. n_vars ( ) > MIN_VARS_FOR_PACKING && !final_layer. is_packed ( ) {
57+ tracing:: warn!( "GKR product not packed despite being large enough for packing" ) ;
5158 }
5259
53- assert_eq ! ( layers_not_packed[ n - last_packed - 2 ] . len( ) , 2 ) ;
54- let product = layers_not_packed[ n - last_packed - 2 ]
55- . iter ( )
56- . copied ( )
57- . product :: < EF > ( ) ;
58- prover_state. add_extension_scalars ( & layers_not_packed[ n - last_packed - 2 ] ) ;
60+ let mut layers = vec ! [ final_layer] ;
61+ loop {
62+ if layers. last ( ) . unwrap ( ) . n_vars ( ) == 1 {
63+ break ;
64+ }
65+ layers. push ( product_2_by_2 ( & layers. last ( ) . unwrap ( ) . by_ref ( ) ) . into ( ) ) ;
66+ }
67+
68+ let last_layer = match layers. last ( ) . unwrap ( ) . by_ref ( ) {
69+ MleRef :: Extension ( slice) => slice,
70+ _ => unreachable ! ( ) ,
71+ } ;
72+ assert_eq ! ( last_layer. len( ) , 2 ) ;
73+ let product = last_layer[ 0 ] * last_layer[ 1 ] ;
74+ prover_state. add_extension_scalars ( & last_layer) ;
5975
6076 let point = MultilinearPoint ( vec ! [ prover_state. sample( ) ] ) ;
6177 let mut claim = Evaluation {
6278 point : point. clone ( ) ,
63- value : layers_not_packed [ n - last_packed - 2 ] . evaluate ( & point) ,
79+ value : last_layer . evaluate ( & point) ,
6480 } ;
6581
66- for layer in layers_not_packed. iter ( ) . rev ( ) . skip ( 1 ) {
67- claim = prove_gkr_product_step ( prover_state, layer, & claim) ;
68- }
69- for layer in layers_packed. iter ( ) . rev ( ) {
70- claim = prove_gkr_product_step_packed ( prover_state, layer, & claim) ;
82+ for layer in layers. iter ( ) . rev ( ) . skip ( 1 ) {
83+ claim = match layer. by_ref ( ) {
84+ MleRef :: Extension ( slice) => prove_gkr_product_step ( prover_state, slice, & claim) ,
85+ MleRef :: ExtensionPacked ( slice) => {
86+ prove_gkr_product_step_packed ( prover_state, slice, & claim)
87+ }
88+ _ => unreachable ! ( ) ,
89+ }
7190 }
7291
7392 ( product, claim)
@@ -201,7 +220,23 @@ where
201220 Ok ( Evaluation :: new ( next_point, next_claim) )
202221}
203222
204- fn product_2_by_2 < EF : PrimeCharacteristicRing + Sync + Send + Copy > ( layer : & [ EF ] ) -> Vec < EF > {
223+ fn product_2_by_2 < EF : ExtensionField < PF < EF > > > ( layer : & MleRef < ' _ , EF > ) -> MleOwned < EF > {
224+ match layer {
225+ MleRef :: Extension ( slice) => MleOwned :: Extension ( product_2_by_2_helper ( slice) ) ,
226+ MleRef :: ExtensionPacked ( slice) => {
227+ if slice. len ( ) >= 1 << MIN_VARS_FOR_PACKING {
228+ MleOwned :: ExtensionPacked ( product_2_by_2_helper ( slice) )
229+ } else {
230+ MleOwned :: Extension ( product_2_by_2_helper ( & unpack_extension ( slice) ) )
231+ }
232+ }
233+ _ => unreachable ! ( ) ,
234+ }
235+ }
236+
237+ fn product_2_by_2_helper < EF : PrimeCharacteristicRing + Sync + Send + Copy > (
238+ layer : & [ EF ] ,
239+ ) -> Vec < EF > {
205240 let n = layer. len ( ) ;
206241 ( 0 ..n / 2 )
207242 . into_par_iter ( )
@@ -221,34 +256,13 @@ mod tests {
221256 type EF = QuinticExtensionFieldKB ;
222257
223258 #[ test]
224- fn test_gkr_product_step ( ) {
225- let log_n = 12 ;
226- let n = 1 << log_n;
227-
228- let mut rng = StdRng :: seed_from_u64 ( 0 ) ;
229-
230- let big = ( 0 ..n) . map ( |_| rng. random ( ) ) . collect :: < Vec < EF > > ( ) ;
231- let small = product_2_by_2 ( & big) ;
232-
233- let point = MultilinearPoint ( ( 0 ..log_n - 1 ) . map ( |_| rng. random ( ) ) . collect :: < Vec < EF > > ( ) ) ;
234- let eval = small. evaluate ( & point) ;
235-
236- let mut prover_state = build_prover_state ( ) ;
237-
238- let time = Instant :: now ( ) ;
239- let claim = Evaluation { point, value : eval } ;
240- prove_gkr_product_step_packed ( & mut prover_state, & pack_extension ( & big) , & claim) ;
241- dbg ! ( time. elapsed( ) ) ;
242-
243- let mut verifier_state = build_verifier_state ( & prover_state) ;
244-
245- let postponed = verify_gkr_product_step ( & mut verifier_state, log_n - 1 , & claim) . unwrap ( ) ;
246- assert_eq ! ( big. evaluate( & postponed. point) , postponed. value) ;
259+ fn test_gkr_product ( ) {
260+ for log_n in 1 ..10 {
261+ test_gkr_product_helper ( log_n) ;
262+ }
247263 }
248264
249- #[ test]
250- fn test_gkr_product ( ) {
251- let log_n = 13 ;
265+ fn test_gkr_product_helper ( log_n : usize ) {
252266 let n = 1 << log_n;
253267
254268 let mut rng = StdRng :: seed_from_u64 ( 0 ) ;
@@ -259,8 +273,8 @@ mod tests {
259273 let mut prover_state = build_prover_state ( ) ;
260274
261275 let time = Instant :: now ( ) ;
262- let ( product_prover , claim_prover ) =
263- prove_gkr_product ( & mut prover_state, pack_extension ( & layer) ) ;
276+
277+ let ( product_prover , claim_prover ) = prove_gkr_product ( & mut prover_state, & layer) ;
264278 println ! ( "GKR product took {:?}" , time. elapsed( ) ) ;
265279
266280 let mut verifier_state = build_verifier_state ( & prover_state) ;
0 commit comments