Skip to content

Commit 2680485

Browse files
committed
Add a version number to WorkerEnvelope
1 parent 85fff43 commit 2680485

File tree

6 files changed

+29
-13
lines changed

6 files changed

+29
-13
lines changed

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ hyper = { version = "0.14.27", features = ["server"] }
149149
reqwest = { version = "0.11.22", features = ["json"] }
150150
url = "2.5.0"
151151
async-trait = "0.1.80"
152+
async-channel = "2.3.1"
152153

153154
# crypto
154155
kzg = { package = "rust-kzg-zkcrypto", git = "https://github.com/brechtpd/rust-kzg.git", branch = "sp1-patch", default-features = false }
@@ -159,7 +160,6 @@ secp256k1 = { version = "0.29", default-features = false, features = [
159160
"global-context",
160161
"recovery",
161162
] }
162-
async-channel = "2.3.1"
163163

164164
# macro
165165
syn = { version = "1.0", features = ["full"] }

host/src/server/worker.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ async fn listen_worker(state: ProverState) {
6060

6161
// We purposely don't spawn the task here, as we want to block to limit the number
6262
// of concurrent connections to one.
63-
if let Err(e) = handle_worker_socket(WorkerSocket::new(socket)).await {
63+
if let Err(e) = handle_worker_socket(WorkerSocket::from_streamm_stream(socket)).await {
6464
error!("Error while handling worker socket: {:?}", e);
6565
}
6666
}

lib/src/prover.rs

+2
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ pub enum WorkerError {
5050
Serde(#[from] bincode::Error),
5151
#[error("Worker invalid magic number")]
5252
InvalidMagicNumber,
53+
#[error("Worker invalid version")]
54+
InvalidVersion,
5355
#[error("Worker invalid request")]
5456
InvalidRequest,
5557
#[error("Worker invalid response")]

provers/sp1/driver/src/distributed/worker/envelope.rs

+19-2
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,34 @@
1+
use raiko_lib::prover::WorkerError;
12
use serde::{Deserialize, Serialize};
23

34
use crate::WorkerProtocol;
45

56
#[derive(Debug, Serialize, Deserialize)]
67
pub struct WorkerEnvelope {
7-
pub magic: u64,
8-
pub data: WorkerProtocol,
8+
magic: u64,
9+
version: u64,
10+
data: WorkerProtocol,
11+
}
12+
13+
impl WorkerEnvelope {
14+
pub fn data(self) -> Result<WorkerProtocol, WorkerError> {
15+
if self.magic != 0xdeadbeef {
16+
return Err(WorkerError::InvalidMagicNumber);
17+
}
18+
19+
if self.version != include!("../../../worker.version") {
20+
return Err(WorkerError::InvalidVersion);
21+
}
22+
23+
Ok(self.data)
24+
}
925
}
1026

1127
impl From<WorkerProtocol> for WorkerEnvelope {
1228
fn from(data: WorkerProtocol) -> Self {
1329
WorkerEnvelope {
1430
magic: 0xdeadbeef,
31+
version: include!("../../../worker.version"),
1532
data,
1633
}
1734
}

provers/sp1/driver/src/distributed/worker/socket.rs

+5-9
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@ use crate::{PartialProofRequest, WorkerEnvelope, WorkerProtocol};
88
const PAYLOAD_MAX_SIZE: usize = 1 << 26;
99

1010
pub struct WorkerSocket {
11-
pub socket: tokio::net::TcpStream,
11+
socket: tokio::net::TcpStream,
1212
}
1313

1414
impl WorkerSocket {
1515
pub async fn connect(url: &str) -> Result<Self, WorkerError> {
16-
let stream = tokio::net::TcpStream::connect(url).await?;
16+
let socket = tokio::net::TcpStream::connect(url).await?;
1717

18-
Ok(WorkerSocket { socket: stream })
18+
Ok(WorkerSocket::from_streamm_stream(socket))
1919
}
2020

21-
pub fn new(socket: tokio::net::TcpStream) -> Self {
21+
pub fn from_streamm_stream(socket: tokio::net::TcpStream) -> Self {
2222
WorkerSocket { socket }
2323
}
2424

@@ -42,11 +42,7 @@ impl WorkerSocket {
4242

4343
let envelope: WorkerEnvelope = bincode::deserialize(&data)?;
4444

45-
if envelope.magic != 0xdeadbeef {
46-
return Err(WorkerError::InvalidMagicNumber);
47-
}
48-
49-
Ok(envelope.data)
45+
envelope.data()
5046
}
5147

5248
// TODO: Add a timeout

provers/sp1/driver/worker.version

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
1

0 commit comments

Comments
 (0)