use torchx for manual many replica (20+) tests
d4l3k committed Jan 16, 2025
1 parent 3ee2360 commit e36c99f
Showing 9 changed files with 114 additions and 19 deletions.
4 changes: 4 additions & 0 deletions .pyre_configuration
@@ -6,5 +6,9 @@
       "import_root": ".",
       "source": "torchft"
     }
   ],
+  "search_path": [
+    {"site-package": "torchx"},
+    {"site-package": "parameterized"}
+  ]
 }
3 changes: 3 additions & 0 deletions .torchxconfig
@@ -0,0 +1,3 @@
[cli:run]
component=torchft/torchx.py:hsdp
scheduler=local_cwd
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -27,7 +27,8 @@ dev = [
     "pyre-check",
     "parameterized",
     "expecttest",
-    "numpy"
+    "numpy",
+    "torchx"
 ]

 [tool.maturin]
13 changes: 9 additions & 4 deletions src/lib.rs
@@ -302,11 +302,16 @@ impl From<Status> for StatusError {
 #[pymodule]
 fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
     // setup logging on import
-    stderrlog::new()
-        .verbosity(2)
+    let mut log = stderrlog::new();
+    log.verbosity(2)
         .show_module_names(true)
-        .timestamp(stderrlog::Timestamp::Millisecond)
-        .init()
+        .timestamp(stderrlog::Timestamp::Millisecond);
+
+    if env::var("CLICOLOR_FORCE").is_ok() {
+        log.color(stderrlog::ColorChoice::AlwaysAnsi);
+    }
+
+    log.init()
         .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;

     m.add_class::<Manager>()?;
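With the CLICOLOR_FORCE check above, the Rust logger only emits ANSI color codes when explicitly asked to; the TorchX component added below sets this variable for every replica. A minimal sketch of opting in from Python (illustrative, not part of this commit; only the variable name comes from the diff):

import os

# Force ANSI-colored output from torchft's Rust logger. Logging is set up when
# the extension module is imported, so the variable must be set beforehand.
os.environ["CLICOLOR_FORCE"] = "1"

import torchft  # noqa: E402  (deliberate late import, after the env setup)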
28 changes: 17 additions & 11 deletions src/lighthouse.rs
@@ -377,21 +377,26 @@ impl Lighthouse {

         let (_, quorum_status) = quorum_compute(Instant::now(), &state, &self.opt);

-        let max_step = {
-            if let Some(quorum) = state.prev_quorum.clone() {
-                quorum
-                    .participants
-                    .iter()
-                    .map(|p| p.step)
-                    .max()
-                    .unwrap_or(-1)
-            } else {
-                -1
-            }
+        let max_step = if let Some(quorum) = &state.prev_quorum {
+            quorum
+                .participants
+                .iter()
+                .map(|p| p.step)
+                .max()
+                .unwrap_or(-1)
+        } else {
+            -1
         };
+
+        let num_participants = if let Some(quorum) = &state.prev_quorum {
+            quorum.participants.len() as i64
+        } else {
+            -1
+        };

         StatusTemplate {
             quorum_id: state.quorum_id,
+            num_participants: num_participants,
             prev_quorum: state.prev_quorum.clone(),
             quorum_status: quorum_status,
             max_step: max_step,
@@ -527,6 +532,7 @@ struct StatusTemplate {
     prev_quorum: Option<Quorum>,
     quorum_id: i64,
     quorum_status: String,
+    num_participants: i64,
     max_step: i64,
     heartbeats: HashMap<String, Instant>,
1 change: 1 addition & 0 deletions templates/status.html
@@ -6,6 +6,7 @@ <h3>Previous Quorum</h3>
 {% if let Some(prev_quorum) = prev_quorum %}

 Previous quorum id: {{prev_quorum.quorum_id}} <br>
+Num participants: {{num_participants}} <br>
 Quorum age:
 {{SystemTime::try_from(prev_quorum.created.unwrap()).unwrap().elapsed().unwrap().as_secs_f64()}}s
3 changes: 0 additions & 3 deletions torchft/manager_integ_test.py
@@ -10,8 +10,6 @@

 import torch
 import torch.distributed as dist
-
-# pyre-fixme[21]: missing module
 from parameterized import parameterized
 from torch import nn, optim
@@ -292,7 +290,6 @@ def test_ddp_healthy(self) -> None:
         for state_dict in state_dicts:
             torch.testing.assert_close(state_dict, state_dicts[0])

-    # pyre-fixme[56]: couldn't infer type of decorator
     @parameterized.expand(
         [
             (
76 changes: 76 additions & 0 deletions torchft/torchx.py
@@ -0,0 +1,76 @@
"""
This is a file for TorchX components used for testing torchft.
"""

import os
from typing import Dict, Optional

import torchx.specs as specs


def hsdp(
    *script_args: str,
    replicas: int = 2,
    workers_per_replica: int = 1,
    max_restarts: int = 10,
    script: str = "train_ddp.py",
    env: Optional[Dict[str, str]] = None,
    image: str = "",
    h: Optional[str] = None,
    cpu: int = 2,
    gpu: int = 0,
    memMB: int = 1024,
) -> specs.AppDef:
    assert replicas > 0, "replicas must be > 0"
    assert workers_per_replica > 0, "workers_per_replica must be > 0"

    env = env or {}

    # Enable logging for PyTorch, torchelastic and Rust.
    env.setdefault("TORCH_CPP_LOG_LEVEL", "INFO")
    env.setdefault("LOGLEVEL", "INFO")
    env.setdefault("RUST_BACKTRACE", "1")

    # Enable colored logging for torchft Rust logger.
    env.setdefault("CLICOLOR_FORCE", "1")

    # Set lighthouse address for replicas
    # This must be run externally
    env.setdefault(
        "TORCHFT_LIGHTHOUSE",
        os.environ.get("TORCHFT_LIGHTHOUSE", f"http://localhost:29510"),
    )

    # Disable CUDA for CPU-only jobs
    env.setdefault("CUDA_VISIBLE_DEVICES", "")

    roles = []
    for replica_id in range(replicas):
        cmd = [
            f"--master_port={29600+replica_id}",
            "--nnodes=1",
            f"--nproc_per_node={workers_per_replica}",
            f"--max_restarts={max_restarts}",
        ]
        if script:
            cmd += [script]
        cmd += list(script_args)

        roles.append(
            specs.Role(
                name=f"replica_{replica_id}",
                image=image,
                min_replicas=workers_per_replica,
                num_replicas=workers_per_replica,
                resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
                max_retries=0,
                env=env,
                entrypoint="torchrun",
                args=cmd,
            )
        )

    return specs.AppDef(
        name="torchft",
        roles=roles,
    )
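The component above is what the new .torchxconfig points `torchx run` at; for the manual 20+ replica tests named in the commit message it can also be submitted programmatically. The sketch below is illustrative and not part of this commit: it assumes the standard TorchX runner API and a lighthouse already serving at the TORCHFT_LIGHTHOUSE address, and the replica count of 25 is arbitrary.

# Illustrative sketch: builds the AppDef from the hsdp component and submits it
# to the local_cwd scheduler, mirroring what `torchx run` does with the
# .torchxconfig defaults.
from torchx.runner import get_runner

from torchft.torchx import hsdp

app = hsdp(replicas=25, workers_per_replica=1)

runner = get_runner()
app_handle = runner.run(app, scheduler="local_cwd")
print(runner.wait(app_handle))  # blocks until every replica role exits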
2 changes: 2 additions & 0 deletions train_ddp.py
@@ -13,6 +13,7 @@
 import torchvision
 import torchvision.transforms as transforms
 from torch import nn, optim
+from torch.distributed.elastic.multiprocessing.errors import record
 from torchdata.stateful_dataloader import StatefulDataLoader

 from torchft import (
@@ -27,6 +28,7 @@
 logging.basicConfig(level=logging.INFO)


+@record
 def main() -> None:
     REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
     NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))
