Skip to content

Commit adb906c

Browse files
feat: hardware abstraction layer (#23)
* chore: move debug out of prover (#21) * chore: debug checks don't use perm trace * chore: move debug out of prove * feat: hardware coordinator (#24) * wip: add main prover backend traits * chore: rename MatrixBuffer to MatrixView to reflect metadata * feat: finished coordinator * chore: delete circuit_api * chore: remove unused commit.rs * chore: mv VerifierConstraintFolder out of air_builders since it's not a builder anymore * wip: cpu * cpu prover done * chore * chore * chore: rename Cpu to CpuBackend * feat: data transport trait * fix engine trait * fix lifetimes * fix all tests * chore: rename trait * fix bad regex * chore: make device memory types owned (#28) Even if they are buffers, make them owned types (no Clone) so that when they are dropped, the device memory is automatically deallocated. * revert Arc in RawAirProofInput * chore: update test utils * chore: add assert
1 parent d2788c7 commit adb906c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+2024
-1629
lines changed

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ tracing = "0.1.40"
8686
serde_json = "1.0.117"
8787
lazy_static = "1.5.0"
8888
once_cell = "1.19.0"
89-
derive-new = "0.6.0"
89+
derive-new = "0.7.0"
9090
derive_more = "1.0.0"
9191
derivative = "2.2.0"
9292
strum_macros = "0.26.4"

crates/stark-backend/Cargo.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@ serde = { workspace = true, default-features = false, features = [
2323
"alloc",
2424
"rc",
2525
] }
26-
derivative = { workspace = true }
26+
derivative.workspace = true
27+
derive-new.workspace = true
2728
metrics = { workspace = true, optional = true }
28-
cfg-if = { workspace = true }
29+
cfg-if.workspace = true
2930
thiserror.workspace = true
3031
async-trait.workspace = true
3132
rustc-hash.workspace = true

crates/stark-backend/src/air_builders/debug/check_constraints.rs

+5-24
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,7 @@ pub fn check_constraints<R, SC>(
2121
rap_name: &str,
2222
preprocessed: &Option<RowMajorMatrixView<Val<SC>>>,
2323
partitioned_main: &[RowMajorMatrixView<Val<SC>>],
24-
after_challenge: &[RowMajorMatrixView<SC::Challenge>],
25-
challenges: &[Vec<SC::Challenge>],
2624
public_values: &[Val<SC>],
27-
exposed_values_after_challenge: &[Vec<SC::Challenge>],
28-
rap_phase_seq_kind: RapPhaseSeqKind,
2925
) where
3026
R: for<'a> Rap<DebugConstraintBuilder<'a, SC>>
3127
+ BaseAir<Val<SC>>
@@ -35,7 +31,6 @@ pub fn check_constraints<R, SC>(
3531
{
3632
let height = partitioned_main[0].height();
3733
assert!(partitioned_main.iter().all(|mat| mat.height() == height));
38-
assert!(after_challenge.iter().all(|mat| mat.height() == height));
3934

4035
// Check that constraints are satisfied.
4136
(0..height).into_par_iter().for_each(|i| {
@@ -65,20 +60,6 @@ pub fn check_constraints<R, SC>(
6560
})
6661
.collect::<Vec<_>>();
6762

68-
let after_challenge_row_pair = after_challenge
69-
.iter()
70-
.map(|mat| (mat.row_slice(i), mat.row_slice(i_next)))
71-
.collect::<Vec<_>>();
72-
let after_challenge = after_challenge_row_pair
73-
.iter()
74-
.map(|(local, next)| {
75-
VerticalPair::new(
76-
RowMajorMatrixView::new_row(local),
77-
RowMajorMatrixView::new_row(next),
78-
)
79-
})
80-
.collect::<Vec<_>>();
81-
8263
let mut builder = DebugConstraintBuilder {
8364
air_name: rap_name,
8465
row_index: i,
@@ -87,14 +68,14 @@ pub fn check_constraints<R, SC>(
8768
RowMajorMatrixView::new_row(preprocessed_next.as_slice()),
8869
),
8970
partitioned_main,
90-
after_challenge,
91-
challenges,
71+
after_challenge: vec![], // unreachable
72+
challenges: &[], // unreachable
9273
public_values,
93-
exposed_values_after_challenge,
74+
exposed_values_after_challenge: &[], // unreachable
9475
is_first_row: Val::<SC>::ZERO,
9576
is_last_row: Val::<SC>::ZERO,
9677
is_transition: Val::<SC>::ONE,
97-
rap_phase_seq_kind,
78+
rap_phase_seq_kind: RapPhaseSeqKind::StarkLogUp, // unused
9879
has_common_main: rap.common_main_width() > 0,
9980
};
10081
if i == 0 {
@@ -111,7 +92,7 @@ pub fn check_constraints<R, SC>(
11192

11293
pub fn check_logup<F: Field>(
11394
air_names: &[String],
114-
interactions: &[&[SymbolicInteraction<F>]],
95+
interactions: &[Vec<SymbolicInteraction<F>>],
11596
preprocessed: &[Option<RowMajorMatrixView<F>>],
11697
partitioned_main: &[Vec<RowMajorMatrixView<F>>],
11798
public_values: &[Vec<F>],

crates/stark-backend/src/air_builders/debug/mod.rs

+62-3
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,79 @@
1+
use std::sync::{Arc, Mutex};
2+
3+
use itertools::{izip, Itertools};
14
use p3_air::{
25
AirBuilder, AirBuilderWithPublicValues, ExtensionBuilder, PairBuilder, PermutationAirBuilder,
36
};
47
use p3_field::FieldAlgebra;
58
use p3_matrix::{dense::RowMajorMatrixView, stack::VerticalPair};
69

7-
use super::{PartitionedAirBuilder, ViewPair};
10+
use super::{symbolic::SymbolicConstraints, PartitionedAirBuilder, ViewPair};
811
use crate::{
912
config::{StarkGenericConfig, Val},
1013
interaction::{
1114
rap::InteractionPhaseAirBuilder, Interaction, InteractionBuilder, InteractionType,
1215
RapPhaseSeqKind,
1316
},
14-
rap::PermutationAirBuilderWithExposedValues,
17+
keygen::types::StarkProvingKey,
18+
rap::{AnyRap, PermutationAirBuilderWithExposedValues},
1519
};
1620

17-
pub mod check_constraints;
21+
mod check_constraints;
22+
23+
use check_constraints::*;
24+
25+
thread_local! {
26+
pub static USE_DEBUG_BUILDER: Arc<Mutex<bool>> = Arc::new(Mutex::new(true));
27+
}
28+
29+
/// The debugging will check the main AIR constraints and then separately check LogUp constraints by
30+
/// checking the actual multiset equalities. Currently it will not debug check any after challenge phase
31+
/// constraints for implementation simplicity.
32+
#[allow(dead_code)]
33+
#[allow(clippy::too_many_arguments)]
34+
pub fn debug_constraints_and_interactions<SC: StarkGenericConfig>(
35+
airs: &[Arc<dyn AnyRap<SC>>],
36+
pk: &[StarkProvingKey<SC>],
37+
main_views_per_air: &[Vec<RowMajorMatrixView<'_, Val<SC>>>],
38+
public_values_per_air: &[Vec<Val<SC>>],
39+
) {
40+
USE_DEBUG_BUILDER.with(|debug| {
41+
if *debug.lock().unwrap() {
42+
let preprocessed = izip!(airs, pk, main_views_per_air, public_values_per_air,)
43+
.map(|(rap, pk, main, public_values)| {
44+
let preprocessed_trace = pk
45+
.preprocessed_data
46+
.as_ref()
47+
.map(|data| data.trace.as_view());
48+
tracing::debug!("Checking constraints for {}", rap.name());
49+
check_constraints(
50+
rap.as_ref(),
51+
&rap.name(),
52+
&preprocessed_trace,
53+
main,
54+
public_values,
55+
);
56+
preprocessed_trace
57+
})
58+
.collect_vec();
59+
60+
let (air_names, interactions): (Vec<_>, Vec<_>) = pk
61+
.iter()
62+
.map(|pk| {
63+
let sym_constraints = SymbolicConstraints::from(&pk.vk.symbolic_constraints);
64+
(pk.air_name.clone(), sym_constraints.interactions)
65+
})
66+
.unzip();
67+
check_logup(
68+
&air_names,
69+
&interactions,
70+
&preprocessed,
71+
main_views_per_air,
72+
public_values_per_air,
73+
);
74+
}
75+
});
76+
}
1877

1978
/// An `AirBuilder` which asserts that each constraint is zero, allowing any failed constraints to
2079
/// be detected early.

crates/stark-backend/src/air_builders/mod.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ use p3_matrix::{dense::RowMajorMatrixView, stack::VerticalPair};
33

44
pub mod debug;
55
pub mod sub;
6+
/// AIR builder that collects the constraints expressed via the [Air](p3_air::Air) trait into
7+
/// a directed acyclic graph of symbolic expressions for serialization purposes.
68
pub mod symbolic;
7-
pub mod verifier;
89

910
pub type ViewPair<'a, T> = VerticalPair<RowMajorMatrixView<'a, T>, RowMajorMatrixView<'a, T>>;
1011

crates/stark-backend/src/air_builders/symbolic/dag.rs

+21-10
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use crate::{
1616
/// Basically replace `Arc`s in `SymbolicExpression` with node IDs.
1717
/// Intended to be serializable and deserializable.
1818
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
19-
#[serde(bound = "F: Field")]
19+
#[serde(bound(serialize = "F: Serialize", deserialize = "F: Deserialize<'de>"))]
2020
#[repr(C)]
2121
pub enum SymbolicExpressionNode<F> {
2222
Variable(SymbolicVariable<F>),
@@ -46,7 +46,8 @@ pub enum SymbolicExpressionNode<F> {
4646
}
4747

4848
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
49-
#[serde(bound = "F: Field")]
49+
#[serde(bound(serialize = "F: Serialize", deserialize = "F: Deserialize<'de>"))]
50+
#[repr(C)]
5051
pub struct SymbolicExpressionDag<F> {
5152
/// Nodes in **topological** order.
5253
pub(crate) nodes: Vec<SymbolicExpressionNode<F>>,
@@ -67,7 +68,8 @@ impl<F> SymbolicExpressionDag<F> {
6768
}
6869

6970
#[derive(Clone, Debug, Serialize, Deserialize)]
70-
#[serde(bound = "F: Field")]
71+
#[serde(bound(serialize = "F: Serialize", deserialize = "F: Deserialize<'de>"))]
72+
#[repr(C)] // TODO[jpw]: device transfer requires usize-independent serialization
7173
pub struct SymbolicConstraintsDag<F> {
7274
/// DAG with all symbolic expressions as nodes.
7375
/// A subset of the nodes represents all constraints that will be
@@ -115,6 +117,9 @@ pub(crate) fn build_symbolic_constraints_dag<F: Field>(
115117
}
116118
})
117119
.collect();
120+
// Note[jpw]: there could be few nodes created after `constraint_idx` is built
121+
// from `interactions` even though constraints already contain all interactions.
122+
// This should be marginal and is not optimized for now.
118123
let constraints = SymbolicExpressionDag {
119124
nodes,
120125
constraint_idx,
@@ -252,23 +257,23 @@ impl<F: Field> SymbolicExpressionDag<F> {
252257
}
253258

254259
// TEMPORARY conversions until we switch main interfaces to use SymbolicConstraintsDag
255-
impl<F: Field> From<SymbolicConstraintsDag<F>> for SymbolicConstraints<F> {
256-
fn from(dag: SymbolicConstraintsDag<F>) -> Self {
260+
impl<'a, F: Field> From<&'a SymbolicConstraintsDag<F>> for SymbolicConstraints<F> {
261+
fn from(dag: &'a SymbolicConstraintsDag<F>) -> Self {
257262
let exprs = dag.constraints.to_symbolic_expressions();
258263
let constraints = dag
259264
.constraints
260265
.constraint_idx
261-
.into_iter()
262-
.map(|idx| exprs[idx].as_ref().clone())
266+
.iter()
267+
.map(|&idx| exprs[idx].as_ref().clone())
263268
.collect::<Vec<_>>();
264269
let interactions = dag
265270
.interactions
266-
.into_iter()
271+
.iter()
267272
.map(|interaction| {
268273
let fields = interaction
269274
.fields
270-
.into_iter()
271-
.map(|idx| exprs[idx].as_ref().clone())
275+
.iter()
276+
.map(|&idx| exprs[idx].as_ref().clone())
272277
.collect();
273278
let count = exprs[interaction.count].as_ref().clone();
274279
Interaction {
@@ -286,6 +291,12 @@ impl<F: Field> From<SymbolicConstraintsDag<F>> for SymbolicConstraints<F> {
286291
}
287292
}
288293

294+
impl<F: Field> From<SymbolicConstraintsDag<F>> for SymbolicConstraints<F> {
295+
fn from(dag: SymbolicConstraintsDag<F>) -> Self {
296+
(&dag).into()
297+
}
298+
}
299+
289300
impl<F: Field> From<SymbolicConstraints<F>> for SymbolicConstraintsDag<F> {
290301
fn from(sc: SymbolicConstraints<F>) -> Self {
291302
build_symbolic_constraints_dag(&sc.constraints, &sc.interactions)

crates/stark-backend/src/air_builders/symbolic/mod.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,12 @@ use crate::{
2424
rap::{BaseAirWithPublicValues, PermutationAirBuilderWithExposedValues, Rap},
2525
};
2626

27-
pub mod dag;
27+
mod dag;
2828
pub mod symbolic_expression;
2929
pub mod symbolic_variable;
3030

31+
pub use dag::*;
32+
3133
/// Symbolic constraints for a single AIR with interactions.
3234
/// The constraints contain the constraints on the logup partial sums.
3335
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]

crates/stark-backend/src/circuit_api.rs

-23
This file was deleted.

crates/stark-backend/src/commit.rs

-66
This file was deleted.

crates/stark-backend/src/config.rs

+4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ where
1515
Com<Self>: Send + Sync,
1616
PcsProof<Self>: Send + Sync,
1717
PcsProverData<Self>: Send + Sync,
18+
RapPhaseSeqPartialProof<Self>: Send + Sync,
19+
RapPhaseSeqProvingKey<Self>: Send + Sync,
1820
{
1921
/// The PCS used to commit to trace polynomials.
2022
type Pcs: Pcs<Self::Challenge, Self::Challenger>;
@@ -115,6 +117,8 @@ where
115117
Pcs::ProverData: Send + Sync,
116118
Pcs::Proof: Send + Sync,
117119
Rps: RapPhaseSeq<<Pcs::Domain as PolynomialSpace>::Val, Challenge, Challenger>,
120+
Rps::PartialProof: Send + Sync,
121+
Rps::ProvingKey: Send + Sync,
118122
Challenger: FieldChallenger<<Pcs::Domain as PolynomialSpace>::Val>
119123
+ CanObserve<<Pcs as p3_commit::Pcs<Challenge, Challenger>>::Commitment>
120124
+ CanSample<Challenge>,

0 commit comments

Comments
 (0)