Skip to content

Commit d03c94c

Browse files
feat: simplify QuotientCommitter trait for performance (#30)
* wip: simplify QuotientCommitter trait * updated coordinator * feat: simplify `QuotientCommitter` trait for performance Just have one function for quotient poly eval and commit. Importantly we update the CPU implementation to avoid additional memory allocation when creating a **view** of the trace LDE matrices necessary for quotient poly evaluation. * chore: remove metric from doc
1 parent 1e53b82 commit d03c94c

File tree

7 files changed

+182
-232
lines changed

7 files changed

+182
-232
lines changed

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[workspace.package]
2-
version = "0.2.0-alpha"
2+
version = "0.3.0-alpha"
33
edition = "2021"
44
rust-version = "1.82"
55
authors = ["OpenVM contributors"]

crates/stark-backend/src/prover/coordinator.rs

+20-125
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,8 @@ use p3_util::log2_strict_usize;
77
use tracing::instrument;
88

99
use super::{
10-
hal::{ProverBackend, ProverDevice, QuotientCommitter},
11-
types::{
12-
DeviceMultiStarkProvingKey, HalProof, ProvingContext, RapSinglePhaseView, RapView,
13-
SingleCommitPreimage,
14-
},
10+
hal::{ProverBackend, ProverDevice},
11+
types::{DeviceMultiStarkProvingKey, HalProof, ProvingContext, SingleCommitPreimage},
1512
Prover,
1613
};
1714
use crate::{
@@ -176,6 +173,10 @@ where
176173
&mpk.per_air,
177174
pair_trace_view_per_air,
178175
);
176+
// Challenger observes additional commitments if any exist:
177+
for (commit, _) in &prover_data_after.committed_pcs_data_per_phase {
178+
self.challenger.observe(commit.clone());
179+
}
179180

180181
// Collect exposed_values_per_air for the proof:
181182
// - transpose per_phase, per_air -> per_air, per_phase
@@ -205,47 +206,31 @@ where
205206
})
206207
.collect_vec();
207208

208-
let (commitments_after, pcs_data_after): (Vec<_>, Vec<_>) = prover_data_after
209-
.committed_pcs_data_per_phase
210-
.into_iter()
211-
.unzip();
212-
// Challenger observes additional commitments if any exist:
213-
for commit in &commitments_after {
214-
self.challenger.observe(commit.clone());
215-
}
216-
217209
// ==================== Quotient polynomial computation and commitment, if any ====================
218210
// Note[jpw]: Currently we always call this step, we could add a flag to skip it for protocols that
219211
// do not require quotient poly.
220-
let extended_rap_views = metrics_span("quotient_extended_view_time_ms", || {
221-
create_trace_view_per_air(
222-
&self.device,
223-
mpk,
224-
&log_trace_height_per_air,
225-
&cached_views_per_air,
226-
&common_main_pcs_data,
227-
&pvs_per_air,
228-
&pcs_data_after,
229-
prover_data_after.rap_views_per_phase,
230-
)
231-
});
232-
let (constraints, quotient_degrees): (Vec<_>, Vec<_>) = mpk
233-
.vk_view()
234-
.per_air
235-
.iter()
236-
.map(|vk| (&vk.symbolic_constraints.constraints, vk.quotient_degree))
237-
.unzip();
238212
let (quotient_commit, quotient_data) = self.device.eval_and_commit_quotient(
239213
&mut self.challenger,
240-
&constraints,
241-
extended_rap_views,
242-
&quotient_degrees,
214+
&mpk.per_air,
215+
&pvs_per_air,
216+
&cached_views_per_air,
217+
&common_main_pcs_data,
218+
&prover_data_after,
243219
);
244220
// Observe quotient commitment
245221
self.challenger.observe(quotient_commit.clone());
246222

223+
let (commitments_after, pcs_data_after): (Vec<_>, Vec<_>) = prover_data_after
224+
.committed_pcs_data_per_phase
225+
.into_iter()
226+
.unzip();
247227
// ==================== Polynomial Opening Proofs ====================
248228
let opening = metrics_span("pcs_opening_time_ms", || {
229+
let quotient_degrees = mpk
230+
.per_air
231+
.iter()
232+
.map(|pk| pk.vk.quotient_degree)
233+
.collect_vec();
249234
let preprocessed = mpk
250235
.per_air
251236
.iter()
@@ -326,93 +311,3 @@ impl<'a, PB: ProverBackend> DeviceMultiStarkProvingKey<'a, PB> {
326311
MultiStarkVerifyingKeyView::new(self.per_air.iter().map(|pk| pk.vk).collect())
327312
}
328313
}
329-
330-
/// Takes in views of pcs data and returns extended views of all matrices evaluated on quotient domains
331-
/// for quotient poly calculation.
332-
#[allow(clippy::too_many_arguments)]
333-
fn create_trace_view_per_air<PB: ProverBackend>(
334-
device: &impl QuotientCommitter<PB>,
335-
mpk: &DeviceMultiStarkProvingKey<PB>,
336-
log_trace_height_per_air: &[u8],
337-
cached_views_per_air: &[Vec<SingleCommitPreimage<&PB::Matrix, &PB::PcsData>>],
338-
common_main_pcs_data: &PB::PcsData,
339-
pvs_per_air: &[Vec<PB::Val>],
340-
pcs_data_per_phase: &[PB::PcsData],
341-
rap_views_per_phase: Vec<Vec<RapSinglePhaseView<usize, PB::Challenge>>>,
342-
) -> Vec<RapView<PB::Matrix, PB::Val, PB::Challenge>> {
343-
let mut common_main_idx = 0;
344-
izip!(
345-
&mpk.per_air,
346-
log_trace_height_per_air,
347-
cached_views_per_air,
348-
pvs_per_air
349-
)
350-
.enumerate()
351-
.map(|(i, (pk, &log_trace_height, cached_views, pvs))| {
352-
let quotient_degree = pk.vk.quotient_degree;
353-
// The AIR will be treated as the full RAP with virtual columns after this
354-
let preprocessed = pk.preprocessed_data.as_ref().map(|cv| {
355-
device
356-
.get_extended_matrix(&cv.data, cv.matrix_idx as usize, quotient_degree)
357-
.unwrap()
358-
});
359-
let mut partitioned_main: Vec<_> = cached_views
360-
.iter()
361-
.map(|cv| {
362-
device
363-
.get_extended_matrix(cv.data, cv.matrix_idx as usize, quotient_degree)
364-
.unwrap()
365-
})
366-
.collect();
367-
if pk.vk.has_common_main() {
368-
partitioned_main.push(
369-
device
370-
.get_extended_matrix(common_main_pcs_data, common_main_idx, quotient_degree)
371-
.unwrap_or_else(|| {
372-
panic!("common main commitment could not get matrix_idx={common_main_idx}")
373-
}),
374-
);
375-
common_main_idx += 1;
376-
}
377-
let pair = PairView {
378-
log_trace_height,
379-
preprocessed,
380-
partitioned_main,
381-
public_values: pvs.to_vec(),
382-
};
383-
let mut per_phase = pcs_data_per_phase
384-
.iter()
385-
.zip_eq(&rap_views_per_phase)
386-
.map(
387-
|(pcs_data, rap_views)| -> Option<RapSinglePhaseView<PB::Matrix, PB::Challenge>> {
388-
let rap_view = rap_views.get(i)?;
389-
let matrix_idx = rap_view.inner?;
390-
let extended_matrix =
391-
device.get_extended_matrix(pcs_data, matrix_idx, quotient_degree);
392-
let extended_matrix = extended_matrix.unwrap_or_else(|| {
393-
panic!("could not get matrix_idx={matrix_idx} for rap {i}")
394-
});
395-
Some(RapSinglePhaseView {
396-
inner: Some(extended_matrix),
397-
challenges: rap_view.challenges.clone(),
398-
exposed_values: rap_view.exposed_values.clone(),
399-
})
400-
},
401-
)
402-
.collect_vec();
403-
while let Some(last) = per_phase.last() {
404-
if last.is_none() {
405-
per_phase.pop();
406-
} else {
407-
break;
408-
}
409-
}
410-
let per_phase = per_phase
411-
.into_iter()
412-
.map(|v| v.unwrap_or_default())
413-
.collect();
414-
415-
RapView { pair, per_phase }
416-
})
417-
.collect()
418-
}

crates/stark-backend/src/prover/cpu/mod.rs

+98-29
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
use std::{marker::PhantomData, ops::Deref, sync::Arc};
1+
use std::{iter::zip, marker::PhantomData, ops::Deref, sync::Arc};
22

33
use derivative::Derivative;
4-
use itertools::{zip_eq, Itertools};
4+
use itertools::{izip, zip_eq, Itertools};
55
use opener::OpeningProver;
66
use p3_challenger::FieldChallenger;
77
use p3_commit::{Pcs, PolynomialSpace};
@@ -18,7 +18,7 @@ use super::{
1818
},
1919
};
2020
use crate::{
21-
air_builders::symbolic::{SymbolicConstraints, SymbolicExpressionDag},
21+
air_builders::symbolic::SymbolicConstraints,
2222
config::{
2323
Com, PcsProof, PcsProverData, RapPhaseSeqPartialProof, RapPhaseSeqProvingKey,
2424
StarkGenericConfig, Val,
@@ -231,43 +231,112 @@ impl<SC: StarkGenericConfig> hal::RapPartialProver<CpuBackend<SC>> for CpuDevice
231231
}
232232
}
233233

234-
type RapMatrixView<SC> =
235-
RapView<Arc<RowMajorMatrix<Val<SC>>>, Val<SC>, <SC as StarkGenericConfig>::Challenge>;
236-
237234
impl<SC: StarkGenericConfig> hal::QuotientCommitter<CpuBackend<SC>> for CpuDevice<'_, SC> {
238-
fn get_extended_matrix(
239-
&self,
240-
view: &PcsData<SC>,
241-
matrix_idx: usize,
242-
quotient_degree: u8,
243-
) -> Option<Arc<RowMajorMatrix<Val<SC>>>> {
244-
let pcs = self.pcs();
245-
let log_trace_height = *view.log_trace_heights.get(matrix_idx)?;
246-
let trace_domain = pcs.natural_domain_for_degree(1usize << log_trace_height);
247-
let quotient_domain =
248-
trace_domain.create_disjoint_domain(trace_domain.size() * quotient_degree as usize);
249-
// NOTE[jpw]: (perf) this clones under the hood!
250-
let lde_matrix = self
251-
.pcs()
252-
.get_evaluations_on_domain(&view.data, matrix_idx, quotient_domain)
253-
.to_row_major_matrix();
254-
Some(Arc::new(lde_matrix))
255-
}
256-
257235
fn eval_and_commit_quotient(
258236
&self,
259237
challenger: &mut SC::Challenger,
260-
constraints: &[&SymbolicExpressionDag<Val<SC>>],
261-
extended_views: Vec<RapMatrixView<SC>>,
262-
quotient_degrees: &[u8],
238+
pk_views: &[DeviceStarkProvingKey<CpuBackend<SC>>],
239+
public_values: &[Vec<Val<SC>>],
240+
cached_views_per_air: &[Vec<
241+
SingleCommitPreimage<&Arc<RowMajorMatrix<Val<SC>>>, &PcsData<SC>>,
242+
>],
243+
common_main_pcs_data: &PcsData<SC>,
244+
prover_data_after: &ProverDataAfterRapPhases<CpuBackend<SC>>,
263245
) -> (Com<SC>, PcsData<SC>) {
246+
let pcs = self.pcs();
264247
// Generate `alpha` challenge
265248
let alpha: SC::Challenge = challenger.sample_ext_element();
266249
tracing::debug!("alpha: {alpha:?}");
250+
// Prepare extended views:
251+
let mut common_main_idx = 0;
252+
let extended_views = izip!(pk_views, cached_views_per_air, public_values)
253+
.enumerate()
254+
.map(|(i, (pk, cached_views, pvs))| {
255+
let quotient_degree = pk.vk.quotient_degree;
256+
let log_trace_height = if pk.vk.has_common_main() {
257+
common_main_pcs_data.log_trace_heights[common_main_idx]
258+
} else {
259+
log2_strict_usize(cached_views[0].trace.height()) as u8
260+
};
261+
let trace_domain = pcs.natural_domain_for_degree(1usize << log_trace_height);
262+
let quotient_domain = trace_domain
263+
.create_disjoint_domain(trace_domain.size() * quotient_degree as usize);
264+
// **IMPORTANT**: the return type of `get_evaluations_on_domain` is a matrix view. DO NOT call to_row_major_matrix as this will allocate new memory
265+
let preprocessed = pk.preprocessed_data.as_ref().map(|cv| {
266+
pcs.get_evaluations_on_domain(
267+
&cv.data.data,
268+
cv.matrix_idx as usize,
269+
quotient_domain,
270+
)
271+
});
272+
let mut partitioned_main: Vec<_> = cached_views
273+
.iter()
274+
.map(|cv| {
275+
pcs.get_evaluations_on_domain(
276+
&cv.data.data,
277+
cv.matrix_idx as usize,
278+
quotient_domain,
279+
)
280+
})
281+
.collect();
282+
if pk.vk.has_common_main() {
283+
partitioned_main.push(pcs.get_evaluations_on_domain(
284+
&common_main_pcs_data.data,
285+
common_main_idx,
286+
quotient_domain,
287+
));
288+
common_main_idx += 1;
289+
}
290+
let pair = PairView {
291+
log_trace_height,
292+
preprocessed,
293+
partitioned_main,
294+
public_values: pvs.to_vec(),
295+
};
296+
let mut per_phase = zip(
297+
&prover_data_after.committed_pcs_data_per_phase,
298+
&prover_data_after.rap_views_per_phase,
299+
)
300+
.map(|((_, pcs_data), rap_views)| -> Option<_> {
301+
let rap_view = rap_views.get(i)?;
302+
let matrix_idx = rap_view.inner?;
303+
let extended_matrix =
304+
pcs.get_evaluations_on_domain(&pcs_data.data, matrix_idx, quotient_domain);
305+
Some(RapSinglePhaseView {
306+
inner: Some(extended_matrix),
307+
challenges: rap_view.challenges.clone(),
308+
exposed_values: rap_view.exposed_values.clone(),
309+
})
310+
})
311+
.collect_vec();
312+
while let Some(last) = per_phase.last() {
313+
if last.is_none() {
314+
per_phase.pop();
315+
} else {
316+
break;
317+
}
318+
}
319+
let per_phase = per_phase
320+
.into_iter()
321+
.map(|v| v.unwrap_or_default())
322+
.collect();
323+
324+
RapView { pair, per_phase }
325+
})
326+
.collect_vec();
267327

328+
let (constraints, quotient_degrees): (Vec<_>, Vec<_>) = pk_views
329+
.iter()
330+
.map(|pk| {
331+
(
332+
&pk.vk.symbolic_constraints.constraints,
333+
pk.vk.quotient_degree,
334+
)
335+
})
336+
.unzip();
268337
let qc = QuotientCommitter::new(self.pcs(), alpha);
269338
let quotient_values = metrics_span("quotient_poly_compute_time_ms", || {
270-
qc.quotient_values(constraints, extended_views, quotient_degrees)
339+
qc.quotient_values(&constraints, extended_views, &quotient_degrees)
271340
});
272341

273342
// Commit to quotient polynomials. One shared commit for all quotient polynomials

0 commit comments

Comments
 (0)