Skip to content

Commit 8fd028c

Browse files
committed
tokio: limit number of threads and set names
1 parent f0a4061 commit 8fd028c

File tree

2 files changed

+34
-5
lines changed

2 files changed

+34
-5
lines changed

src/lib.rs

+33-4
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@ use anyhow::Result;
1414
use atty::Stream;
1515
use core::time::Duration;
1616
use pyo3::exceptions::{PyRuntimeError, PyTimeoutError};
17+
use std::cmp;
1718
use std::env;
1819
use std::sync::Arc;
20+
use std::thread::available_parallelism;
1921
use structopt::StructOpt;
2022
use tokio::runtime::Runtime;
2123
use tokio::task::JoinHandle;
@@ -34,6 +36,17 @@ use crate::torchftpb::manager_service_client::ManagerServiceClient;
3436
use crate::torchftpb::{CheckpointMetadataRequest, ManagerQuorumRequest, ShouldCommitRequest};
3537
use pyo3::prelude::*;
3638

39+
// Get the number of threads to use for the tokio runtime
40+
fn num_threads() -> usize {
41+
let default_threads = cmp::min(4, available_parallelism().unwrap().get());
42+
let num_threads = env::var("TOKIO_WORKER_THREADS")
43+
.ok()
44+
.and_then(|s| s.parse().ok())
45+
.unwrap_or(default_threads);
46+
47+
num_threads
48+
}
49+
3750
/// ManagerServer is a GRPC server for the manager service.
3851
/// There should be one manager server per replica group (typically running on
3952
/// the rank 0 host). The individual ranks within a replica group should use
@@ -71,7 +84,11 @@ impl ManagerServer {
7184
connect_timeout: Duration,
7285
) -> PyResult<Self> {
7386
py.allow_threads(move || {
74-
let runtime = Runtime::new()?;
87+
let runtime = tokio::runtime::Builder::new_multi_thread()
88+
.worker_threads(num_threads())
89+
.thread_name("torchft-manager")
90+
.enable_all()
91+
.build()?;
7592
let manager = runtime
7693
.block_on(manager::Manager::new(
7794
replica_id,
@@ -127,7 +144,11 @@ impl ManagerClient {
127144
#[new]
128145
fn new(py: Python<'_>, addr: String, connect_timeout: Duration) -> PyResult<Self> {
129146
py.allow_threads(move || {
130-
let runtime = Runtime::new()?;
147+
let runtime = tokio::runtime::Builder::new_multi_thread()
148+
.worker_threads(num_threads())
149+
.thread_name("torchft-mgrclnt")
150+
.enable_all()
151+
.build()?;
131152
let client = runtime
132153
.block_on(manager::manager_client_new(addr, connect_timeout))
133154
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
@@ -294,7 +315,11 @@ fn lighthouse_main(py: Python<'_>) -> PyResult<()> {
294315
let mut args = env::args();
295316
args.next(); // discard binary arg
296317
let opt = lighthouse::LighthouseOpt::from_iter(args);
297-
let rt = Runtime::new()?;
318+
let rt = tokio::runtime::Builder::new_multi_thread()
319+
.thread_name("torchft-lighths")
320+
.worker_threads(num_threads())
321+
.enable_all()
322+
.build()?;
298323
rt.block_on(lighthouse_main_async(opt))
299324
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
300325
Ok(())
@@ -345,7 +370,11 @@ impl LighthouseServer {
345370
let heartbeat_timeout_ms = heartbeat_timeout_ms.unwrap_or(5000);
346371

347372
py.allow_threads(move || {
348-
let rt = Runtime::new()?;
373+
let rt = tokio::runtime::Builder::new_multi_thread()
374+
.worker_threads(num_threads())
375+
.thread_name("torchft-lighths")
376+
.enable_all()
377+
.build()?;
349378

350379
let lighthouse = rt
351380
.block_on(lighthouse::Lighthouse::new(lighthouse::LighthouseOpt {

src/manager.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ use std::{println as info, println as warn};
4242
macro_rules! info_with_replica {
4343
($replica_id:expr, $($arg:tt)*) => {{
4444
let parts: Vec<&str> = $replica_id.splitn(2, ':').collect();
45-
let formatted_message = if parts.len() == 2 {
45+
if parts.len() == 2 {
4646
// If there are two parts, use the replica name
4747
info!("[Replica {}] {}", parts[0], format!($($arg)*))
4848
} else {

0 commit comments

Comments
 (0)