
Commit 39a40b2

use torchx for manual many replica (20+) tests (#75)
Parent: 3ee2360

10 files changed: +129 −19 lines

.pyre_configuration

+4

```diff
@@ -6,5 +6,9 @@
       "import_root": ".",
       "source": "torchft"
     }
+  ],
+  "search_path": [
+    {"site-package": "torchx"},
+    {"site-package": "parameterized"}
   ]
 }
```

.torchxconfig

+3

```diff
@@ -0,0 +1,3 @@
+[cli:run]
+component=torchft/torchx.py:hsdp
+scheduler=local_cwd
```
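With these defaults, a bare `torchx run` resolves to the `hsdp` component in `torchft/torchx.py` on the `local_cwd` scheduler, so `torchx run -- --replicas 10` should be equivalent to spelling it out as `torchx run --scheduler local_cwd torchft/torchx.py:hsdp -- --replicas 10`.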

CONTRIBUTING.md

+15

````diff
@@ -96,6 +96,21 @@ make livehtml
 The docs will be built in the `docs/build/html` directory and served at http://localhost:8000.
 The page will be automatically re-built as long as the process is kept running.
 
+### Running Multi-Replica Local Jobs
+
+We use torchx to run multi-replica local test jobs. Start the
+lighthouse first; you can then use torchx to launch as many replica
+groups as you want. This uses the [train_ddp.py](./train_ddp.py) script.
+
+```sh
+$ torchft_lighthouse --min_replicas 2 --join_timeout_ms 10000 &
+$ torchx run -- --replicas 10
+```
+
+Once the Lighthouse has started, you can view the status of all the workers on the Lighthouse dashboard.
+
+The default address is http://localhost:29510.
+
 ## Contributor License Agreement ("CLA")
 
 In order to accept your pull request, we need you to submit a CLA. You only need to do this once to work on any of
````
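The flags after the `--` in `torchx run` are arguments to the `hsdp` component added in `torchft/torchx.py`, so its other parameters can be overridden the same way; for instance, `torchx run -- --replicas 4 --workers_per_replica 2` should launch four replica groups with two workers each, assuming the usual torchx mapping of component keyword arguments to CLI flags.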

pyproject.toml

+2 −1

```diff
@@ -27,7 +27,8 @@ dev = [
     "pyre-check",
     "parameterized",
     "expecttest",
-    "numpy"
+    "numpy",
+    "torchx"
 ]
 
 [tool.maturin]
```

src/lib.rs

+9 −4

```diff
@@ -302,11 +302,16 @@ impl From<Status> for StatusError {
 #[pymodule]
 fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
     // setup logging on import
-    stderrlog::new()
-        .verbosity(2)
+    let mut log = stderrlog::new();
+    log.verbosity(2)
         .show_module_names(true)
-        .timestamp(stderrlog::Timestamp::Millisecond)
-        .init()
+        .timestamp(stderrlog::Timestamp::Millisecond);
+
+    if env::var("CLICOLOR_FORCE").is_ok() {
+        log.color(stderrlog::ColorChoice::AlwaysAnsi);
+    }
+
+    log.init()
         .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
 
     m.add_class::<Manager>()?;
```
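Because the stderr logger is initialized once, when the Rust extension module is first imported, color has to be opted into before that import happens. A minimal sketch of forcing colored output from Python, assuming the `CLICOLOR_FORCE` check added above:

```python
import os

# Per the diff above, the Rust logger checks CLICOLOR_FORCE once at
# import time, so the variable must be set before the first
# `import torchft` anywhere in the process.
os.environ["CLICOLOR_FORCE"] = "1"

import torchft  # stderr logs are now emitted with ANSI color codes
```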

src/lighthouse.rs

+17 −11

```diff
@@ -377,21 +377,26 @@ impl Lighthouse {
 
         let (_, quorum_status) = quorum_compute(Instant::now(), &state, &self.opt);
 
-        let max_step = {
-            if let Some(quorum) = state.prev_quorum.clone() {
-                quorum
-                    .participants
-                    .iter()
-                    .map(|p| p.step)
-                    .max()
-                    .unwrap_or(-1)
-            } else {
-                -1
-            }
+        let max_step = if let Some(quorum) = &state.prev_quorum {
+            quorum
+                .participants
+                .iter()
+                .map(|p| p.step)
+                .max()
+                .unwrap_or(-1)
+        } else {
+            -1
+        };
+
+        let num_participants = if let Some(quorum) = &state.prev_quorum {
+            quorum.participants.len() as i64
+        } else {
+            -1
         };
 
         StatusTemplate {
             quorum_id: state.quorum_id,
+            num_participants: num_participants,
             prev_quorum: state.prev_quorum.clone(),
             quorum_status: quorum_status,
             max_step: max_step,
@@ -527,6 +532,7 @@ struct StatusTemplate {
     prev_quorum: Option<Quorum>,
     quorum_id: i64,
     quorum_status: String,
+    num_participants: i64,
     max_step: i64,
     heartbeats: HashMap<String, Instant>,
```

templates/status.html

+1

```diff
@@ -6,6 +6,7 @@ <h3>Previous Quorum</h3>
 {% if let Some(prev_quorum) = prev_quorum %}
 
 Previous quorum id: {{prev_quorum.quorum_id}} <br>
+Num participants: {{num_participants}} <br>
 Quorum age:
 {{SystemTime::try_from(prev_quorum.created.unwrap()).unwrap().elapsed().unwrap().as_secs_f64()}}s
```
torchft/manager_integ_test.py

−3

```diff
@@ -10,8 +10,6 @@
 
 import torch
 import torch.distributed as dist
-
-# pyre-fixme[21]: missing module
 from parameterized import parameterized
 from torch import nn, optim
 
@@ -292,7 +290,6 @@ def test_ddp_healthy(self) -> None:
         for state_dict in state_dicts:
             torch.testing.assert_close(state_dict, state_dicts[0])
 
-    # pyre-fixme[56]: couldn't infer type of decorator
     @parameterized.expand(
         [
             (
```

torchft/torchx.py

+76

```diff
@@ -0,0 +1,76 @@
+"""
+This file defines TorchX components used for testing torchft.
+"""
+
+import os
+from typing import Dict, Optional
+
+import torchx.specs as specs
+
+
+def hsdp(
+    *script_args: str,
+    replicas: int = 2,
+    workers_per_replica: int = 1,
+    max_restarts: int = 10,
+    script: str = "train_ddp.py",
+    env: Optional[Dict[str, str]] = None,
+    image: str = "",
+    h: Optional[str] = None,
+    cpu: int = 2,
+    gpu: int = 0,
+    memMB: int = 1024,
+) -> specs.AppDef:
+    assert replicas > 0, "replicas must be > 0"
+    assert workers_per_replica > 0, "workers_per_replica must be > 0"
+
+    env = env or {}
+
+    # Enable logging for PyTorch, torchelastic and Rust.
+    env.setdefault("TORCH_CPP_LOG_LEVEL", "INFO")
+    env.setdefault("LOGLEVEL", "INFO")
+    env.setdefault("RUST_BACKTRACE", "1")
+
+    # Enable colored logging for the torchft Rust logger.
+    env.setdefault("CLICOLOR_FORCE", "1")
+
+    # Set the lighthouse address for the replicas.
+    # The lighthouse itself must be started externally.
+    env.setdefault(
+        "TORCHFT_LIGHTHOUSE",
+        os.environ.get("TORCHFT_LIGHTHOUSE", "http://localhost:29510"),
+    )
+
+    # Disable CUDA for CPU-only jobs.
+    env.setdefault("CUDA_VISIBLE_DEVICES", "")
+
+    roles = []
+    for replica_id in range(replicas):
+        cmd = [
+            f"--master_port={29600 + replica_id}",
+            "--nnodes=1",
+            f"--nproc_per_node={workers_per_replica}",
+            f"--max_restarts={max_restarts}",
+        ]
+        if script:
+            cmd += [script]
+        cmd += list(script_args)
+
+        roles.append(
+            specs.Role(
+                name=f"replica_{replica_id}",
+                image=image,
+                min_replicas=workers_per_replica,
+                num_replicas=workers_per_replica,
+                resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
+                max_retries=0,
+                env=env,
+                entrypoint="torchrun",
+                args=cmd,
+            )
+        )
+
+    return specs.AppDef(
+        name="torchft",
+        roles=roles,
+    )
```
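The component can also be exercised directly from Python, which is handy for checking what `torchx run` will launch. A minimal sketch; `app.roles`, `role.entrypoint`, and `role.args` are standard `torchx.specs` fields:

```python
from torchft.torchx import hsdp

# Build an AppDef for 4 replica groups with 2 workers each; each role
# wraps `torchrun --nnodes=1 --nproc_per_node=2 ... train_ddp.py`
# on its own master port (29600, 29601, ...).
app = hsdp(replicas=4, workers_per_replica=2)

for role in app.roles:
    print(role.name, role.entrypoint, role.args)
```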

train_ddp.py

+2

```diff
@@ -13,6 +13,7 @@
 import torchvision
 import torchvision.transforms as transforms
 from torch import nn, optim
+from torch.distributed.elastic.multiprocessing.errors import record
 from torchdata.stateful_dataloader import StatefulDataLoader
 
 from torchft import (
@@ -27,6 +28,7 @@
 logging.basicConfig(level=logging.INFO)
 
 
+@record
 def main() -> None:
     REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
     NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))
```
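`@record` is torchelastic's error-propagation decorator: if the decorated entrypoint raises, it writes a structured error file that `torchrun` picks up when reporting which worker failed and why. A minimal sketch of the pattern, independent of torchft:

```python
from torch.distributed.elastic.multiprocessing.errors import record


@record
def main() -> None:
    # Any exception raised here is captured in an error file that
    # torchrun reads back to report the root-cause worker failure.
    raise RuntimeError("simulated worker failure")


if __name__ == "__main__":
    main()
```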
