Skip to content

Commit 00b88ed

Browse files
committed
perf: transmute_to_base to avoid needless clones
1 parent 4ce80c0 commit 00b88ed

File tree

2 files changed

+54
-10
lines changed

2 files changed

+54
-10
lines changed

crates/stark-backend/src/prover/cpu/mod.rs

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
use std::{iter::zip, marker::PhantomData, ops::Deref, sync::Arc};
1+
use std::{iter::zip, marker::PhantomData, mem::ManuallyDrop, ops::Deref, sync::Arc};
22

33
use derivative::Derivative;
44
use itertools::{izip, zip_eq, Itertools};
55
use opener::OpeningProver;
66
use p3_challenger::FieldChallenger;
77
use p3_commit::{Pcs, PolynomialSpace};
8-
use p3_field::FieldExtensionAlgebra;
8+
use p3_field::{ExtensionField, Field, FieldExtensionAlgebra};
99
use p3_matrix::{dense::RowMajorMatrix, Matrix};
1010
use p3_util::log2_strict_usize;
1111
use quotient::QuotientCommitter;
@@ -36,12 +36,23 @@ pub mod opener;
3636
pub mod quotient;
3737

3838
/// CPU backend using Plonky3 traits.
39+
///
40+
/// # Safety
41+
/// For performance optimization of extension field operations, we assumes that `SC::Challenge` is
42+
/// an extension field of `F = Val<SC>` that is `repr(C)` or `repr(transparent)` with
43+
/// internal memory layout `[F; SC::Challenge::D]`.
44+
/// This ensures `SC::Challenge` and `F` have the same alignment and
45+
/// `size_of::<SC::Challenge>() == size_of::<F>() * SC::Challenge::D`.
46+
/// We assume that `<SC::Challenge as ExtensionField<F>::as_base_slice` is the same as
47+
/// transmuting `SC::Challenge` to `[F; SC::Challenge::D]`.
3948
#[derive(Derivative)]
4049
#[derivative(Clone(bound = ""), Copy(bound = ""), Default(bound = ""))]
4150
pub struct CpuBackend<SC> {
4251
phantom: PhantomData<SC>,
4352
}
4453

54+
/// # Safety
55+
/// See [`CpuBackend`].
4556
#[derive(Derivative, derive_new::new)]
4657
#[derivative(Clone(bound = ""), Copy(bound = ""))]
4758
pub struct CpuDevice<'a, SC> {
@@ -65,7 +76,7 @@ impl<SC: StarkGenericConfig> ProverBackend for CpuBackend<SC> {
6576
type RapPartialProvingKey = RapPartialProvingKey<SC>;
6677
}
6778

68-
#[derive(Derivative)]
79+
#[derive(Derivative, derive_new::new)]
6980
#[derivative(Clone(bound = ""))]
7081
pub struct PcsData<SC: StarkGenericConfig> {
7182
/// The preimage of a single commitment.
@@ -232,16 +243,23 @@ impl<SC: StarkGenericConfig> hal::RapPartialProver<CpuBackend<SC>> for CpuDevice
232243
// One shared commit for all permutation traces
233244
let committed_pcs_data_per_phase: Vec<(Com<SC>, PcsData<SC>)> =
234245
metrics_span("perm_trace_commit_time_ms", || {
235-
let flattened_traces: Vec<_> = perm_trace_per_air
246+
let (log_trace_heights, flattened_traces): (Vec<_>, Vec<_>) = perm_trace_per_air
236247
.into_iter()
237-
.flat_map(|perm_trace| {
238-
perm_trace.map(|trace| Arc::new(trace.flatten_to_base()))
248+
.flatten()
249+
.map(|perm_trace| {
250+
// SAFETY: `Challenge` is assumed to be extension field of `F`
251+
// with memory layout `[F; Challenge::D]`
252+
let trace = unsafe { transmute_to_base(perm_trace) };
253+
let height = trace.height();
254+
let log_height: u8 = log2_strict_usize(height).try_into().unwrap();
255+
let domain = self.pcs().natural_domain_for_degree(height);
256+
(log_height, (domain, trace))
239257
})
240258
.collect();
241259
// Only commit if there are permutation traces
242260
if !flattened_traces.is_empty() {
243-
let (commit, data) = self.commit(&flattened_traces);
244-
Some((commit, data))
261+
let (commit, data) = self.pcs().commit(flattened_traces);
262+
Some((commit, PcsData::new(Arc::new(data), log_trace_heights)))
245263
} else {
246264
None
247265
}
@@ -485,3 +503,26 @@ where
485503
data.clone()
486504
}
487505
}
506+
507+
/// # Safety
508+
/// Assumes that `EF` is `repr(C)` or `repr(transparent)` with internal memory layout `[F; EF::D]`.
509+
/// This ensures `EF` and `F` have the same alignment and `size_of::<EF>() == size_of::<F>() * EF::D`.
510+
/// We assume that `EF::as_base_slice` is the same as transmuting `EF` to `[F; EF::D]`.
511+
unsafe fn transmute_to_base<F: Field, EF: ExtensionField<F>>(
512+
ext_matrix: RowMajorMatrix<EF>,
513+
) -> RowMajorMatrix<F> {
514+
let width = ext_matrix.width * EF::D;
515+
// Prevent ptr from deallocating
516+
let mut values = ManuallyDrop::new(ext_matrix.values);
517+
let mut len = values.len();
518+
let mut cap = values.capacity();
519+
let ptr = values.as_mut_ptr();
520+
len *= EF::D;
521+
cap *= EF::D;
522+
// SAFETY:
523+
// - We know that `ptr` is from `Vec` so it is allocated by global allocator,
524+
// - Based on assumptions, `T` and `F` have the same alignment
525+
// - Based on memory layout assumptions, length and capacity is correct
526+
let base_values = Vec::from_raw_parts(ptr as *mut F, len, cap);
527+
RowMajorMatrix::new(base_values, width)
528+
}

crates/stark-backend/src/prover/cpu/quotient/mod.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use p3_util::log2_strict_usize;
88
use tracing::instrument;
99

1010
use self::single::compute_single_rap_quotient_values;
11-
use super::PcsData;
11+
use super::{transmute_to_base, PcsData};
1212
use crate::{
1313
air_builders::symbolic::SymbolicExpressionDag,
1414
config::{Com, Domain, PackedChallenge, StarkGenericConfig, Val},
@@ -156,7 +156,10 @@ impl<SC: StarkGenericConfig> SingleQuotientData<SC> {
156156
let quotient_degree = self.quotient_degree;
157157
let quotient_domain = self.quotient_domain;
158158
// Flatten from extension field elements to base field elements
159-
let quotient_flat = RowMajorMatrix::new_col(self.quotient_values).flatten_to_base();
159+
// SAFETY: `Challenge` is assumed to be extension field of `F`
160+
// with memory layout `[F; Challenge::D]`
161+
let quotient_flat =
162+
unsafe { transmute_to_base(RowMajorMatrix::new_col(self.quotient_values)) };
160163
let quotient_chunks = quotient_domain.split_evals(quotient_degree, quotient_flat);
161164
let qc_domains = quotient_domain.split_domains(quotient_degree);
162165
qc_domains

0 commit comments

Comments
 (0)