1- use std:: { iter:: zip, marker:: PhantomData , ops:: Deref , sync:: Arc } ;
1+ use std:: { iter:: zip, marker:: PhantomData , mem :: ManuallyDrop , ops:: Deref , sync:: Arc } ;
22
33use derivative:: Derivative ;
44use itertools:: { izip, zip_eq, Itertools } ;
55use opener:: OpeningProver ;
66use p3_challenger:: FieldChallenger ;
77use p3_commit:: { Pcs , PolynomialSpace } ;
8- use p3_field:: FieldExtensionAlgebra ;
8+ use p3_field:: { ExtensionField , Field , FieldExtensionAlgebra } ;
99use p3_matrix:: { dense:: RowMajorMatrix , Matrix } ;
1010use p3_util:: log2_strict_usize;
1111use quotient:: QuotientCommitter ;
@@ -36,12 +36,23 @@ pub mod opener;
3636pub mod quotient;
3737
3838/// CPU backend using Plonky3 traits.
39+ ///
40+ /// # Safety
41+ /// To optimize the performance of extension field operations, we assume that `SC::Challenge` is
42+ /// an extension field of `F = Val<SC>` that is `repr(C)` or `repr(transparent)` with
43+ /// internal memory layout `[F; SC::Challenge::D]`.
44+ /// This ensures that `SC::Challenge` and `F` have the same alignment and that
45+ /// `size_of::<SC::Challenge>() == size_of::<F>() * SC::Challenge::D`.
46+ /// We also assume that `<SC::Challenge as ExtensionField<F>>::as_base_slice` is equivalent to
47+ /// transmuting `SC::Challenge` to `[F; SC::Challenge::D]`.
3948#[ derive( Derivative ) ]
4049#[ derivative( Clone ( bound = "" ) , Copy ( bound = "" ) , Default ( bound = "" ) ) ]
4150pub struct CpuBackend < SC > {
4251 phantom : PhantomData < SC > ,
4352}
4453
54+ /// # Safety
55+ /// See [`CpuBackend`].
4556#[ derive( Derivative , derive_new:: new) ]
4657#[ derivative( Clone ( bound = "" ) , Copy ( bound = "" ) ) ]
4758pub struct CpuDevice < ' a , SC > {
@@ -65,7 +76,7 @@ impl<SC: StarkGenericConfig> ProverBackend for CpuBackend<SC> {
6576 type RapPartialProvingKey = RapPartialProvingKey < SC > ;
6677}
6778
68- #[ derive( Derivative ) ]
79+ #[ derive( Derivative , derive_new :: new ) ]
6980#[ derivative( Clone ( bound = "" ) ) ]
7081pub struct PcsData < SC : StarkGenericConfig > {
7182 /// The preimage of a single commitment.
@@ -232,16 +243,23 @@ impl<SC: StarkGenericConfig> hal::RapPartialProver<CpuBackend<SC>> for CpuDevice
232243 // One shared commit for all permutation traces
233244 let committed_pcs_data_per_phase: Vec < ( Com < SC > , PcsData < SC > ) > =
234245 metrics_span ( "perm_trace_commit_time_ms" , || {
235- let flattened_traces: Vec < _ > = perm_trace_per_air
246+ let ( log_trace_heights , flattened_traces) : ( Vec < _ > , Vec < _ > ) = perm_trace_per_air
236247 . into_iter ( )
237- . flat_map ( |perm_trace| {
238- perm_trace. map ( |trace| Arc :: new ( trace. flatten_to_base ( ) ) )
248+ . flatten ( )
249+ . map ( |perm_trace| {
250+ // SAFETY: `Challenge` is assumed to be extension field of `F`
251+ // with memory layout `[F; Challenge::D]`
252+ let trace = unsafe { transmute_to_base ( perm_trace) } ;
253+ let height = trace. height ( ) ;
254+ let log_height: u8 = log2_strict_usize ( height) . try_into ( ) . unwrap ( ) ;
255+ let domain = self . pcs ( ) . natural_domain_for_degree ( height) ;
256+ ( log_height, ( domain, trace) )
239257 } )
240258 . collect ( ) ;
241259 // Only commit if there are permutation traces
242260 if !flattened_traces. is_empty ( ) {
243- let ( commit, data) = self . commit ( & flattened_traces) ;
244- Some ( ( commit, data) )
261+ let ( commit, data) = self . pcs ( ) . commit ( flattened_traces) ;
262+ Some ( ( commit, PcsData :: new ( Arc :: new ( data) , log_trace_heights ) ) )
245263 } else {
246264 None
247265 }
@@ -485,3 +503,26 @@ where
485503 data. clone ( )
486504 }
487505}
506+
507+ /// # Safety
508+ /// Assumes that `EF` is `repr(C)` or `repr(transparent)` with internal memory layout `[F; EF::D]`.
509+ /// This ensures `EF` and `F` have the same alignment and `size_of::<EF>() == size_of::<F>() * EF::D`.
510+ /// We assume that `EF::as_base_slice` is the same as transmuting `EF` to `[F; EF::D]`.
511+ unsafe fn transmute_to_base < F : Field , EF : ExtensionField < F > > (
512+ ext_matrix : RowMajorMatrix < EF > ,
513+ ) -> RowMajorMatrix < F > {
514+ let width = ext_matrix. width * EF :: D ;
515+ // Prevent ptr from deallocating
516+ let mut values = ManuallyDrop :: new ( ext_matrix. values ) ;
517+ let mut len = values. len ( ) ;
518+ let mut cap = values. capacity ( ) ;
519+ let ptr = values. as_mut_ptr ( ) ;
520+ len *= EF :: D ;
521+ cap *= EF :: D ;
522+ // SAFETY:
523+ // - We know that `ptr` is from `Vec` so it is allocated by global allocator,
524+ // - Based on assumptions, `T` and `F` have the same alignment
525+ // - Based on memory layout assumptions, length and capacity is correct
526+ let base_values = Vec :: from_raw_parts ( ptr as * mut F , len, cap) ;
527+ RowMajorMatrix :: new ( base_values, width)
528+ }
0 commit comments