Skip to content

Commit 4c662fe

Browse files
committed
tokio: limit number of threads and set names
1 parent f0a4061 commit 4c662fe

File tree

2 files changed

+38
-5
lines changed

2 files changed

+38
-5
lines changed

src/lib.rs

+37-4
Original file line number | Diff line number | Diff line change
@@ -14,8 +14,10 @@ use anyhow::Result;
1414
use atty::Stream;
1515
use core::time::Duration;
1616
use pyo3::exceptions::{PyRuntimeError, PyTimeoutError};
17+
use std::cmp;
1718
use std::env;
1819
use std::sync::Arc;
20+
use std::thread::available_parallelism;
1921
use structopt::StructOpt;
2022
use tokio::runtime::Runtime;
2123
use tokio::task::JoinHandle;
@@ -34,6 +36,21 @@ use crate::torchftpb::manager_service_client::ManagerServiceClient;
3436
use crate::torchftpb::{CheckpointMetadataRequest, ManagerQuorumRequest, ShouldCommitRequest};
3537
use pyo3::prelude::*;
3638

39+
/// Returns the number of worker threads to use for each tokio runtime.
///
/// A `TOKIO_WORKER_THREADS` environment variable that parses as a positive
/// integer takes precedence. Otherwise the count is the detected available
/// parallelism capped at 4, or 4 when the parallelism cannot be determined.
///
/// A value of `0` (or anything that does not parse as `usize`) is treated as
/// unset: tokio's `Builder::worker_threads` panics when given zero, so it must
/// never be returned from here.
fn num_threads() -> usize {
    const DEFAULT_THREADS: usize = 4;

    // Computed lazily so the parallelism query is skipped when the
    // environment override is valid.
    let fallback = || {
        let num_cpus = available_parallelism().map_or(DEFAULT_THREADS, |p| p.get());
        cmp::min(DEFAULT_THREADS, num_cpus)
    };

    env::var("TOKIO_WORKER_THREADS")
        .ok()
        .and_then(|s| s.parse::<usize>().ok())
        // Reject zero: the runtime builder would panic on it.
        .filter(|&n| n > 0)
        .unwrap_or_else(fallback)
}
53+
3754
/// ManagerServer is a GRPC server for the manager service.
3855
/// There should be one manager server per replica group (typically running on
3956
/// the rank 0 host). The individual ranks within a replica group should use
@@ -71,7 +88,11 @@ impl ManagerServer {
7188
connect_timeout: Duration,
7289
) -> PyResult<Self> {
7390
py.allow_threads(move || {
74-
let runtime = Runtime::new()?;
91+
let runtime = tokio::runtime::Builder::new_multi_thread()
92+
.worker_threads(num_threads())
93+
.thread_name("torchft-manager")
94+
.enable_all()
95+
.build()?;
7596
let manager = runtime
7697
.block_on(manager::Manager::new(
7798
replica_id,
@@ -127,7 +148,11 @@ impl ManagerClient {
127148
#[new]
128149
fn new(py: Python<'_>, addr: String, connect_timeout: Duration) -> PyResult<Self> {
129150
py.allow_threads(move || {
130-
let runtime = Runtime::new()?;
151+
let runtime = tokio::runtime::Builder::new_multi_thread()
152+
.worker_threads(num_threads())
153+
.thread_name("torchft-mgrclnt")
154+
.enable_all()
155+
.build()?;
131156
let client = runtime
132157
.block_on(manager::manager_client_new(addr, connect_timeout))
133158
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
@@ -294,7 +319,11 @@ fn lighthouse_main(py: Python<'_>) -> PyResult<()> {
294319
let mut args = env::args();
295320
args.next(); // discard binary arg
296321
let opt = lighthouse::LighthouseOpt::from_iter(args);
297-
let rt = Runtime::new()?;
322+
let rt = tokio::runtime::Builder::new_multi_thread()
323+
.thread_name("torchft-lighths")
324+
.worker_threads(num_threads())
325+
.enable_all()
326+
.build()?;
298327
rt.block_on(lighthouse_main_async(opt))
299328
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
300329
Ok(())
@@ -345,7 +374,11 @@ impl LighthouseServer {
345374
let heartbeat_timeout_ms = heartbeat_timeout_ms.unwrap_or(5000);
346375

347376
py.allow_threads(move || {
348-
let rt = Runtime::new()?;
377+
let rt = tokio::runtime::Builder::new_multi_thread()
378+
.worker_threads(num_threads())
379+
.thread_name("torchft-lighths")
380+
.enable_all()
381+
.build()?;
349382

350383
let lighthouse = rt
351384
.block_on(lighthouse::Lighthouse::new(lighthouse::LighthouseOpt {

src/manager.rs

+1-1
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ use std::{println as info, println as warn};
4242
macro_rules! info_with_replica {
4343
($replica_id:expr, $($arg:tt)*) => {{
4444
let parts: Vec<&str> = $replica_id.splitn(2, ':').collect();
45-
let formatted_message = if parts.len() == 2 {
45+
if parts.len() == 2 {
4646
// If there are two parts, use the replica name
4747
info!("[Replica {}] {}", parts[0], format!($($arg)*))
4848
} else {

0 commit comments

Comments
 (0)