Skip to content

Commit 434455e

Browse files
VitaliiHVitaliiH
authored andcommitted
par map
1 parent 0825b87 commit 434455e

File tree

5 files changed

+50
-6
lines changed

5 files changed

+50
-6
lines changed

crates/prover/src/constraint_framework/component.rs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,11 +286,22 @@ impl<E: FrameworkEval + Sync> ComponentProver<SimdBackend> for FrameworkComponen
286286
// Extend trace if necessary.
287287
// TODO: Don't extend when eval_size < committed_size. Instead, pick a good
288288
// subdomain. (For larger blowup factors).
289-
nvtx_timed!("extend trace");
289+
nvtx_timed!("need_to_extend");
290+
#[cfg(not(feature = "parallel"))]
290291
let need_to_extend = component_evals
291292
.iter()
292293
.flatten()
293294
.any(|c| c.domain != eval_domain);
295+
nvtx_timed_pop!();
296+
297+
#[cfg(feature = "parallel")]
298+
let need_to_extend = component_evals
299+
.par_iter()
300+
.flat_map_iter(|v| v.iter())
301+
.any(|c| c.domain != eval_domain);
302+
nvtx_timed_pop!();
303+
304+
nvtx_timed!("SIMD extend trace");
294305
let trace: TreeVec<
295306
Vec<Cow<'_, CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>>,
296307
> = if need_to_extend {
@@ -578,13 +589,20 @@ impl<E: FrameworkEval + Sync> ComponentProver<IcicleBackend> for FrameworkCompon
578589
// Extend trace if necessary.
579590
// TODO: Don't extend when eval_size < committed_size. Instead, pick a good
580591
// subdomain. (For larger blowup factors).
581-
nvtx_timed!("extend trace");
592+
nvtx_timed!("need_to_extend");
593+
#[cfg(not(feature = "parallel"))]
582594
let need_to_extend = component_evals
583595
.iter()
584596
.flatten()
585597
.any(|c| c.domain != eval_domain);
586598
nvtx_timed_pop!();
587599

600+
#[cfg(feature = "parallel")]
601+
let need_to_extend = component_evals
602+
.par_iter()
603+
.flat_map_iter(|v| v.iter())
604+
.any(|c| c.domain != eval_domain);
605+
588606
// Denom inverses.
589607
nvtx_timed!("denom inverses");
590608
let log_expand = eval_domain.log_size() - trace_domain.log_size();

crates/prover/src/core/backend/cpu/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ pub fn bit_reverse<T>(v: &mut [T]) {
4848
}
4949
}
5050

51-
impl<T: Debug + Clone + Default> ColumnOps<T> for CpuBackend {
51+
impl<T: Debug + Clone + Default + Sync> ColumnOps<T> for CpuBackend {
5252
type Column = Vec<T>;
5353

5454
fn bit_reverse_column(column: &mut Self::Column) {
@@ -64,7 +64,7 @@ impl<F: Field> FieldOps<F> for CpuBackend {
6464
}
6565
}
6666

67-
impl<T: Debug + Clone + Default> Column<T> for Vec<T> {
67+
impl<T: Debug + Clone + Default + Sync> Column<T> for Vec<T> {
6868
fn zeros(len: usize) -> Self {
6969
vec![T::default(); len]
7070
}

crates/prover/src/core/backend/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ pub trait ColumnOps<T> {
4848
pub type Col<B, T> = <B as ColumnOps<T>>::Column;
4949

5050
// TODO(alont): Consider removing the generic parameter and only support BaseField.
51-
pub trait Column<T>: Clone + Debug + FromIterator<T> {
51+
pub trait Column<T>: Clone + Debug + FromIterator<T> + Sync {
5252
/// Creates a new column of zeros with the given length.
5353
fn zeros(len: usize) -> Self;
5454
/// Creates a new column of uninitialized values with the given length.

crates/prover/src/core/pcs/prover.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ pub struct CommitmentSchemeProver<'a, B: BackendForChannel<MC>, MC: MerkleChanne
3030
pub config: PcsConfig,
3131
twiddles: &'a TwiddleTree<B>,
3232
}
33+
use crate::core::backend::ColumnOps;
3334

3435
impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a, B, MC> {
3536
pub fn new(config: PcsConfig, twiddles: &'a TwiddleTree<B>) -> Self {
@@ -82,7 +83,6 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,
8283
let evals = self.evaluations();
8384
Trace { polys, evals }
8485
}
85-
8686
pub fn prove_values(
8787
self,
8888
sampled_points: TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>>,

crates/prover/src/core/pcs/utils.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,11 @@ impl<T> Default for TreeVec<T> {
6767
}
6868
}
6969

70+
#[cfg(feature = "parallel")]
71+
use rayon::prelude::*;
72+
7073
impl<T> TreeVec<ColumnVec<T>> {
74+
#[cfg(not(feature = "parallel"))]
7175
pub fn map_cols<U, F: FnMut(T) -> U>(self, mut f: F) -> TreeVec<ColumnVec<U>> {
7276
TreeVec(
7377
self.0
@@ -77,6 +81,28 @@ impl<T> TreeVec<ColumnVec<T>> {
7781
)
7882
}
7983

84+
85+
#[cfg(feature = "parallel")]
86+
pub fn map_cols<U, F>(self, f: F) -> TreeVec<ColumnVec<U>>
87+
where
88+
F: Fn(T) -> U + Send + Sync + Clone,
89+
T: Send,
90+
U: Send,
91+
{
92+
TreeVec(
93+
self.0
94+
.into_par_iter()
95+
.map(|column| {
96+
let f = f.clone();
97+
column
98+
.into_par_iter()
99+
.map(f)
100+
.collect()
101+
})
102+
.collect(),
103+
)
104+
}
105+
80106
/// Zips two [`TreeVec<ColumVec<T>>`] with the same structure (number of columns in each tree).
81107
/// The resulting [`TreeVec<ColumVec<T>>`] has the same structure, with each value being a tuple
82108
/// of the corresponding values from the input [`TreeVec<ColumVec<T>>`].

0 commit comments

Comments
 (0)