Skip to content

Commit e14f6d6

Browse files
committed
First working version of SP1 Distributed Prover
Optimized prototype Remove unecessary complexity on async send/receive Fix the commitment of shards public values Setting the shard_batch_size to 1 and processing multiple checkpoints in the workers Send only the first checkpoint and reexecute the runtime for the next ones Make the worker computation stateless Share the shard_batch_size and shard_size with workers Make worker able to receive multiple requests Redistribute a request when a worker fails Reducing the size of the shard public values Better request data structure Keep a single instance of program and machine Remove debugs about time duration
1 parent 65c1758 commit e14f6d6

23 files changed

+1310
-141
lines changed

Cargo.lock

+118-65
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+22-3
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,19 @@ risc0-build = { version = "1.0.1" }
5454
risc0-binfmt = { version = "1.0.1" }
5555

5656
# SP1
57-
sp1-sdk = { git = "https://github.com/succinctlabs/sp1.git", branch = "main" }
58-
sp1-zkvm = { git = "https://github.com/succinctlabs/sp1.git", branch = "main" }
59-
sp1-helper = { git = "https://github.com/succinctlabs/sp1.git", branch = "main" }
57+
sp1-sdk = { git = "https://github.com/succinctlabs/sp1.git", rev = "14eb569d41d24721ffbd407d6060e202482d659c" }
58+
sp1-zkvm = { git = "https://github.com/succinctlabs/sp1.git", rev = "14eb569d41d24721ffbd407d6060e202482d659c" }
59+
sp1-helper = { git = "https://github.com/succinctlabs/sp1.git", rev = "14eb569d41d24721ffbd407d6060e202482d659c" }
60+
sp1-core = { git = "https://github.com/succinctlabs/sp1.git", rev = "14eb569d41d24721ffbd407d6060e202482d659c" }
61+
62+
63+
# Plonky3
64+
p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }
65+
p3-challenger = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }
66+
p3-poseidon2 = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }
67+
p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }
68+
p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }
69+
6070

6171
# alloy
6272
alloy-rlp = { version = "0.3.4", default-features = false }
@@ -187,3 +197,12 @@ revm-primitives = { git = "https://github.com/taikoxyz/revm.git", branch = "v36-
187197
revm-precompile = { git = "https://github.com/taikoxyz/revm.git", branch = "v36-taiko" }
188198
secp256k1 = { git = "https://github.com/CeciliaZ030/rust-secp256k1", branch = "sp1-patch" }
189199
blst = { git = "https://github.com/CeciliaZ030/blst.git", branch = "v0.3.12-serialize" }
200+
201+
# Patch Plonky3 for Serialize and Deserialize of DuplexChallenger
202+
[patch."https://github.com/Plonky3/Plonky3.git"]
203+
p3-field = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }
204+
p3-challenger = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }
205+
p3-poseidon2 = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }
206+
p3-baby-bear = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }
207+
p3-symmetric = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }
208+

core/src/interfaces.rs

+14
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,10 @@ pub enum ProofType {
104104
///
105105
/// Uses the SP1 prover to build the block.
106106
Sp1,
107+
/// # Sp1Distributed
108+
///
109+
/// Uses the SP1 prover to build the block in a distributed way.
110+
Sp1Distributed,
107111
/// # Sgx
108112
///
109113
/// Builds the block on a SGX supported CPU to create a proof.
@@ -119,6 +123,7 @@ impl std::fmt::Display for ProofType {
119123
f.write_str(match self {
120124
ProofType::Native => "native",
121125
ProofType::Sp1 => "sp1",
126+
ProofType::Sp1Distributed => "sp1_distributed",
122127
ProofType::Sgx => "sgx",
123128
ProofType::Risc0 => "risc0",
124129
})
@@ -132,6 +137,7 @@ impl FromStr for ProofType {
132137
match s.trim().to_lowercase().as_str() {
133138
"native" => Ok(ProofType::Native),
134139
"sp1" => Ok(ProofType::Sp1),
140+
"sp1_distributed" => Ok(ProofType::Sp1Distributed),
135141
"sgx" => Ok(ProofType::Sgx),
136142
"risc0" => Ok(ProofType::Risc0),
137143
_ => Err(RaikoError::InvalidProofType(s.to_string())),
@@ -173,6 +179,14 @@ impl ProofType {
173179
#[cfg(not(feature = "sp1"))]
174180
Err(RaikoError::FeatureNotSupportedError(*self))
175181
}
182+
ProofType::Sp1Distributed => {
183+
#[cfg(feature = "sp1")]
184+
return sp1_driver::Sp1DistributedProver::run(input, output, config)
185+
.await
186+
.map_err(|e| e.into());
187+
#[cfg(not(feature = "sp1"))]
188+
Err(RaikoError::FeatureNotSupportedError(*self))
189+
}
176190
ProofType::Risc0 => {
177191
#[cfg(feature = "risc0")]
178192
return risc0_driver::Risc0Prover::run(input.clone(), output, config)

host/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ ethers-core = { workspace = true }
8181

8282
[features]
8383
default = []
84-
sp1 = ["raiko-core/sp1"]
84+
sp1 = ["raiko-core/sp1", "sp1-driver"]
8585
risc0 = ["raiko-core/risc0"]
8686
sgx = ["raiko-core/sgx"]
8787

host/src/lib.rs

+19
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,21 @@ pub struct Opts {
3434
/// [default: 0.0.0.0:8080]
3535
address: String,
3636

37+
#[arg(long, require_equals = true, default_value = "0.0.0.0:8081")]
38+
#[serde(default = "Opts::default_sp1_worker_address")]
39+
/// Distributed SP1 worker listening address
40+
/// [default: 0.0.0.0:8081]
41+
sp1_worker_address: String,
42+
43+
#[arg(long, default_value = None)]
44+
/// Distributed SP1 worker orchestrator address
45+
///
46+
/// Setting this will enable the worker and restrict it to only accept requests from
47+
/// this orchestrator
48+
///
49+
/// [default: None]
50+
sp1_orchestrator_address: Option<String>,
51+
3752
#[arg(long, require_equals = true, default_value = "16")]
3853
#[serde(default = "Opts::default_concurrency_limit")]
3954
/// Limit the max number of in-flight requests
@@ -87,6 +102,10 @@ impl Opts {
87102
"0.0.0.0:8080".to_string()
88103
}
89104

105+
fn default_sp1_worker_address() -> String {
106+
"0.0.0.0:8081".to_string()
107+
}
108+
90109
fn default_concurrency_limit() -> usize {
91110
16
92111
}

host/src/server/mod.rs

+5
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,14 @@ use tokio::net::TcpListener;
55
use tracing::info;
66

77
pub mod api;
8+
#[cfg(feature = "sp1")]
9+
pub mod worker;
810

911
/// Starts the proverd server.
1012
pub async fn serve(state: ProverState) -> anyhow::Result<()> {
13+
#[cfg(feature = "sp1")]
14+
worker::serve(state.clone()).await;
15+
1116
let addr = SocketAddr::from_str(&state.opts.address)
1217
.map_err(|_| HostError::InvalidAddress(state.opts.address.clone()))?;
1318
let listener = TcpListener::bind(addr).await?;

host/src/server/worker.rs

+136
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
use crate::ProverState;
2+
use raiko_lib::prover::{ProverError, WorkerError};
3+
use sp1_driver::{
4+
sp1_specifics::{Challenger, CoreSC, Machine, Program, ProvingKey, RiscvAir},
5+
RequestData, WorkerProtocol, WorkerRequest, WorkerResponse, WorkerSocket, ELF,
6+
};
7+
use tokio::net::TcpListener;
8+
use tracing::{error, info, warn};
9+
10+
pub async fn serve(state: ProverState) {
11+
if state.opts.sp1_orchestrator_address.is_some() {
12+
tokio::spawn(listen_worker(state));
13+
}
14+
}
15+
16+
async fn listen_worker(state: ProverState) {
17+
info!(
18+
"Listening as a SP1 worker on: {}",
19+
state.opts.sp1_worker_address
20+
);
21+
22+
let listener = TcpListener::bind(state.opts.sp1_worker_address).await.unwrap();
23+
24+
loop {
25+
let Ok((socket, addr)) = listener.accept().await else {
26+
error!("Error while accepting connection from orchestrator: Closing socket");
27+
28+
return;
29+
};
30+
31+
if let Some(orchestrator_address) = &state.opts.sp1_orchestrator_address {
32+
if addr.ip().to_string() != *orchestrator_address {
33+
warn!("Unauthorized orchestrator connection from: {}", addr);
34+
35+
continue;
36+
}
37+
}
38+
39+
// We purposely don't spawn the task here, as we want to block to limit the number
40+
// of concurrent connections to one.
41+
if let Err(e) = handle_worker_socket(WorkerSocket::from_stream(socket)).await {
42+
error!("Error while handling worker socket: {:?}", e);
43+
}
44+
}
45+
}
46+
47+
async fn handle_worker_socket(mut socket: WorkerSocket) -> Result<(), ProverError> {
48+
let program = Program::from(ELF);
49+
let config = CoreSC::default();
50+
51+
let machine = RiscvAir::machine(config.clone());
52+
let (pk, _vk) = machine.setup(&program);
53+
54+
while let Ok(protocol) = socket.receive().await {
55+
match protocol {
56+
WorkerProtocol::Request(request) => match request {
57+
WorkerRequest::Ping => handle_ping(&mut socket).await?,
58+
WorkerRequest::Commit(request_data) => {
59+
handle_commit(&mut socket, &program, &machine, request_data).await?
60+
}
61+
WorkerRequest::Prove {
62+
request_data,
63+
challenger,
64+
} => {
65+
handle_prove(
66+
&mut socket,
67+
&program,
68+
&machine,
69+
&pk,
70+
request_data,
71+
challenger,
72+
)
73+
.await?
74+
}
75+
},
76+
_ => Err(WorkerError::InvalidRequest)?,
77+
}
78+
}
79+
80+
Ok(())
81+
}
82+
83+
async fn handle_ping(socket: &mut WorkerSocket) -> Result<(), WorkerError> {
84+
socket
85+
.send(WorkerProtocol::Response(WorkerResponse::Pong))
86+
.await
87+
}
88+
89+
async fn handle_commit(
90+
socket: &mut WorkerSocket,
91+
program: &Program,
92+
machine: &Machine,
93+
request_data: RequestData,
94+
) -> Result<(), WorkerError> {
95+
let (commitments, shards_public_values) = sp1_driver::sp1_specifics::commit(
96+
program,
97+
machine,
98+
request_data.checkpoint,
99+
request_data.nb_checkpoints,
100+
request_data.public_values,
101+
request_data.shard_batch_size,
102+
request_data.shard_size,
103+
)?;
104+
105+
socket
106+
.send(WorkerProtocol::Response(WorkerResponse::Commitment {
107+
commitments,
108+
shards_public_values,
109+
}))
110+
.await
111+
}
112+
113+
async fn handle_prove(
114+
socket: &mut WorkerSocket,
115+
program: &Program,
116+
machine: &Machine,
117+
pk: &ProvingKey,
118+
request_data: RequestData,
119+
challenger: Challenger,
120+
) -> Result<(), WorkerError> {
121+
let proof = sp1_driver::sp1_specifics::prove(
122+
program,
123+
machine,
124+
pk,
125+
request_data.checkpoint,
126+
request_data.nb_checkpoints,
127+
request_data.public_values,
128+
request_data.shard_batch_size,
129+
request_data.shard_size,
130+
challenger,
131+
)?;
132+
133+
socket
134+
.send(WorkerProtocol::Response(WorkerResponse::Proof(proof)))
135+
.await
136+
}

lib/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,4 @@ std = [
7171
sgx = []
7272
sp1 = []
7373
risc0 = []
74-
sp1-cycle-tracker = []
74+
sp1-cycle-tracker = []

lib/src/prover.rs

+20
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ pub enum ProverError {
1111
FileIo(#[from] std::io::Error),
1212
#[error("ProverError::Param `{0}`")]
1313
Param(#[from] serde_json::Error),
14+
#[error("ProverError::Worker `{0}`")]
15+
Worker(#[from] WorkerError),
1416
}
1517

1618
impl From<String> for ProverError {
@@ -37,3 +39,21 @@ pub fn to_proof(proof: ProverResult<impl Serialize>) -> ProverResult<Proof> {
3739
serde_json::to_value(res).map_err(|err| ProverError::GuestError(err.to_string()))
3840
})
3941
}
42+
43+
#[derive(ThisError, Debug)]
44+
pub enum WorkerError {
45+
#[error("All workers failed")]
46+
AllWorkersFailed,
47+
#[error("Worker IO error: {0}")]
48+
IO(#[from] std::io::Error),
49+
#[error("Worker Serde error: {0}")]
50+
Serde(#[from] bincode::Error),
51+
#[error("Worker invalid version")]
52+
InvalidVersion,
53+
#[error("Worker invalid request")]
54+
InvalidRequest,
55+
#[error("Worker invalid response")]
56+
InvalidResponse,
57+
#[error("Worker payload too big")]
58+
PayloadTooBig,
59+
}

provers/sp1/driver/Cargo.toml

+24
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,46 @@ alloy-sol-types = { workspace = true }
1616
serde = { workspace = true, optional = true }
1717
serde_json = { workspace = true, optional = true }
1818
sp1-sdk = { workspace = true, optional = true }
19+
sp1-core = { workspace = true, optional = true }
1920
anyhow = { workspace = true, optional = true }
2021
once_cell = { workspace = true, optional = true }
2122
sha3 = { workspace = true, optional = true, default-features = false }
2223
tracing = { workspace = true, optional = true }
2324

25+
log = { workspace = true, optional = true }
26+
tokio = { workspace = true, optional = true }
27+
tempfile = { workspace = true, optional = true }
28+
bincode = { workspace = true, optional = true }
29+
30+
p3-field = { workspace = true, optional = true }
31+
p3-challenger = { workspace = true, optional = true }
32+
p3-poseidon2 = { workspace = true, optional = true }
33+
p3-baby-bear = { workspace = true, optional = true }
34+
p3-symmetric = { workspace = true, optional = true }
35+
2436

2537
[features]
2638
enable = [
2739
"serde",
2840
"serde_json",
2941
"raiko-lib",
3042
"sp1-sdk",
43+
"sp1-core",
3144
"anyhow",
3245
"alloy-primitives",
3346
"once_cell",
3447
"sha3",
3548
"tracing",
49+
50+
"log",
51+
"tokio",
52+
"tempfile",
53+
"bincode",
54+
55+
"p3-field",
56+
"p3-challenger",
57+
"p3-poseidon2",
58+
"p3-baby-bear",
59+
"p3-symmetric",
3660
]
3761
neon = ["sp1-sdk?/neon"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
mod prover;
2+
pub mod sp1_specifics;
3+
mod worker;
4+
5+
pub use prover::Sp1DistributedProver;
6+
pub use worker::{
7+
RequestData, WorkerEnvelope, WorkerPool, WorkerProtocol, WorkerRequest, WorkerResponse,
8+
WorkerSocket,
9+
};

0 commit comments

Comments
 (0)