Skip to content

Commit 21a6338

Browse files
committed
First working version of SP1 Distributed Prover
1 parent 48ea079 commit 21a6338

24 files changed

+1212
-117
lines changed

Cargo.lock

+168-65
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+22-3
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,19 @@ risc0-build = { version = "0.21.0" }
5454
risc0-binfmt = { version = "0.21.0" }
5555

5656
# SP1
57-
sp1-sdk = { git = "https://github.com/succinctlabs/sp1.git", branch = "main" }
58-
sp1-zkvm = { git = "https://github.com/succinctlabs/sp1.git", branch = "main" }
59-
sp1-helper = { git = "https://github.com/succinctlabs/sp1.git", branch = "main" }
57+
sp1-sdk = { git = "https://github.com/succinctlabs/sp1.git", rev = "14eb569d41d24721ffbd407d6060e202482d659c" }
58+
sp1-zkvm = { git = "https://github.com/succinctlabs/sp1.git", rev = "14eb569d41d24721ffbd407d6060e202482d659c" }
59+
sp1-helper = { git = "https://github.com/succinctlabs/sp1.git", rev = "14eb569d41d24721ffbd407d6060e202482d659c" }
60+
sp1-core = { git = "https://github.com/succinctlabs/sp1.git", rev = "14eb569d41d24721ffbd407d6060e202482d659c" }
61+
62+
63+
# Plonky3
64+
p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }
65+
p3-challenger = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }
66+
p3-poseidon2 = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }
67+
p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }
68+
p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }
69+
6070

6171
# alloy
6272
alloy-rlp = { version = "0.3.4", default-features = false }
@@ -149,6 +159,7 @@ secp256k1 = { version = "0.29", default-features = false, features = [
149159
"global-context",
150160
"recovery",
151161
] }
162+
async-channel = "2.3.1"
152163

153164
# macro
154165
syn = { version = "1.0", features = ["full"] }
@@ -188,3 +199,11 @@ revm-primitives = { git = "https://github.com/taikoxyz/revm.git", branch = "v36-
188199
revm-precompile = { git = "https://github.com/taikoxyz/revm.git", branch = "v36-taiko" }
189200
secp256k1 = { git = "https://github.com/CeciliaZ030/rust-secp256k1", branch = "sp1-patch" }
190201
blst = { git = "https://github.com/CeciliaZ030/blst.git", branch = "v0.3.12-serialize" }
202+
203+
# Patch Plonky3 for Serialize and Deserialize of DuplexChallenger
204+
[patch."https://github.com/Plonky3/Plonky3.git"]
205+
p3-field = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }
206+
p3-challenger = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }
207+
p3-poseidon2 = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }
208+
p3-baby-bear = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }
209+
p3-symmetric = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }

core/src/interfaces.rs

+14
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,10 @@ pub enum ProofType {
104104
///
105105
/// Uses the SP1 prover to build the block.
106106
Sp1,
107+
/// # Sp1Distributed
108+
///
109+
/// Uses the SP1 prover to build the block in a distributed way.
110+
Sp1Distributed,
107111
/// # Sgx
108112
///
109113
/// Builds the block on a SGX supported CPU to create a proof.
@@ -119,6 +123,7 @@ impl std::fmt::Display for ProofType {
119123
f.write_str(match self {
120124
ProofType::Native => "native",
121125
ProofType::Sp1 => "sp1",
126+
ProofType::Sp1Distributed => "sp1_distributed",
122127
ProofType::Sgx => "sgx",
123128
ProofType::Risc0 => "risc0",
124129
})
@@ -132,6 +137,7 @@ impl FromStr for ProofType {
132137
match s.trim().to_lowercase().as_str() {
133138
"native" => Ok(ProofType::Native),
134139
"sp1" => Ok(ProofType::Sp1),
140+
"sp1_distributed" => Ok(ProofType::Sp1Distributed),
135141
"sgx" => Ok(ProofType::Sgx),
136142
"risc0" => Ok(ProofType::Risc0),
137143
_ => Err(RaikoError::InvalidProofType(s.to_string())),
@@ -159,6 +165,14 @@ impl ProofType {
159165
#[cfg(not(feature = "sp1"))]
160166
Err(RaikoError::FeatureNotSupportedError(*self))
161167
}
168+
ProofType::Sp1Distributed => {
169+
#[cfg(feature = "sp1")]
170+
return sp1_driver::Sp1DistributedProver::run(input, output, config)
171+
.await
172+
.map_err(|e| e.into());
173+
#[cfg(not(feature = "sp1"))]
174+
Err(RaikoError::FeatureNotSupportedError(self.clone()))
175+
}
162176
ProofType::Risc0 => {
163177
#[cfg(feature = "risc0")]
164178
return risc0_driver::Risc0Prover::run(input.clone(), output, config)

host/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ ethers-core = { workspace = true }
8282

8383
[features]
8484
default = []
85-
sp1 = ["raiko-core/sp1"]
85+
sp1 = ["raiko-core/sp1", "sp1-driver"]
8686
risc0 = ["raiko-core/risc0"]
8787
sgx = ["raiko-core/sgx"]
8888

host/src/lib.rs

+15
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ fn default_address() -> String {
3939
"0.0.0.0:8080".to_string()
4040
}
4141

42+
fn default_worker_address() -> String {
43+
"0.0.0.0:8081".to_string()
44+
}
45+
4246
fn default_concurrency_limit() -> usize {
4347
16
4448
}
@@ -69,6 +73,17 @@ pub struct Cli {
6973
/// [default: 0.0.0.0:8080]
7074
address: String,
7175

76+
#[arg(long, require_equals = true, default_value = "0.0.0.0:8081")]
77+
#[serde(default = "default_worker_address")]
78+
/// Distributed SP1 worker listening address
79+
/// [default: 0.0.0.0:8081]
80+
worker_address: String,
81+
82+
#[arg(long, default_value = None)]
83+
/// Distributed SP1 worker orchestrator address
84+
/// [default: None]
85+
orchestrator_address: Option<String>,
86+
7287
#[arg(long, require_equals = true, default_value = "16")]
7388
#[serde(default = "default_concurrency_limit")]
7489
/// Limit the max number of in-flight requests

host/src/server/mod.rs

+5
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,14 @@ use tokio::net::TcpListener;
55
use tracing::info;
66

77
pub mod api;
8+
#[cfg(feature = "sp1")]
9+
pub mod worker;
810

911
/// Starts the proverd server.
1012
pub async fn serve(state: ProverState) -> anyhow::Result<()> {
13+
#[cfg(feature = "sp1")]
14+
worker::serve(state.clone()).await;
15+
1116
let addr = SocketAddr::from_str(&state.opts.address)
1217
.map_err(|_| HostError::InvalidAddress(state.opts.address.clone()))?;
1318
let listener = TcpListener::bind(addr).await?;

host/src/server/worker.rs

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
use crate::ProverState;
2+
use raiko_lib::prover::{ProverError, WorkerError};
3+
use sp1_driver::{PartialProofRequest, WorkerProtocol, WorkerSocket};
4+
use tokio::net::TcpListener;
5+
use tracing::{error, info, warn};
6+
7+
async fn handle_worker_socket(mut socket: WorkerSocket) -> Result<(), ProverError> {
8+
let protocol = socket.receive().await?;
9+
10+
info!("Received request from orchestrator: {}", protocol);
11+
12+
match protocol {
13+
WorkerProtocol::Ping => {
14+
socket.send(WorkerProtocol::Pong).await?;
15+
}
16+
WorkerProtocol::PartialProofRequest(data) => {
17+
process_partial_proof_request(socket, data).await?;
18+
}
19+
_ => Err(WorkerError::InvalidRequest)?,
20+
}
21+
22+
Ok(())
23+
}
24+
25+
async fn process_partial_proof_request(
26+
mut socket: WorkerSocket,
27+
data: PartialProofRequest,
28+
) -> Result<(), ProverError> {
29+
let partial_proof = sp1_driver::Sp1DistributedProver::run_as_worker(data).await?;
30+
31+
socket
32+
.send(WorkerProtocol::PartialProofResponse(partial_proof))
33+
.await?;
34+
35+
Ok(())
36+
}
37+
38+
async fn listen_worker(state: ProverState) {
39+
info!(
40+
"Listening as a SP1 worker on: {}",
41+
state.opts.worker_address
42+
);
43+
44+
let listener = TcpListener::bind(state.opts.worker_address).await.unwrap();
45+
46+
loop {
47+
let Ok((socket, addr)) = listener.accept().await else {
48+
error!("Error while accepting connection from orchestrator: Closing socket");
49+
50+
return;
51+
};
52+
53+
if let Some(orchestrator_address) = &state.opts.orchestrator_address {
54+
if addr.ip().to_string() != *orchestrator_address {
55+
warn!("Unauthorized orchestrator connection from: {}", addr);
56+
57+
continue;
58+
}
59+
}
60+
61+
// We purposely don't spawn the task here, as we want to block to limit the number
62+
// of concurrent connections to one.
63+
if let Err(e) = handle_worker_socket(WorkerSocket::new(socket)).await {
64+
error!("Error while handling worker socket: {:?}", e);
65+
}
66+
}
67+
}
68+
69+
pub async fn serve(state: ProverState) {
70+
if state.opts.orchestrator_address.is_some() {
71+
tokio::spawn(listen_worker(state));
72+
}
73+
}

lib/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,4 @@ std = [
7171
sgx = []
7272
sp1 = []
7373
risc0 = []
74-
sp1-cycle-tracker = []
74+
sp1-cycle-tracker = []

lib/src/prover.rs

+20
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ pub enum ProverError {
1111
FileIo(#[from] std::io::Error),
1212
#[error("ProverError::Param `{0}`")]
1313
Param(#[from] serde_json::Error),
14+
#[error("ProverError::Worker `{0}`")]
15+
Worker(#[from] WorkerError),
1416
}
1517

1618
impl From<String> for ProverError {
@@ -37,3 +39,21 @@ pub fn to_proof(proof: ProverResult<impl Serialize>) -> ProverResult<Proof> {
3739
serde_json::to_value(res).map_err(|err| ProverError::GuestError(err.to_string()))
3840
})
3941
}
42+
43+
#[derive(ThisError, Debug)]
44+
pub enum WorkerError {
45+
#[error("All workers failed")]
46+
AllWorkersFailed,
47+
#[error("Worker IO error: {0}")]
48+
IO(#[from] std::io::Error),
49+
#[error("Worker Serde error: {0}")]
50+
Serde(#[from] bincode::Error),
51+
#[error("Worker invalid magic number")]
52+
InvalidMagicNumber,
53+
#[error("Worker invalid request")]
54+
InvalidRequest,
55+
#[error("Worker invalid response")]
56+
InvalidResponse,
57+
#[error("Worker payload too big")]
58+
PayloadTooBig,
59+
}

provers/sp1/driver/Cargo.toml

+15
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,32 @@ alloy-sol-types = { workspace = true }
1616
serde = { workspace = true , optional = true}
1717
serde_json = { workspace = true , optional = true }
1818
sp1-sdk = { workspace = true, optional = true }
19+
sp1-core = { workspace = true, optional = true }
1920
anyhow = { workspace = true, optional = true }
2021
once_cell = { workspace = true, optional = true }
2122
sha3 = { workspace = true, optional = true, default-features = false}
2223

24+
log = { workspace = true }
25+
tokio = { workspace = true }
26+
async-channel = { workspace = true }
27+
tracing = { workspace = true }
28+
tempfile = { workspace = true }
29+
bincode = { workspace = true }
30+
31+
p3-field = { workspace = true }
32+
p3-challenger = { workspace = true }
33+
p3-poseidon2 = { workspace = true }
34+
p3-baby-bear = { workspace = true }
35+
p3-symmetric = { workspace = true }
36+
2337

2438
[features]
2539
enable = [
2640
"serde",
2741
"serde_json",
2842
"raiko-lib",
2943
"sp1-sdk",
44+
"sp1-core",
3045
"anyhow",
3146
"alloy-primitives",
3247
"once_cell",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
mod orchestrator;
2+
mod partial_proof_request;
3+
mod prover;
4+
mod sp1_specifics;
5+
mod worker;
6+
7+
pub use partial_proof_request::PartialProofRequest;
8+
pub use prover::Sp1DistributedProver;
9+
pub use worker::{WorkerEnvelope, WorkerProtocol, WorkerSocket};
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
mod worker_client;
2+
3+
use raiko_lib::prover::WorkerError;
4+
use sp1_core::{runtime::ExecutionState, stark::ShardProof, utils::BabyBearPoseidon2};
5+
use worker_client::WorkerClient;
6+
7+
use super::partial_proof_request::PartialProofRequest;
8+
9+
pub async fn distribute_work(
10+
ip_list: Vec<String>,
11+
checkpoints: Vec<ExecutionState>,
12+
partial_proof_request: PartialProofRequest,
13+
) -> Result<Vec<ShardProof<BabyBearPoseidon2>>, WorkerError> {
14+
let mut nb_workers = ip_list.len();
15+
16+
let (queue_tx, queue_rx) = async_channel::bounded(nb_workers);
17+
let (answer_tx, answer_rx) = async_channel::bounded(nb_workers);
18+
19+
// Spawn the workers
20+
for (i, url) in ip_list.iter().enumerate() {
21+
let worker = WorkerClient::new(
22+
i,
23+
url.clone(),
24+
queue_rx.clone(),
25+
answer_tx.clone(),
26+
partial_proof_request.clone(),
27+
);
28+
29+
tokio::spawn(async move {
30+
worker.run().await;
31+
});
32+
}
33+
34+
// Send the checkpoints to the workers
35+
for (i, checkpoint) in checkpoints.iter().enumerate() {
36+
queue_tx.send((i, checkpoint.clone())).await.unwrap();
37+
}
38+
39+
let mut proofs = Vec::new();
40+
41+
// Get the partial proofs from the workers
42+
loop {
43+
let (checkpoint_id, partial_proof_result) = answer_rx.recv().await.unwrap();
44+
45+
match partial_proof_result {
46+
Ok(partial_proof) => {
47+
proofs.push((checkpoint_id as usize, partial_proof));
48+
}
49+
Err(_e) => {
50+
// Decrease the number of workers
51+
nb_workers -= 1;
52+
53+
if nb_workers == 0 {
54+
return Err(WorkerError::AllWorkersFailed);
55+
}
56+
57+
// Push back the work for it to be done by another worker
58+
queue_tx
59+
.send((checkpoint_id, checkpoints[checkpoint_id as usize].clone()))
60+
.await
61+
.unwrap();
62+
}
63+
}
64+
65+
if proofs.len() == checkpoints.len() {
66+
break;
67+
}
68+
}
69+
70+
proofs.sort_by_key(|(checkpoint_id, _)| *checkpoint_id);
71+
72+
let proofs = proofs
73+
.into_iter()
74+
.map(|(_, proof)| proof)
75+
.flatten()
76+
.collect();
77+
78+
Ok(proofs)
79+
}

0 commit comments

Comments
 (0)