1
1
use raiko_lib:: prover:: WorkerError ;
2
+ use sp1_core:: { stark:: ShardProof , utils:: BabyBearPoseidon2 } ;
2
3
use tokio:: io:: { AsyncReadExt , AsyncWriteExt , BufWriter } ;
3
4
4
- use crate :: { WorkerEnvelope , WorkerProtocol } ;
5
+ use crate :: { PartialProofRequest , WorkerEnvelope , WorkerProtocol } ;
6
+
7
+ // 64MB
8
+ const PAYLOAD_MAX_SIZE : usize = 1 << 26 ;
5
9
6
10
pub struct WorkerSocket {
7
11
pub socket : tokio:: net:: TcpStream ,
@@ -23,6 +27,10 @@ impl WorkerSocket {
23
27
24
28
let data = bincode:: serialize ( & envelope) ?;
25
29
30
+ if data. len ( ) > PAYLOAD_MAX_SIZE {
31
+ return Err ( WorkerError :: PayloadTooBig ) ;
32
+ }
33
+
26
34
self . socket . write_u64 ( data. len ( ) as u64 ) . await ?;
27
35
self . socket . write_all ( & data) . await ?;
28
36
@@ -42,10 +50,13 @@ impl WorkerSocket {
42
50
}
43
51
44
52
// TODO: Add a timeout
45
- pub async fn read_data ( & mut self ) -> Result < Vec < u8 > , std:: io:: Error > {
46
- // TODO: limit the size of the data
53
+ pub async fn read_data ( & mut self ) -> Result < Vec < u8 > , WorkerError > {
47
54
let size = self . socket . read_u64 ( ) . await ? as usize ;
48
55
56
+ if size > PAYLOAD_MAX_SIZE {
57
+ return Err ( WorkerError :: PayloadTooBig ) ;
58
+ }
59
+
49
60
let mut data = Vec :: new ( ) ;
50
61
51
62
let mut buf_data = BufWriter :: new ( & mut data) ;
@@ -72,9 +83,36 @@ impl WorkerSocket {
72
83
Err ( e) => {
73
84
log:: error!( "failed to read from socket; err = {:?}" , e) ;
74
85
75
- return Err ( e) ;
86
+ return Err ( e. into ( ) ) ;
76
87
}
77
88
} ;
78
89
}
79
90
}
91
+
92
+ pub async fn ping ( & mut self ) -> Result < ( ) , WorkerError > {
93
+ self . send ( WorkerProtocol :: Ping ) . await ?;
94
+
95
+ let response = self . receive ( ) . await ?;
96
+
97
+ match response {
98
+ WorkerProtocol :: Pong => Ok ( ( ) ) ,
99
+ _ => Err ( WorkerError :: InvalidResponse ) ,
100
+ }
101
+ }
102
+
103
+ pub async fn partial_proof_request (
104
+ & mut self ,
105
+ request : PartialProofRequest ,
106
+ ) -> Result < Vec < ShardProof < BabyBearPoseidon2 > > , WorkerError > {
107
+ self . send ( WorkerProtocol :: PartialProofRequest ( request) )
108
+ . await ?;
109
+
110
+ let response = self . receive ( ) . await ?;
111
+
112
+ if let WorkerProtocol :: PartialProofResponse ( partial_proofs) = response {
113
+ Ok ( partial_proofs)
114
+ } else {
115
+ Err ( WorkerError :: InvalidResponse )
116
+ }
117
+ }
80
118
}
0 commit comments