@@ -12,11 +12,12 @@ use std::env;
12
12
use std:: sync:: Arc ;
13
13
14
14
use anyhow:: Result ;
15
- use pyo3:: exceptions:: PyRuntimeError ;
15
+ use pyo3:: exceptions:: { PyRuntimeError , PyTimeoutError } ;
16
16
use structopt:: StructOpt ;
17
17
use tokio:: runtime:: Runtime ;
18
18
use tokio:: task:: JoinHandle ;
19
19
use tonic:: transport:: Channel ;
20
+ use tonic:: Status ;
20
21
21
22
pub mod torchftpb {
22
23
tonic:: include_proto!( "torchft" ) ;
@@ -102,14 +103,16 @@ impl ManagerClient {
102
103
} )
103
104
}
104
105
106
+ #[ pyo3( signature = ( room_id, rank, step, checkpoint_server_addr, timeout=None ) ) ]
105
107
fn quorum (
106
108
& mut self ,
107
109
py : Python < ' _ > ,
108
110
room_id : String ,
109
111
rank : i64 ,
110
112
step : i64 ,
111
113
checkpoint_server_addr : String ,
112
- ) -> PyResult < ( i64 , i64 , i64 , String , String , i64 , Option < i64 > , i64 , bool ) > {
114
+ timeout : Option < Duration > ,
115
+ ) -> Result < ( i64 , i64 , i64 , String , String , i64 , Option < i64 > , i64 , bool ) , StatusError > {
113
116
py. allow_threads ( move || {
114
117
let mut request = tonic:: Request :: new ( ManagerQuorumRequest {
115
118
room_id : room_id,
@@ -119,12 +122,9 @@ impl ManagerClient {
119
122
} ) ;
120
123
// This notifies the server about the timeout but doesn't affect the
121
124
// endpoint timeout which we set on client creation.
122
- request. set_timeout ( self . timeout ) ;
125
+ request. set_timeout ( timeout . unwrap_or ( self . timeout ) ) ;
123
126
124
- let response = self
125
- . runtime
126
- . block_on ( self . client . quorum ( request) )
127
- . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
127
+ let response = self . runtime . block_on ( self . client . quorum ( request) ) ?;
128
128
let resp = response. into_inner ( ) ;
129
129
Ok ( (
130
130
resp. quorum_id ,
@@ -140,29 +140,36 @@ impl ManagerClient {
140
140
} )
141
141
}
142
142
143
- fn checkpoint_address ( & mut self , py : Python < ' _ > , rank : i64 ) -> PyResult < String > {
143
+ #[ pyo3( signature = ( rank, timeout=None ) ) ]
144
+ fn checkpoint_address (
145
+ & mut self ,
146
+ py : Python < ' _ > ,
147
+ rank : i64 ,
148
+ timeout : Option < Duration > ,
149
+ ) -> Result < String , StatusError > {
144
150
py. allow_threads ( move || {
145
151
let mut request = tonic:: Request :: new ( CheckpointAddressRequest { rank : rank } ) ;
146
152
// This notifies the server about the timeout but doesn't affect the
147
153
// endpoint timeout which we set on client creation.
148
- request. set_timeout ( self . timeout ) ;
154
+ request. set_timeout ( timeout . unwrap_or ( self . timeout ) ) ;
149
155
150
156
let response = self
151
157
. runtime
152
- . block_on ( self . client . checkpoint_address ( request) )
153
- . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
158
+ . block_on ( self . client . checkpoint_address ( request) ) ?;
154
159
let resp = response. into_inner ( ) ;
155
160
Ok ( resp. checkpoint_server_address )
156
161
} )
157
162
}
158
163
164
+ #[ pyo3( signature = ( rank, step, should_commit, timeout=None ) ) ]
159
165
fn should_commit (
160
166
& mut self ,
161
167
py : Python < ' _ > ,
162
168
rank : i64 ,
163
169
step : i64 ,
164
170
should_commit : bool ,
165
- ) -> PyResult < bool > {
171
+ timeout : Option < Duration > ,
172
+ ) -> Result < bool , StatusError > {
166
173
py. allow_threads ( move || {
167
174
let mut request = tonic:: Request :: new ( ShouldCommitRequest {
168
175
rank : rank,
@@ -171,12 +178,9 @@ impl ManagerClient {
171
178
} ) ;
172
179
// This notifies the server about the timeout but doesn't affect the
173
180
// endpoint timeout which we set on client creation.
174
- request. set_timeout ( self . timeout ) ;
181
+ request. set_timeout ( timeout . unwrap_or ( self . timeout ) ) ;
175
182
176
- let response = self
177
- . runtime
178
- . block_on ( self . client . should_commit ( request) )
179
- . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
183
+ let response = self . runtime . block_on ( self . client . should_commit ( request) ) ?;
180
184
let resp = response. into_inner ( ) ;
181
185
Ok ( resp. should_commit )
182
186
} )
@@ -225,16 +229,25 @@ struct Lighthouse {
225
229
#[ pymethods]
226
230
impl Lighthouse {
227
231
#[ new]
228
- fn new ( py : Python < ' _ > , bind : String , min_replicas : u64 ) -> PyResult < Self > {
232
+ fn new (
233
+ py : Python < ' _ > ,
234
+ bind : String ,
235
+ min_replicas : u64 ,
236
+ join_timeout_ms : Option < u64 > ,
237
+ quorum_tick_ms : Option < u64 > ,
238
+ ) -> PyResult < Self > {
239
+ let join_timeout_ms = join_timeout_ms. unwrap_or ( 100 ) ;
240
+ let quorum_tick_ms = quorum_tick_ms. unwrap_or ( 100 ) ;
241
+
229
242
py. allow_threads ( move || {
230
243
let rt = Runtime :: new ( ) ?;
231
244
232
245
let lighthouse = rt
233
246
. block_on ( lighthouse:: Lighthouse :: new ( lighthouse:: LighthouseOpt {
234
247
bind : bind,
235
248
min_replicas : min_replicas,
236
- join_timeout_ms : 100 ,
237
- quorum_tick_ms : 100 ,
249
+ join_timeout_ms : join_timeout_ms ,
250
+ quorum_tick_ms : quorum_tick_ms ,
238
251
} ) )
239
252
. map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
240
253
@@ -257,6 +270,26 @@ impl Lighthouse {
257
270
}
258
271
}
259
272
273
+ struct StatusError ( Status ) ;
274
+
275
+ impl From < StatusError > for PyErr {
276
+ fn from ( error : StatusError ) -> Self {
277
+ let code = error. 0 . code ( ) ;
278
+ match code {
279
+ tonic:: Code :: Cancelled | tonic:: Code :: DeadlineExceeded => {
280
+ PyTimeoutError :: new_err ( error. 0 . to_string ( ) )
281
+ }
282
+ _ => PyRuntimeError :: new_err ( error. 0 . to_string ( ) ) ,
283
+ }
284
+ }
285
+ }
286
+
287
+ impl From < Status > for StatusError {
288
+ fn from ( other : Status ) -> Self {
289
+ Self ( other)
290
+ }
291
+ }
292
+
260
293
#[ pymodule]
261
294
fn torchft ( m : & Bound < ' _ , PyModule > ) -> PyResult < ( ) > {
262
295
// setup logging on import
0 commit comments