Skip to content

feat(raiko): proof aggregation #347

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
fmt
Brechtpd committed Aug 25, 2024
commit ee6fe0a562925995bfdc91a57860daf0529a3325
6 changes: 4 additions & 2 deletions core/src/interfaces.rs
Original file line number Diff line number Diff line change
@@ -3,7 +3,9 @@ use alloy_primitives::{Address, B256};
use clap::{Args, ValueEnum};
use raiko_lib::{
consts::VerifierType,
input::{AggregationGuestInput, AggregationGuestOutput, BlobProofType, GuestInput, GuestOutput},
input::{
AggregationGuestInput, AggregationGuestOutput, BlobProofType, GuestInput, GuestOutput,
},
primitives::eip4844::{calc_kzg_proof, commitment_to_version_hash, kzg_proof_to_bytes},
prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverError},
};
@@ -460,4 +462,4 @@ pub struct AggregationRequest {
pub proof_type: ProofType,
/// Additional prover params.
pub prover_args: HashMap<String, Value>,
}
}
17 changes: 11 additions & 6 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -222,7 +222,8 @@ mod tests {
use raiko_lib::{
consts::{Network, SupportedChainSpecs},
input::{AggregationGuestInput, AggregationGuestOutput, BlobProofType},
primitives::B256, prover::Proof,
primitives::B256,
prover::Proof,
};
use serde_json::{json, Value};
use std::{collections::HashMap, env};
@@ -431,13 +432,17 @@ mod tests {
proofs: vec![proof.clone(), proof],
};

let output = AggregationGuestOutput {
hash: B256::ZERO,
};
let output = AggregationGuestOutput { hash: B256::ZERO };

let aggregated_proof = proof_type
.aggregate_proofs(input, &output, &serde_json::to_value(&test_proof_params(false)).unwrap(), None)
.await.expect("proof aggregation failed");
.aggregate_proofs(
input,
&output,
&serde_json::to_value(&test_proof_params(false)).unwrap(),
None,
)
.await
.expect("proof aggregation failed");
println!("aggregated proof: {:?}", aggregated_proof);
}
}
4 changes: 3 additions & 1 deletion lib/src/input.rs
Original file line number Diff line number Diff line change
@@ -13,7 +13,9 @@ use serde_with::serde_as;

#[cfg(not(feature = "std"))]
use crate::no_std::*;
use crate::{consts::ChainSpec, primitives::mpt::MptNode, prover::Proof, utils::zlib_compress_data};
use crate::{
consts::ChainSpec, primitives::mpt::MptNode, prover::Proof, utils::zlib_compress_data,
};

/// Represents the state of an account's storage.
/// The storage trie together with the used storage slots allow us to reconstruct all the
5 changes: 1 addition & 4 deletions lib/src/protocol_instance.rs
Original file line number Diff line number Diff line change
@@ -232,10 +232,7 @@ pub fn aggregation_output_combine(public_inputs: Vec<B256>) -> Vec<u8> {
}

pub fn aggregation_output(program: B256, public_inputs: Vec<B256>) -> Vec<u8> {
aggregation_output_combine([
vec![program],
public_inputs,
].concat())
aggregation_output_combine([vec![program], public_inputs].concat())
}

#[cfg(test)]
2 changes: 1 addition & 1 deletion lib/src/prover.rs
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@ use reth_primitives::{ChainId, B256};
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;

use crate::input::{GuestInput, GuestOutput, AggregationGuestInput, AggregationGuestOutput};
use crate::input::{AggregationGuestInput, AggregationGuestOutput, GuestInput, GuestOutput};

#[derive(thiserror::Error, Debug)]
pub enum ProverError {
5 changes: 4 additions & 1 deletion provers/risc0/builder/src/main.rs
Original file line number Diff line number Diff line change
@@ -5,7 +5,10 @@ use std::path::PathBuf;

fn main() {
let pipeline = Risc0Pipeline::new("provers/risc0/guest", "release");
pipeline.bins(&["risc0-guest", "risc0-aggregation"], "provers/risc0/driver/src/methods");
pipeline.bins(
&["risc0-guest", "risc0-aggregation"],
"provers/risc0/driver/src/methods",
);
#[cfg(feature = "test")]
pipeline.tests(&["risc0-guest"], "provers/risc0/driver/src/methods");
#[cfg(feature = "bench")]
12 changes: 7 additions & 5 deletions provers/risc0/driver/src/bonsai.rs
Original file line number Diff line number Diff line change
@@ -10,7 +10,8 @@ use raiko_lib::{
prover::{IdWrite, ProofKey, ProverError, ProverResult},
};
use risc0_zkvm::{
compute_image_id, is_dev_mode, serde::to_vec, sha::Digest, Assumption, AssumptionReceipt, ExecutorEnv, ExecutorImpl, Receipt
compute_image_id, is_dev_mode, serde::to_vec, sha::Digest, Assumption, AssumptionReceipt,
ExecutorEnv, ExecutorImpl, Receipt,
};
use serde::{de::DeserializeOwned, Serialize};
use std::{
@@ -286,9 +287,10 @@ pub async fn bonsai_stark_to_snark(
input: B256,
) -> ProverResult<Risc0Response> {
let image_id = Digest::from(RISC0_GUEST_ID);
let (snark_uuid, snark_receipt) = stark2snark(image_id, stark_uuid.clone(), stark_receipt.clone())
.await
.map_err(|err| format!("Failed to convert STARK to SNARK: {err:?}"))?;
let (snark_uuid, snark_receipt) =
stark2snark(image_id, stark_uuid.clone(), stark_receipt.clone())
.await
.map_err(|err| format!("Failed to convert STARK to SNARK: {err:?}"))?;

info!("Validating SNARK uuid: {snark_uuid}");

@@ -297,7 +299,7 @@ pub async fn bonsai_stark_to_snark(
.map_err(|err| format!("Failed to verify SNARK: {err:?}"))?;

let snark_proof = format!("0x{}", hex::encode(enc_proof));
Ok(Risc0Response {
Ok(Risc0Response {
proof: snark_proof,
receipt: serde_json::to_string(&stark_receipt).unwrap(),
uuid: stark_uuid,
35 changes: 18 additions & 17 deletions provers/risc0/driver/src/lib.rs
Original file line number Diff line number Diff line change
@@ -3,15 +3,18 @@
#[cfg(feature = "bonsai-auto-scaling")]
use crate::bonsai::auto_scaling::shutdown_bonsai;
use crate::{
methods::risc0_guest::{RISC0_GUEST_ELF, RISC0_GUEST_ID},
methods::risc0_aggregation::{RISC0_AGGREGATION_ELF, RISC0_AGGREGATION_ID},
methods::risc0_guest::{RISC0_GUEST_ELF, RISC0_GUEST_ID},
snarks::verify_groth16_snark,
};
use alloy_primitives::{hex::ToHexExt, B256};
use bonsai::{cancel_proof, maybe_prove};
use log::warn;
use raiko_lib::{
input::{AggregationGuestInput, AggregationGuestOutput, GuestInput, GuestOutput, ZkAggregationGuestInput},
input::{
AggregationGuestInput, AggregationGuestOutput, GuestInput, GuestOutput,
ZkAggregationGuestInput,
},
prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverConfig, ProverError, ProverResult},
};
use risc0_zkvm::{serde::to_vec, sha::Digest, Receipt};
@@ -103,7 +106,7 @@ impl Prover for Risc0Prover {
proof: stark_receipt.journal.encode_hex_with_prefix(),
receipt: serde_json::to_string(&receipt).unwrap(),
uuid,
input: output.hash
input: output.hash,
}
.into())
}
@@ -132,28 +135,26 @@ impl Prover for Risc0Prover {
) -> ProverResult<Proof> {
let mut id_store = id_store;
let config = Risc0Param::deserialize(config.get("risc0").unwrap()).unwrap();
let proof_key = (
0,
output.hash.clone(),
RISC0_PROVER_CODE,
);
let proof_key = (0, output.hash.clone(), RISC0_PROVER_CODE);

// Extract the block proof receipts
let assumptions: Vec<Receipt> = input.proofs
let assumptions: Vec<Receipt> = input
.proofs
.iter()
.map(|proof| {
let receipt: Receipt = serde_json::from_str(&proof.quote.clone().unwrap()).expect("Failed to deserialize");
let receipt: Receipt = serde_json::from_str(&proof.quote.clone().unwrap())
.expect("Failed to deserialize");
receipt
})
.collect::<Vec<_>>();
let block_inputs: Vec<B256> = input.proofs
let block_inputs: Vec<B256> = input
.proofs
.iter()
.map(|proof| {
proof.input.unwrap()
})
.map(|proof| proof.input.unwrap())
.collect::<Vec<_>>();
// For bonsai
let assumptions_uuids: Vec<String> = input.proofs
let assumptions_uuids: Vec<String> = input
.proofs
.iter()
.map(|proof| proof.uuid.clone().unwrap())
.collect::<Vec<_>>();
@@ -179,7 +180,7 @@ impl Prover for Risc0Prover {

let receipt = result.clone().unwrap().1.clone();
let uuid = result.clone().unwrap().0;

let proof_gen_result = if result.is_some() {
if config.snark && config.bonsai {
let (stark_uuid, stark_receipt) = result.clone().unwrap();
@@ -194,7 +195,7 @@ impl Prover for Risc0Prover {
proof: stark_receipt.journal.encode_hex_with_prefix(),
receipt: serde_json::to_string(&receipt).unwrap(),
uuid,
input: output.hash
input: output.hash,
}
.into())
}
2 changes: 1 addition & 1 deletion provers/risc0/driver/src/methods/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pub mod risc0_guest;
pub mod risc0_aggregation;
pub mod risc0_guest;

// To build the following `$ cargo run --features test,bench --bin risc0-builder`
// or `$ $TARGET=risc0 make test`
8 changes: 5 additions & 3 deletions provers/risc0/driver/src/methods/risc0_aggregation.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@

pub const RISC0_AGGREGATION_ELF: &[u8] = include_bytes!("../../../guest/target/riscv32im-risc0-zkvm-elf/release/risc0-aggregation");
pub const RISC0_AGGREGATION_ID: [u32; 8] = [834745027, 3860709824, 1052791454, 925104520, 3609882255, 551703375, 2495735124, 1897996989];
pub const RISC0_AGGREGATION_ELF: &[u8] =
include_bytes!("../../../guest/target/riscv32im-risc0-zkvm-elf/release/risc0-aggregation");
pub const RISC0_AGGREGATION_ID: [u32; 8] = [
834745027, 3860709824, 1052791454, 925104520, 3609882255, 551703375, 2495735124, 1897996989,
];
8 changes: 5 additions & 3 deletions provers/risc0/driver/src/methods/risc0_guest.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@

pub const RISC0_GUEST_ELF: &[u8] = include_bytes!("../../../guest/target/riscv32im-risc0-zkvm-elf/release/risc0-guest");
pub const RISC0_GUEST_ID: [u32; 8] = [395633193, 936490633, 451059026, 3955219165, 3837062005, 3412945197, 1606123515, 1234626647];
pub const RISC0_GUEST_ELF: &[u8] =
include_bytes!("../../../guest/target/riscv32im-risc0-zkvm-elf/release/risc0-guest");
pub const RISC0_GUEST_ID: [u32; 8] = [
395633193, 936490633, 451059026, 3955219165, 3837062005, 3412945197, 1606123515, 1234626647,
];
1 change: 0 additions & 1 deletion provers/risc0/driver/src/snarks.rs
Original file line number Diff line number Diff line change
@@ -32,7 +32,6 @@ use tracing::{error as tracing_err, info as tracing_info};

use crate::bonsai::save_receipt;


sol!(
/// A Groth16 seal over the claimed receipt claim.
struct Seal {
33 changes: 18 additions & 15 deletions provers/sgx/guest/src/one_shot.rs
Original file line number Diff line number Diff line change
@@ -8,7 +8,10 @@ use std::{
use anyhow::{anyhow, bail, Context, Error, Result};
use base64_serde::base64_serde_type;
use raiko_lib::{
builder::calculate_block_header, consts::VerifierType, input::{AggregationGuestInput, GuestInput, RawAggregationGuestInput}, primitives::{keccak, Address, B256},
builder::calculate_block_header,
consts::VerifierType,
input::{AggregationGuestInput, GuestInput, RawAggregationGuestInput},
primitives::{keccak, Address, B256},
protocol_instance::{aggregation_output, aggregation_output_combine, ProtocolInstance},
};
use secp256k1::{Keypair, SecretKey};
@@ -189,11 +192,10 @@ pub async fn aggregate(global_opts: GlobalOpts, args: OneShotArgs) -> Result<()>

// Verify the proofs
for proof in input.proofs.iter() {
// TODO: verify protocol instance data so we can trust the old/new instance data
assert_eq!(
recover_signer_unchecked(
&proof.proof.clone()[44..].try_into().unwrap(),
&proof.input,
).unwrap(),
recover_signer_unchecked(&proof.proof.clone()[44..].try_into().unwrap(), &proof.input,)
.unwrap(),
cur_instance,
);
cur_instance = Address::from_slice(&proof.proof.clone()[24..44]);
@@ -203,22 +205,23 @@ pub async fn aggregate(global_opts: GlobalOpts, args: OneShotArgs) -> Result<()>
assert_eq!(cur_instance, new_instance);

// Calculate the aggregation hash
let aggregation_hash = keccak::keccak(
aggregation_output_combine(
[
vec![
let aggregation_hash = keccak::keccak(aggregation_output_combine(
[
vec![
B256::left_padding_from(&old_instance.to_vec()),
B256::left_padding_from(&new_instance.to_vec()),
],
input.proofs.iter().map(|proof| proof.input).collect::<Vec<_>>(),
].concat(),
input
.proofs
.iter()
.map(|proof| proof.input)
.collect::<Vec<_>>(),
]
.concat(),
));

// Sign the public aggregation hash
let sig = sign_message(
&prev_privkey,
aggregation_hash.into(),
)?;
let sig = sign_message(&prev_privkey, aggregation_hash.into())?;

// Create the proof for the onchain SGX verifier
const SGX_PROOF_LEN: usize = 89;
16 changes: 12 additions & 4 deletions provers/sgx/prover/src/lib.rs
Original file line number Diff line number Diff line change
@@ -10,7 +10,12 @@ use std::{

use once_cell::sync::Lazy;
use raiko_lib::{
input::{AggregationGuestInput, AggregationGuestOutput, GuestInput, GuestOutput, RawAggregationGuestInput, RawProof}, primitives::B256, prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverConfig, ProverError, ProverResult}
input::{
AggregationGuestInput, AggregationGuestOutput, GuestInput, GuestOutput,
RawAggregationGuestInput, RawProof,
},
primitives::B256,
prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverConfig, ProverError, ProverResult},
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
@@ -393,11 +398,14 @@ async fn aggregate(
) -> ProverResult<SgxResponse, ProverError> {
// Extract the useful parts of the proof here so the guest doesn't have to do it
let raw_input = RawAggregationGuestInput {
proofs: input.proofs.iter().map(|proof| RawProof {
proofs: input
.proofs
.iter()
.map(|proof| RawProof {
input: proof.clone().input.unwrap(),
proof: hex::decode(&proof.clone().proof.unwrap()[2..]).unwrap(),
}
).collect(),
})
.collect(),
};

tokio::task::spawn_blocking(move || {
32 changes: 18 additions & 14 deletions provers/sp1/driver/src/lib.rs
Original file line number Diff line number Diff line change
@@ -3,7 +3,10 @@

use once_cell::sync::Lazy;
use raiko_lib::{
input::{AggregationGuestInput, AggregationGuestOutput, GuestInput, GuestOutput, ZkAggregationGuestInput},
input::{
AggregationGuestInput, AggregationGuestOutput, GuestInput, GuestOutput,
ZkAggregationGuestInput,
},
prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverConfig, ProverError, ProverResult},
Measurement,
};
@@ -13,7 +16,8 @@ use serde_with::serde_as;
use sp1_sdk::{
action,
network::client::NetworkClient,
proto::network::{ProofMode, UnclaimReason}, SP1Proof,
proto::network::{ProofMode, UnclaimReason},
SP1Proof,
};
use sp1_sdk::{HashableKey, ProverClient, SP1Stdin, SP1VerifyingKey};
use std::env;
@@ -152,7 +156,7 @@ impl Prover for Sp1Prover {
};

if param.verify {
if matches!(param.recursion, RecursionMode::Plonk ) {
if matches!(param.recursion, RecursionMode::Plonk) {
let time = Measurement::start("verify", false);
verify_sol(vk, prove_result)?;
time.stop_with("==> Verification complete");
@@ -174,9 +178,15 @@ impl Prover for Sp1Prover {
let mode = param.prover.clone().unwrap_or_else(get_env_mock);

// Extract the block proofs
let proofs: Vec<sp1_sdk::SP1ProofWithPublicValues> = input.proofs
let proofs: Vec<sp1_sdk::SP1ProofWithPublicValues> = input
.proofs
.iter()
.map(|input| serde_json::from_str::<sp1_sdk::SP1ProofWithPublicValues>(&input.proof.clone().unwrap()).unwrap())
.map(|input| {
serde_json::from_str::<sp1_sdk::SP1ProofWithPublicValues>(
&input.proof.clone().unwrap(),
)
.unwrap()
})
.collect::<Vec<_>>();

// Generate the proof for the given program.
@@ -236,16 +246,10 @@ impl Prover for Sp1Prover {
.map_err(|_| ProverError::GuestError("Sp1: requesting proof failed".to_owned()))?;
if let Some(id_store) = id_store {
id_store
.store_id(
(123456, output.hash, SP1_PROVER_CODE),
proof_id.clone(),
)
.store_id((123456, output.hash, SP1_PROVER_CODE), proof_id.clone())
.await?;
}
info!(
"Sp1 Prover: aggregation proof id {:?}",
proof_id
);
info!("Sp1 Prover: aggregation proof id {:?}", proof_id);
network_prover
.wait_proof::<sp1_sdk::SP1ProofWithPublicValues>(&proof_id)
.await
@@ -259,7 +263,7 @@ impl Prover for Sp1Prover {
};

if param.verify {
if matches!(param.recursion, RecursionMode::Plonk ) {
if matches!(param.recursion, RecursionMode::Plonk) {
let time = Measurement::start("verify", false);
verify_sol(vk, prove_result)?;
time.stop_with("==> Verification complete");