Skip to content

Commit 04a0761

Browse files
committed
fix logup* edge cases on small lengths
1 parent b32601d commit 04a0761

File tree

4 files changed

+123
-75
lines changed

4 files changed

+123
-75
lines changed

Cargo.lock

Lines changed: 12 additions & 12 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/lookup/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,5 @@ pub use logup_star::*;
1212

1313
mod product_gkr;
1414
pub use product_gkr::*;
15+
16+
pub(crate) const MIN_VARS_FOR_PACKING: usize = 8;

crates/lookup/src/logup_star.rs

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ use p3_field::PrimeCharacteristicRing;
1313
use tracing::{info_span, instrument};
1414
use utils::{FSProver, FSVerifier};
1515

16-
use crate::quotient_gkr::{prove_gkr_quotient, verify_gkr_quotient};
16+
use crate::{
17+
MIN_VARS_FOR_PACKING,
18+
quotient_gkr::{prove_gkr_quotient, verify_gkr_quotient},
19+
};
1720

1821
#[derive(Debug)]
1922
pub struct LogupStarStatements<EF> {
@@ -38,25 +41,27 @@ where
3841
{
3942
let table_length = table.unpacked_len();
4043
let indexes_length = indexes.len();
41-
let max_index = max_index
42-
.unwrap_or(table_length)
43-
.next_multiple_of(packing_width::<EF>());
44-
let max_index_packed = max_index.div_ceil(packing_width::<EF>());
44+
let packing = log2_strict_usize(table_length) >= MIN_VARS_FOR_PACKING
45+
&& log2_strict_usize(indexes_length) >= MIN_VARS_FOR_PACKING;
46+
let mut max_index = max_index.unwrap_or(table_length);
47+
if packing {
48+
max_index = max_index.div_ceil(packing_width::<EF>());
49+
}
4550

4651
let (poly_eq_point_packed, pushforward_packed, table_packed) =
4752
info_span!("packing").in_scope(|| {
4853
(
49-
pack_extension(poly_eq_point),
50-
pack_extension(pushforward),
51-
table.pack(),
54+
MleRef::Extension(poly_eq_point).pack_if(packing),
55+
MleRef::Extension(pushforward).pack_if(packing),
56+
table.pack_if(packing),
5257
)
5358
});
5459

5560
let (sc_point, inner_evals, prod) =
5661
info_span!("logup_star sumcheck", table_length, indexes_length).in_scope(|| {
5762
let (sc_point, prod, table_folded, pushforward_folded) = run_product_sumcheck(
5863
&table_packed.by_ref(),
59-
&MleRef::ExtensionPacked(&pushforward_packed),
64+
&pushforward_packed.by_ref(),
6065
prover_state,
6166
claimed_value,
6267
table.n_vars(),
@@ -83,18 +88,22 @@ where
8388

8489
let c = prover_state.sample();
8590

86-
let (claim_left, _, eval_c_minux_indexes) =
87-
prove_gkr_quotient(prover_state, &poly_eq_point_packed, (c, indexes), None);
91+
let (claim_left, _, eval_c_minux_indexes) = prove_gkr_quotient(
92+
prover_state,
93+
&poly_eq_point_packed.by_ref(),
94+
(c, indexes),
95+
None,
96+
);
8897

8998
let increments = (0..table.unpacked_len())
9099
.into_par_iter()
91100
.map(PF::<EF>::from_usize)
92101
.collect::<Vec<_>>();
93102
let (claim_right, pushforward_final_eval, _) = prove_gkr_quotient(
94103
prover_state,
95-
&pushforward_packed,
104+
&pushforward_packed.by_ref(),
96105
(c, &increments),
97-
Some(max_index_packed),
106+
Some(max_index),
98107
);
99108

100109
let final_point_left = claim_left.point[1..].to_vec();
@@ -236,7 +245,7 @@ mod tests {
236245
fn test_logup_star() {
237246
init_tracing();
238247

239-
let log_table_len = 21;
248+
let log_table_len = 14;
240249
let table_length = 1 << log_table_len;
241250

242251
let log_indexes_len = log_table_len + 1;
@@ -319,7 +328,7 @@ mod tests {
319328
.par_iter()
320329
.map(|x| (0..n_muls).map(|_| *x).product::<EFPacking<EF>>())
321330
.sum::<EFPacking<EF>>();
322-
assert!(sum != EFPacking::<EF>::ZERO);
331+
assert!(sum != EFPacking::<EF>::ONE);
323332
println!(
324333
"Optimal time we can hope for: {} ms",
325334
time.elapsed().as_millis()

crates/lookup/src/quotient_gkr.rs

Lines changed: 85 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ use p3_field::{ExtensionField, PrimeField64, dot_product};
1313
use tracing::instrument;
1414
use utils::{FSProver, FSVerifier};
1515

16+
use crate::MIN_VARS_FOR_PACKING;
17+
1618
/*
1719
Custom GKR to compute sum of fractions.
1820
@@ -38,64 +40,100 @@ with: U0 = AB(0 0 --- )
3840
#[instrument(skip_all)]
3941
pub fn prove_gkr_quotient<EF>(
4042
prover_state: &mut FSProver<EF, impl FSChallenger<EF>>,
41-
numerators: &[EFPacking<EF>],
43+
numerators: &MleRef<'_, EF>,
4244
(c, denominator_indexes): (EF, &[PF<EF>]),
4345
n_non_zeros_numerator: Option<usize>, // final_layer[n_non_zeros_numerator..n / 2] are zeros
4446
) -> (Evaluation<EF>, EF, EF)
4547
where
4648
EF: ExtensionField<PF<EF>>,
4749
PF<EF>: PrimeField64,
4850
{
49-
let n = log2_strict_usize(numerators.len()) + packing_log_width::<EF>() + 1;
50-
let n_non_zeros_numerator = n_non_zeros_numerator.unwrap_or(numerators.len());
51-
let mut layers_packed = Vec::new();
52-
assert!(
53-
n >= 5 + packing_log_width::<EF>(),
54-
"TODO small GKR, no packing"
55-
);
56-
let mut layers_not_packed = Vec::new();
57-
let last_packed = n - (4 + packing_log_width::<EF>());
58-
let denominator_indexes_packed = PFPacking::<EF>::pack_slice(denominator_indexes);
59-
let c_packed = EFPacking::<EF>::from(c);
60-
layers_packed.push(sum_quotients_2_by_2_num_and_den(
61-
numerators,
62-
|i| c_packed - denominator_indexes_packed[i],
63-
Some(n_non_zeros_numerator),
64-
));
65-
for i in 0..last_packed - 1 {
66-
layers_packed.push(sum_quotients_2_by_2(&layers_packed[i], None));
51+
let n = numerators.n_vars() + 1;
52+
assert!(n >= 2);
53+
let n_non_zeros_numerator = n_non_zeros_numerator.unwrap_or(numerators.packed_len());
54+
let mut layers = Vec::new();
55+
match numerators {
56+
MleRef::ExtensionPacked(numerators) => {
57+
let denominator_indexes_packed = PFPacking::<EF>::pack_slice(denominator_indexes);
58+
layers.push(MleOwned::ExtensionPacked(sum_quotients_2_by_2_num_and_den(
59+
numerators,
60+
|i| EFPacking::<EF>::from(c) - denominator_indexes_packed[i],
61+
Some(n_non_zeros_numerator),
62+
)));
63+
}
64+
MleRef::Extension(numerators) => {
65+
layers.push(MleOwned::Extension(sum_quotients_2_by_2_num_and_den(
66+
numerators,
67+
|i| c - denominator_indexes[i],
68+
Some(n_non_zeros_numerator),
69+
)));
70+
}
71+
_ => unreachable!(),
6772
}
68-
layers_not_packed.push(sum_quotients_2_by_2(
69-
&unpack_extension(&layers_packed[last_packed - 1]),
70-
None,
71-
));
72-
for i in 0..n - last_packed - 2 {
73-
layers_not_packed.push(sum_quotients_2_by_2(&layers_not_packed[i], None));
73+
74+
loop {
75+
let prev_layer: Mle<'_, EF> = layers.last().unwrap().by_ref().into();
76+
let prev_layer = if prev_layer.is_packed() && prev_layer.n_vars() < MIN_VARS_FOR_PACKING {
77+
prev_layer.unpack()
78+
} else {
79+
prev_layer
80+
};
81+
if prev_layer.n_vars() == 1 {
82+
break;
83+
}
84+
layers.push(match prev_layer.by_ref() {
85+
MleRef::ExtensionPacked(prev_layer) => {
86+
MleOwned::ExtensionPacked(sum_quotients_2_by_2(prev_layer, None))
87+
}
88+
MleRef::Extension(numerators) => {
89+
MleOwned::Extension(sum_quotients_2_by_2(numerators, None))
90+
}
91+
_ => unreachable!(),
92+
})
7493
}
7594

76-
assert_eq!(layers_not_packed[n - last_packed - 2].len(), 2);
77-
prover_state.add_extension_scalars(&layers_not_packed[n - last_packed - 2]);
95+
assert_eq!(layers.last().unwrap().n_vars(), 1);
96+
prover_state.add_extension_scalars(layers.last().unwrap().by_ref().as_extension().unwrap());
7897

7998
let point = MultilinearPoint(vec![prover_state.sample()]);
80-
let mut claim = Evaluation::new(
81-
point.clone(),
82-
layers_not_packed[n - last_packed - 2].evaluate(&point),
83-
);
84-
85-
for layer in layers_not_packed.iter().rev().skip(1) {
86-
(claim, _, _) = prove_gkr_quotient_step(prover_state, layer, &claim);
87-
}
88-
for layer in layers_packed.iter().rev() {
89-
(claim, _, _) = prove_gkr_quotient_step_packed(prover_state, layer, &claim);
99+
let mut claim = Evaluation::new(point.clone(), layers.last().unwrap().evaluate(&point));
100+
101+
for layer in layers.iter().rev().skip(1) {
102+
match layer {
103+
MleOwned::Extension(layer) => {
104+
(claim, _, _) = prove_gkr_quotient_step(prover_state, layer, &claim);
105+
}
106+
MleOwned::ExtensionPacked(layer) => {
107+
(claim, _, _) = prove_gkr_quotient_step_packed(prover_state, layer, &claim);
108+
}
109+
_ => unreachable!(),
110+
}
90111
}
91112
let (up_layer_eval_left, up_layer_eval_right);
92-
(claim, up_layer_eval_left, up_layer_eval_right) = prove_gkr_quotient_step_packed_first_round(
93-
prover_state,
94-
numerators,
95-
(c_packed, denominator_indexes_packed),
96-
&claim,
97-
Some(n_non_zeros_numerator),
98-
);
113+
114+
match numerators {
115+
MleRef::ExtensionPacked(numerators) => {
116+
let denominator_indexes_packed = PFPacking::<EF>::pack_slice(denominator_indexes);
117+
(claim, up_layer_eval_left, up_layer_eval_right) =
118+
prove_gkr_quotient_step_packed_first_round(
119+
prover_state,
120+
numerators,
121+
(EFPacking::<EF>::from(c), denominator_indexes_packed),
122+
&claim,
123+
Some(n_non_zeros_numerator),
124+
);
125+
}
126+
MleRef::Extension(numerators) => {
127+
let mut layer = EF::zero_vec(numerators.len() * 2);
128+
layer[..numerators.len()].copy_from_slice(numerators);
129+
for i in 0..denominator_indexes.len() {
130+
layer[numerators.len() + i] = c - denominator_indexes[i];
131+
}
132+
(claim, up_layer_eval_left, up_layer_eval_right) =
133+
prove_gkr_quotient_step(prover_state, &layer, &claim);
134+
}
135+
_ => unreachable!(),
136+
}
99137

100138
(claim, up_layer_eval_left, up_layer_eval_right)
101139
}
@@ -474,7 +512,7 @@ where
474512
let mid_len_packed = len_packed / 2;
475513
let quarter_len_packed = mid_len_packed / 2;
476514

477-
let mut eq_poly_packed = eval_eq_packed(&claim.point.0[1..]);
515+
let eq_poly_packed = eval_eq_packed(&claim.point.0[1..]);
478516

479517
let up_layer_octics = split_at_many(
480518
up_layer_packed,
@@ -613,7 +651,6 @@ where
613651
let sumcheck_challenge_2 = prover_state.sample();
614652
let sum_2 = sumcheck_polynomial_2.evaluate(sumcheck_challenge_2);
615653

616-
eq_poly_packed.resize(eq_poly_packed.len() / 4, Default::default());
617654
missing_mul_factor *= ((EF::ONE - claim.point[1]) * (EF::ONE - sumcheck_challenge_2)
618655
+ claim.point[1] * sumcheck_challenge_2)
619656
/ (EF::ONE - claim.point.get(2).copied().unwrap_or_default());
@@ -631,7 +668,7 @@ where
631668
&[],
632669
Some((
633670
claim.point.0[2..].to_vec(),
634-
Some(MleOwned::ExtensionPacked(eq_poly_packed)),
671+
Some(MleOwned::ExtensionPacked(eq_poly_packed).halve().halve()),
635672
)),
636673
false,
637674
prover_state,
@@ -854,7 +891,7 @@ mod tests {
854891

855892
let _ = prove_gkr_quotient(
856893
&mut prover_state,
857-
&pack_extension(&numerators),
894+
&MleRef::ExtensionPacked(&pack_extension(&numerators)),
858895
(c, &denominators_indexes),
859896
None,
860897
);

0 commit comments

Comments
 (0)