@@ -14,8 +14,10 @@ use anyhow::Result;
14
14
use atty:: Stream ;
15
15
use core:: time:: Duration ;
16
16
use pyo3:: exceptions:: { PyRuntimeError , PyTimeoutError } ;
17
+ use std:: cmp;
17
18
use std:: env;
18
19
use std:: sync:: Arc ;
20
+ use std:: thread:: available_parallelism;
19
21
use structopt:: StructOpt ;
20
22
use tokio:: runtime:: Runtime ;
21
23
use tokio:: task:: JoinHandle ;
@@ -34,6 +36,21 @@ use crate::torchftpb::manager_service_client::ManagerServiceClient;
34
36
use crate :: torchftpb:: { CheckpointMetadataRequest , ManagerQuorumRequest , ShouldCommitRequest } ;
35
37
use pyo3:: prelude:: * ;
36
38
39
/// Returns the number of worker threads to use for the tokio runtimes.
///
/// Resolution order:
/// 1. `TOKIO_WORKER_THREADS`, when it is set and parses as a positive integer.
/// 2. Otherwise the smaller of the default (4) and the parallelism reported
///    by `available_parallelism`; if that query fails, the default is used.
///
/// The result is always at least 1, so it is safe to hand to
/// `tokio::runtime::Builder::worker_threads`, which panics on 0.
fn num_threads() -> usize {
    let default_threads = 4;
    let num_cpus = available_parallelism()
        .map(|p| p.get())
        .unwrap_or(default_threads);

    env::var("TOKIO_WORKER_THREADS")
        .ok()
        .and_then(|s| s.parse::<usize>().ok())
        // A value of 0 would make the runtime builder panic; treat it like an
        // unparsable value and fall back to the computed default instead.
        .filter(|&n| n > 0)
        .unwrap_or_else(|| cmp::min(default_threads, num_cpus))
}
53
+
37
54
/// ManagerServer is a GRPC server for the manager service.
38
55
/// There should be one manager server per replica group (typically running on
39
56
/// the rank 0 host). The individual ranks within a replica group should use
@@ -71,7 +88,11 @@ impl ManagerServer {
71
88
connect_timeout : Duration ,
72
89
) -> PyResult < Self > {
73
90
py. allow_threads ( move || {
74
- let runtime = Runtime :: new ( ) ?;
91
+ let runtime = tokio:: runtime:: Builder :: new_multi_thread ( )
92
+ . worker_threads ( num_threads ( ) )
93
+ . thread_name ( "torchft-manager" )
94
+ . enable_all ( )
95
+ . build ( ) ?;
75
96
let manager = runtime
76
97
. block_on ( manager:: Manager :: new (
77
98
replica_id,
@@ -127,7 +148,11 @@ impl ManagerClient {
127
148
#[ new]
128
149
fn new ( py : Python < ' _ > , addr : String , connect_timeout : Duration ) -> PyResult < Self > {
129
150
py. allow_threads ( move || {
130
- let runtime = Runtime :: new ( ) ?;
151
+ let runtime = tokio:: runtime:: Builder :: new_multi_thread ( )
152
+ . worker_threads ( num_threads ( ) )
153
+ . thread_name ( "torchft-mgrclnt" )
154
+ . enable_all ( )
155
+ . build ( ) ?;
131
156
let client = runtime
132
157
. block_on ( manager:: manager_client_new ( addr, connect_timeout) )
133
158
. map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
@@ -294,7 +319,11 @@ fn lighthouse_main(py: Python<'_>) -> PyResult<()> {
294
319
let mut args = env:: args ( ) ;
295
320
args. next ( ) ; // discard binary arg
296
321
let opt = lighthouse:: LighthouseOpt :: from_iter ( args) ;
297
- let rt = Runtime :: new ( ) ?;
322
+ let rt = tokio:: runtime:: Builder :: new_multi_thread ( )
323
+ . thread_name ( "torchft-lighths" )
324
+ . worker_threads ( num_threads ( ) )
325
+ . enable_all ( )
326
+ . build ( ) ?;
298
327
rt. block_on ( lighthouse_main_async ( opt) )
299
328
. map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
300
329
Ok ( ( ) )
@@ -345,7 +374,11 @@ impl LighthouseServer {
345
374
let heartbeat_timeout_ms = heartbeat_timeout_ms. unwrap_or ( 5000 ) ;
346
375
347
376
py. allow_threads ( move || {
348
- let rt = Runtime :: new ( ) ?;
377
+ let rt = tokio:: runtime:: Builder :: new_multi_thread ( )
378
+ . worker_threads ( num_threads ( ) )
379
+ . thread_name ( "torchft-lighths" )
380
+ . enable_all ( )
381
+ . build ( ) ?;
349
382
350
383
let lighthouse = rt
351
384
. block_on ( lighthouse:: Lighthouse :: new ( lighthouse:: LighthouseOpt {
0 commit comments