diff --git a/backend/Cargo.toml b/backend/Cargo.toml index 0b9d891d54..eff547acb9 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -63,7 +63,7 @@ p3-commit = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf2 ], optional = true } p3-matrix = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf28e7359dd2c577447886463e6124f0", optional = true } p3-uni-stark = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf28e7359dd2c577447886463e6124f0", optional = true } -stwo-prover = { git = "https://github.com/starkware-libs/stwo.git", optional = true, rev = "81d1fe349b490089f65723ad49ef72b9d09495ba", features = [ +stwo-prover = { git = "https://github.com/ShuangWu121/stwo-GKR-serialized-proof.git", optional = true, features = [ "parallel", ] } diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index efbc137907..99e2861132 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -48,22 +48,22 @@ where CircleEvaluation::new(domain, column) } - +#[derive(Default, Clone)] pub struct PowdrEval { log_degree: u32, - analyzed: Analyzed, + pub analyzed: Analyzed, // the pre-processed are indexed in the whole proof, instead of in each component. // this offset represents the index of the first pre-processed column in this component preprocess_col_offset: usize, // The name of the public, the poly-id of the witness poly that this public is related to, the public value pub(crate) publics_values: Vec<(String, PolyID, M31)>, - stage0_witness_columns: BTreeMap, + pub stage0_witness_columns: BTreeMap, stage1_witness_columns: BTreeMap, constant_shifted: BTreeMap, constant_columns: BTreeMap, // stwo supports maximum 2 stages, challenges are only created after stage 0 - pub challenges: BTreeMap, - poly_stage_map: BTreeMap, + challenges: BTreeMap, + pub poly_stage_map: BTreeMap, } impl PowdrEval { @@ -134,15 +134,15 @@ impl PowdrEval { } } -struct Data<'a, F> { - stage0_witness_eval: &'a BTreeMap, - stage1_witness_eval: &'a BTreeMap, - constant_shifted_eval: &'a BTreeMap, - constant_eval: &'a BTreeMap, - publics_values: &'a BTreeMap, +pub struct Data<'a, F> { + pub stage0_witness_eval: &'a BTreeMap, + pub stage1_witness_eval: &'a BTreeMap, + pub constant_shifted_eval: &'a BTreeMap, + pub constant_eval: &'a BTreeMap, + pub publics_values: &'a BTreeMap, // challenges for stage 1 - challenges: &'a BTreeMap, - poly_stage_map: &'a BTreeMap, + pub challenges: &'a BTreeMap, + pub poly_stage_map: &'a BTreeMap, } impl TerminalAccess for &Data<'_, F> { diff --git a/backend/src/stwo/logup_gkr.rs b/backend/src/stwo/logup_gkr.rs new file mode 100644 index 0000000000..d0cf78bb6b --- /dev/null +++ b/backend/src/stwo/logup_gkr.rs @@ -0,0 +1,356 @@ +use itertools::Itertools; +use num_traits::{One, Zero}; +use serde::Serialize; + +use powdr_ast::analyzed::AlgebraicExpression; +use powdr_backend_utils::machine_fixed_columns; +use powdr_executor_utils::expression_evaluator::ExpressionEvaluator; +use powdr_number::Mersenne31Field; +use stwo_prover::constraint_framework::EvalAtRow; +use stwo_prover::constraint_framework::PointEvaluator; +use stwo_prover::constraint_framework::{FrameworkComponent, FrameworkEval}; +use stwo_prover::core::air::accumulation::PointEvaluationAccumulator; +use stwo_prover::core::air::ComponentProver; + +use stwo_prover::core::backend::simd::SimdBackend; +use stwo_prover::core::backend::BackendForChannel; +use stwo_prover::core::channel::Channel; +use stwo_prover::core::channel::MerkleChannel; +use stwo_prover::core::circle::CirclePoint; + +use stwo_prover::core::fields::qm31::SecureField; +use stwo_prover::core::fields::qm31::QM31; +use stwo_prover::core::lookups::gkr_verifier::GkrArtifact; +use stwo_prover::core::lookups::gkr_verifier::GkrBatchProof; +use stwo_prover::core::pcs::TreeVec; + +use crate::stwo::circuit_builder::STAGE0_TRACE_IDX; +use stwo_prover::core::ColumnVec; + +use powdr_ast::analyzed::PolyID; + +use stwo_prover::core::lookups::gkr_prover::{prove_batch, Layer}; +use stwo_prover::core::lookups::mle::Mle; +use stwo_prover::examples::xor::gkr_lookups::mle_eval::MleCoeffColumnOracle; + +use crate::stwo::circuit_builder::Data; +use crate::stwo::prover::into_stwo_field; +use powdr_ast::analyzed::Identity; + +use serde::de::DeserializeOwned; +use std::collections::BTreeMap; +use std::ops::Deref; + +use crate::stwo::circuit_builder::PowdrComponent; +use crate::stwo::StwoProver; +use stwo_prover::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; + +use super::circuit_builder::PowdrEval; + +// using this flag to enable logup-GKR for now +pub const LOGUP_GKR: bool = true; +// preprocess column has commitment tree index 0 +// stage 0 witness columns has commitment tree index 1 +// index 2 is for GKR auxiliary traces or stage 1 witness +pub const MLE_TRACE_IDX: usize = 2; + +// Wrapper for PowdrComponent to implement MleCoeffColumnOracle +pub struct PowdrComponentWrapper<'a> { + pub powdr_component: &'a FrameworkComponent, + pub logup_challenge: QM31, + pub main_machine_powdr_eval: PowdrEval, +} + +// MleCoeffColumnOracle returns the ood point evaluation of the bus payload +impl MleCoeffColumnOracle for PowdrComponentWrapper<'_> { + fn evaluate_at_point( + &self, + _point: CirclePoint, + mask: &TreeVec>>, + ) -> SecureField { + // Create dummy point evaluator just to extract the ood point evaluation value we need from the mask + let mut accumulator = PointEvaluationAccumulator::new(SecureField::one()); + + // TODO: evaluator cannot get constant columns, need to fix this + let eval_mask = mask.sub_tree(self.powdr_component.trace_locations()); + + let mut eval = PointEvaluator::new( + eval_mask, + &mut accumulator, + SecureField::one(), + self.powdr_component.log_size(), + SecureField::zero(), + ); + + let stage0_witness_eval: BTreeMap::F; 2]> = self + .main_machine_powdr_eval + .stage0_witness_columns + .keys() + .map(|poly_id| { + ( + *poly_id, + eval.next_interaction_mask(STAGE0_TRACE_IDX, [0, 1]), + ) + }) + .collect(); + + let intermediate_definitions = self.analyzed.intermediate_definitions(); + + let data = Data { + stage0_witness_eval: &stage0_witness_eval, + stage1_witness_eval: &BTreeMap::new(), + publics_values: &BTreeMap::new(), + constant_shifted_eval: &BTreeMap::new(), + constant_eval: &BTreeMap::new(), + challenges: &BTreeMap::new(), + poly_stage_map: &self.main_machine_powdr_eval.poly_stage_map, + }; + + let mut evaluator = + ExpressionEvaluator::new_with_custom_expr(&data, &intermediate_definitions, |v| { + ::F::from(into_stwo_field(v)) + }); + + let mut accumulator = SecureField::zero(); + + for id in &self.main_machine_powdr_eval.analyzed.identities { + if let Identity::BusInteraction(id) = id { + let payload: Vec<::F> = + id.payload.0.iter().map(|e| evaluator.evaluate(e)).collect(); + + let multiplicity = + ::EF::from(evaluator.evaluate(&id.multiplicity)); + // TODO: update this accumulator when the sound challenge is implemented + accumulator += payload[0] + self.logup_challenge + multiplicity; + } + } + + accumulator + } +} + +impl Deref for PowdrComponentWrapper<'_> { + type Target = PowdrComponent; + + fn deref(&self) -> &Self::Target { + self.powdr_component + } +} + +pub struct GkrProofArtifacts { + pub gkr_proof: GkrBatchProof, + pub gkr_artifacts: GkrArtifact, + pub combined_mle: Mle, + pub combine_mle_claim: SecureField, +} + +impl StwoProver +where + MC: MerkleChannel + Send, + C: Channel + Send, + MC::H: DeserializeOwned + Serialize, + PowdrComponent: ComponentProver, + SimdBackend: BackendForChannel, +{ + pub fn gkr_prove( + &self, + witness: &[(String, Vec)], + machine_log_sizes: BTreeMap, + logup_challenge: QM31, + prover_channel: &mut ::C, + ) -> Option { + if !LOGUP_GKR { + return None; + } + + // PhantomBusInteraction exists, which means the pil file is not generated with BusInteraction only, + // then the GKR is not not supported for that pil file yet + for id in &self.analyzed.identities { + if let Identity::PhantomBusInteraction(_) = id { + return None; + } + } + + let has_bus_interaction = self + .analyzed + .identities + .iter() + .any(|id| matches!(id, Identity::BusInteraction(_))); + + if !has_bus_interaction { + return None; + }; + // The payload of the bus can come from all the expressions, therefore inorder to rebuild the payload trace, constant columns,witness columns + // and intermidiate columns are needed. + // get all the fix columns + // TODO: if GKR applies only on main machine, then only the fixed columns of the main machine are needed + let all_fixed_columns: Vec<(String, Vec<_>)> = self + .split + .iter() + .flat_map(|(machine_name, pil)| { + let machine_fixed_col = machine_fixed_columns(&self.fixed, pil); + machine_fixed_col + .iter() + .filter(|(size, _)| size.ilog2() == machine_log_sizes[machine_name]) + .flat_map(|(_, vec)| { + vec.iter() + .map(|(s, w)| (s.clone(), w.to_vec())) + .collect_vec() + }) + .collect_vec() + }) + .collect(); + + // find senders and receivers to build denominator traces + + // GKR toplayer is the input layer of the circuit, it consists of numerator MLE poly and denominator MLE poly + // these MLE polys are from the trace polys in bus payload, multiplicity and selecotr + // numerator MLE poly is 1 for bus send, is from multiplicity poly for bus receive + // denominator MLE poly is from the trace poly in bus payload + // Collect all the top layer inputs of GKR, each of them is a GKR instance for now, later they should be linear combined + let mut gkr_top_layers = Vec::new(); + + // Collect all the MLEs for the numerators of the GKR instances + let mut mle_numerators = Vec::new(); + + // Collect all the MLEs for the denominators of the GKR instances + let mut mle_denominators = Vec::new(); + + let mut all_mle_values = Vec::new(); + + for id in &self.analyzed.identities { + if let Identity::BusInteraction(identity) = id { + for e in &identity.payload.0 { + // For now, only consider payload with polynomial identity + if let AlgebraicExpression::Reference(_) = e { + } else { + break; + }; + + let denominator_trace = witness + .iter() + .chain(all_fixed_columns.iter()) + .find(|(name, _)| { + if let AlgebraicExpression::Reference(r) = e { + name == &r.name + } else { + panic!("cannot find bus payload trace {e:?}"); + } + }) + .unwrap(); + + // create fractions that are to be added by GKR circuit + // numerator is 1 for bus send, is multiplicity for bus receive + // all take 1 for now + // TODO: include multiplicity for bus receive, latch/1 for bus send, 1 needs to be committed as well + + let numerator_values: Vec = match identity.multiplicity { + AlgebraicExpression::Number(n) => { + vec![ + SecureField::from_m31( + into_stwo_field(&n), + 0.into(), + 0.into(), + 0.into() + ); + self.analyzed.degree() as usize + ] + } + _ => panic!("only support multiplicity as Number expression for now"), + }; + + // traces need to be bit-reverse order + let denominator_values = get_bit_reversed_col( + &denominator_trace.1, + self.analyzed.degree() as usize, + logup_challenge, + ); + + // covert to SecureColumn, which is used to crate MLE in secure field + let numerator_secure_column = numerator_values.iter().copied().collect(); + let denominator_secure_column = denominator_values.iter().copied().collect(); + + // create multilinear polynomial for the input layer + let mle_numerator = + Mle::::new(numerator_secure_column); + let mle_denominator = + Mle::::new(denominator_secure_column); + + mle_numerators.push(mle_numerator.clone()); + mle_denominators.push(mle_denominator.clone()); + + let top_layer = Layer::LogUpGeneric { + numerators: mle_numerator, + denominators: mle_denominator, + }; + + gkr_top_layers.push(top_layer); + all_mle_values.push(numerator_values); + all_mle_values.push(denominator_values); + } + } + } + + let (gkr_proof, gkr_artifacts) = prove_batch(prover_channel, gkr_top_layers); + + // check the logop accumulation is zero + if gkr_proof + .output_claims_by_instance + .iter() + .fold(SecureField::zero(), |acc, vec| acc + vec[0] / vec[1]) + != SecureField::zero() + { + panic!("logup accumulation is not zero, prove failed"); + } + + // combine the GKR instances + // TODO: use randomness that is generated based on the claims of the instance, to make the challenge sound + let linear_combine_challenge = SecureField::one(); + + let combined_mle_values: Vec = (0..self.analyzed.degree()) + .map(|index| { + let combined_mle_value: SecureField = all_mle_values + .iter() + .fold(SecureField::zero(), |acc, mle_values| { + acc + linear_combine_challenge * mle_values[index as usize] + }); + combined_mle_value + }) + .collect(); + + let combined_mle_secure_column = combined_mle_values.iter().copied().collect(); + + // create multilinear polynomial for the input layer + let combined_mle = Mle::::new(combined_mle_secure_column); + + // TODO: modify this according to the challenge when the sound challenge is implemented + let combine_mle_claim: SecureField = gkr_artifacts + .claims_to_verify_by_instance + .iter() + .flatten() + .fold(SecureField::zero(), |acc, claim| acc + *claim); + + Some(GkrProofArtifacts { + gkr_proof, + gkr_artifacts, + combined_mle, + combine_mle_claim, + }) + } +} + +fn get_bit_reversed_col( + values: &[Mersenne31Field], + degree: usize, + off_set: QM31, // for challenge if any +) -> Vec { + let mut bit_reversed_col = vec![SecureField::zero(); degree]; + values.iter().enumerate().for_each(|(index, value)| { + bit_reversed_col[bit_reverse_index( + coset_index_to_circle_domain_index(index, degree.ilog2()), + degree.ilog2(), + )] = off_set + SecureField::from_m31(into_stwo_field(value), 0.into(), 0.into(), 0.into()); + }); + + bit_reversed_col +} diff --git a/backend/src/stwo/mod.rs b/backend/src/stwo/mod.rs index 0534042de0..b2020c5802 100644 --- a/backend/src/stwo/mod.rs +++ b/backend/src/stwo/mod.rs @@ -18,6 +18,7 @@ use stwo_prover::core::channel::{Blake2sChannel, Channel, MerkleChannel}; use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel; mod circuit_builder; +mod logup_gkr; mod proof; mod prover; @@ -42,7 +43,7 @@ impl BackendFactory for RestrictedFactory { assert!(pil.stage_count() <= 2, "stwo supports max 2 stages"); - let mut stwo: Box> = + let mut stwo: Box> = Box::new(StwoProver::new(pil, fixed)?); match (proving_key, verification_key) { @@ -61,7 +62,7 @@ impl BackendFactory for RestrictedFactory { generalize_factory!(Factory <- RestrictedFactory, [M31]); -impl Backend for StwoProver +impl Backend for StwoProver where SimdBackend: BackendForChannel, MC: MerkleChannel, diff --git a/backend/src/stwo/proof.rs b/backend/src/stwo/proof.rs index 0aa0ea23a5..bb3ccef503 100644 --- a/backend/src/stwo/proof.rs +++ b/backend/src/stwo/proof.rs @@ -6,6 +6,7 @@ use stwo_prover::core::backend::Column; use stwo_prover::core::backend::ColumnOps; use stwo_prover::core::channel::MerkleChannel; use stwo_prover::core::fields::m31::BaseField; +use stwo_prover::core::lookups::gkr_verifier::GkrBatchProof; use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; use stwo_prover::core::poly::BitReversedOrder; use stwo_prover::core::prover::StarkProof; @@ -133,4 +134,5 @@ where { pub stark_proof: StarkProof, pub machine_log_sizes: BTreeMap, + pub gkr_proof: Option, } diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 0f6c4973cf..cfe151b693 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -5,6 +5,14 @@ use powdr_ast::parsed::visitor::AllChildren; use powdr_backend_utils::{machine_fixed_columns, machine_witness_columns}; use powdr_executor::constant_evaluator::VariablySizedColumn; use powdr_executor::witgen::WitgenCallback; +use stwo_prover::core::backend::simd::SimdBackend; +use stwo_prover::core::lookups::gkr_verifier::partially_verify_batch; +use stwo_prover::core::lookups::gkr_verifier::Gate; +use stwo_prover::core::lookups::gkr_verifier::GkrArtifact; +use stwo_prover::core::poly::circle::PolyOps; +use stwo_prover::examples::xor::gkr_lookups::mle_eval::{ + build_trace, MleEvalProverComponent, MleEvalVerifierComponent, +}; use powdr_number::{FieldElement, LargeInt, Mersenne31Field as M31}; @@ -23,6 +31,7 @@ use crate::stwo::circuit_builder::{ gen_stwo_circle_column, get_constant_with_next_list, PowdrComponent, PowdrEval, PREPROCESSED_TRACE_IDX, STAGE0_TRACE_IDX, STAGE1_TRACE_IDX, }; +use crate::stwo::logup_gkr::{PowdrComponentWrapper, MLE_TRACE_IDX}; use crate::stwo::proof::{ Proof, SerializableStarkProvingKey, StarkProvingKey, TableProvingKey, TableProvingKeyCollection, }; @@ -30,7 +39,7 @@ use crate::stwo::proof::{ use stwo_prover::constraint_framework::TraceLocationAllocator; use stwo_prover::core::air::{Component, ComponentProver}; -use stwo_prover::core::backend::{Backend, BackendForChannel, Col, Column}; +use stwo_prover::core::backend::{BackendForChannel, Col, Column}; use stwo_prover::core::channel::{Channel, MerkleChannel}; use stwo_prover::core::fields::m31::BaseField; use stwo_prover::core::fields::qm31::SecureField; @@ -60,28 +69,28 @@ impl fmt::Display for KeyExportError { } } -pub struct StwoProver + Send, MC: MerkleChannel, C: Channel> { +pub struct StwoProver { pub analyzed: Arc>, /// The split analyzed PIL - split: BTreeMap>, + pub split: BTreeMap>, /// The value of the fixed columns pub fixed: Arc)>>, /// Proving key - proving_key: StarkProvingKey, + proving_key: StarkProvingKey, /// TODO: Add verification key. _verifying_key: Option<()>, _channel_marker: PhantomData, _merkle_channel_marker: PhantomData, } -impl StwoProver +impl StwoProver where - B: Backend + Send + BackendForChannel, MC: MerkleChannel + Send, C: Channel + Send, MC::H: DeserializeOwned + Serialize, - PowdrComponent: ComponentProver, + PowdrComponent: ComponentProver, + SimdBackend: BackendForChannel, { pub fn new( analyzed: Arc>, @@ -148,7 +157,7 @@ where }) .collect(); - let preprocessed: BTreeMap> = self + let preprocessed: BTreeMap> = self .split .iter() .filter_map(|(namespace, pil)| { @@ -169,7 +178,7 @@ where let fixed_columns = &fixed_columns[&size]; let log_size = size.ilog2(); let mut constant_trace: ColumnVec< - CircleEvaluation, + CircleEvaluation, > = fixed_columns .iter() .map(|(_, vec)| { @@ -183,7 +192,7 @@ where let constant_with_next_list = get_constant_with_next_list(pil); let constant_shifted_trace: ColumnVec< - CircleEvaluation, + CircleEvaluation, > = fixed_columns .iter() .filter(|(name, _)| constant_with_next_list.contains(name)) @@ -201,13 +210,14 @@ where // get selector columns for the public inputs let publics_selectors: ColumnVec< - CircleEvaluation, + CircleEvaluation, > = pil .get_publics() .into_iter() .map(|(_, _, _, row_id, _)| { // Create a column with a single 1 at the row_id-th (in circle domain bitreverse order) position - let mut col = Col::::zeros(1 << log_size); + let mut col = + Col::::zeros(1 << log_size); col.set( bit_reverse_index( coset_index_to_circle_domain_index( @@ -217,10 +227,14 @@ where ), BaseField::one(), ); - CircleEvaluation::::new( - *domain_map.get(&(log_size as usize)).unwrap(), - col, - ) + CircleEvaluation::< + SimdBackend, + BaseField, + BitReversedOrder, + >::new( + *domain_map.get(&(log_size as usize)).unwrap(), + col, + ) }) .collect(); @@ -352,7 +366,7 @@ where // Get witness columns in circle domain for stage 0 let stage0_witness_cols_circle_domain_eval: ColumnVec< - CircleEvaluation, + CircleEvaluation, > = witness_by_machine .values() .flat_map(|witness_cols| { @@ -370,7 +384,7 @@ where }) .collect_vec(); - let twiddles_max_degree = B::precompute_twiddles( + let twiddles_max_degree = SimdBackend::precompute_twiddles( CanonicCoset::new(domain_degree_range.max.ilog2() + 1 + FRI_LOG_BLOWUP as u32) .circle_domain() .half_coset, @@ -378,7 +392,7 @@ where let prover_channel = &mut ::C::default(); let mut commitment_scheme = - CommitmentSchemeProver::<'_, B, MC>::new(config, &twiddles_max_degree); + CommitmentSchemeProver::<'_, SimdBackend, MC>::new(config, &twiddles_max_degree); // commit to constant columns let mut tree_builder = commitment_scheme.tree_builder(); @@ -387,6 +401,7 @@ where // commit to witness columns of stage 0 let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(stage0_witness_cols_circle_domain_eval); tree_builder.commit(prover_channel); @@ -472,8 +487,32 @@ where span.exit(); } + // Generate GKR proof, get None if LOGUP_GKR is false + // logup challenge alpha, take dummy challenge and dummy gkr_prover_channel for now + // TODO: implement sound challenge + let alpha = SecureField::from_u32_unchecked(42, 42, 42, 42); + let gkr_prover_channel = &mut ::C::default(); + let gkr_result = self.gkr_prove( + witness, + machine_log_sizes.clone(), + alpha, + gkr_prover_channel, + ); + + if let Some(ref gkr_proof_artifacts) = gkr_result { + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(build_trace( + &gkr_proof_artifacts.combined_mle, + &gkr_proof_artifacts.gkr_artifacts.ood_point, + gkr_proof_artifacts.combine_mle_claim, + )); + tree_builder.commit(prover_channel); + } + let tree_span_provider = &mut TraceLocationAllocator::default(); + let mut main_machine_powdr_eval = PowdrEval::default(); + // Build the circuit. The circuit includes constraints of all the machines in both stage 0 and stage 1 let mut constant_cols_offset_acc = 0; let components = self @@ -484,17 +523,19 @@ where |((machine_name, pil), (proof_machine_name, &machine_log_size))| { assert_eq!(machine_name, proof_machine_name); - let component = PowdrComponent::new( - tree_span_provider, - PowdrEval::new( - (*pil).clone(), - constant_cols_offset_acc, - machine_log_size, - stage0_challenges.clone(), - public_values.clone(), - ), - SecureField::zero(), + let powdr_eval = PowdrEval::new( + (*pil).clone(), + constant_cols_offset_acc, + machine_log_size, + stage0_challenges.clone(), + public_values.clone(), ); + if machine_name == "main" { + main_machine_powdr_eval = powdr_eval.clone(); + }; + + let component = + PowdrComponent::new(tree_span_provider, powdr_eval, SecureField::zero()); constant_cols_offset_acc += pil.constant_count() + get_constant_with_next_list(pil).len(); @@ -503,28 +544,70 @@ where ) .collect_vec(); - let components_slice: Vec<&dyn ComponentProver> = components + let mut components_slice: Vec<&dyn ComponentProver> = components .iter() - .map(|component| component as &dyn ComponentProver) + .map(|component| component as &dyn ComponentProver) .collect(); - let proof_result = stwo_prover::core::prover::prove::( - &components_slice, - prover_channel, - commitment_scheme, - ); + if let Some(gkr_proof_artifacts) = gkr_result { + let last_component = components.last().unwrap(); // &FrameworkComponent + let wrapped_component = PowdrComponentWrapper { + powdr_component: last_component, + logup_challenge: alpha, + main_machine_powdr_eval, + }; + + // create component for MLE + let mle_eval_component = MleEvalProverComponent::generate( + tree_span_provider, + &wrapped_component, + &gkr_proof_artifacts.gkr_artifacts.ood_point, + gkr_proof_artifacts.combined_mle.clone(), + gkr_proof_artifacts.combine_mle_claim, + &twiddles_max_degree, + MLE_TRACE_IDX, + ); - let stark_proof = match proof_result { - Ok(value) => value, - Err(e) => return Err(e.to_string()), // Propagate the error instead of panicking - }; + components_slice.push(&mle_eval_component); - let proof: Proof = Proof { - stark_proof, - machine_log_sizes, - }; - prove_span.exit(); - Ok(bincode::serialize(&proof).unwrap()) + let proof_result = stwo_prover::core::prover::prove::( + &components_slice, + prover_channel, + commitment_scheme, + ); + + let stark_proof = match proof_result { + Ok(value) => value, + Err(e) => return Err(e.to_string()), // Propagate the error instead of panicking + }; + + let proof: Proof = Proof { + stark_proof, + machine_log_sizes, + gkr_proof: Some(gkr_proof_artifacts.gkr_proof), + }; + prove_span.exit(); + Ok(bincode::serialize(&proof).unwrap()) + } else { + let proof_result = stwo_prover::core::prover::prove::( + &components_slice, + prover_channel, + commitment_scheme, + ); + + let stark_proof = match proof_result { + Ok(value) => value, + Err(e) => return Err(e.to_string()), // Propagate the error instead of panicking + }; + + let proof: Proof = Proof { + stark_proof, + machine_log_sizes, + gkr_proof: None, + }; + prove_span.exit(); + Ok(bincode::serialize(&proof).unwrap()) + } } pub fn verify(&self, proof: &[u8], instances: &[M31]) -> Result<(), String> { @@ -556,6 +639,7 @@ where // Constraints that are to be proved let tree_span_provider = &mut TraceLocationAllocator::default(); + let mut main_machine_powdr_eval = PowdrEval::default(); let mut constant_cols_offset_acc = 0; let iter = self @@ -565,13 +649,13 @@ where .map( |((machine_name, pil), (proof_machine_name, &machine_log_size))| { assert_eq!(machine_name, proof_machine_name); - (pil, machine_log_size) // Keep only relevant values + (pil, machine_name, machine_log_size) // Keep only relevant values }, ); let constant_col_log_sizes = iter .clone() - .flat_map(|(pil, machine_log_size)| { + .flat_map(|(pil, _machine_name, machine_log_size)| { repeat(machine_log_size).take( pil.constant_count() + get_constant_with_next_list(pil).len() @@ -582,14 +666,14 @@ where let stage0_witness_col_log_sizes = iter .clone() - .flat_map(|(pil, machine_log_size)| { + .flat_map(|(pil, _machine_name, machine_log_size)| { repeat(machine_log_size).take(pil.stage_commitment_count(0)) }) .collect_vec(); let stage1_witness_col_log_sizes = iter .clone() - .flat_map(|(pil, machine_log_size)| { + .flat_map(|(pil, _machine_name, machine_log_size)| { repeat(machine_log_size).take(pil.stage_commitment_count(1)) }) .collect_vec(); @@ -612,18 +696,19 @@ where let components = iter .clone() - .map(|(pil, machine_log_size)| { - let machine_component = PowdrComponent::new( - tree_span_provider, - PowdrEval::new( - (*pil).clone(), - constant_cols_offset_acc, - machine_log_size, - stage0_challenges.clone(), - public_values.clone(), - ), - SecureField::zero(), + .map(|(pil, machine_name, machine_log_size)| { + let powdr_eval = PowdrEval::new( + (*pil).clone(), + constant_cols_offset_acc, + machine_log_size, + stage0_challenges.clone(), + public_values.clone(), ); + if machine_name == "main" { + main_machine_powdr_eval = powdr_eval.clone(); + }; + let machine_component = + PowdrComponent::new(tree_span_provider, powdr_eval, SecureField::zero()); constant_cols_offset_acc += pil.constant_count(); @@ -632,26 +717,88 @@ where }) .collect_vec(); - let components_slice: Vec<&dyn Component> = components + let mut components_slice: Vec<&dyn Component> = components .iter() .map(|component| component as &dyn Component) .collect(); - if self.analyzed.stage_count() > 1 { + let gkr_verifier_channel = &mut ::C::default(); + if let Some(gkr_proof) = &proof.gkr_proof { + //let gkr_trace_size=proof.machine_log_sizes.get("mle").unwrap(); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance, + } = partially_verify_batch(vec![Gate::LogUp; 2], gkr_proof, gkr_verifier_channel) + .unwrap(); + + // check the logop accumulation is zero + if gkr_proof + .output_claims_by_instance + .iter() + .fold(SecureField::zero(), |acc, vec| acc + vec[0] / vec[1]) + != SecureField::zero() + { + return Err("logup accumulation is not zero".to_string()); + } + + gkr_proof.output_claims_by_instance.len(); + + // TODO: modify this according to the challenge when the sound challenge is implemented + let combine_mle_claim: SecureField = claims_to_verify_by_instance + .iter() + .flatten() + .fold(SecureField::zero(), |acc, claim| acc + *claim); + commitment_scheme.commit( - proof.stark_proof.commitments[STAGE1_TRACE_IDX], - &stage1_witness_col_log_sizes, + proof.stark_proof.commitments[MLE_TRACE_IDX], + // 8 is number of the extra columns for GKR + &[n_variables_by_instance[0] as u32; 8], verifier_channel, ); - } - stwo_prover::core::prover::verify( - &components_slice, - verifier_channel, - commitment_scheme, - proof.stark_proof, - ) - .map_err(|e| e.to_string()) + let alpha = SecureField::from_u32_unchecked(42, 42, 42, 42); + let last_component = components.last().unwrap(); // &FrameworkComponent + let wrapped_component = PowdrComponentWrapper { + powdr_component: last_component, + logup_challenge: alpha, + main_machine_powdr_eval, + }; + let mle_eval_component = MleEvalVerifierComponent::new( + tree_span_provider, + &wrapped_component, + &ood_point, + combine_mle_claim, + MLE_TRACE_IDX, + ); + + components_slice.push(&mle_eval_component); + + stwo_prover::core::prover::verify( + &components_slice, + verifier_channel, + commitment_scheme, + proof.stark_proof, + ) + .map_err(|e| e.to_string()) + } else { + if self.analyzed.stage_count() > 1 { + commitment_scheme.commit( + proof.stark_proof.commitments[STAGE1_TRACE_IDX], + &stage1_witness_col_log_sizes, + verifier_channel, + ); + } + + stwo_prover::core::prover::verify( + &components_slice, + verifier_channel, + commitment_scheme, + proof.stark_proof, + ) + .map_err(|e| e.to_string()) + } } } diff --git a/pipeline/tests/pil.rs b/pipeline/tests/pil.rs index 6160aaa9a8..8afde7501c 100644 --- a/pipeline/tests/pil.rs +++ b/pipeline/tests/pil.rs @@ -269,6 +269,19 @@ fn stwo_fixed_columns() { test_stwo(f, Default::default()); } +#[test] +fn stwo_gkr_lookup() { + let f = "pil/gkr_lookup.pil"; + test_stwo(f, Default::default()); +} + +#[test] +#[should_panic(expected = "logup accumulation is not zero, prove failed")] +fn stwo_gkr_lookup_invalid() { + let f = "pil/gkr_lookup_invalid.pil"; + test_stwo(f, Default::default()); +} + #[test] fn stwo_stage1_publics() { let f = "pil/stage1_publics.pil"; diff --git a/test_data/pil/gkr_lookup.pil b/test_data/pil/gkr_lookup.pil new file mode 100644 index 0000000000..13fc74878d --- /dev/null +++ b/test_data/pil/gkr_lookup.pil @@ -0,0 +1,9 @@ +namespace main(32); + col fixed x = [1_fe, 2_fe, 3_fe, 4_fe, 5_fe, 6_fe, 7_fe, 8_fe, 9_fe, 10_fe, 11_fe, 12_fe, 13_fe, 14_fe, 15_fe, 16_fe, 17_fe, 18_fe, 19_fe, 20_fe, 21_fe, 22_fe, 23_fe, 24_fe, 25_fe, 26_fe, 27_fe, 28_fe, 29_fe, 30_fe, 31_fe, 32_fe]; + col witness y; + col witness z; + main::y = main::x; + main::z = main::x; + col fixed is_first(i) { if i == 0_int { 1_int } else { 0_int } }; + Constr::BusInteraction(1, 42, [main::y], 1); + Constr::BusInteraction(-1, 42, [main::z], 1); diff --git a/test_data/pil/gkr_lookup_invalid.pil b/test_data/pil/gkr_lookup_invalid.pil new file mode 100644 index 0000000000..c1d6556d36 --- /dev/null +++ b/test_data/pil/gkr_lookup_invalid.pil @@ -0,0 +1,9 @@ +namespace main(32); + col fixed x(i) {i}; + col witness y; + col witness z; + main::y = main::x; + main::z = main::x+1; + + Constr::BusInteraction(1, 42, [main::y], 1); + Constr::BusInteraction(-1, 42, [main::z], 1);