Skip to content

Commit 5dd6f38

Browse files
authored
feat: expose lighthouse join timeout (#61)
* feat: expose lighthouse join timeout
* lint
* add join timeout test
* expose quorum tick ms
* blacken
1 parent 2ae42a0 commit 5dd6f38

File tree

3 files changed

+106
-4
lines changed

3 files changed

+106
-4
lines changed

src/lib.rs

+12-3
Original file line numberDiff line numberDiff line change
@@ -225,16 +225,25 @@ struct Lighthouse {
225225
#[pymethods]
226226
impl Lighthouse {
227227
#[new]
228-
fn new(py: Python<'_>, bind: String, min_replicas: u64) -> PyResult<Self> {
228+
fn new(
229+
py: Python<'_>,
230+
bind: String,
231+
min_replicas: u64,
232+
join_timeout_ms: Option<u64>,
233+
quorum_tick_ms: Option<u64>,
234+
) -> PyResult<Self> {
235+
let join_timeout_ms = join_timeout_ms.unwrap_or(100);
236+
let quorum_tick_ms = quorum_tick_ms.unwrap_or(100);
237+
229238
py.allow_threads(move || {
230239
let rt = Runtime::new()?;
231240

232241
let lighthouse = rt
233242
.block_on(lighthouse::Lighthouse::new(lighthouse::LighthouseOpt {
234243
bind: bind,
235244
min_replicas: min_replicas,
236-
join_timeout_ms: 100,
237-
quorum_tick_ms: 100,
245+
join_timeout_ms: join_timeout_ms,
246+
quorum_tick_ms: quorum_tick_ms,
238247
}))
239248
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
240249

torchft/lighthouse_test.py

+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import time
2+
from unittest import TestCase
3+
4+
import torch.distributed as dist
5+
6+
from torchft import Manager, ProcessGroupGloo
7+
from torchft.torchft import Lighthouse
8+
9+
10+
class TestLighthouse(TestCase):
    """Tests for the ``Lighthouse`` binding's ``join_timeout_ms`` option."""

    def _time_to_join(self, join_timeout_ms: int) -> float:
        """Time a single manager's ``start_quorum()`` against a fresh lighthouse.

        A new ``Lighthouse`` is created with the given ``join_timeout_ms`` and
        ``min_replicas=1``; one ``Manager`` (rank 0, world size 1, synchronous
        quorum) is pointed at it, and the duration of ``manager.start_quorum()``
        is measured.

        Args:
            join_timeout_ms: join timeout, in milliseconds, passed to the
                ``Lighthouse`` constructor.

        Returns:
            Seconds elapsed during the ``start_quorum()`` call.
        """
        lighthouse = Lighthouse(
            bind="[::]:0",
            min_replicas=1,
            join_timeout_ms=join_timeout_ms,
        )

        # Track the manager explicitly instead of probing ``locals()``: each
        # scenario must only clean up the manager it created itself.
        manager = None
        try:
            store = dist.TCPStore(
                host_name="localhost",
                port=0,
                is_master=True,
                wait_for_workers=False,
            )
            pg = ProcessGroupGloo()
            manager = Manager(
                pg=pg,
                min_replica_size=1,
                load_state_dict=lambda x: None,
                state_dict=lambda: None,
                replica_id="lighthouse_test",
                store_addr="localhost",
                store_port=store.port,
                rank=0,
                world_size=1,
                use_async_quorum=False,
                lighthouse_addr=lighthouse.address(),
            )

            # perf_counter() is monotonic, so the measured interval cannot be
            # skewed by wall-clock adjustments the way time.time() can.
            start_time = time.perf_counter()
            manager.start_quorum()
            return time.perf_counter() - start_time
        finally:
            # Cleanup
            lighthouse.shutdown()
            if manager is not None:
                manager.shutdown()

    def test_join_timeout_behavior(self) -> None:
        """Test that join_timeout_ms affects joining behavior"""
        # To test, we create lighthouses with 100ms and 400ms join timeouts
        # and measure the time taken to validate the quorum in each case.
        time_taken = self._time_to_join(join_timeout_ms=100)
        self.assertLess(
            time_taken, 0.4, f"Time taken to join: {time_taken} > 0.4s"
        )

        time_taken = self._time_to_join(join_timeout_ms=400)
        self.assertGreater(
            time_taken, 0.4, f"Time taken to join: {time_taken} < 0.4s"
        )

torchft/torchft.pyi

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,6 @@ class Manager:
2323
def shutdown(self) -> None: ...
2424

2525
class Lighthouse:
26-
def __init__(self, bind: str, min_replicas: int) -> None: ...
26+
def __init__(self, bind: str, min_replicas: int, join_timeout_ms: Optional[int] = None, quorum_tick_ms: Optional[int] = None) -> None: ...
2727
def address(self) -> str: ...
2828
def shutdown(self) -> None: ...

0 commit comments

Comments
 (0)