Skip to content

Commit eb4d032

Browse files
petarvujovic98Brechtpdsmtmfft
authored
feat(core,host): initial aggregation API (#375)
* initial proof aggregation implementation * aggregation improvements + risc0 aggregation * sp1 aggregation fixes * sp1 aggregation elf * uuid support for risc0 aggregation * risc0 aggregation circuit compile fixes * fix sgx proof aggregation * fmt * feat(core,host): initial aggregation API * fix(core,host,sgx): fix compiler and clippy errors * fix(core,lib,provers): revert merge bugs and add sp1 stubs * fix(core): remove double member * fix(sp1): fix dependency naming * refactor(risc0): clean up aggregation file * fix(sp1): enable verification for proof aggregation * feat(host): migrate to v3 API * feat(sp1): run cargo fmt * feat(core): make `l1_inclusion_block_number` optional * fixproof req input into prove state manager Signed-off-by: smtmfft <[email protected]> * feat(core,host,lib,tasks): add aggregation tasks and API * fix(core): fix typo * fix v3 error return Signed-off-by: smtmfft <[email protected]> * feat(sp1): implement aggregate function * fix sgx aggregation for back compatibility Signed-off-by: smtmfft <[email protected]> * fix(lib): fix typo * fix risc0 aggregation Signed-off-by: smtmfft <[email protected]> * fix(host,sp1): handle statuses * enable sp1 aggregation Signed-off-by: smtmfft <[email protected]> * feat(host): error out on empty proof array request * fix(host): return proper status report * feat(host,tasks): adding details to error statuses * fix sp1 aggregation Signed-off-by: smtmfft <[email protected]> * update prove-block script Signed-off-by: smtmfft <[email protected]> * fix(fmt): run cargo fmt * fix(clippy): fix clippy issues * chore(repo): cleanup captured vars in format calls * fix(sp1): convert to proper types * chore(sp1): remove the unneccessary --------- Signed-off-by: smtmfft <[email protected]> Co-authored-by: Brecht Devos <[email protected]> Co-authored-by: smtmfft <[email protected]>
1 parent 7e10837 commit eb4d032

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+2195
-191
lines changed

Cargo.lock

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

core/src/interfaces.rs

+166-3
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@ use alloy_primitives::{Address, B256};
33
use clap::{Args, ValueEnum};
44
use raiko_lib::{
55
consts::VerifierType,
6-
input::{BlobProofType, GuestInput, GuestOutput},
6+
input::{
7+
AggregationGuestInput, AggregationGuestOutput, BlobProofType, GuestInput, GuestOutput,
8+
},
79
prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverError},
810
};
911
use serde::{Deserialize, Serialize};
1012
use serde_json::Value;
1113
use serde_with::{serde_as, DisplayFromStr};
12-
use std::{collections::HashMap, path::Path, str::FromStr};
14+
use std::{collections::HashMap, fmt::Display, path::Path, str::FromStr};
1315
use utoipa::ToSchema;
1416

1517
#[derive(Debug, thiserror::Error, ToSchema)]
@@ -203,6 +205,47 @@ impl ProofType {
203205
}
204206
}
205207

208+
/// Run the prover driver depending on the proof type.
209+
pub async fn aggregate_proofs(
210+
&self,
211+
input: AggregationGuestInput,
212+
output: &AggregationGuestOutput,
213+
config: &Value,
214+
store: Option<&mut dyn IdWrite>,
215+
) -> RaikoResult<Proof> {
216+
let proof = match self {
217+
ProofType::Native => NativeProver::aggregate(input.clone(), output, config, store)
218+
.await
219+
.map_err(<ProverError as Into<RaikoError>>::into),
220+
ProofType::Sp1 => {
221+
#[cfg(feature = "sp1")]
222+
return sp1_driver::Sp1Prover::aggregate(input.clone(), output, config, store)
223+
.await
224+
.map_err(|e| e.into());
225+
#[cfg(not(feature = "sp1"))]
226+
Err(RaikoError::FeatureNotSupportedError(*self))
227+
}
228+
ProofType::Risc0 => {
229+
#[cfg(feature = "risc0")]
230+
return risc0_driver::Risc0Prover::aggregate(input.clone(), output, config, store)
231+
.await
232+
.map_err(|e| e.into());
233+
#[cfg(not(feature = "risc0"))]
234+
Err(RaikoError::FeatureNotSupportedError(*self))
235+
}
236+
ProofType::Sgx => {
237+
#[cfg(feature = "sgx")]
238+
return sgx_prover::SgxProver::aggregate(input.clone(), output, config, store)
239+
.await
240+
.map_err(|e| e.into());
241+
#[cfg(not(feature = "sgx"))]
242+
Err(RaikoError::FeatureNotSupportedError(*self))
243+
}
244+
}?;
245+
246+
Ok(proof)
247+
}
248+
206249
pub async fn cancel_proof(
207250
&self,
208251
proof_key: ProofKey,
@@ -302,7 +345,7 @@ pub struct ProofRequestOpt {
302345
pub prover_args: ProverSpecificOpts,
303346
}
304347

305-
#[derive(Default, Clone, Serialize, Deserialize, Debug, ToSchema, Args)]
348+
#[derive(Default, Clone, Serialize, Deserialize, Debug, ToSchema, Args, PartialEq, Eq, Hash)]
306349
pub struct ProverSpecificOpts {
307350
/// Native prover specific options.
308351
pub native: Option<Value>,
@@ -398,3 +441,123 @@ impl TryFrom<ProofRequestOpt> for ProofRequest {
398441
})
399442
}
400443
}
444+
445+
#[derive(Default, Clone, Serialize, Deserialize, Debug, ToSchema)]
446+
#[serde(default)]
447+
/// A request for proof aggregation of multiple proofs.
448+
pub struct AggregationRequest {
449+
/// The block numbers and l1 inclusion block numbers for the blocks to aggregate proofs for.
450+
pub block_numbers: Vec<(u64, Option<u64>)>,
451+
/// The network to generate the proof for.
452+
pub network: Option<String>,
453+
/// The L1 network to generate the proof for.
454+
pub l1_network: Option<String>,
455+
// Graffiti.
456+
pub graffiti: Option<String>,
457+
/// The protocol instance data.
458+
pub prover: Option<String>,
459+
/// The proof type.
460+
pub proof_type: Option<String>,
461+
/// Blob proof type.
462+
pub blob_proof_type: Option<String>,
463+
#[serde(flatten)]
464+
/// Any additional prover params in JSON format.
465+
pub prover_args: ProverSpecificOpts,
466+
}
467+
468+
impl AggregationRequest {
469+
/// Merge proof request options into aggregation request options.
470+
pub fn merge(&mut self, opts: &ProofRequestOpt) -> RaikoResult<()> {
471+
let this = serde_json::to_value(&self)?;
472+
let mut opts = serde_json::to_value(opts)?;
473+
merge(&mut opts, &this);
474+
*self = serde_json::from_value(opts)?;
475+
Ok(())
476+
}
477+
}
478+
479+
impl From<AggregationRequest> for Vec<ProofRequestOpt> {
480+
fn from(value: AggregationRequest) -> Self {
481+
value
482+
.block_numbers
483+
.iter()
484+
.map(
485+
|&(block_number, l1_inclusion_block_number)| ProofRequestOpt {
486+
block_number: Some(block_number),
487+
l1_inclusion_block_number,
488+
network: value.network.clone(),
489+
l1_network: value.l1_network.clone(),
490+
graffiti: value.graffiti.clone(),
491+
prover: value.prover.clone(),
492+
proof_type: value.proof_type.clone(),
493+
blob_proof_type: value.blob_proof_type.clone(),
494+
prover_args: value.prover_args.clone(),
495+
},
496+
)
497+
.collect()
498+
}
499+
}
500+
501+
impl From<ProofRequestOpt> for AggregationRequest {
502+
fn from(value: ProofRequestOpt) -> Self {
503+
let block_numbers = if let Some(block_number) = value.block_number {
504+
vec![(block_number, value.l1_inclusion_block_number)]
505+
} else {
506+
vec![]
507+
};
508+
509+
Self {
510+
block_numbers,
511+
network: value.network,
512+
l1_network: value.l1_network,
513+
graffiti: value.graffiti,
514+
prover: value.prover,
515+
proof_type: value.proof_type,
516+
blob_proof_type: value.blob_proof_type,
517+
prover_args: value.prover_args,
518+
}
519+
}
520+
}
521+
522+
#[derive(Default, Clone, Serialize, Deserialize, Debug, ToSchema, PartialEq, Eq, Hash)]
523+
#[serde(default)]
524+
/// A request for proof aggregation of multiple proofs.
525+
pub struct AggregationOnlyRequest {
526+
/// The block numbers and l1 inclusion block numbers for the blocks to aggregate proofs for.
527+
pub proofs: Vec<Proof>,
528+
/// The proof type.
529+
pub proof_type: Option<String>,
530+
#[serde(flatten)]
531+
/// Any additional prover params in JSON format.
532+
pub prover_args: ProverSpecificOpts,
533+
}
534+
535+
impl Display for AggregationOnlyRequest {
536+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
537+
f.write_str(&format!(
538+
"AggregationOnlyRequest {{ {:?}, {:?} }}",
539+
self.proof_type, self.prover_args
540+
))
541+
}
542+
}
543+
544+
impl From<(AggregationRequest, Vec<Proof>)> for AggregationOnlyRequest {
545+
fn from((request, proofs): (AggregationRequest, Vec<Proof>)) -> Self {
546+
Self {
547+
proofs,
548+
proof_type: request.proof_type,
549+
prover_args: request.prover_args,
550+
}
551+
}
552+
}
553+
554+
impl AggregationOnlyRequest {
555+
/// Merge proof request options into aggregation request options.
556+
pub fn merge(&mut self, opts: &ProofRequestOpt) -> RaikoResult<()> {
557+
let this = serde_json::to_value(&self)?;
558+
let mut opts = serde_json::to_value(opts)?;
559+
merge(&mut opts, &this);
560+
*self = serde_json::from_value(opts)?;
561+
Ok(())
562+
}
563+
}

core/src/lib.rs

+59-12
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,9 @@ mod tests {
226226
use clap::ValueEnum;
227227
use raiko_lib::{
228228
consts::{Network, SupportedChainSpecs},
229-
input::BlobProofType,
229+
input::{AggregationGuestInput, AggregationGuestOutput, BlobProofType},
230230
primitives::B256,
231+
prover::Proof,
231232
};
232233
use serde_json::{json, Value};
233234
use std::{collections::HashMap, env};
@@ -242,7 +243,7 @@ mod tests {
242243
ci == "1"
243244
}
244245

245-
fn test_proof_params() -> HashMap<String, Value> {
246+
fn test_proof_params(enable_aggregation: bool) -> HashMap<String, Value> {
246247
let mut prover_args = HashMap::new();
247248
prover_args.insert(
248249
"native".to_string(),
@@ -256,7 +257,7 @@ mod tests {
256257
"sp1".to_string(),
257258
json! {
258259
{
259-
"recursion": "core",
260+
"recursion": if enable_aggregation { "compressed" } else { "plonk" },
260261
"prover": "mock",
261262
"verify": true
262263
}
@@ -278,8 +279,8 @@ mod tests {
278279
json! {
279280
{
280281
"instance_id": 121,
281-
"setup": true,
282-
"bootstrap": true,
282+
"setup": enable_aggregation,
283+
"bootstrap": enable_aggregation,
283284
"prove": true,
284285
}
285286
},
@@ -291,7 +292,7 @@ mod tests {
291292
l1_chain_spec: ChainSpec,
292293
taiko_chain_spec: ChainSpec,
293294
proof_request: ProofRequest,
294-
) {
295+
) -> Proof {
295296
let provider =
296297
RpcBlockDataProvider::new(&taiko_chain_spec.rpc, proof_request.block_number - 1)
297298
.expect("Could not create RpcBlockDataProvider");
@@ -301,10 +302,10 @@ mod tests {
301302
.await
302303
.expect("input generation failed");
303304
let output = raiko.get_output(&input).expect("output generation failed");
304-
let _proof = raiko
305+
raiko
305306
.prove(input, &output, None)
306307
.await
307-
.expect("proof generation failed");
308+
.expect("proof generation failed")
308309
}
309310

310311
#[ignore]
@@ -332,7 +333,7 @@ mod tests {
332333
l1_network,
333334
proof_type,
334335
blob_proof_type: BlobProofType::ProofOfEquivalence,
335-
prover_args: test_proof_params(),
336+
prover_args: test_proof_params(false),
336337
};
337338
prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await;
338339
}
@@ -361,7 +362,7 @@ mod tests {
361362
l1_network,
362363
proof_type,
363364
blob_proof_type: BlobProofType::ProofOfEquivalence,
364-
prover_args: test_proof_params(),
365+
prover_args: test_proof_params(false),
365366
};
366367
prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await;
367368
}
@@ -399,7 +400,7 @@ mod tests {
399400
l1_network,
400401
proof_type,
401402
blob_proof_type: BlobProofType::ProofOfEquivalence,
402-
prover_args: test_proof_params(),
403+
prover_args: test_proof_params(false),
403404
};
404405
prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await;
405406
}
@@ -432,9 +433,55 @@ mod tests {
432433
l1_network,
433434
proof_type,
434435
blob_proof_type: BlobProofType::ProofOfEquivalence,
435-
prover_args: test_proof_params(),
436+
prover_args: test_proof_params(false),
436437
};
437438
prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await;
438439
}
439440
}
441+
442+
#[tokio::test(flavor = "multi_thread")]
443+
async fn test_prove_block_taiko_a7_aggregated() {
444+
let proof_type = get_proof_type_from_env();
445+
let l1_network = Network::Holesky.to_string();
446+
let network = Network::TaikoA7.to_string();
447+
// Give the CI an simpler block to test because it doesn't have enough memory.
448+
// Unfortunately that also means that kzg is not getting fully verified by CI.
449+
let block_number = if is_ci() { 105987 } else { 101368 };
450+
let taiko_chain_spec = SupportedChainSpecs::default()
451+
.get_chain_spec(&network)
452+
.unwrap();
453+
let l1_chain_spec = SupportedChainSpecs::default()
454+
.get_chain_spec(&l1_network)
455+
.unwrap();
456+
457+
let proof_request = ProofRequest {
458+
block_number,
459+
l1_inclusion_block_number: 0,
460+
network,
461+
graffiti: B256::ZERO,
462+
prover: Address::ZERO,
463+
l1_network,
464+
proof_type,
465+
blob_proof_type: BlobProofType::ProofOfEquivalence,
466+
prover_args: test_proof_params(true),
467+
};
468+
let proof = prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await;
469+
470+
let input = AggregationGuestInput {
471+
proofs: vec![proof.clone(), proof],
472+
};
473+
474+
let output = AggregationGuestOutput { hash: B256::ZERO };
475+
476+
let aggregated_proof = proof_type
477+
.aggregate_proofs(
478+
input,
479+
&output,
480+
&serde_json::to_value(&test_proof_params(false)).unwrap(),
481+
None,
482+
)
483+
.await
484+
.expect("proof aggregation failed");
485+
println!("aggregated proof: {aggregated_proof:?}");
486+
}
440487
}

core/src/preflight/util.rs

+1-4
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,8 @@ pub async fn prepare_taiko_chain_input(
136136
RaikoError::Preflight("No L1 inclusion block hash for the requested block".to_owned())
137137
})?;
138138
info!(
139-
"L1 inclusion block number: {:?}, hash: {:?}. L1 state block number: {:?}, hash: {:?}",
140-
l1_inclusion_block_number,
141-
l1_inclusion_block_hash,
139+
"L1 inclusion block number: {l1_inclusion_block_number:?}, hash: {l1_inclusion_block_hash:?}. L1 state block number: {:?}, hash: {l1_state_block_hash:?}",
142140
l1_state_header.number,
143-
l1_state_block_hash
144141
);
145142

146143
// Fetch the tx data from either calldata or blobdata

core/src/prover.rs

+14
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,28 @@ impl Prover for NativeProver {
5858
}
5959

6060
Ok(Proof {
61+
input: None,
6162
proof: None,
6263
quote: None,
64+
uuid: None,
65+
kzg_proof: None,
6366
})
6467
}
6568

6669
async fn cancel(_proof_key: ProofKey, _read: Box<&mut dyn IdStore>) -> ProverResult<()> {
6770
Ok(())
6871
}
72+
73+
async fn aggregate(
74+
_input: raiko_lib::input::AggregationGuestInput,
75+
_output: &raiko_lib::input::AggregationGuestOutput,
76+
_config: &ProverConfig,
77+
_store: Option<&mut dyn IdWrite>,
78+
) -> ProverResult<Proof> {
79+
Ok(Proof {
80+
..Default::default()
81+
})
82+
}
6983
}
7084

7185
#[ignore = "Only used to test serialized data"]

0 commit comments

Comments
 (0)