
Commit 7b93da7

manager_integ_tests: added Python integration test with lighthouse (#27)
1 parent ddbc3c9 commit 7b93da7

5 files changed (+176 -5)

.lintrunner.toml (+2)
@@ -1,6 +1,7 @@
 [[linter]]
 code = 'BLACK-ISORT'
 include_patterns = [
+    '*.py',
     '**/*.py',
 ]
 exclude_patterns = []
@@ -46,6 +47,7 @@ command = [
 [[linter]]
 code = 'PYRE'
 include_patterns = [
+    '*.py',
     '**/*.py',
     '**/*.pyi',
 ]
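Why add '*.py' when '**/*.py' is already present: lintrunner's include patterns appear to be fnmatch-style globs, and in that style '**/*.py' requires a literal path separator, so files at the repository root never match it. A small Python sketch of the gap (an assumption about the matcher's semantics, using Python's fnmatch rather than lintrunner's actual implementation; 'setup.py' is a stand-in filename):

    from fnmatch import fnmatch

    # '**/*.py' demands a '/', so root-level files fall through...
    print(fnmatch("setup.py", "**/*.py"))            # False
    print(fnmatch("torchft/manager.py", "**/*.py"))  # True
    # ...which is what the added '*.py' pattern catches.
    print(fnmatch("setup.py", "*.py"))               # True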

src/lib.rs (+55 -1)
@@ -9,6 +9,7 @@ pub mod manager;

 use core::time::Duration;
 use std::env;
+use std::sync::Arc;

 use anyhow::Result;
 use pyo3::exceptions::PyRuntimeError;
@@ -28,6 +29,8 @@ use pyo3::prelude::*;
 #[pyclass]
 struct Manager {
     handle: JoinHandle<Result<()>>,
+    manager: Arc<manager::Manager>,
+    _runtime: Runtime,
 }

 #[pymethods]
@@ -55,10 +58,18 @@ impl Manager {
                 ))
                 .unwrap();
             let handle = runtime.spawn(manager.clone().run());
-            Self { handle: handle }
+            Self {
+                handle: handle,
+                manager: manager,
+                _runtime: runtime,
+            }
         })
     }

+    fn address(&self) -> PyResult<String> {
+        Ok(self.manager.address().to_string())
+    }
+
     fn shutdown(&self, py: Python<'_>) {
         py.allow_threads(move || {
             self.handle.abort();
@@ -200,6 +211,48 @@ async fn lighthouse_main_async(opt: lighthouse::LighthouseOpt) -> Result<()> {
     Ok(())
 }

+#[pyclass]
+struct Lighthouse {
+    lighthouse: Arc<lighthouse::Lighthouse>,
+    handle: JoinHandle<Result<()>>,
+    _runtime: Runtime,
+}
+
+#[pymethods]
+impl Lighthouse {
+    #[new]
+    fn new(py: Python<'_>, bind: String, min_replicas: u64) -> PyResult<Self> {
+        py.allow_threads(move || {
+            let rt = Runtime::new().unwrap();
+
+            let lighthouse = rt
+                .block_on(lighthouse::Lighthouse::new(lighthouse::LighthouseOpt {
+                    bind: bind,
+                    min_replicas: min_replicas,
+                    join_timeout_ms: 100,
+                    quorum_tick_ms: 100,
+                }))
+                .unwrap();
+
+            Ok(Self {
+                handle: rt.spawn(lighthouse.clone().run()),
+                lighthouse: lighthouse,
+                _runtime: rt,
+            })
+        })
+    }
+
+    fn address(&self) -> PyResult<String> {
+        Ok(self.lighthouse.address().to_string())
+    }
+
+    fn shutdown(&self, py: Python<'_>) {
+        py.allow_threads(move || {
+            self.handle.abort();
+        })
+    }
+}
+
 #[pymodule]
 fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
     // setup logging on import
@@ -212,6 +265,7 @@ fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {

     m.add_class::<Manager>()?;
     m.add_class::<ManagerClient>()?;
+    m.add_class::<Lighthouse>()?;
     m.add_function(wrap_pyfunction!(lighthouse_main, m)?)?;

     Ok(())
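From Python, the new Lighthouse binding follows the same lifecycle as Manager: the constructor blocks until the server is bound, address() returns the resolved listen address, and shutdown() aborts the background task. A minimal usage sketch, assuming the Rust extension is built and importable as torchft.torchft:

    from torchft.torchft import Lighthouse

    # Bind to an OS-assigned port; min_replicas gates quorum formation.
    lighthouse = Lighthouse(bind="[::]:0", min_replicas=1)
    print(lighthouse.address())  # pass this as lighthouse_addr to Manager
    lighthouse.shutdown()        # aborts the spawned background task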

torchft/manager.py (+1 -4)
@@ -39,8 +39,6 @@
 from torch.optim import Optimizer

 from torchft.checkpointing import CheckpointServer
-
-# pyre-fixme[21]: can't find rust module
 from torchft.torchft import Manager as _Manager, ManagerClient

 if TYPE_CHECKING:
@@ -121,7 +119,7 @@ def __init__(

         store_addr = store_addr or os.environ["MASTER_ADDR"]
         store_port = store_port or int(os.environ["MASTER_PORT"])
-        self._rank: int = rank or int(os.environ["RANK"])
+        self._rank: int = rank if rank is not None else int(os.environ["RANK"])
         rank = self._rank
         world_size = world_size or int(os.environ["WORLD_SIZE"])
         self._min_replica_size = min_replica_size
@@ -151,7 +149,6 @@ def __init__(

         if replica_id is None:
             replica_id = str(uuid.uuid4())
-        # pyre-fixme[16]: can't find rust module
         self._manager = _Manager(
             replica_id=replica_id,
             lighthouse_addr=lighthouse_addr,
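The rank change fixes a falsy-zero bug: with `rank or int(os.environ["RANK"])`, an explicitly passed rank of 0 is treated as false and silently replaced by the environment variable. A minimal demonstration:

    import os

    os.environ["RANK"] = "5"
    rank = 0  # a legitimate rank passed explicitly by the caller

    buggy = rank or int(os.environ["RANK"])                        # -> 5
    fixed = rank if rank is not None else int(os.environ["RANK"])  # -> 0
    assert (buggy, fixed) == (5, 0)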

torchft/manager_integ_test.py (+100)
@@ -0,0 +1,100 @@
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from unittest import TestCase
+
+import torch
+import torch.distributed as dist
+from torch import nn, optim
+
+from torchft.ddp import DistributedDataParallel
+from torchft.manager import Manager
+from torchft.optim import OptimizerWrapper
+from torchft.process_group import ProcessGroupGloo
+from torchft.torchft import Lighthouse
+
+
+class MyModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.model = nn.Sequential(
+            nn.Linear(3, 4),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x):
+        return self.model(x)
+
+
+def train_loop(replica_id: int, lighthouse_address: str) -> None:
+    store = dist.TCPStore(
+        host_name="localhost",
+        port=0,
+        is_master=True,
+        wait_for_workers=False,
+    )
+
+    def load_state_dict(state_dict):
+        m.load_state_dict(state_dict["model"])
+        optimizer.load_state_dict(state_dict["optim"])
+
+    def state_dict():
+        return {
+            "model": m.state_dict(),
+            "optim": optimizer.state_dict(),
+        }
+
+    pg = ProcessGroupGloo()
+    manager = Manager(
+        pg=pg,
+        min_replica_size=2,
+        load_state_dict=load_state_dict,
+        state_dict=state_dict,
+        replica_id=str(replica_id),
+        store_addr="localhost",
+        store_port=store.port,
+        rank=0,
+        world_size=1,
+        lighthouse_addr=lighthouse_address,
+        port=19530 + replica_id,
+    )
+    m = DistributedDataParallel(manager, MyModel())
+    optimizer = OptimizerWrapper(manager, optim.Adam(m.parameters()))
+    criterion = nn.CrossEntropyLoss()
+
+    while True:
+        inputs = torch.rand(2, 3)
+        labels = torch.randint(4, (2,))
+
+        optimizer.zero_grad()
+        out = m(inputs)
+        loss = criterion(out, labels)
+
+        loss.backward()
+        optimizer.step()
+
+        # TODO: assert weights are equal across replicas
+
+        if manager.current_step() >= 5:
+            break
+
+    manager.shutdown()
+
+
+class ManagerIntegTest(TestCase):
+    def test_ddp(self):
+        lighthouse = Lighthouse(
+            bind="[::]:0",
+            min_replicas=2,
+        )
+        num_replicas = 2
+        futures = []
+
+        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
+            for replica_id in range(num_replicas):
+                futures.append(
+                    executor.submit(train_loop, replica_id, lighthouse.address())
+                )
+
+            for fut in as_completed(futures):
+                fut.result()
+
+        lighthouse.shutdown()
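The test surfaces replica failures through concurrent.futures: an exception raised inside a train_loop thread is stored on its Future and re-raised by fut.result(), so a crashing replica fails the test instead of hanging it. The pattern in isolation (worker and its error are illustrative stand-ins):

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def worker(i: int) -> None:
        if i == 1:
            raise RuntimeError("replica 1 failed")

    with ThreadPoolExecutor(max_workers=2) as executor:
        futures = [executor.submit(worker, i) for i in range(2)]
        for fut in as_completed(futures):
            try:
                fut.result()  # re-raises the worker's exception here
            except RuntimeError as e:
                print(f"propagated from worker thread: {e}")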

torchft/torchft.pyi (+18)
@@ -8,3 +8,21 @@ class ManagerClient:
     ) -> Tuple[int, int, int, str, str, int, Optional[int], int, bool]: ...
     def checkpoint_address(self, rank: int) -> str: ...
     def should_commit(self, rank: int, step: int, should_commit: bool) -> bool: ...
+
+class Manager:
+    def __init__(
+        self,
+        replica_id: str,
+        lighthouse_addr: str,
+        address: str,
+        bind: str,
+        store_addr: str,
+        world_size: int,
+    ) -> None: ...
+    def address(self) -> str: ...
+    def shutdown(self) -> None: ...
+
+class Lighthouse:
+    def __init__(self, bind: str, min_replicas: int) -> None: ...
+    def address(self) -> str: ...
+    def shutdown(self) -> None: ...
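These stubs are what allow the two pyre-fixme "can't find rust module" suppressions to be dropped from torchft/manager.py above: with Manager and Lighthouse declared in the .pyi, a type checker can resolve the import and check call sites against it. A sketch of type-checked usage (assuming pyre or mypy picks up the stub; start_lighthouse is a hypothetical helper):

    from torchft.torchft import Lighthouse  # resolved via torchft/torchft.pyi

    def start_lighthouse(bind: str, min_replicas: int) -> str:
        lh = Lighthouse(bind=bind, min_replicas=min_replicas)
        return lh.address()  # the checker knows this returns str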
