Skip to content

Commit a92ba9e

Browse files
committed
wip edge cases GKR small length
1 parent 04a0761 commit a92ba9e

File tree

3 files changed

+88
-74
lines changed

3 files changed

+88
-74
lines changed

crates/lean_prover/src/prove_execution.rs

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -250,10 +250,8 @@ pub fn prove_execution(
250250
);
251251
}
252252

253-
let (grand_product_exec_res, grand_product_exec_statement) = prove_gkr_product(
254-
&mut prover_state,
255-
pack_extension(&exec_column_for_grand_product),
256-
);
253+
let (grand_product_exec_res, grand_product_exec_statement) =
254+
prove_gkr_product(&mut prover_state, &exec_column_for_grand_product);
257255

258256
let p16_column_for_grand_product = poseidons_16
259257
.par_iter()
@@ -267,10 +265,8 @@ pub fn prove_execution(
267265
})
268266
.collect::<Vec<_>>();
269267

270-
let (grand_product_p16_res, grand_product_p16_statement) = prove_gkr_product(
271-
&mut prover_state,
272-
pack_extension(&p16_column_for_grand_product),
273-
);
268+
let (grand_product_p16_res, grand_product_p16_statement) =
269+
prove_gkr_product(&mut prover_state, &p16_column_for_grand_product);
274270

275271
let p24_column_for_grand_product = poseidons_24
276272
.par_iter()
@@ -284,10 +280,8 @@ pub fn prove_execution(
284280
})
285281
.collect::<Vec<_>>();
286282

287-
let (grand_product_p24_res, grand_product_p24_statement) = prove_gkr_product(
288-
&mut prover_state,
289-
pack_extension(&p24_column_for_grand_product),
290-
);
283+
let (grand_product_p24_res, grand_product_p24_statement) =
284+
prove_gkr_product(&mut prover_state, &p24_column_for_grand_product);
291285

292286
let dot_product_column_for_grand_product = (0..1 << log_n_rows_dot_product_table)
293287
.into_par_iter()
@@ -322,10 +316,8 @@ pub fn prove_execution(
322316
})
323317
.product::<EF>();
324318

325-
let (grand_product_dot_product_res, grand_product_dot_product_statement) = prove_gkr_product(
326-
&mut prover_state,
327-
pack_extension(&dot_product_column_for_grand_product),
328-
);
319+
let (grand_product_dot_product_res, grand_product_dot_product_statement) =
320+
prove_gkr_product(&mut prover_state, &dot_product_column_for_grand_product);
329321

330322
let corrected_prod_exec = grand_product_exec_res
331323
/ grand_product_challenge_global.exp_u64(

crates/lookup/src/logup_star.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,12 +243,20 @@ mod tests {
243243

244244
#[test]
245245
fn test_logup_star() {
246+
for log_table_len in [1, 10] {
247+
for log_indexes_len in 1..10 {
248+
test_logup_star_helper(log_table_len, log_indexes_len);
249+
}
250+
}
251+
252+
test_logup_star_helper(15, 17);
253+
}
254+
255+
fn test_logup_star_helper(log_table_len: usize, log_indexes_len: usize) {
246256
init_tracing();
247257

248-
let log_table_len = 14;
249258
let table_length = 1 << log_table_len;
250259

251-
let log_indexes_len = log_table_len + 1;
252260
let indexes_len = 1 << log_indexes_len;
253261

254262
let mut rng = StdRng::seed_from_u64(0);

crates/lookup/src/product_gkr.rs

Lines changed: 70 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ use utils::left_ref;
1515
use utils::right_ref;
1616
use utils::{FSProver, FSVerifier};
1717

18+
use crate::MIN_VARS_FOR_PACKING;
19+
1820
/*
1921
Custom GKR to compute a product.
2022
@@ -27,47 +29,64 @@ A': [a0*a4, a1*a5, a2*a6, a3*a7]
2729
#[instrument(skip_all)]
2830
pub fn prove_gkr_product<EF>(
2931
prover_state: &mut FSProver<EF, impl FSChallenger<EF>>,
30-
final_layer: Vec<EFPacking<EF>>,
32+
final_layer: &[EF],
3133
) -> (EF, Evaluation<EF>)
3234
where
3335
EF: ExtensionField<PF<EF>>,
3436
PF<EF>: PrimeField64,
3537
{
36-
let n = (final_layer.len() * packing_width::<EF>()).ilog2() as usize;
37-
let mut layers_packed = Vec::new();
38-
let mut layers_not_packed = Vec::new();
39-
let last_packed = n
40-
.checked_sub(6 + packing_log_width::<EF>())
41-
.expect("TODO small GKR, no packing");
42-
layers_packed.push(final_layer);
43-
for i in 0..last_packed {
44-
layers_packed.push(product_2_by_2(&layers_packed[i]));
38+
assert!(log2_strict_usize(final_layer.len()) >= 1);
39+
if final_layer.len() == 2 {
40+
prover_state.add_extension_scalars(&final_layer);
41+
let product = final_layer[0] * final_layer[1];
42+
let point = MultilinearPoint(vec![prover_state.sample()]);
43+
let claim = Evaluation {
44+
point: point.clone(),
45+
value: final_layer.evaluate(&point),
46+
};
47+
return (product, claim);
4548
}
46-
layers_not_packed.push(product_2_by_2(&unpack_extension(
47-
&layers_packed[last_packed],
48-
)));
49-
for i in 0..n - last_packed - 2 {
50-
layers_not_packed.push(product_2_by_2(&layers_not_packed[i]));
49+
50+
let final_layer: Mle<'_, EF> = if final_layer.len() >= 1 << MIN_VARS_FOR_PACKING {
51+
// TODO packing beforehand
52+
MleOwned::ExtensionPacked(pack_extension(final_layer)).into()
53+
} else {
54+
MleRef::Extension(final_layer).into()
55+
};
56+
if final_layer.n_vars() > MIN_VARS_FOR_PACKING && !final_layer.is_packed() {
57+
tracing::warn!("GKR product not packed despite being large enough for packing");
5158
}
5259

53-
assert_eq!(layers_not_packed[n - last_packed - 2].len(), 2);
54-
let product = layers_not_packed[n - last_packed - 2]
55-
.iter()
56-
.copied()
57-
.product::<EF>();
58-
prover_state.add_extension_scalars(&layers_not_packed[n - last_packed - 2]);
60+
let mut layers = vec![final_layer];
61+
loop {
62+
if layers.last().unwrap().n_vars() == 1 {
63+
break;
64+
}
65+
layers.push(product_2_by_2(&layers.last().unwrap().by_ref()).into());
66+
}
67+
68+
let last_layer = match layers.last().unwrap().by_ref() {
69+
MleRef::Extension(slice) => slice,
70+
_ => unreachable!(),
71+
};
72+
assert_eq!(last_layer.len(), 2);
73+
let product = last_layer[0] * last_layer[1];
74+
prover_state.add_extension_scalars(&last_layer);
5975

6076
let point = MultilinearPoint(vec![prover_state.sample()]);
6177
let mut claim = Evaluation {
6278
point: point.clone(),
63-
value: layers_not_packed[n - last_packed - 2].evaluate(&point),
79+
value: last_layer.evaluate(&point),
6480
};
6581

66-
for layer in layers_not_packed.iter().rev().skip(1) {
67-
claim = prove_gkr_product_step(prover_state, layer, &claim);
68-
}
69-
for layer in layers_packed.iter().rev() {
70-
claim = prove_gkr_product_step_packed(prover_state, layer, &claim);
82+
for layer in layers.iter().rev().skip(1) {
83+
claim = match layer.by_ref() {
84+
MleRef::Extension(slice) => prove_gkr_product_step(prover_state, slice, &claim),
85+
MleRef::ExtensionPacked(slice) => {
86+
prove_gkr_product_step_packed(prover_state, slice, &claim)
87+
}
88+
_ => unreachable!(),
89+
}
7190
}
7291

7392
(product, claim)
@@ -201,7 +220,23 @@ where
201220
Ok(Evaluation::new(next_point, next_claim))
202221
}
203222

204-
fn product_2_by_2<EF: PrimeCharacteristicRing + Sync + Send + Copy>(layer: &[EF]) -> Vec<EF> {
223+
fn product_2_by_2<EF: ExtensionField<PF<EF>>>(layer: &MleRef<'_, EF>) -> MleOwned<EF> {
224+
match layer {
225+
MleRef::Extension(slice) => MleOwned::Extension(product_2_by_2_helper(slice)),
226+
MleRef::ExtensionPacked(slice) => {
227+
if slice.len() >= 1 << MIN_VARS_FOR_PACKING {
228+
MleOwned::ExtensionPacked(product_2_by_2_helper(slice))
229+
} else {
230+
MleOwned::Extension(product_2_by_2_helper(&unpack_extension(slice)))
231+
}
232+
}
233+
_ => unreachable!(),
234+
}
235+
}
236+
237+
fn product_2_by_2_helper<EF: PrimeCharacteristicRing + Sync + Send + Copy>(
238+
layer: &[EF],
239+
) -> Vec<EF> {
205240
let n = layer.len();
206241
(0..n / 2)
207242
.into_par_iter()
@@ -221,34 +256,13 @@ mod tests {
221256
type EF = QuinticExtensionFieldKB;
222257

223258
#[test]
224-
fn test_gkr_product_step() {
225-
let log_n = 12;
226-
let n = 1 << log_n;
227-
228-
let mut rng = StdRng::seed_from_u64(0);
229-
230-
let big = (0..n).map(|_| rng.random()).collect::<Vec<EF>>();
231-
let small = product_2_by_2(&big);
232-
233-
let point = MultilinearPoint((0..log_n - 1).map(|_| rng.random()).collect::<Vec<EF>>());
234-
let eval = small.evaluate(&point);
235-
236-
let mut prover_state = build_prover_state();
237-
238-
let time = Instant::now();
239-
let claim = Evaluation { point, value: eval };
240-
prove_gkr_product_step_packed(&mut prover_state, &pack_extension(&big), &claim);
241-
dbg!(time.elapsed());
242-
243-
let mut verifier_state = build_verifier_state(&prover_state);
244-
245-
let postponed = verify_gkr_product_step(&mut verifier_state, log_n - 1, &claim).unwrap();
246-
assert_eq!(big.evaluate(&postponed.point), postponed.value);
259+
fn test_gkr_product() {
260+
for log_n in 1..10 {
261+
test_gkr_product_helper(log_n);
262+
}
247263
}
248264

249-
#[test]
250-
fn test_gkr_product() {
251-
let log_n = 13;
265+
fn test_gkr_product_helper(log_n: usize) {
252266
let n = 1 << log_n;
253267

254268
let mut rng = StdRng::seed_from_u64(0);
@@ -259,8 +273,8 @@ mod tests {
259273
let mut prover_state = build_prover_state();
260274

261275
let time = Instant::now();
262-
let (product_prover, claim_prover) =
263-
prove_gkr_product(&mut prover_state, pack_extension(&layer));
276+
277+
let (product_prover, claim_prover) = prove_gkr_product(&mut prover_state, &layer);
264278
println!("GKR product took {:?}", time.elapsed());
265279

266280
let mut verifier_state = build_verifier_state(&prover_state);

0 commit comments

Comments
 (0)