From e36c99f236484a07114868a0a4990c95ba95cc8e Mon Sep 17 00:00:00 2001
From: Tristan Rice
Date: Thu, 16 Jan 2025 14:45:35 -0800
Subject: [PATCH] use torchx for manual many replica (20+) tests

---
 .pyre_configuration           |  4 ++
 .torchxconfig                 |  3 ++
 pyproject.toml                |  3 +-
 src/lib.rs                    | 13 ++++--
 src/lighthouse.rs             | 28 ++++++++-----
 templates/status.html         |  1 +
 torchft/manager_integ_test.py |  3 --
 torchft/torchx.py             | 76 +++++++++++++++++++++++++++++++++++
 train_ddp.py                  |  2 +
 9 files changed, 114 insertions(+), 19 deletions(-)
 create mode 100644 .torchxconfig
 create mode 100644 torchft/torchx.py

diff --git a/.pyre_configuration b/.pyre_configuration
index 0f2e67f..04ce7f0 100644
--- a/.pyre_configuration
+++ b/.pyre_configuration
@@ -6,5 +6,9 @@
       "import_root": ".",
       "source": "torchft"
     }
+  ],
+  "search_path": [
+    {"site-package": "torchx"},
+    {"site-package": "parameterized"}
   ]
 }
diff --git a/.torchxconfig b/.torchxconfig
new file mode 100644
index 0000000..9e45c1d
--- /dev/null
+++ b/.torchxconfig
@@ -0,0 +1,3 @@
+[cli:run]
+component=torchft/torchx.py:hsdp
+scheduler=local_cwd
diff --git a/pyproject.toml b/pyproject.toml
index 597ba27..b76e204 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,7 +27,8 @@ dev = [
     "pyre-check",
     "parameterized",
     "expecttest",
-    "numpy"
+    "numpy",
+    "torchx"
 ]
 
 [tool.maturin]
diff --git a/src/lib.rs b/src/lib.rs
index bfbae26..8d6db1b 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -302,11 +302,16 @@ impl From<Status> for StatusError {
 #[pymodule]
 fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
     // setup logging on import
-    stderrlog::new()
-        .verbosity(2)
+    let mut log = stderrlog::new();
+    log.verbosity(2)
         .show_module_names(true)
-        .timestamp(stderrlog::Timestamp::Millisecond)
-        .init()
+        .timestamp(stderrlog::Timestamp::Millisecond);
+
+    if env::var("CLICOLOR_FORCE").is_ok() {
+        log.color(stderrlog::ColorChoice::AlwaysAnsi);
+    }
+
+    log.init()
         .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
 
     m.add_class::<Manager>()?;
diff --git a/src/lighthouse.rs b/src/lighthouse.rs
index 10b51ad..e6be595 100644
--- a/src/lighthouse.rs
+++ b/src/lighthouse.rs
@@ -377,21 +377,26 @@ impl Lighthouse {
 
         let (_, quorum_status) = quorum_compute(Instant::now(), &state, &self.opt);
 
-        let max_step = {
-            if let Some(quorum) = state.prev_quorum.clone() {
-                quorum
-                    .participants
-                    .iter()
-                    .map(|p| p.step)
-                    .max()
-                    .unwrap_or(-1)
-            } else {
-                -1
-            }
+        let max_step = if let Some(quorum) = &state.prev_quorum {
+            quorum
+                .participants
+                .iter()
+                .map(|p| p.step)
+                .max()
+                .unwrap_or(-1)
+        } else {
+            -1
+        };
+
+        let num_participants = if let Some(quorum) = &state.prev_quorum {
+            quorum.participants.len() as i64
+        } else {
+            -1
         };
 
         StatusTemplate {
             quorum_id: state.quorum_id,
+            num_participants: num_participants,
             prev_quorum: state.prev_quorum.clone(),
             quorum_status: quorum_status,
             max_step: max_step,
@@ -527,6 +532,7 @@ struct StatusTemplate {
     prev_quorum: Option<Quorum>,
     quorum_id: i64,
     quorum_status: String,
+    num_participants: i64,
     max_step: i64,
 
     heartbeats: HashMap<String, Instant>,
diff --git a/templates/status.html b/templates/status.html
index 83ca845..f602137 100644
--- a/templates/status.html
+++ b/templates/status.html
@@ -6,6 +6,7 @@
 
 Previous Quorum
 
 {% if let Some(prev_quorum) = prev_quorum %}
 Previous quorum id: {{prev_quorum.quorum_id}}
+Num participants: {{num_participants}}
 Quorum age: {{SystemTime::try_from(prev_quorum.created.unwrap()).unwrap().elapsed().unwrap().as_secs_f64()}}s
diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py
index fb72496..d6e7bde 100644
--- a/torchft/manager_integ_test.py
+++ b/torchft/manager_integ_test.py
@@ -10,8 +10,6 @@
 
 import torch
 import torch.distributed as dist
-
-# pyre-fixme[21]: missing module
 from parameterized import parameterized
 from torch import nn, optim
 
@@ -292,7 +290,6 @@ def test_ddp_healthy(self) -> None:
         for state_dict in state_dicts:
             torch.testing.assert_close(state_dict, state_dicts[0])
 
-    # pyre-fixme[56]: couldn't infer type of decorator
     @parameterized.expand(
         [
             (
diff --git a/torchft/torchx.py b/torchft/torchx.py
new file mode 100644
index 0000000..16c8fad
--- /dev/null
+++ b/torchft/torchx.py
@@ -0,0 +1,76 @@
+"""
+This is a file for TorchX components used for testing torchft.
+"""
+
+import os
+from typing import Dict, Optional
+
+import torchx.specs as specs
+
+
+def hsdp(
+    *script_args: str,
+    replicas: int = 2,
+    workers_per_replica: int = 1,
+    max_restarts: int = 10,
+    script: str = "train_ddp.py",
+    env: Optional[Dict[str, str]] = None,
+    image: str = "",
+    h: Optional[str] = None,
+    cpu: int = 2,
+    gpu: int = 0,
+    memMB: int = 1024,
+) -> specs.AppDef:
+    assert replicas > 0, "replicas must be > 0"
+    assert workers_per_replica > 0, "workers_per_replica must be > 0"
+
+    env = env or {}
+
+    # Enable verbose logging for PyTorch, torchelastic and Rust.
+    env.setdefault("TORCH_CPP_LOG_LEVEL", "INFO")
+    env.setdefault("LOGLEVEL", "INFO")
+    env.setdefault("RUST_BACKTRACE", "1")
+
+    # Enable colored logging for the torchft Rust logger.
+    env.setdefault("CLICOLOR_FORCE", "1")
+
+    # Set the lighthouse address for the replicas; the lighthouse itself
+    # must be started externally.
+    env.setdefault(
+        "TORCHFT_LIGHTHOUSE",
+        os.environ.get("TORCHFT_LIGHTHOUSE", "http://localhost:29510"),
+    )
+
+    # Disable CUDA for CPU-only jobs.
+    env.setdefault("CUDA_VISIBLE_DEVICES", "")
+
+    roles = []
+    for replica_id in range(replicas):
+        cmd = [
+            f"--master_port={29600+replica_id}",
+            "--nnodes=1",
+            f"--nproc_per_node={workers_per_replica}",
+            f"--max_restarts={max_restarts}",
+        ]
+        if script:
+            cmd += [script]
+        cmd += list(script_args)
+
+        roles.append(
+            specs.Role(
+                name=f"replica_{replica_id}",
+                image=image,
+                min_replicas=workers_per_replica,
+                num_replicas=workers_per_replica,
+                resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
+                max_retries=0,
+                env=env,
+                entrypoint="torchrun",
+                args=cmd,
+            )
+        )
+
+    return specs.AppDef(
+        name="torchft",
+        roles=roles,
+    )
diff --git a/train_ddp.py b/train_ddp.py
index c15d7e7..9ad9cc8 100644
--- a/train_ddp.py
+++ b/train_ddp.py
@@ -13,6 +13,7 @@
 import torchvision
 import torchvision.transforms as transforms
 from torch import nn, optim
+from torch.distributed.elastic.multiprocessing.errors import record
 from torchdata.stateful_dataloader import StatefulDataLoader
 
 from torchft import (
@@ -27,6 +28,7 @@
 logging.basicConfig(level=logging.INFO)
 
 
+@record
 def main() -> None:
     REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
     NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))
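
Since hsdp() only builds a specs.AppDef, the component can be smoke-tested without launching anything. A minimal sketch: hsdp() comes from this patch, while the argument values and assertions below are illustrative only.

    # Inspect the AppDef produced by the hsdp component above.
    from torchft.torchx import hsdp

    app = hsdp("--epochs=1", replicas=3, workers_per_replica=2)

    assert len(app.roles) == 3  # one torchrun Role per replica group
    assert app.roles[0].entrypoint == "torchrun"
    # Each replica group gets its own rendezvous port (29600 + replica_id).
    assert app.roles[1].args[0] == "--master_port=29601"
    # Extra script args are appended after the training script.
    assert app.roles[2].args[-1] == "--epochs=1"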
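The new .torchxconfig supplies the default component (torchft/torchx.py:hsdp) and scheduler (local_cwd) to the torchx CLI, so manual many-replica runs don't have to spell either out. The same launch can also be driven from Python; this is a rough sketch using torchx's runner API, assuming a lighthouse server is already listening at TORCHFT_LIGHTHOUSE (the component defaults to http://localhost:29510) and that the chosen replica count fits on the local machine.

    # Programmatic rough equivalent of `torchx run` with the .torchxconfig
    # defaults. get_runner() and the local_cwd scheduler are standard torchx
    # APIs; the replica count here is an example value.
    from torchx.runner import get_runner

    from torchft.torchx import hsdp

    with get_runner() as runner:
        app_handle = runner.run(hsdp(replicas=20), scheduler="local_cwd")
        print(runner.status(app_handle))
        runner.wait(app_handle)  # block until all replica groups exit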