From 064c1412d38c4623e21dd664dc234e032e50402b Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Wed, 8 Jan 2025 08:56:54 +0000 Subject: [PATCH 1/5] feat: expose lighthouse join timeout --- src/lib.rs | 6 ++++-- torchft/torchft.pyi | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index e4d84a0..34c7cc4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -225,7 +225,9 @@ struct Lighthouse { #[pymethods] impl Lighthouse { #[new] - fn new(py: Python<'_>, bind: String, min_replicas: u64) -> PyResult { + fn new(py: Python<'_>, bind: String, min_replicas: u64, join_timeout_ms: Option) -> PyResult { + let join_timeout_ms = join_timeout_ms.unwrap_or(100); + py.allow_threads(move || { let rt = Runtime::new()?; @@ -233,7 +235,7 @@ impl Lighthouse { .block_on(lighthouse::Lighthouse::new(lighthouse::LighthouseOpt { bind: bind, min_replicas: min_replicas, - join_timeout_ms: 100, + join_timeout_ms: join_timeout_ms, quorum_tick_ms: 100, })) .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; diff --git a/torchft/torchft.pyi b/torchft/torchft.pyi index aee2947..4249ac3 100644 --- a/torchft/torchft.pyi +++ b/torchft/torchft.pyi @@ -23,6 +23,6 @@ class Manager: def shutdown(self) -> None: ... class Lighthouse: - def __init__(self, bind: str, min_replicas: int) -> None: ... + def __init__(self, bind: str, min_replicas: int, join_timeout_ms: Optional[int]) -> None: ... def address(self) -> str: ... def shutdown(self) -> None: ... From 91624dd888115c8e7574712d3e4cdfe2e2610b85 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Wed, 8 Jan 2025 09:18:08 +0000 Subject: [PATCH 2/5] lint --- src/lib.rs | 7 ++++++- torchft/torchft.pyi | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 34c7cc4..5f64e09 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -225,7 +225,12 @@ struct Lighthouse { #[pymethods] impl Lighthouse { #[new] - fn new(py: Python<'_>, bind: String, min_replicas: u64, join_timeout_ms: Option) -> PyResult { + fn new( + py: Python<'_>, + bind: String, + min_replicas: u64, + join_timeout_ms: Option, + ) -> PyResult { let join_timeout_ms = join_timeout_ms.unwrap_or(100); py.allow_threads(move || { diff --git a/torchft/torchft.pyi b/torchft/torchft.pyi index 4249ac3..f71113e 100644 --- a/torchft/torchft.pyi +++ b/torchft/torchft.pyi @@ -23,6 +23,6 @@ class Manager: def shutdown(self) -> None: ... class Lighthouse: - def __init__(self, bind: str, min_replicas: int, join_timeout_ms: Optional[int]) -> None: ... + def __init__(self, bind: str, min_replicas: int, join_timeout_ms: Optional[int] = None) -> None: ... def address(self) -> str: ... def shutdown(self) -> None: ... From a41eb1c24b9ac5f9396bbb8aa85511334405754b Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Wed, 8 Jan 2025 23:54:48 +0000 Subject: [PATCH 3/5] add join timeout test --- torchft/lighthouse_test.py | 91 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 torchft/lighthouse_test.py diff --git a/torchft/lighthouse_test.py b/torchft/lighthouse_test.py new file mode 100644 index 0000000..c1e83b3 --- /dev/null +++ b/torchft/lighthouse_test.py @@ -0,0 +1,91 @@ +import torch.distributed as dist +from unittest import TestCase +from torchft.torchft import Lighthouse +from torchft import Manager, ProcessGroupGloo +import time + +class TestLighthouse(TestCase): + def test_join_timeout_behavior(self) -> None: + """Test that join_timeout_ms affects joining behavior""" + # To test, we create a lighthouse with 100ms and 400ms join timeouts + # and measure the time taken to validate the quorum. + lighthouse = Lighthouse( + bind="[::]:0", + min_replicas=1, + join_timeout_ms=100, + ) + + # Create a manager that tries to join + try: + store = dist.TCPStore( + host_name="localhost", + port=0, + is_master=True, + wait_for_workers=False, + ) + pg = ProcessGroupGloo() + manager = Manager( + pg=pg, + min_replica_size=1, + load_state_dict=lambda x: None, + state_dict=lambda: None, + replica_id=f"lighthouse_test", + store_addr="localhost", + store_port=store.port, + rank=0, + world_size=1, + use_async_quorum=False, + lighthouse_addr=lighthouse.address(), + ) + + start_time = time.time() + manager.start_quorum() + time_taken = time.time() - start_time + assert time_taken < 0.4, f"Time taken to join: {time_taken} > 0.4s" + + finally: + # Cleanup + lighthouse.shutdown() + if 'manager' in locals(): + manager.shutdown() + + lighthouse = Lighthouse( + bind="[::]:0", + min_replicas=1, + join_timeout_ms=400, + ) + + # Create a manager that tries to join + try: + store = dist.TCPStore( + host_name="localhost", + port=0, + is_master=True, + wait_for_workers=False, + ) + pg = ProcessGroupGloo() + manager = Manager( + pg=pg, + min_replica_size=1, + load_state_dict=lambda x: None, + state_dict=lambda: None, + replica_id=f"lighthouse_test", + store_addr="localhost", + store_port=store.port, + rank=0, + world_size=1, + use_async_quorum=False, + lighthouse_addr=lighthouse.address(), + ) + + start_time = time.time() + manager.start_quorum() + time_taken = time.time() - start_time + assert time_taken > 0.4, f"Time taken to join: {time_taken} < 0.4s" + + finally: + # Cleanup + lighthouse.shutdown() + if 'manager' in locals(): + manager.shutdown() + \ No newline at end of file From 13e72f124c0f003b491bbb6a7d4776053932be2f Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Wed, 8 Jan 2025 23:55:32 +0000 Subject: [PATCH 4/5] expose quorum tick ms --- src/lib.rs | 4 +++- torchft/torchft.pyi | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 5f64e09..6fc8d10 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -230,8 +230,10 @@ impl Lighthouse { bind: String, min_replicas: u64, join_timeout_ms: Option, + quorum_tick_ms: Option, ) -> PyResult { let join_timeout_ms = join_timeout_ms.unwrap_or(100); + let quorum_tick_ms = quorum_tick_ms.unwrap_or(100); py.allow_threads(move || { let rt = Runtime::new()?; @@ -241,7 +243,7 @@ impl Lighthouse { bind: bind, min_replicas: min_replicas, join_timeout_ms: join_timeout_ms, - quorum_tick_ms: 100, + quorum_tick_ms: quorum_tick_ms, })) .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; diff --git a/torchft/torchft.pyi b/torchft/torchft.pyi index f71113e..644a4ea 100644 --- a/torchft/torchft.pyi +++ b/torchft/torchft.pyi @@ -23,6 +23,6 @@ class Manager: def shutdown(self) -> None: ... class Lighthouse: - def __init__(self, bind: str, min_replicas: int, join_timeout_ms: Optional[int] = None) -> None: ... + def __init__(self, bind: str, min_replicas: int, join_timeout_ms: Optional[int] = None, quorum_tick_ms: Optional[int] = None) -> None: ... def address(self) -> str: ... def shutdown(self) -> None: ... From 653f192cb7d33b7fd59e3567de9172ab1c3b7428 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Thu, 9 Jan 2025 03:15:57 +0000 Subject: [PATCH 5/5] blacken --- torchft/lighthouse_test.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/torchft/lighthouse_test.py b/torchft/lighthouse_test.py index c1e83b3..f6efc32 100644 --- a/torchft/lighthouse_test.py +++ b/torchft/lighthouse_test.py @@ -1,8 +1,11 @@ -import torch.distributed as dist +import time from unittest import TestCase -from torchft.torchft import Lighthouse + +import torch.distributed as dist + from torchft import Manager, ProcessGroupGloo -import time +from torchft.torchft import Lighthouse + class TestLighthouse(TestCase): def test_join_timeout_behavior(self) -> None: @@ -14,7 +17,7 @@ def test_join_timeout_behavior(self) -> None: min_replicas=1, join_timeout_ms=100, ) - + # Create a manager that tries to join try: store = dist.TCPStore( @@ -37,7 +40,7 @@ def test_join_timeout_behavior(self) -> None: use_async_quorum=False, lighthouse_addr=lighthouse.address(), ) - + start_time = time.time() manager.start_quorum() time_taken = time.time() - start_time @@ -46,15 +49,15 @@ def test_join_timeout_behavior(self) -> None: finally: # Cleanup lighthouse.shutdown() - if 'manager' in locals(): + if "manager" in locals(): manager.shutdown() - + lighthouse = Lighthouse( bind="[::]:0", min_replicas=1, join_timeout_ms=400, ) - + # Create a manager that tries to join try: store = dist.TCPStore( @@ -77,7 +80,7 @@ def test_join_timeout_behavior(self) -> None: use_async_quorum=False, lighthouse_addr=lighthouse.address(), ) - + start_time = time.time() manager.start_quorum() time_taken = time.time() - start_time @@ -86,6 +89,5 @@ def test_join_timeout_behavior(self) -> None: finally: # Cleanup lighthouse.shutdown() - if 'manager' in locals(): + if "manager" in locals(): manager.shutdown() - \ No newline at end of file