Skip to content

Commit 6d43a2d

Browse files
committed
Merge branch 'batch-memory-lookup-and-bus-consistency-GKR' into lean-vm-simple
2 parents c0d0bfb + eedbb94 commit 6d43a2d

File tree

9 files changed

+540
-329
lines changed

9 files changed

+540
-329
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ The full recursion program is not finished yet. Instead, we prove validity of a
4747
- n-to-1: Recursive proof of many WHIR openings (≈ 8) (we report prover time per WHIR)
4848

4949
```console
50-
RUSTFLAGS='-C target-cpu=native' cargo run --release -- recursion --count 12
50+
RUSTFLAGS='-C target-cpu=native' cargo run --release -- recursion --count 18
5151
```
5252

5353
![Alt text](docs/benchmark_graphs/graphs/recursive_whir_opening.svg)

crates/lean_prover/src/prove_execution.rs

Lines changed: 98 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::*;
55
use air::prove_air;
66
use itertools::Itertools;
77
use lean_vm::*;
8-
use lookup::{compute_pushforward, prove_gkr_quotient, prove_logup_star};
8+
use lookup::{compute_pushforward, prove_logup_star};
99
use multilinear_toolkit::prelude::*;
1010

1111
use p3_util::log2_ceil_usize;
@@ -132,39 +132,41 @@ pub fn prove_execution(
132132
let bus_challenge = prover_state.sample();
133133
let fingerprint_challenge = prover_state.sample();
134134

135-
let mut bus_quotients: BTreeMap<Table, EF> = Default::default();
136-
let mut air_points: BTreeMap<Table, MultilinearPoint<EF>> = Default::default();
137-
let mut evals_f: BTreeMap<Table, Vec<EF>> = Default::default();
138-
let mut evals_ef: BTreeMap<Table, Vec<EF>> = Default::default();
139-
135+
let mut bus_numerators = vec![];
136+
let mut bus_denominators = vec![];
140137
for (table, trace) in &traces {
141-
let (this_bus_quotient, this_air_point, this_evals_f, this_evals_ef) =
142-
prove_bus_and_air(&mut prover_state, table, trace, bus_challenge, fingerprint_challenge);
143-
bus_quotients.insert(*table, this_bus_quotient);
144-
air_points.insert(*table, this_air_point);
145-
evals_f.insert(*table, this_evals_f);
146-
evals_ef.insert(*table, this_evals_ef);
138+
for bus in table.buses() {
139+
let numerator = trace.base[bus.selector]
140+
.par_iter()
141+
.map(|&selector| match bus.direction {
142+
BusDirection::Pull => -selector,
143+
BusDirection::Push => selector,
144+
})
145+
.collect::<Vec<_>>();
146+
let denominator = (0..trace.n_rows_padded())
147+
.into_par_iter()
148+
.map(|i| {
149+
bus_challenge
150+
+ finger_print(
151+
match &bus.table {
152+
BusTable::Constant(table) => table.embed(),
153+
BusTable::Variable(col) => trace.base[*col][i],
154+
},
155+
bus.data
156+
.iter()
157+
.map(|col| trace.base[*col][i])
158+
.collect::<Vec<_>>()
159+
.as_slice(),
160+
fingerprint_challenge,
161+
)
162+
})
163+
.collect::<Vec<_>>();
164+
165+
bus_numerators.push(numerator);
166+
bus_denominators.push(denominator);
167+
}
147168
}
148169

149-
assert_eq!(bus_quotients.values().copied().sum::<EF>(), EF::ZERO);
150-
151-
let bytecode_compression_challenges =
152-
MultilinearPoint(prover_state.sample_vec(log2_ceil_usize(N_INSTRUCTION_COLUMNS)));
153-
154-
let folded_bytecode = fold_bytecode(bytecode, &bytecode_compression_challenges);
155-
156-
let bytecode_lookup_claim = Evaluation::new(
157-
air_points[&Table::execution()].clone(),
158-
padd_with_zero_to_next_power_of_two(&evals_f[&Table::execution()][..N_INSTRUCTION_COLUMNS])
159-
.evaluate(&bytecode_compression_challenges),
160-
);
161-
let bytecode_poly_eq_point = eval_eq(&air_points[&Table::execution()]);
162-
let bytecode_pushforward = MleOwned::Extension(compute_pushforward(
163-
&traces[&Table::execution()].base[COL_INDEX_PC],
164-
folded_bytecode.len(),
165-
&bytecode_poly_eq_point,
166-
));
167-
168170
let mut lookup_into_memory = CustomLookupProver::run::<EF, DIMENSION, VECTOR_LEN>(
169171
&mut prover_state,
170172
&memory,
@@ -193,7 +195,53 @@ pub fn prove_execution(
193195
.iter()
194196
.flat_map(|(table, trace)| table.vector_lookup_values_columns(trace))
195197
.collect(),
198+
collect_refs(&bus_numerators),
199+
collect_refs(&bus_denominators),
200+
UNIVARIATE_SKIPS,
201+
);
202+
203+
let mut air_points: BTreeMap<Table, MultilinearPoint<EF>> = Default::default();
204+
let mut evals_f: BTreeMap<Table, Vec<EF>> = Default::default();
205+
let mut evals_ef: BTreeMap<Table, Vec<EF>> = Default::default();
206+
207+
let mut bus_offset = 0;
208+
for (table, trace) in &traces {
209+
let (this_air_point, this_evals_f, this_evals_ef) = prove_bus_and_air(
210+
&mut prover_state,
211+
table,
212+
trace,
213+
bus_challenge,
214+
fingerprint_challenge,
215+
&lookup_into_memory.on_bus_numerators[bus_offset..][..table.buses().len()],
216+
&lookup_into_memory.on_bus_denominators[bus_offset..][..table.buses().len()],
217+
);
218+
air_points.insert(*table, this_air_point);
219+
evals_f.insert(*table, this_evals_f);
220+
evals_ef.insert(*table, this_evals_ef);
221+
bus_offset += table.buses().len();
222+
}
223+
assert_eq_many!(
224+
bus_offset,
225+
lookup_into_memory.on_bus_numerators.len(),
226+
lookup_into_memory.on_bus_denominators.len()
227+
);
228+
229+
let bytecode_compression_challenges =
230+
MultilinearPoint(prover_state.sample_vec(log2_ceil_usize(N_INSTRUCTION_COLUMNS)));
231+
232+
let folded_bytecode = fold_bytecode(bytecode, &bytecode_compression_challenges);
233+
234+
let bytecode_lookup_claim = Evaluation::new(
235+
air_points[&Table::execution()].clone(),
236+
padd_with_zero_to_next_power_of_two(&evals_f[&Table::execution()][..N_INSTRUCTION_COLUMNS])
237+
.evaluate(&bytecode_compression_challenges),
196238
);
239+
let bytecode_poly_eq_point = eval_eq(&air_points[&Table::execution()]);
240+
let bytecode_pushforward = MleOwned::Extension(compute_pushforward(
241+
&traces[&Table::execution()].base[COL_INDEX_PC],
242+
folded_bytecode.len(),
243+
&bytecode_poly_eq_point,
244+
));
197245

198246
let bytecode_pushforward_commitment =
199247
WhirConfig::new(whir_config_builder_b(), log2_ceil_usize(bytecode.instructions.len()))
@@ -294,135 +342,31 @@ fn prove_bus_and_air(
294342
trace: &TableTrace,
295343
bus_challenge: EF,
296344
fingerprint_challenge: EF,
297-
) -> (EF, MultilinearPoint<EF>, Vec<EF>, Vec<EF>) {
298-
let n_buses = t.buses().len();
299-
let n_buses_padded = n_buses.next_power_of_two();
300-
let log_n_buses = log2_ceil_usize(n_buses);
301-
let n_rows = trace.n_rows_padded();
302-
let log_n_rows = trace.log_padded();
303-
304-
assert!(n_buses > 0, "Table {} has no buses", t.name());
305-
306-
let mut numerators = F::zero_vec(n_buses_padded * n_rows);
307-
for (bus, numerators_chunk) in t.buses().iter().zip(numerators.chunks_mut(n_rows)) {
308-
assert!(bus.selector < trace.base.len());
309-
trace.base[bus.selector]
310-
.par_iter()
311-
.zip(numerators_chunk)
312-
.for_each(|(&selector, v)| {
313-
*v = match bus.direction {
314-
BusDirection::Pull => -selector,
315-
BusDirection::Push => selector,
316-
}
317-
});
318-
}
319-
320-
let mut denominators = unsafe { uninitialized_vec(n_buses_padded * n_rows) };
321-
for (bus, denomniators_chunk) in t.buses().iter().zip(denominators.chunks_exact_mut(n_rows)) {
322-
denomniators_chunk.par_iter_mut().enumerate().for_each(|(i, v)| {
323-
*v = bus_challenge
324-
+ finger_print(
325-
match &bus.table {
326-
BusTable::Constant(table) => table.embed(),
327-
BusTable::Variable(col) => trace.base[*col][i],
328-
},
329-
bus.data
330-
.iter()
331-
.map(|col| trace.base[*col][i])
332-
.collect::<Vec<_>>()
333-
.as_slice(),
334-
fingerprint_challenge,
335-
);
336-
});
337-
}
338-
denominators[n_rows * n_buses..]
339-
.par_iter_mut()
340-
.for_each(|v| *v = EF::ONE);
341-
342-
// TODO avoid embedding !!
343-
let numerators_embedded = numerators.par_iter().copied().map(EF::from).collect::<Vec<_>>();
344-
345-
// TODO avoid reallocation due to packing (pack directly when constructing)
346-
let numerators_packed = pack_extension(&numerators_embedded);
347-
let denominators_packed = pack_extension(&denominators);
348-
let (quotient, bus_point_global, numerator_value_global, denominator_value_global) =
349-
prove_gkr_quotient::<_, TWO_POW_UNIVARIATE_SKIPS>(
350-
prover_state,
351-
&MleGroupRef::ExtensionPacked(vec![&numerators_packed, &denominators_packed]),
352-
);
353-
354-
let (bus_point, bus_selector_values, bus_data_values) = if n_buses == 1 {
355-
// easy case
356-
(
357-
bus_point_global,
358-
vec![numerator_value_global],
359-
vec![denominator_value_global],
360-
)
361-
} else {
362-
let uni_selectors = univariate_selectors::<F>(UNIVARIATE_SKIPS);
363-
364-
let sub_numerators_evals = numerators
365-
.par_chunks_exact(1 << (log_n_rows - UNIVARIATE_SKIPS))
366-
.take(n_buses << UNIVARIATE_SKIPS)
367-
.map(|chunk| chunk.evaluate(&MultilinearPoint(bus_point_global[1 + log_n_buses..].to_vec())))
368-
.collect::<Vec<_>>();
369-
prover_state.add_extension_scalars(&sub_numerators_evals);
370-
// sanity check:
371-
assert_eq!(
372-
numerator_value_global,
373-
evaluate_univariate_multilinear::<_, _, _, false>(
374-
&padd_with_zero_to_next_power_of_two(&sub_numerators_evals),
375-
&bus_point_global[..1 + log_n_buses],
376-
&uni_selectors,
377-
None
378-
),
379-
);
380-
381-
let sub_denominators_evals = denominators
382-
.par_chunks_exact(1 << (log_n_rows - UNIVARIATE_SKIPS))
383-
.take(n_buses << UNIVARIATE_SKIPS)
384-
.map(|chunk| chunk.evaluate(&MultilinearPoint(bus_point_global[1 + log_n_buses..].to_vec())))
385-
.collect::<Vec<_>>();
386-
prover_state.add_extension_scalars(&sub_denominators_evals);
387-
// sanity check:
388-
assert_eq!(
389-
denominator_value_global,
390-
evaluate_univariate_multilinear::<_, _, _, false>(
391-
&padd_to_next_power_of_two(&sub_denominators_evals, EF::ONE),
392-
&bus_point_global[..1 + log_n_buses],
393-
&uni_selectors,
394-
None
395-
),
396-
);
397-
398-
let epsilon = prover_state.sample();
399-
let bus_point = MultilinearPoint([vec![epsilon], bus_point_global[1 + log_n_buses..].to_vec()].concat());
400-
401-
let bus_selector_values = sub_numerators_evals
402-
.chunks_exact(1 << UNIVARIATE_SKIPS)
403-
.map(|chunk| evaluate_univariate_multilinear::<_, _, _, false>(chunk, &[epsilon], &uni_selectors, None))
404-
.collect();
405-
let bus_data_values = sub_denominators_evals
406-
.chunks_exact(1 << UNIVARIATE_SKIPS)
407-
.map(|chunk| evaluate_univariate_multilinear::<_, _, _, false>(chunk, &[epsilon], &uni_selectors, None))
408-
.collect();
409-
410-
(bus_point, bus_selector_values, bus_data_values)
411-
};
345+
bus_numerator_statements: &[Evaluation<EF>],
346+
bus_denominator_statements: &[Evaluation<EF>],
347+
) -> (MultilinearPoint<EF>, Vec<EF>, Vec<EF>) {
348+
assert_eq!(t.buses().len(), bus_numerator_statements.len());
349+
let bus_point = bus_numerator_statements[0].point.clone();
350+
assert!(t.buses().iter().all(|_| bus_numerator_statements[0].point == bus_point));
351+
assert!(
352+
t.buses()
353+
.iter()
354+
.all(|_| bus_denominator_statements[0].point == bus_point)
355+
);
412356

413357
let bus_beta = prover_state.sample();
414358

415-
let bus_final_values = bus_selector_values
359+
let bus_final_values = bus_numerator_statements
416360
.iter()
417-
.zip_eq(&bus_data_values)
418-
.zip_eq(&t.buses())
419-
.map(|((&bus_selector_value, &bus_data_value), bus)| {
420-
bus_selector_value
361+
.zip_eq(bus_denominator_statements)
362+
.zip_eq(t.buses())
363+
.map(|((bus_selector_statement, bus_data_statement), bus)| {
364+
bus_selector_statement.value
421365
* match bus.direction {
422366
BusDirection::Pull => EF::NEG_ONE,
423367
BusDirection::Push => EF::ONE,
424368
}
425-
+ bus_beta * (bus_data_value - bus_challenge)
369+
+ bus_beta * (bus_data_statement.value - bus_challenge)
426370
})
427371
.collect::<Vec<_>>();
428372

@@ -438,7 +382,7 @@ fn prove_bus_and_air(
438382
alpha_powers: vec![], // filled later
439383
};
440384

441-
let (air_point, evals_f, evals_ef) = info_span!("Table AIR proof", table = t.name()).in_scope(|| {
385+
let (air_point, evals_f, evals_ef) = info_span!("AIR proof", table = t.name()).in_scope(|| {
442386
macro_rules! prove_air_for_table {
443387
($t:expr) => {
444388
prove_air(
@@ -458,5 +402,5 @@ fn prove_bus_and_air(
458402
delegate_to_inner!(t => prove_air_for_table)
459403
});
460404

461-
(quotient, air_point, evals_f, evals_ef)
405+
(air_point, evals_f, evals_ef)
462406
}

0 commit comments

Comments
 (0)