@@ -14,8 +14,10 @@ use anyhow::Result;
14
14
use atty:: Stream ;
15
15
use core:: time:: Duration ;
16
16
use pyo3:: exceptions:: { PyRuntimeError , PyTimeoutError } ;
17
+ use std:: cmp;
17
18
use std:: env;
18
19
use std:: sync:: Arc ;
20
+ use std:: thread:: available_parallelism;
19
21
use structopt:: StructOpt ;
20
22
use tokio:: runtime:: Runtime ;
21
23
use tokio:: task:: JoinHandle ;
@@ -34,6 +36,17 @@ use crate::torchftpb::manager_service_client::ManagerServiceClient;
34
36
use crate :: torchftpb:: { CheckpointMetadataRequest , ManagerQuorumRequest , ShouldCommitRequest } ;
35
37
use pyo3:: prelude:: * ;
36
38
39
/// Number of worker threads to use for the tokio runtimes.
///
/// Defaults to `min(4, available_parallelism)` so these helper runtimes do
/// not spawn one thread per core on large hosts. The default can be
/// overridden via the `TOKIO_WORKER_THREADS` environment variable (matching
/// tokio's own convention); values that fail to parse fall back to the
/// default. The result is clamped to at least 1, because
/// `tokio::runtime::Builder::worker_threads(0)` panics.
fn num_threads() -> usize {
    // available_parallelism() returns io::Result and can fail (e.g. on
    // exotic platforms or restricted environments); fall back to 1 instead
    // of panicking the whole process at runtime construction.
    let parallelism = available_parallelism().map(|n| n.get()).unwrap_or(1);
    let default_threads = cmp::min(4, parallelism);

    env::var("TOKIO_WORKER_THREADS")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or(default_threads)
        .max(1)
}
49
+
37
50
/// ManagerServer is a GRPC server for the manager service.
38
51
/// There should be one manager server per replica group (typically running on
39
52
/// the rank 0 host). The individual ranks within a replica group should use
@@ -71,7 +84,11 @@ impl ManagerServer {
71
84
connect_timeout : Duration ,
72
85
) -> PyResult < Self > {
73
86
py. allow_threads ( move || {
74
- let runtime = Runtime :: new ( ) ?;
87
+ let runtime = tokio:: runtime:: Builder :: new_multi_thread ( )
88
+ . worker_threads ( num_threads ( ) )
89
+ . thread_name ( "torchft-manager" )
90
+ . enable_all ( )
91
+ . build ( ) ?;
75
92
let manager = runtime
76
93
. block_on ( manager:: Manager :: new (
77
94
replica_id,
@@ -127,7 +144,11 @@ impl ManagerClient {
127
144
#[ new]
128
145
fn new ( py : Python < ' _ > , addr : String , connect_timeout : Duration ) -> PyResult < Self > {
129
146
py. allow_threads ( move || {
130
- let runtime = Runtime :: new ( ) ?;
147
+ let runtime = tokio:: runtime:: Builder :: new_multi_thread ( )
148
+ . worker_threads ( num_threads ( ) )
149
+ . thread_name ( "torchft-mgrclnt" )
150
+ . enable_all ( )
151
+ . build ( ) ?;
131
152
let client = runtime
132
153
. block_on ( manager:: manager_client_new ( addr, connect_timeout) )
133
154
. map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
@@ -294,7 +315,11 @@ fn lighthouse_main(py: Python<'_>) -> PyResult<()> {
294
315
let mut args = env:: args ( ) ;
295
316
args. next ( ) ; // discard binary arg
296
317
let opt = lighthouse:: LighthouseOpt :: from_iter ( args) ;
297
- let rt = Runtime :: new ( ) ?;
318
+ let rt = tokio:: runtime:: Builder :: new_multi_thread ( )
319
+ . thread_name ( "torchft-lighths" )
320
+ . worker_threads ( num_threads ( ) )
321
+ . enable_all ( )
322
+ . build ( ) ?;
298
323
rt. block_on ( lighthouse_main_async ( opt) )
299
324
. map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
300
325
Ok ( ( ) )
@@ -345,7 +370,11 @@ impl LighthouseServer {
345
370
let heartbeat_timeout_ms = heartbeat_timeout_ms. unwrap_or ( 5000 ) ;
346
371
347
372
py. allow_threads ( move || {
348
- let rt = Runtime :: new ( ) ?;
373
+ let rt = tokio:: runtime:: Builder :: new_multi_thread ( )
374
+ . worker_threads ( num_threads ( ) )
375
+ . thread_name ( "torchft-lighths" )
376
+ . enable_all ( )
377
+ . build ( ) ?;
349
378
350
379
let lighthouse = rt
351
380
. block_on ( lighthouse:: Lighthouse :: new ( lighthouse:: LighthouseOpt {
0 commit comments