Skip to content

Commit a0272b8

Browse files
committed
Add RPC style call for WorkerSocket and limit the payload size to 64MB
1 parent dfd8366 commit a0272b8

File tree

7 files changed

+62
-48
lines changed

7 files changed

+62
-48
lines changed

host/src/server/worker.rs

+7-15
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ use tracing::{error, info, warn};
77
async fn handle_worker_socket(mut socket: WorkerSocket) -> Result<(), ProverError> {
88
let protocol = socket.receive().await?;
99

10-
info!("Received request: {}", protocol);
10+
info!("Received request from orchestrator: {}", protocol);
1111

1212
match protocol {
1313
WorkerProtocol::Ping => {
14-
socket.send(WorkerProtocol::Ping).await?;
14+
socket.send(WorkerProtocol::Pong).await?;
1515
}
1616
WorkerProtocol::PartialProofRequest(data) => {
1717
process_partial_proof_request(socket, data).await?;
@@ -26,18 +26,13 @@ async fn process_partial_proof_request(
2626
mut socket: WorkerSocket,
2727
data: PartialProofRequest,
2828
) -> Result<(), ProverError> {
29-
let result = sp1_driver::Sp1DistributedProver::run_as_worker(data).await;
29+
let partial_proof = sp1_driver::Sp1DistributedProver::run_as_worker(data).await?;
3030

31-
match result {
32-
Ok(data) => Ok(socket
33-
.send(WorkerProtocol::PartialProofResponse(data))
34-
.await?),
35-
Err(e) => {
36-
error!("Error while processing worker request: {:?}", e);
31+
socket
32+
.send(WorkerProtocol::PartialProofResponse(partial_proof))
33+
.await?;
3734

38-
Err(e)
39-
}
40-
}
35+
Ok(())
4136
}
4237

4338
async fn listen_worker(state: ProverState) {
@@ -63,11 +58,8 @@ async fn listen_worker(state: ProverState) {
6358
}
6459
}
6560

66-
info!("Receiving connection from orchestrator: {}", addr);
67-
6861
// We purposely don't spawn the task here, as we want to block to limit the number
6962
// of concurrent connections to one.
70-
7163
if let Err(e) = handle_worker_socket(WorkerSocket::new(socket)).await {
7264
error!("Error while handling worker socket: {:?}", e);
7365
}

lib/src/prover.rs

+2
Original file line numberDiff line numberDiff line change
@@ -54,4 +54,6 @@ pub enum WorkerError {
5454
InvalidRequest,
5555
#[error("Worker invalid response")]
5656
InvalidResponse,
57+
#[error("Worker payload too big")]
58+
PayloadTooBig,
5759
}

provers/sp1/driver/src/distributed/orchestrator/worker_client.rs

+2-14
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@ use async_channel::{Receiver, Sender};
22
use raiko_lib::prover::WorkerError;
33
use sp1_core::{runtime::ExecutionState, stark::ShardProof, utils::BabyBearPoseidon2};
44

5-
use crate::{
6-
distributed::partial_proof_request::PartialProofRequest, WorkerProtocol, WorkerSocket,
7-
};
5+
use crate::{distributed::partial_proof_request::PartialProofRequest, WorkerSocket};
86

97
pub struct WorkerClient {
108
/// The id of the worker
@@ -84,16 +82,6 @@ impl WorkerClient {
8482
request.checkpoint_id = i;
8583
request.checkpoint_data = checkpoint;
8684

87-
socket
88-
.send(WorkerProtocol::PartialProofRequest(request))
89-
.await?;
90-
91-
let response = socket.receive().await?;
92-
93-
if let WorkerProtocol::PartialProofResponse(partial_proofs) = response {
94-
Ok(partial_proofs)
95-
} else {
96-
Err(WorkerError::InvalidResponse)
97-
}
85+
socket.partial_proof_request(request).await
9886
}
9987
}

provers/sp1/driver/src/distributed/prover.rs

+4-14
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use crate::{
1212
partial_proof_request::PartialProofRequest,
1313
sp1_specifics::{commit, prove_partial},
1414
},
15-
Sp1Response, WorkerProtocol, WorkerSocket, ELF,
15+
Sp1Response, WorkerSocket, ELF,
1616
};
1717

1818
pub struct Sp1DistributedProver;
@@ -123,23 +123,13 @@ impl Sp1DistributedProver {
123123
continue;
124124
};
125125

126-
if let Err(_) = socket.send(WorkerProtocol::Ping).await {
127-
log::warn!("Sp1 Distributed: Worker at {} is not reachable. Removing from the list for this task", ip);
126+
if let Err(_) = socket.ping().await {
127+
log::warn!("Sp1 Distributed: Worker at {} is not sending good response to Ping. Removing from the list for this task", ip);
128128

129129
continue;
130130
}
131131

132-
let Ok(response) = socket.receive().await else {
133-
log::warn!("Sp1 Distributed: Worker at {} is not a valid SP1 worker. Removing from the list for this task", ip);
134-
135-
continue;
136-
};
137-
138-
if let WorkerProtocol::Ping = response {
139-
reachable_ip_list.push(ip.clone());
140-
} else {
141-
log::warn!("Sp1 Distributed: Worker at {} is not a valid SP1 worker. Removing from the list for this task", ip);
142-
}
132+
reachable_ip_list.push(ip.clone());
143133
}
144134

145135
if reachable_ip_list.is_empty() {

provers/sp1/driver/src/distributed/sp1_specifics.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -238,11 +238,13 @@ pub fn prove_partial(request_data: &PartialProofRequest) -> Vec<ShardProof<BabyB
238238

239239
log::debug!("Checkpoint sharding took {:?}", now.elapsed());
240240

241+
let nb_shards = checkpoint_shards.len();
242+
241243
let mut proofs = checkpoint_shards
242244
.into_iter()
243245
.enumerate()
244246
.map(|(i, shard)| {
245-
log::info!("Proving shard {}/{}", i + 1, request_data.shard_batch_size);
247+
log::info!("Proving shard {}/{}", i + 1, nb_shards);
246248

247249
let config = machine.config();
248250

provers/sp1/driver/src/distributed/worker/protocol.rs

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::PartialProofRequest;
88
#[derive(Debug, Serialize, Deserialize)]
99
pub enum WorkerProtocol {
1010
Ping,
11+
Pong,
1112
PartialProofRequest(PartialProofRequest),
1213
PartialProofResponse(Vec<ShardProof<BabyBearPoseidon2>>),
1314
}
@@ -16,6 +17,7 @@ impl Display for WorkerProtocol {
1617
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1718
match self {
1819
WorkerProtocol::Ping => write!(f, "Ping"),
20+
WorkerProtocol::Pong => write!(f, "Pong"),
1921
WorkerProtocol::PartialProofRequest(_) => write!(f, "PartialProofRequest"),
2022
WorkerProtocol::PartialProofResponse(_) => write!(f, "PartialProofResponse"),
2123
}

provers/sp1/driver/src/distributed/worker/socket.rs

+42-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
use raiko_lib::prover::WorkerError;
2+
use sp1_core::{stark::ShardProof, utils::BabyBearPoseidon2};
23
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter};
34

4-
use crate::{WorkerEnvelope, WorkerProtocol};
5+
use crate::{PartialProofRequest, WorkerEnvelope, WorkerProtocol};
6+
7+
// 64MB
8+
const PAYLOAD_MAX_SIZE: usize = 1 << 26;
59

610
pub struct WorkerSocket {
711
pub socket: tokio::net::TcpStream,
@@ -23,6 +27,10 @@ impl WorkerSocket {
2327

2428
let data = bincode::serialize(&envelope)?;
2529

30+
if data.len() > PAYLOAD_MAX_SIZE {
31+
return Err(WorkerError::PayloadTooBig);
32+
}
33+
2634
self.socket.write_u64(data.len() as u64).await?;
2735
self.socket.write_all(&data).await?;
2836

@@ -42,10 +50,13 @@ impl WorkerSocket {
4250
}
4351

4452
// TODO: Add a timeout
45-
pub async fn read_data(&mut self) -> Result<Vec<u8>, std::io::Error> {
46-
// TODO: limit the size of the data
53+
pub async fn read_data(&mut self) -> Result<Vec<u8>, WorkerError> {
4754
let size = self.socket.read_u64().await? as usize;
4855

56+
if size > PAYLOAD_MAX_SIZE {
57+
return Err(WorkerError::PayloadTooBig);
58+
}
59+
4960
let mut data = Vec::new();
5061

5162
let mut buf_data = BufWriter::new(&mut data);
@@ -72,9 +83,36 @@ impl WorkerSocket {
7283
Err(e) => {
7384
log::error!("failed to read from socket; err = {:?}", e);
7485

75-
return Err(e);
86+
return Err(e.into());
7687
}
7788
};
7889
}
7990
}
91+
92+
pub async fn ping(&mut self) -> Result<(), WorkerError> {
93+
self.send(WorkerProtocol::Ping).await?;
94+
95+
let response = self.receive().await?;
96+
97+
match response {
98+
WorkerProtocol::Pong => Ok(()),
99+
_ => Err(WorkerError::InvalidResponse),
100+
}
101+
}
102+
103+
pub async fn partial_proof_request(
104+
&mut self,
105+
request: PartialProofRequest,
106+
) -> Result<Vec<ShardProof<BabyBearPoseidon2>>, WorkerError> {
107+
self.send(WorkerProtocol::PartialProofRequest(request))
108+
.await?;
109+
110+
let response = self.receive().await?;
111+
112+
if let WorkerProtocol::PartialProofResponse(partial_proofs) = response {
113+
Ok(partial_proofs)
114+
} else {
115+
Err(WorkerError::InvalidResponse)
116+
}
117+
}
80118
}

0 commit comments

Comments
 (0)