Skip to content

Commit 94fd92c

Browse files
committed
First working version of SP1 Distributed Prover
1 parent 4c8c44e commit 94fd92c

24 files changed

+1179
-118
lines changed

Diff for: Cargo.lock

+168-65
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: Cargo.toml

+23-4
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,19 @@ risc0-build = { version = "0.21.0" }
5252
risc0-binfmt = { version = "0.21.0" }
5353

5454
# SP1
55-
sp1-sdk = { git = "https://github.com/succinctlabs/sp1.git", branch = "main" }
56-
sp1-zkvm = { git = "https://github.com/succinctlabs/sp1.git", branch = "main" }
57-
sp1-helper = { git = "https://github.com/succinctlabs/sp1.git", branch = "main" }
55+
sp1-sdk = { git = "https://github.com/succinctlabs/sp1.git", rev = "14eb569d41d24721ffbd407d6060e202482d659c" }
56+
sp1-zkvm = { git = "https://github.com/succinctlabs/sp1.git", rev = "14eb569d41d24721ffbd407d6060e202482d659c" }
57+
sp1-helper = { git = "https://github.com/succinctlabs/sp1.git", rev = "14eb569d41d24721ffbd407d6060e202482d659c" }
58+
sp1-core = { git = "https://github.com/succinctlabs/sp1.git", rev = "14eb569d41d24721ffbd407d6060e202482d659c" }
59+
60+
61+
# Plonky3
62+
p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }
63+
p3-challenger = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }
64+
p3-poseidon2 = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }
65+
p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }
66+
p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }
67+
5868

5969
# alloy
6070
alloy-rlp = { version = "0.3.4", default-features = false }
@@ -145,6 +155,7 @@ secp256k1 = { version = "0.29", default-features = false, features = [
145155
"global-context",
146156
"recovery",
147157
] }
158+
async-channel = "2.3.1"
148159

149160
# macro
150161
syn = { version = "1.0", features = ["full"] }
@@ -180,4 +191,12 @@ revm = { git = "https://github.com/taikoxyz/revm.git", branch = "v36-taiko" }
180191
revm-primitives = { git = "https://github.com/taikoxyz/revm.git", branch = "v36-taiko" }
181192
revm-precompile = { git = "https://github.com/taikoxyz/revm.git", branch = "v36-taiko" }
182193
secp256k1 = { git = "https://github.com/CeciliaZ030/rust-secp256k1", branch = "sp1-patch" }
183-
blst = { git = "https://github.com/CeciliaZ030/blst.git", branch = "v0.3.12-serialize" }
194+
blst = { git = "https://github.com/CeciliaZ030/blst.git", branch = "v0.3.12-serialize" }
195+
196+
# Patch Plonky3 for Serialize and Deserialize of DuplexChallenger
197+
[patch."https://github.com/Plonky3/Plonky3.git"]
198+
p3-field = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }
199+
p3-challenger = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }
200+
p3-poseidon2 = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }
201+
p3-baby-bear = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }
202+
p3-symmetric = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }

Diff for: core/src/interfaces.rs

+14
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ pub enum ProofType {
9191
///
9292
/// Uses the SP1 prover to build the block.
9393
Sp1,
94+
/// # Sp1Distributed
95+
///
96+
/// Uses the SP1 prover to build the block in a distributed way.
97+
Sp1Distributed,
9498
/// # Sgx
9599
///
96100
/// Builds the block on a SGX supported CPU to create a proof.
@@ -106,6 +110,7 @@ impl std::fmt::Display for ProofType {
106110
f.write_str(match self {
107111
ProofType::Native => "native",
108112
ProofType::Sp1 => "sp1",
113+
ProofType::Sp1Distributed => "sp1_distributed",
109114
ProofType::Sgx => "sgx",
110115
ProofType::Risc0 => "risc0",
111116
})
@@ -119,6 +124,7 @@ impl FromStr for ProofType {
119124
match s.trim().to_lowercase().as_str() {
120125
"native" => Ok(ProofType::Native),
121126
"sp1" => Ok(ProofType::Sp1),
127+
"sp1_distributed" => Ok(ProofType::Sp1Distributed),
122128
"sgx" => Ok(ProofType::Sgx),
123129
"risc0" => Ok(ProofType::Risc0),
124130
_ => Err(RaikoError::InvalidProofType(s.to_string())),
@@ -146,6 +152,14 @@ impl ProofType {
146152
#[cfg(not(feature = "sp1"))]
147153
Err(RaikoError::FeatureNotSupportedError(self.clone()))
148154
}
155+
ProofType::Sp1Distributed => {
156+
#[cfg(feature = "sp1")]
157+
return sp1_driver::Sp1DistributedProver::run(input, output, config)
158+
.await
159+
.map_err(|e| e.into());
160+
#[cfg(not(feature = "sp1"))]
161+
Err(RaikoError::FeatureNotSupportedError(self.clone()))
162+
}
149163
ProofType::Risc0 => {
150164
#[cfg(feature = "risc0")]
151165
return risc0_driver::Risc0Prover::run(input.clone(), output, config)

Diff for: host/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ ethers-core = { workspace = true }
8181

8282
[features]
8383
default = []
84-
sp1 = ["raiko-core/sp1"]
84+
sp1 = ["raiko-core/sp1", "sp1-driver"]
8585
risc0 = ["raiko-core/risc0"]
8686
sgx = ["raiko-core/sgx"]
8787

Diff for: host/src/lib.rs

+15
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ fn default_address() -> String {
2121
"0.0.0.0:8080".to_string()
2222
}
2323

24+
fn default_worker_address() -> String {
25+
"0.0.0.0:8081".to_string()
26+
}
27+
2428
fn default_concurrency_limit() -> usize {
2529
16
2630
}
@@ -51,6 +55,17 @@ pub struct Cli {
5155
/// [default: 0.0.0.0:8080]
5256
address: String,
5357

58+
#[arg(long, require_equals = true, default_value = "0.0.0.0:8081")]
59+
#[serde(default = "default_worker_address")]
60+
/// Distributed SP1 worker listening address
61+
/// [default: 0.0.0.0:8081]
62+
worker_address: String,
63+
64+
#[arg(long, default_value = None)]
65+
/// Distributed SP1 worker orchestrator address
66+
/// [default: None]
67+
orchestrator_address: Option<String>,
68+
5469
#[arg(long, require_equals = true, default_value = "16")]
5570
#[serde(default = "default_concurrency_limit")]
5671
/// Limit the max number of in-flight requests

Diff for: host/src/server/mod.rs

+3
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@ use tokio::net::TcpListener;
55
use tracing::info;
66

77
pub mod api;
8+
pub mod worker;
89

910
/// Starts the proverd server.
1011
pub async fn serve(state: ProverState) -> anyhow::Result<()> {
12+
worker::serve(state.clone()).await;
13+
1114
let addr = SocketAddr::from_str(&state.opts.address)
1215
.map_err(|_| HostError::InvalidAddress(state.opts.address.clone()))?;
1316
let listener = TcpListener::bind(addr).await?;

Diff for: host/src/server/worker.rs

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
use crate::ProverState;
2+
use raiko_lib::prover::{ProverError, WorkerError};
3+
use sp1_driver::{PartialProofRequest, WorkerProtocol, WorkerSocket};
4+
use tokio::net::TcpListener;
5+
use tracing::{error, info, warn};
6+
7+
async fn handle_worker_socket(mut socket: WorkerSocket) -> Result<(), ProverError> {
8+
let protocol = socket.receive().await?;
9+
10+
info!("Received request: {}", protocol);
11+
12+
match protocol {
13+
WorkerProtocol::Ping => {
14+
socket.send(WorkerProtocol::Ping).await?;
15+
}
16+
WorkerProtocol::PartialProofRequest(data) => {
17+
process_partial_proof_request(socket, data).await?;
18+
}
19+
_ => Err(WorkerError::InvalidRequest)?,
20+
}
21+
22+
Ok(())
23+
}
24+
25+
async fn process_partial_proof_request(
26+
mut socket: WorkerSocket,
27+
data: PartialProofRequest,
28+
) -> Result<(), ProverError> {
29+
let result = sp1_driver::Sp1DistributedProver::run_as_worker(data).await;
30+
31+
match result {
32+
Ok(data) => Ok(socket
33+
.send(WorkerProtocol::PartialProofResponse(data))
34+
.await?),
35+
Err(e) => {
36+
error!("Error while processing worker request: {:?}", e);
37+
38+
Err(e)
39+
}
40+
}
41+
}
42+
43+
async fn listen_worker(state: ProverState) {
44+
info!(
45+
"Listening as a SP1 worker on: {}",
46+
state.opts.worker_address
47+
);
48+
49+
let listener = TcpListener::bind(state.opts.worker_address).await.unwrap();
50+
51+
loop {
52+
let Ok((socket, addr)) = listener.accept().await else {
53+
error!("Error while accepting connection from orchestrator: Closing socket");
54+
55+
return;
56+
};
57+
58+
if let Some(orchestrator_address) = &state.opts.orchestrator_address {
59+
if addr.ip().to_string() != *orchestrator_address {
60+
warn!("Unauthorized orchestrator connection from: {}", addr);
61+
62+
continue;
63+
}
64+
}
65+
66+
info!("Receiving connection from orchestrator: {}", addr);
67+
68+
// We purposely don't spawn the task here, as we want to block to limit the number
69+
// of concurrent connections to one.
70+
71+
if let Err(e) = handle_worker_socket(WorkerSocket::new(socket)).await {
72+
error!("Error while handling worker socket: {:?}", e);
73+
}
74+
}
75+
}
76+
77+
pub async fn serve(state: ProverState) {
78+
if state.opts.orchestrator_address.is_some() {
79+
tokio::spawn(listen_worker(state));
80+
}
81+
}

Diff for: lib/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,4 @@ std = [
7171
sgx = []
7272
sp1 = []
7373
risc0 = []
74-
sp1-cycle-tracker = []
74+
sp1-cycle-tracker = []

Diff for: lib/src/prover.rs

+18
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ pub enum ProverError {
1111
FileIo(#[from] std::io::Error),
1212
#[error("ProverError::Param `{0}`")]
1313
Param(#[from] serde_json::Error),
14+
#[error("ProverError::Worker `{0}`")]
15+
Worker(#[from] WorkerError),
1416
}
1517

1618
impl From<String> for ProverError {
@@ -37,3 +39,19 @@ pub fn to_proof(proof: ProverResult<impl Serialize>) -> ProverResult<Proof> {
3739
serde_json::to_value(res).map_err(|err| ProverError::GuestError(err.to_string()))
3840
})
3941
}
42+
43+
#[derive(ThisError, Debug)]
44+
pub enum WorkerError {
45+
#[error("All workers failed")]
46+
AllWorkersFailed,
47+
#[error("Worker IO error: {0}")]
48+
IO(#[from] std::io::Error),
49+
#[error("Worker Serde error: {0}")]
50+
Serde(#[from] bincode::Error),
51+
#[error("Worker invalid magic number")]
52+
InvalidMagicNumber,
53+
#[error("Worker invalid request")]
54+
InvalidRequest,
55+
#[error("Worker invalid response")]
56+
InvalidResponse,
57+
}

Diff for: provers/sp1/driver/Cargo.toml

+15
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,32 @@ alloy-sol-types = { workspace = true }
1616
serde = { workspace = true , optional = true}
1717
serde_json = { workspace = true , optional = true }
1818
sp1-sdk = { workspace = true, optional = true }
19+
sp1-core = { workspace = true, optional = true }
1920
anyhow = { workspace = true, optional = true }
2021
once_cell = { workspace = true, optional = true }
2122
sha3 = { workspace = true, optional = true, default-features = false}
2223

24+
log = { workspace = true }
25+
tokio = { workspace = true }
26+
async-channel = { workspace = true }
27+
tracing = { workspace = true }
28+
tempfile = { workspace = true }
29+
bincode = { workspace = true }
30+
31+
p3-field = { workspace = true }
32+
p3-challenger = { workspace = true }
33+
p3-poseidon2 = { workspace = true }
34+
p3-baby-bear = { workspace = true }
35+
p3-symmetric = { workspace = true }
36+
2337

2438
[features]
2539
enable = [
2640
"serde",
2741
"serde_json",
2842
"raiko-lib",
2943
"sp1-sdk",
44+
"sp1-core",
3045
"anyhow",
3146
"alloy-primitives",
3247
"once_cell",

Diff for: provers/sp1/driver/src/distributed/mod.rs

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
mod orchestrator;
2+
mod partial_proof_request;
3+
mod prover;
4+
mod sp1_specifics;
5+
mod worker;
6+
7+
pub use partial_proof_request::PartialProofRequest;
8+
pub use prover::Sp1DistributedProver;
9+
pub use worker::{WorkerEnvelope, WorkerProtocol, WorkerSocket};
+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
mod worker_client;
2+
3+
use raiko_lib::prover::WorkerError;
4+
use sp1_core::{runtime::ExecutionState, stark::ShardProof, utils::BabyBearPoseidon2};
5+
use worker_client::WorkerClient;
6+
7+
use super::partial_proof_request::PartialProofRequest;
8+
9+
pub async fn distribute_work(
10+
ip_list: Vec<String>,
11+
checkpoints: Vec<ExecutionState>,
12+
partial_proof_request: PartialProofRequest,
13+
) -> Result<Vec<ShardProof<BabyBearPoseidon2>>, WorkerError> {
14+
let mut nb_workers = ip_list.len();
15+
16+
let (queue_tx, queue_rx) = async_channel::bounded(nb_workers);
17+
let (answer_tx, answer_rx) = async_channel::bounded(nb_workers);
18+
19+
// Spawn the workers
20+
for (i, url) in ip_list.iter().enumerate() {
21+
let worker = WorkerClient::new(
22+
i,
23+
url.clone(),
24+
queue_rx.clone(),
25+
answer_tx.clone(),
26+
partial_proof_request.clone(),
27+
);
28+
29+
tokio::spawn(async move {
30+
worker.run().await;
31+
});
32+
}
33+
34+
// Send the checkpoints to the workers
35+
for (i, checkpoint) in checkpoints.iter().enumerate() {
36+
queue_tx.send((i, checkpoint.clone())).await.unwrap();
37+
}
38+
39+
let mut proofs = Vec::new();
40+
41+
// Get the partial proofs from the workers
42+
loop {
43+
let (checkpoint_id, partial_proof_result) = answer_rx.recv().await.unwrap();
44+
45+
match partial_proof_result {
46+
Ok(partial_proof) => {
47+
proofs.push((checkpoint_id as usize, partial_proof));
48+
}
49+
Err(_e) => {
50+
// Decrease the number of workers
51+
nb_workers -= 1;
52+
53+
if nb_workers == 0 {
54+
return Err(WorkerError::AllWorkersFailed);
55+
}
56+
57+
// Push back the work for it to be done by another worker
58+
queue_tx
59+
.send((checkpoint_id, checkpoints[checkpoint_id as usize].clone()))
60+
.await
61+
.unwrap();
62+
}
63+
}
64+
65+
if proofs.len() == checkpoints.len() {
66+
break;
67+
}
68+
}
69+
70+
proofs.sort_by_key(|(checkpoint_id, _)| *checkpoint_id);
71+
72+
let proofs = proofs
73+
.into_iter()
74+
.map(|(_, proof)| proof)
75+
.flatten()
76+
.collect();
77+
78+
Ok(proofs)
79+
}

0 commit comments

Comments
 (0)