diff --git a/Cargo.lock b/Cargo.lock index a8c66e950..550f84c4e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8285,6 +8285,7 @@ dependencies = [ "alloy-transport-http", "anyhow", "bincode", + "hex", "once_cell", "pem 3.0.4", "raiko-lib", diff --git a/core/src/interfaces.rs b/core/src/interfaces.rs index 3099565ed..962baf1d8 100644 --- a/core/src/interfaces.rs +++ b/core/src/interfaces.rs @@ -3,7 +3,9 @@ use alloy_primitives::{Address, B256}; use clap::{Args, ValueEnum}; use raiko_lib::{ consts::VerifierType, - input::{BlobProofType, GuestInput, GuestOutput}, + input::{ + AggregationGuestInput, AggregationGuestOutput, BlobProofType, GuestInput, GuestOutput, + }, prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverError}, }; use serde::{Deserialize, Serialize}; @@ -203,6 +205,47 @@ impl ProofType { } } + /// Run the prover driver depending on the proof type. + pub async fn aggregate_proofs( + &self, + input: AggregationGuestInput, + output: &AggregationGuestOutput, + config: &Value, + store: Option<&mut dyn IdWrite>, + ) -> RaikoResult { + let proof = match self { + ProofType::Native => NativeProver::aggregate(input.clone(), output, config, store) + .await + .map_err(>::into), + ProofType::Sp1 => { + #[cfg(feature = "sp1")] + return sp1_driver::Sp1Prover::aggregate(input.clone(), output, config, store) + .await + .map_err(|e| e.into()); + #[cfg(not(feature = "sp1"))] + Err(RaikoError::FeatureNotSupportedError(*self)) + } + ProofType::Risc0 => { + #[cfg(feature = "risc0")] + return risc0_driver::Risc0Prover::aggregate(input.clone(), output, config, store) + .await + .map_err(|e| e.into()); + #[cfg(not(feature = "risc0"))] + Err(RaikoError::FeatureNotSupportedError(*self)) + } + ProofType::Sgx => { + #[cfg(feature = "sgx")] + return sgx_prover::SgxProver::aggregate(input.clone(), output, config, store) + .await + .map_err(|e| e.into()); + #[cfg(not(feature = "sgx"))] + Err(RaikoError::FeatureNotSupportedError(*self)) + } + }?; + + Ok(proof) + } + pub async fn cancel_proof( &self, proof_key: ProofKey, @@ -398,3 +441,15 @@ impl TryFrom for ProofRequest { }) } } + +#[serde_as] +#[derive(Clone, Debug, Serialize, Deserialize)] +/// A request for proof aggregation of multiple proofs. +pub struct AggregationRequest { + /// All the proofs to verify + pub proofs: Vec, + /// The proof type. + pub proof_type: ProofType, + /// Additional prover params. + pub prover_args: HashMap, +} diff --git a/core/src/lib.rs b/core/src/lib.rs index cd026952b..47c8d20cd 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -226,8 +226,9 @@ mod tests { use clap::ValueEnum; use raiko_lib::{ consts::{Network, SupportedChainSpecs}, - input::BlobProofType, + input::{AggregationGuestInput, AggregationGuestOutput, BlobProofType}, primitives::B256, + prover::Proof, }; use serde_json::{json, Value}; use std::{collections::HashMap, env}; @@ -242,7 +243,7 @@ mod tests { ci == "1" } - fn test_proof_params() -> HashMap { + fn test_proof_params(enable_aggregation: bool) -> HashMap { let mut prover_args = HashMap::new(); prover_args.insert( "native".to_string(), @@ -256,7 +257,7 @@ mod tests { "sp1".to_string(), json! { { - "recursion": "core", + "recursion": if enable_aggregation { "compressed" } else { "plonk" }, "prover": "mock", "verify": true } @@ -278,8 +279,8 @@ mod tests { json! { { "instance_id": 121, - "setup": true, - "bootstrap": true, + "setup": enable_aggregation, + "bootstrap": enable_aggregation, "prove": true, } }, @@ -291,7 +292,7 @@ mod tests { l1_chain_spec: ChainSpec, taiko_chain_spec: ChainSpec, proof_request: ProofRequest, - ) { + ) -> Proof { let provider = RpcBlockDataProvider::new(&taiko_chain_spec.rpc, proof_request.block_number - 1) .expect("Could not create RpcBlockDataProvider"); @@ -301,10 +302,10 @@ mod tests { .await .expect("input generation failed"); let output = raiko.get_output(&input).expect("output generation failed"); - let _proof = raiko + raiko .prove(input, &output, None) .await - .expect("proof generation failed"); + .expect("proof generation failed") } #[ignore] @@ -332,7 +333,7 @@ mod tests { l1_network, proof_type, blob_proof_type: BlobProofType::ProofOfEquivalence, - prover_args: test_proof_params(), + prover_args: test_proof_params(false), }; prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await; } @@ -361,7 +362,7 @@ mod tests { l1_network, proof_type, blob_proof_type: BlobProofType::ProofOfEquivalence, - prover_args: test_proof_params(), + prover_args: test_proof_params(false), }; prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await; } @@ -399,7 +400,7 @@ mod tests { l1_network, proof_type, blob_proof_type: BlobProofType::ProofOfEquivalence, - prover_args: test_proof_params(), + prover_args: test_proof_params(false), }; prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await; } @@ -432,9 +433,55 @@ mod tests { l1_network, proof_type, blob_proof_type: BlobProofType::ProofOfEquivalence, - prover_args: test_proof_params(), + prover_args: test_proof_params(false), }; prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await; } } + + #[tokio::test(flavor = "multi_thread")] + async fn test_prove_block_taiko_a7_aggregated() { + let proof_type = get_proof_type_from_env(); + let l1_network = Network::Holesky.to_string(); + let network = Network::TaikoA7.to_string(); + // Give the CI an simpler block to test because it doesn't have enough memory. + // Unfortunately that also means that kzg is not getting fully verified by CI. + let block_number = if is_ci() { 105987 } else { 101368 }; + let taiko_chain_spec = SupportedChainSpecs::default() + .get_chain_spec(&network) + .unwrap(); + let l1_chain_spec = SupportedChainSpecs::default() + .get_chain_spec(&l1_network) + .unwrap(); + + let proof_request = ProofRequest { + l1_inclusion_block_number: 0, + block_number, + network, + graffiti: B256::ZERO, + prover: Address::ZERO, + l1_network, + proof_type, + blob_proof_type: BlobProofType::ProofOfEquivalence, + prover_args: test_proof_params(true), + }; + let proof = prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await; + + let input = AggregationGuestInput { + proofs: vec![proof.clone(), proof], + }; + + let output = AggregationGuestOutput { hash: B256::ZERO }; + + let aggregated_proof = proof_type + .aggregate_proofs( + input, + &output, + &serde_json::to_value(&test_proof_params(false)).unwrap(), + None, + ) + .await + .expect("proof aggregation failed"); + println!("aggregated proof: {:?}", aggregated_proof); + } } diff --git a/core/src/prover.rs b/core/src/prover.rs index 577c5318a..de89d859e 100644 --- a/core/src/prover.rs +++ b/core/src/prover.rs @@ -58,14 +58,28 @@ impl Prover for NativeProver { } Ok(Proof { + input: None, proof: None, quote: None, + uuid: None, + kzg_proof: None, }) } async fn cancel(_proof_key: ProofKey, _read: Box<&mut dyn IdStore>) -> ProverResult<()> { Ok(()) } + + async fn aggregate( + _input: raiko_lib::input::AggregationGuestInput, + _output: &raiko_lib::input::AggregationGuestOutput, + _config: &ProverConfig, + _store: Option<&mut dyn IdWrite>, + ) -> ProverResult { + Ok(Proof { + ..Default::default() + }) + } } #[ignore = "Only used to test serialized data"] diff --git a/host/src/server/api/v2/mod.rs b/host/src/server/api/v2/mod.rs index 65f4894e4..7c32b4ff0 100644 --- a/host/src/server/api/v2/mod.rs +++ b/host/src/server/api/v2/mod.rs @@ -151,6 +151,8 @@ pub fn create_router() -> Router { // Only add the concurrency limit to the proof route. We want to still be able to call // healthchecks and metrics to have insight into the system. .nest("/proof", proof::create_router()) + // TODO: Separate task or try to get it into /proof somehow? Probably separate + .nest("/aggregate", proof::create_router()) .nest("/health", v1::health::create_router()) .nest("/metrics", v1::metrics::create_router()) .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", docs.clone())) diff --git a/lib/src/input.rs b/lib/src/input.rs index 1b0688b16..44a18ef65 100644 --- a/lib/src/input.rs +++ b/lib/src/input.rs @@ -12,7 +12,9 @@ use serde_with::serde_as; #[cfg(not(feature = "std"))] use crate::no_std::*; -use crate::{consts::ChainSpec, primitives::mpt::MptNode, utils::zlib_compress_data}; +use crate::{ + consts::ChainSpec, primitives::mpt::MptNode, prover::Proof, utils::zlib_compress_data, +}; /// Represents the state of an account's storage. /// The storage trie together with the used storage slots allow us to reconstruct all the @@ -41,6 +43,42 @@ pub struct GuestInput { pub taiko: TaikoGuestInput, } +/// External aggregation input. +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct AggregationGuestInput { + /// All block proofs to prove + pub proofs: Vec, +} + +/// The raw proof data necessary to verify a proof +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct RawProof { + /// The actual proof + pub proof: Vec, + /// The resulting hash + pub input: B256, +} + +/// External aggregation input. +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct RawAggregationGuestInput { + /// All block proofs to prove + pub proofs: Vec, +} + +/// External aggregation input. +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct AggregationGuestOutput { + /// The resulting hash + pub hash: B256, +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct ZkAggregationGuestInput { + pub image_id: [u32; 8], + pub block_inputs: Vec, +} + impl From<(Block, Header, ChainSpec, TaikoGuestInput)> for GuestInput { fn from( (block, parent_header, chain_spec, taiko): (Block, Header, ChainSpec, TaikoGuestInput), diff --git a/lib/src/protocol_instance.rs b/lib/src/protocol_instance.rs index 5036173f7..786181cce 100644 --- a/lib/src/protocol_instance.rs +++ b/lib/src/protocol_instance.rs @@ -315,6 +315,27 @@ fn bytes_to_bytes32(input: &[u8]) -> [u8; 32] { bytes } +pub fn words_to_bytes_le(words: &[u32; 8]) -> [u8; 32] { + let mut bytes = [0u8; 32]; + for i in 0..8 { + let word_bytes = words[i].to_le_bytes(); + bytes[i * 4..(i + 1) * 4].copy_from_slice(&word_bytes); + } + bytes +} + +pub fn aggregation_output_combine(public_inputs: Vec) -> Vec { + let mut output = Vec::with_capacity(public_inputs.len() * 32); + for public_input in public_inputs.iter() { + output.extend_from_slice(&public_input.0); + } + output +} + +pub fn aggregation_output(program: B256, public_inputs: Vec) -> Vec { + aggregation_output_combine([vec![program], public_inputs].concat()) +} + #[cfg(test)] mod tests { use alloy_primitives::{address, b256}; diff --git a/lib/src/prover.rs b/lib/src/prover.rs index 948f57af4..5a1c7669e 100644 --- a/lib/src/prover.rs +++ b/lib/src/prover.rs @@ -2,7 +2,7 @@ use reth_primitives::{ChainId, B256}; use serde::{Deserialize, Serialize}; use utoipa::ToSchema; -use crate::input::{GuestInput, GuestOutput}; +use crate::input::{AggregationGuestInput, AggregationGuestOutput, GuestInput, GuestOutput}; #[derive(thiserror::Error, Debug)] pub enum ProverError { @@ -26,13 +26,19 @@ pub type ProverResult = core::result::Result; pub type ProverConfig = serde_json::Value; pub type ProofKey = (ChainId, B256, u8); -#[derive(Debug, Serialize, ToSchema, Deserialize, Default)] +#[derive(Clone, Debug, Serialize, ToSchema, Deserialize, Default)] /// The response body of a proof request. pub struct Proof { /// The proof either TEE or ZK. pub proof: Option, + /// The public input + pub input: Option, /// The TEE quote. pub quote: Option, + /// The assumption UUID. + pub uuid: Option, + /// The kzg proof. + pub kzg_proof: Option, } #[async_trait::async_trait] @@ -56,5 +62,12 @@ pub trait Prover { store: Option<&mut dyn IdWrite>, ) -> ProverResult; + async fn aggregate( + input: AggregationGuestInput, + output: &AggregationGuestOutput, + config: &ProverConfig, + store: Option<&mut dyn IdWrite>, + ) -> ProverResult; + async fn cancel(proof_key: ProofKey, read: Box<&mut dyn IdStore>) -> ProverResult<()>; } diff --git a/provers/risc0/builder/src/main.rs b/provers/risc0/builder/src/main.rs index b0de9edb1..523824f40 100644 --- a/provers/risc0/builder/src/main.rs +++ b/provers/risc0/builder/src/main.rs @@ -5,7 +5,10 @@ use std::path::PathBuf; fn main() { let pipeline = Risc0Pipeline::new("provers/risc0/guest", "release"); - pipeline.bins(&["risc0-guest"], "provers/risc0/driver/src/methods"); + pipeline.bins( + &["risc0-guest", "risc0-aggregation"], + "provers/risc0/driver/src/methods", + ); #[cfg(feature = "test")] pipeline.tests(&["risc0-guest"], "provers/risc0/driver/src/methods"); #[cfg(feature = "bench")] diff --git a/provers/risc0/driver/src/bonsai.rs b/provers/risc0/driver/src/bonsai.rs index 65c369338..40ee4f322 100644 --- a/provers/risc0/driver/src/bonsai.rs +++ b/provers/risc0/driver/src/bonsai.rs @@ -3,13 +3,14 @@ use crate::{ snarks::{stark2snark, verify_groth16_snark}, Risc0Response, }; +use alloy_primitives::B256; use log::{debug, error, info, warn}; use raiko_lib::{ primitives::keccak::keccak, prover::{IdWrite, ProofKey, ProverError, ProverResult}, }; use risc0_zkvm::{ - compute_image_id, is_dev_mode, serde::to_vec, sha::Digest, Assumption, ExecutorEnv, + compute_image_id, is_dev_mode, serde::to_vec, sha::Digest, AssumptionReceipt, ExecutorEnv, ExecutorImpl, Receipt, }; use serde::{de::DeserializeOwned, Serialize}; @@ -120,7 +121,7 @@ pub async fn maybe_prove, elf: &[u8], expected_output: &O, - assumptions: (Vec, Vec), + assumptions: (Vec>, Vec), proof_key: ProofKey, id_store: &mut Option<&mut dyn IdWrite>, ) -> Option<(String, Receipt)> { @@ -283,11 +284,13 @@ pub async fn prove_bonsai( pub async fn bonsai_stark_to_snark( stark_uuid: String, stark_receipt: Receipt, + input: B256, ) -> ProverResult { let image_id = Digest::from(RISC0_GUEST_ID); - let (snark_uuid, snark_receipt) = stark2snark(image_id, stark_uuid, stark_receipt) - .await - .map_err(|err| format!("Failed to convert STARK to SNARK: {err:?}"))?; + let (snark_uuid, snark_receipt) = + stark2snark(image_id, stark_uuid.clone(), stark_receipt.clone()) + .await + .map_err(|err| format!("Failed to convert STARK to SNARK: {err:?}"))?; info!("Validating SNARK uuid: {snark_uuid}"); @@ -296,7 +299,12 @@ pub async fn bonsai_stark_to_snark( .map_err(|err| format!("Failed to verify SNARK: {err:?}"))?; let snark_proof = format!("0x{}", hex::encode(enc_proof)); - Ok(Risc0Response { proof: snark_proof }) + Ok(Risc0Response { + proof: snark_proof, + receipt: serde_json::to_string(&stark_receipt).unwrap(), + uuid: stark_uuid, + input, + }) } /// Prove the given ELF locally with the given input and assumptions. The segments are @@ -305,7 +313,7 @@ pub fn prove_locally( segment_limit_po2: u32, encoded_input: Vec, elf: &[u8], - assumptions: Vec, + assumptions: Vec>, profile: bool, ) -> ProverResult { debug!("Proving with segment_limit_po2 = {segment_limit_po2:?}"); diff --git a/provers/risc0/driver/src/lib.rs b/provers/risc0/driver/src/lib.rs index 177ba6742..4cf4e32d7 100644 --- a/provers/risc0/driver/src/lib.rs +++ b/provers/risc0/driver/src/lib.rs @@ -2,15 +2,21 @@ #[cfg(feature = "bonsai-auto-scaling")] use crate::bonsai::auto_scaling::shutdown_bonsai; -use crate::methods::risc0_guest::RISC0_GUEST_ELF; +use crate::{ + methods::risc0_aggregation::RISC0_AGGREGATION_ELF, + methods::risc0_guest::{RISC0_GUEST_ELF, RISC0_GUEST_ID}, +}; use alloy_primitives::{hex::ToHexExt, B256}; -pub use bonsai::*; +use bonsai::{cancel_proof, maybe_prove}; use log::warn; use raiko_lib::{ - input::{GuestInput, GuestOutput}, + input::{ + AggregationGuestInput, AggregationGuestOutput, GuestInput, GuestOutput, + ZkAggregationGuestInput, + }, prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverConfig, ProverError, ProverResult}, }; -use risc0_zkvm::serde::to_vec; +use risc0_zkvm::{serde::to_vec, Receipt}; use serde::{Deserialize, Serialize}; use serde_with::serde_as; use std::fmt::Debug; @@ -32,6 +38,9 @@ pub struct Risc0Param { #[derive(Clone, Serialize, Deserialize)] pub struct Risc0Response { pub proof: String, + pub receipt: String, + pub uuid: String, + pub input: B256, } impl From for Proof { @@ -39,6 +48,9 @@ impl From for Proof { Self { proof: Some(value.proof), quote: None, + input: Some(value.input), + uuid: Some(value.uuid), + kzg_proof: None, } } } @@ -70,25 +82,119 @@ impl Prover for Risc0Prover { encoded_input, RISC0_GUEST_ELF, &output.hash, - Default::default(), + (Vec::::new(), Vec::new()), proof_key, &mut id_store, ) .await; + let receipt = result.clone().unwrap().1.clone(); + let uuid = result.clone().unwrap().0; + let proof_gen_result = if result.is_some() { if config.snark && config.bonsai { let (stark_uuid, stark_receipt) = result.clone().unwrap(); - bonsai::bonsai_stark_to_snark(stark_uuid, stark_receipt) + bonsai::bonsai_stark_to_snark(stark_uuid, stark_receipt, output.hash) .await .map(|r0_response| r0_response.into()) .map_err(|e| ProverError::GuestError(e.to_string())) } else { warn!("proof is not in snark mode, please check."); let (_, stark_receipt) = result.clone().unwrap(); + Ok(Risc0Response { + proof: stark_receipt.journal.encode_hex_with_prefix(), + receipt: serde_json::to_string(&receipt).unwrap(), + uuid, + input: output.hash, + } + .into()) + } + } else { + Err(ProverError::GuestError( + "Failed to generate proof".to_string(), + )) + }; + + #[cfg(feature = "bonsai-auto-scaling")] + if config.bonsai { + // shutdown bonsai + shutdown_bonsai() + .await + .map_err(|e| ProverError::GuestError(e.to_string()))?; + } + + proof_gen_result + } + + async fn aggregate( + input: AggregationGuestInput, + output: &AggregationGuestOutput, + config: &ProverConfig, + id_store: Option<&mut dyn IdWrite>, + ) -> ProverResult { + let mut id_store = id_store; + let config = Risc0Param::deserialize(config.get("risc0").unwrap()).unwrap(); + let proof_key = (0, output.hash.clone(), RISC0_PROVER_CODE); + + // Extract the block proof receipts + let assumptions: Vec = input + .proofs + .iter() + .map(|proof| { + let receipt: Receipt = serde_json::from_str(&proof.quote.clone().unwrap()) + .expect("Failed to deserialize"); + receipt + }) + .collect::>(); + let block_inputs: Vec = input + .proofs + .iter() + .map(|proof| proof.input.unwrap()) + .collect::>(); + // For bonsai + let assumptions_uuids: Vec = input + .proofs + .iter() + .map(|proof| proof.uuid.clone().unwrap()) + .collect::>(); + + let input = ZkAggregationGuestInput { + image_id: RISC0_GUEST_ID, + block_inputs, + }; + + debug!("elf code length: {}", RISC0_AGGREGATION_ELF.len()); + let encoded_input = to_vec(&input).expect("Could not serialize proving input!"); + + let result = maybe_prove::( + &config, + encoded_input, + RISC0_AGGREGATION_ELF, + &output.hash, + (assumptions, assumptions_uuids), + proof_key, + &mut id_store, + ) + .await; + + let receipt = result.clone().unwrap().1.clone(); + let uuid = result.clone().unwrap().0; + let proof_gen_result = if result.is_some() { + if config.snark && config.bonsai { + let (stark_uuid, stark_receipt) = result.clone().unwrap(); + bonsai::bonsai_stark_to_snark(stark_uuid, stark_receipt, output.hash) + .await + .map(|r0_response| r0_response.into()) + .map_err(|e| ProverError::GuestError(e.to_string())) + } else { + warn!("proof is not in snark mode, please check."); + let (_, stark_receipt) = result.clone().unwrap(); Ok(Risc0Response { proof: stark_receipt.journal.encode_hex_with_prefix(), + receipt: serde_json::to_string(&receipt).unwrap(), + uuid, + input: output.hash, } .into()) } diff --git a/provers/risc0/driver/src/methods/mod.rs b/provers/risc0/driver/src/methods/mod.rs index 0211d22de..19219d8af 100644 --- a/provers/risc0/driver/src/methods/mod.rs +++ b/provers/risc0/driver/src/methods/mod.rs @@ -1,3 +1,4 @@ +pub mod risc0_aggregation; pub mod risc0_guest; // To build the following `$ cargo run --features test,bench --bin risc0-builder` diff --git a/provers/risc0/driver/src/methods/risc0_aggregation.rs b/provers/risc0/driver/src/methods/risc0_aggregation.rs new file mode 100644 index 000000000..f3b1fe64f --- /dev/null +++ b/provers/risc0/driver/src/methods/risc0_aggregation.rs @@ -0,0 +1,5 @@ +pub const RISC0_AGGREGATION_ELF: &[u8] = + include_bytes!("../../../guest/target/riscv32im-risc0-zkvm-elf/release/risc0-aggregation"); +pub const RISC0_AGGREGATION_ID: [u32; 8] = [ + 834745027, 3860709824, 1052791454, 925104520, 3609882255, 551703375, 2495735124, 1897996989, +]; diff --git a/provers/risc0/driver/src/snarks.rs b/provers/risc0/driver/src/snarks.rs index 5cc00d232..056a1e8cf 100644 --- a/provers/risc0/driver/src/snarks.rs +++ b/provers/risc0/driver/src/snarks.rs @@ -30,7 +30,7 @@ use risc0_zkvm::{ use tracing::{error as tracing_err, info as tracing_info}; -use crate::save_receipt; +use crate::bonsai::save_receipt; sol!( /// A Groth16 seal over the claimed receipt claim. diff --git a/provers/risc0/guest/Cargo.toml b/provers/risc0/guest/Cargo.toml index 28091f3c9..190ac9a60 100644 --- a/provers/risc0/guest/Cargo.toml +++ b/provers/risc0/guest/Cargo.toml @@ -9,6 +9,10 @@ edition = "2021" name = "zk_op" path = "src/zk_op.rs" +[[bin]] +name = "risc0-aggregation" +path = "src/aggregation.rs" + [[bin]] name = "sha256" path = "src/benchmark/sha256.rs" diff --git a/provers/risc0/guest/src/aggregation.rs b/provers/risc0/guest/src/aggregation.rs new file mode 100644 index 000000000..3f65701e1 --- /dev/null +++ b/provers/risc0/guest/src/aggregation.rs @@ -0,0 +1,21 @@ +#![no_main] +harness::entrypoint!(main); +use risc0_zkvm::{serde, guest::env}; +use raiko_lib::protocol_instance::words_to_bytes_le; +use raiko_lib::protocol_instance::aggregation_output; +use raiko_lib::input::ZkAggregationGuestInput; +use raiko_lib::primitives::B256; + +fn main() { + // Read the aggregation input + let input: ZkAggregationGuestInput = env::read(); + + // Verify the proofs. + for block_input in input.block_inputs.iter() { + // Verify that n has a known factorization. + env::verify(input.image_id, &serde::to_vec(&block_input).unwrap()).unwrap(); + } + + // The aggregation output + env::commit(&aggregation_output(B256::from(words_to_bytes_le(&input.image_id)), input.block_inputs)); +} diff --git a/provers/sgx/guest/src/app_args.rs b/provers/sgx/guest/src/app_args.rs index 35020f272..10f8ca18e 100644 --- a/provers/sgx/guest/src/app_args.rs +++ b/provers/sgx/guest/src/app_args.rs @@ -17,6 +17,8 @@ pub struct App { pub enum Command { /// Prove (i.e. sign) a single block and exit. OneShot(OneShotArgs), + /// Aggregate proofs + Aggregate(OneShotArgs), /// Bootstrap the application and then exit. The bootstrapping process generates the /// initial public-private key pair and stores it on the disk in an encrypted /// format using SGX encryption primitives. diff --git a/provers/sgx/guest/src/main.rs b/provers/sgx/guest/src/main.rs index accd54913..c7af5db30 100644 --- a/provers/sgx/guest/src/main.rs +++ b/provers/sgx/guest/src/main.rs @@ -3,6 +3,7 @@ extern crate secp256k1; use anyhow::{anyhow, Result}; use clap::Parser; +use one_shot::aggregate; use crate::{ app_args::{App, Command}, @@ -22,6 +23,10 @@ pub async fn main() -> Result<()> { println!("Starting one shot mode"); one_shot(args.global_opts, one_shot_args).await? } + Command::Aggregate(one_shot_args) => { + println!("Starting one shot mode"); + aggregate(args.global_opts, one_shot_args).await? + } Command::Bootstrap => { println!("Bootstrapping the app"); bootstrap(args.global_opts)? diff --git a/provers/sgx/guest/src/one_shot.rs b/provers/sgx/guest/src/one_shot.rs index 4c4cfee71..18778b457 100644 --- a/provers/sgx/guest/src/one_shot.rs +++ b/provers/sgx/guest/src/one_shot.rs @@ -8,8 +8,11 @@ use std::{ use anyhow::{anyhow, bail, Context, Error, Result}; use base64_serde::base64_serde_type; use raiko_lib::{ - builder::calculate_block_header, consts::VerifierType, input::GuestInput, primitives::Address, - protocol_instance::ProtocolInstance, + builder::calculate_block_header, + consts::VerifierType, + input::{GuestInput, RawAggregationGuestInput}, + primitives::{keccak, Address, B256}, + protocol_instance::{aggregation_output_combine, ProtocolInstance}, }; use secp256k1::{Keypair, SecretKey}; use serde::Serialize; @@ -147,6 +150,85 @@ pub async fn one_shot(global_opts: GlobalOpts, args: OneShotArgs) -> Result<()> let mut proof = Vec::with_capacity(SGX_PROOF_LEN); proof.extend(args.sgx_instance_id.to_be_bytes()); proof.extend(new_instance); + proof.extend(new_instance); + proof.extend(sig); + let proof = hex::encode(proof); + + // Store the public key address in the attestation data + save_attestation_user_report_data(new_instance)?; + + // Print out the proof and updated public info + let quote = get_sgx_quote()?; + let data = serde_json::json!({ + "proof": format!("0x{proof}"), + "quote": hex::encode(quote), + "public_key": format!("0x{new_pubkey}"), + "instance_address": new_instance.to_string(), + "input": pi_hash.to_string(), + }); + println!("{data}"); + + // Print out general SGX information + print_sgx_info() +} + +pub async fn aggregate(global_opts: GlobalOpts, args: OneShotArgs) -> Result<()> { + // Make sure this SGX instance was bootstrapped + let prev_privkey = load_bootstrap(&global_opts.secrets_dir) + .or_else(|_| bail!("Application was not bootstrapped or has a deprecated bootstrap.")) + .unwrap(); + + println!("Global options: {global_opts:?}, OneShot options: {args:?}"); + + let new_pubkey = public_key(&prev_privkey); + let new_instance = public_key_to_address(&new_pubkey); + + let input: RawAggregationGuestInput = + bincode::deserialize_from(std::io::stdin()).expect("unable to deserialize input"); + + // Make sure the chain of old/new public keys is preserved + let old_instance = Address::from_slice(&input.proofs[0].proof.clone()[4..24]); + let mut cur_instance = old_instance; + + // Verify the proofs + for proof in input.proofs.iter() { + // TODO: verify protocol instance data so we can trust the old/new instance data + assert_eq!( + recover_signer_unchecked(&proof.proof.clone()[44..].try_into().unwrap(), &proof.input,) + .unwrap(), + cur_instance, + ); + cur_instance = Address::from_slice(&proof.proof.clone()[24..44]); + } + + // Current public key needs to match latest proof new public key + assert_eq!(cur_instance, new_instance); + + // Calculate the aggregation hash + let aggregation_hash = keccak::keccak(aggregation_output_combine( + [ + vec![ + B256::left_padding_from(old_instance.as_ref()), + B256::left_padding_from(new_instance.as_ref()), + ], + input + .proofs + .iter() + .map(|proof| proof.input) + .collect::>(), + ] + .concat(), + )); + + // Sign the public aggregation hash + let sig = sign_message(&prev_privkey, aggregation_hash.into())?; + + // Create the proof for the onchain SGX verifier + const SGX_PROOF_LEN: usize = 89; + let mut proof = Vec::with_capacity(SGX_PROOF_LEN); + proof.extend(args.sgx_instance_id.to_be_bytes()); + proof.extend(old_instance); + proof.extend(new_instance); proof.extend(sig); let proof = hex::encode(proof); @@ -160,6 +242,7 @@ pub async fn one_shot(global_opts: GlobalOpts, args: OneShotArgs) -> Result<()> "quote": hex::encode(quote), "public_key": format!("0x{new_pubkey}"), "instance_address": new_instance.to_string(), + "input": B256::from(aggregation_hash).to_string(), }); println!("{data}"); diff --git a/provers/sgx/prover/Cargo.toml b/provers/sgx/prover/Cargo.toml index 69c0c3570..0c5f5a6c9 100644 --- a/provers/sgx/prover/Cargo.toml +++ b/provers/sgx/prover/Cargo.toml @@ -24,6 +24,7 @@ alloy-transport-http = { workspace = true } pem = { version = "3.0.4", optional = true } url = { workspace = true } anyhow = { workspace = true } +hex = { workspace = true } [features] default = ["dep:pem"] diff --git a/provers/sgx/prover/src/lib.rs b/provers/sgx/prover/src/lib.rs index 7f7688ac7..a74ee0e06 100644 --- a/provers/sgx/prover/src/lib.rs +++ b/provers/sgx/prover/src/lib.rs @@ -5,12 +5,16 @@ use std::{ fs::{copy, create_dir_all, remove_file}, path::{Path, PathBuf}, process::{Command as StdCommand, Output, Stdio}, - str, + str::{self, FromStr}, }; use once_cell::sync::Lazy; use raiko_lib::{ - input::{GuestInput, GuestOutput}, + input::{ + AggregationGuestInput, AggregationGuestOutput, GuestInput, GuestOutput, + RawAggregationGuestInput, RawProof, + }, + primitives::B256, prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverConfig, ProverError, ProverResult}, }; use serde::{Deserialize, Serialize}; @@ -42,13 +46,17 @@ pub struct SgxResponse { /// proof format: 4b(id)+20b(pubkey)+65b(signature) pub proof: String, pub quote: String, + pub input: B256, } impl From for Proof { fn from(value: SgxResponse) -> Self { Self { proof: Some(value.proof), + input: Some(value.input), quote: Some(value.quote), + uuid: None, + kzg_proof: None, } } } @@ -147,6 +155,87 @@ impl Prover for SgxProver { sgx_proof.map(|r| r.into()) } + async fn aggregate( + input: AggregationGuestInput, + _output: &AggregationGuestOutput, + config: &ProverConfig, + _id_store: Option<&mut dyn IdWrite>, + ) -> ProverResult { + let sgx_param = SgxParam::deserialize(config.get("sgx").unwrap()).unwrap(); + + // Support both SGX and the direct backend for testing + let direct_mode = match env::var("SGX_DIRECT") { + Ok(value) => value == "1", + Err(_) => false, + }; + + println!( + "WARNING: running SGX in {} mode!", + if direct_mode { + "direct (a.k.a. simulation)" + } else { + "hardware" + } + ); + + // The working directory + let mut cur_dir = env::current_exe() + .expect("Fail to get current directory") + .parent() + .unwrap() + .to_path_buf(); + + // When running in tests we might be in a child folder + if cur_dir.ends_with("deps") { + cur_dir = cur_dir.parent().unwrap().to_path_buf(); + } + + println!("Current directory: {cur_dir:?}\n"); + // Working paths + PRIVATE_KEY + .get_or_init(|| async { cur_dir.join("secrets").join(PRIV_KEY_FILENAME) }) + .await; + GRAMINE_MANIFEST_TEMPLATE + .get_or_init(|| async { + cur_dir + .join(CONFIG) + .join("sgx-guest.local.manifest.template") + }) + .await; + + // The gramine command (gramine or gramine-direct for testing in non-SGX environment) + let gramine_cmd = || -> StdCommand { + let mut cmd = if direct_mode { + StdCommand::new("gramine-direct") + } else { + let mut cmd = StdCommand::new("sudo"); + cmd.arg("gramine-sgx"); + cmd + }; + cmd.current_dir(&cur_dir).arg(ELF_NAME); + cmd + }; + + // Setup: run this once while setting up your SGX instance + if sgx_param.setup { + setup(&cur_dir, direct_mode).await?; + } + + let mut sgx_proof = if sgx_param.bootstrap { + bootstrap(cur_dir.clone().join("secrets"), gramine_cmd()).await + } else { + // Dummy proof: it's ok when only setup/bootstrap was requested + Ok(SgxResponse::default()) + }; + + if sgx_param.prove { + // overwrite sgx_proof as the bootstrap quote stays the same in bootstrap & prove. + sgx_proof = aggregate(gramine_cmd(), input.clone(), sgx_param.instance_id).await + } + + sgx_proof.map(|r| r.into()) + } + async fn cancel(_proof_key: ProofKey, _read: Box<&mut dyn IdStore>) -> ProverResult<()> { Ok(()) } @@ -303,6 +392,54 @@ async fn prove( .map_err(|e| ProverError::GuestError(e.to_string()))? } +async fn aggregate( + mut gramine_cmd: StdCommand, + input: AggregationGuestInput, + instance_id: u64, +) -> ProverResult { + // Extract the useful parts of the proof here so the guest doesn't have to do it + let raw_input = RawAggregationGuestInput { + proofs: input + .proofs + .iter() + .map(|proof| RawProof { + input: proof.clone().input.unwrap(), + proof: hex::decode(&proof.clone().proof.unwrap()[2..]).unwrap(), + }) + .collect(), + }; + + tokio::task::spawn_blocking(move || { + let mut child = gramine_cmd + .arg("aggregate") + .arg("--sgx-instance-id") + .arg(instance_id.to_string()) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .map_err(|e| format!("Could not spawn gramine cmd: {e}"))?; + let stdin = child.stdin.as_mut().expect("Failed to open stdin"); + let input_success = bincode::serialize_into(stdin, &raw_input); + let output_success = child.wait_with_output(); + + match (input_success, output_success) { + (Ok(_), Ok(output)) => { + handle_output(&output, "SGX prove")?; + Ok(parse_sgx_result(output.stdout)?) + } + (Err(i), output_success) => Err(ProverError::GuestError(format!( + "Can not serialize input for SGX {i}, output is {output_success:?}" + ))), + (Ok(_), Err(output_err)) => Err(ProverError::GuestError( + handle_gramine_error("Could not run SGX guest prover", output_err).to_string(), + )), + } + }) + .await + .map_err(|e| ProverError::GuestError(e.to_string()))? +} + fn parse_sgx_result(output: Vec) -> ProverResult { let mut json_value: Option = None; let output = String::from_utf8(output).map_err(|e| e.to_string())?; @@ -324,6 +461,7 @@ fn parse_sgx_result(output: Vec) -> ProverResult { Ok(SgxResponse { proof: extract_field("proof"), quote: extract_field("quote"), + input: B256::from_str(&extract_field("input")).unwrap(), }) } diff --git a/provers/sp1/builder/src/main.rs b/provers/sp1/builder/src/main.rs index 7db899a13..fe696594e 100644 --- a/provers/sp1/builder/src/main.rs +++ b/provers/sp1/builder/src/main.rs @@ -5,7 +5,7 @@ use std::path::PathBuf; fn main() { let pipeline = Sp1Pipeline::new("provers/sp1/guest", "release"); - pipeline.bins(&["sp1-guest"], "provers/sp1/guest/elf"); + pipeline.bins(&["sp1-guest", "sp1-aggregation"], "provers/sp1/guest/elf"); #[cfg(feature = "test")] pipeline.tests(&["sp1-guest"], "provers/sp1/guest/elf"); #[cfg(feature = "bench")] diff --git a/provers/sp1/driver/src/lib.rs b/provers/sp1/driver/src/lib.rs index da38ea11f..c8f0fe604 100644 --- a/provers/sp1/driver/src/lib.rs +++ b/provers/sp1/driver/src/lib.rs @@ -3,7 +3,7 @@ use once_cell::sync::Lazy; use raiko_lib::{ - input::{GuestInput, GuestOutput}, + input::{AggregationGuestInput, AggregationGuestOutput, GuestInput, GuestOutput}, prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverConfig, ProverError, ProverResult}, Measurement, }; @@ -24,6 +24,7 @@ use std::{ use tracing::{debug, info}; pub const ELF: &[u8] = include_bytes!("../../guest/elf/sp1-guest"); +pub const AGGREGATION_ELF: &[u8] = include_bytes!("../../guest/elf/sp1-aggregation"); const SP1_PROVER_CODE: u8 = 1; static FIXTURE_PATH: Lazy = Lazy::new(|| Path::new(env!("CARGO_MANIFEST_DIR")).join("../contracts/src/fixtures/")); @@ -74,15 +75,18 @@ pub enum ProverMode { impl From for Proof { fn from(value: Sp1Response) -> Self { Self { - proof: Some(value.proof), + proof: value.proof, quote: None, + input: None, + uuid: None, + kzg_proof: None, } } } #[derive(Clone, Serialize, Deserialize)] pub struct Sp1Response { - pub proof: String, + pub proof: Option, } pub struct Sp1Prover; @@ -97,6 +101,8 @@ impl Prover for Sp1Prover { let param = Sp1Param::deserialize(config.get("sp1").unwrap()).unwrap(); let mode = param.prover.clone().unwrap_or_else(get_env_mock); + println!("param: {param:?}"); + let mut stdin = SP1Stdin::new(); stdin.write(&input); @@ -186,13 +192,15 @@ impl Prover for Sp1Prover { }; info!( - "Sp1 Prover: block {:?} completed! proof: {:?}", - output.header.number, proof_string + "Sp1 Prover: block {:?} completed! proof: {proof_string:?}", + output.header.number, ); - Ok::<_, ProverError>(Proof { - proof: proof_string, - quote: None, - }) + Ok::<_, ProverError>( + Sp1Response { + proof: proof_string, + } + .into(), + ) } async fn cancel(key: ProofKey, id_store: Box<&mut dyn IdStore>) -> ProverResult<()> { @@ -217,6 +225,15 @@ impl Prover for Sp1Prover { id_store.remove_id(key).await?; Ok(()) } + + async fn aggregate( + _input: AggregationGuestInput, + _output: &AggregationGuestOutput, + _config: &ProverConfig, + _store: Option<&mut dyn IdWrite>, + ) -> ProverResult { + todo!() + } } fn get_env_mock() -> ProverMode { @@ -235,10 +252,10 @@ fn get_env_mock() -> ProverMode { fn init_verifier() -> Result { // In cargo run, Cargo sets the working directory to the root of the workspace let contract_path = &*CONTRACT_PATH; - info!("Contract dir: {:?}", contract_path); + info!("Contract dir: {contract_path:?}"); let artifacts_dir = sp1_sdk::install::try_install_circuit_artifacts(); // Create the destination directory if it doesn't exist - fs::create_dir_all(&contract_path)?; + fs::create_dir_all(contract_path)?; // Read the entries in the source directory for entry in fs::read_dir(artifacts_dir)? { @@ -266,22 +283,21 @@ pub(crate) struct RaikoProofFixture { fn verify_sol(fixture: &RaikoProofFixture) -> ProverResult<()> { assert!(VERIFIER.is_ok()); - debug!("===> Fixture: {:#?}", fixture); + debug!("===> Fixture: {fixture:#?}"); // Save the fixture to a file. let fixture_path = &*FIXTURE_PATH; - info!("Writing fixture to: {:?}", fixture_path); + info!("Writing fixture to: {fixture_path:?}"); if !fixture_path.exists() { - std::fs::create_dir_all(fixture_path).map_err(|e| { - ProverError::GuestError(format!("Failed to create fixture path: {}", e)) - })?; + std::fs::create_dir_all(fixture_path.clone()) + .map_err(|e| ProverError::GuestError(format!("Failed to create fixture path: {e}")))?; } std::fs::write( fixture_path.join("fixture.json"), serde_json::to_string_pretty(&fixture).unwrap(), ) - .map_err(|e| ProverError::GuestError(format!("Failed to write fixture: {}", e)))?; + .map_err(|e| ProverError::GuestError(format!("Failed to write fixture: {e}")))?; let child = std::process::Command::new("forge") .arg("test") @@ -289,7 +305,7 @@ fn verify_sol(fixture: &RaikoProofFixture) -> ProverResult<()> { .stdout(std::process::Stdio::inherit()) // Inherit the parent process' stdout .spawn(); info!("Verification started {:?}", child); - child.map_err(|e| ProverError::GuestError(format!("Failed to run forge: {}", e)))?; + child.map_err(|e| ProverError::GuestError(format!("Failed to run forge: {e}")))?; Ok(()) } @@ -314,11 +330,11 @@ mod test { prover: Some(ProverMode::Network), verify: true, }; - let serialized = serde_json::to_value(¶m).unwrap(); + let serialized = serde_json::to_value(param).unwrap(); assert_eq!(json, serialized); let deserialized: Sp1Param = serde_json::from_value(serialized).unwrap(); - println!("{:?} {:?}", json, deserialized); + println!("{json:?} {deserialized:?}"); } #[test] diff --git a/provers/sp1/guest/Cargo.lock b/provers/sp1/guest/Cargo.lock index 3f00879a3..eab71cc46 100644 --- a/provers/sp1/guest/Cargo.lock +++ b/provers/sp1/guest/Cargo.lock @@ -3599,6 +3599,8 @@ dependencies = [ "lazy_static", "libm", "once_cell", + "p3-baby-bear", + "p3-field", "rand", "serde", "sha2", diff --git a/provers/sp1/guest/Cargo.toml b/provers/sp1/guest/Cargo.toml index efa74446b..21856d7ff 100644 --- a/provers/sp1/guest/Cargo.toml +++ b/provers/sp1/guest/Cargo.toml @@ -9,6 +9,10 @@ edition = "2021" name = "zk_op" path = "src/zk_op.rs" +[[bin]] +name = "sp1-aggregation" +path = "src/aggregation.rs" + [[bin]] name = "sha256" path = "src/benchmark/sha256.rs" diff --git a/provers/sp1/guest/elf/sp1-aggregation b/provers/sp1/guest/elf/sp1-aggregation new file mode 100755 index 000000000..0dbad9cd0 Binary files /dev/null and b/provers/sp1/guest/elf/sp1-aggregation differ diff --git a/provers/sp1/guest/src/aggregation.rs b/provers/sp1/guest/src/aggregation.rs new file mode 100644 index 000000000..b69a50bc2 --- /dev/null +++ b/provers/sp1/guest/src/aggregation.rs @@ -0,0 +1,25 @@ +//! Aggregates multiple block proofs + +#![no_main] +sp1_zkvm::entrypoint!(main); + +use sha2::Sha256; +use sha2::Digest; + +use raiko_lib::protocol_instance::words_to_bytes_le; +use raiko_lib::protocol_instance::aggregation_output; +use raiko_lib::input::ZkAggregationGuestInput; +use raiko_lib::primitives::B256; + +pub fn main() { + // Read the aggregation input + let input = sp1_zkvm::io::read::(); + + // Verify the block proofs. + for block_input in input.block_inputs.iter() { + sp1_zkvm::lib::verify::verify_sp1_proof(&input.image_id, &Sha256::digest(block_input).into()); + } + + // The aggregation output + sp1_zkvm::io::commit_slice(&aggregation_output(B256::from(words_to_bytes_le(&input.image_id)), input.block_inputs)); +} \ No newline at end of file diff --git a/provers/sp1/guest/src/zk_op.rs b/provers/sp1/guest/src/zk_op.rs index e6ed28be4..9fad10a90 100644 --- a/provers/sp1/guest/src/zk_op.rs +++ b/provers/sp1/guest/src/zk_op.rs @@ -6,7 +6,7 @@ use secp256k1::{ ecdsa::{RecoverableSignature, RecoveryId}, Message, }; -use sha2_v0_10_8 as sp1_sha2; +use sha2 as sp1_sha2; use sp1_core::utils::ec::{weierstrass::bn254::Bn254, AffinePoint};