Skip to content

Commit

Permalink
process_group: wait for future_thread to join before creating a new one
Browse files Browse the repository at this point in the history
  • Loading branch information
dwancn committed Jan 24, 2025
1 parent bed29d2 commit 3893b42
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
8 changes: 7 additions & 1 deletion torchft/process_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,8 +609,14 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
if self._rx is not None:
self._rx.close()
if self._future_queue is not None:
# wait for the future thread to exit and then close the queue
self._future_queue.put(_QUEUE_CLOSE)
assert self._future_queue is not None
assert self._future_thread is not None
self._future_thread.join(timeout=10.0)
# pyre-ignore[16]: optional value is checked above
if self._future_thread.is_alive():
raise RuntimeError("future thread did not exit")
# pyre-ignore[16]: optional value is checked above
self._future_queue.close()

ctx = mp.get_context("spawn")
Expand Down
36 changes: 36 additions & 0 deletions torchft/process_group_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,42 @@ def test_baby_gloo_timeout(self) -> None:
with self.assertRaisesRegex(TimeoutError, "timed out after 0.01 seconds"):
a.configure(store_addr, 0, 2)

def test_reconfigure_baby_process_group(self) -> None:
    """Check that a second ``configure`` call replaces the old workers.

    A single-rank ``ProcessGroupBabyGloo`` is configured twice against two
    different store prefixes. After the second call the future thread,
    future queue and subprocess must all be new objects; the first set must
    be stopped/closed while the second set must be running/open.
    """
    store = TCPStore(
        host_name="localhost", port=0, is_master=True, wait_for_workers=False
    )
    pg = ProcessGroupBabyGloo()

    # First configuration: capture the handles it created.
    pg.configure(f"localhost:{store.port}/prefix", 0, 1)
    old_thread, old_queue, old_proc = pg._future_thread, pg._future_queue, pg._p

    # Second configuration under a different prefix: capture the new handles.
    pg.configure(f"localhost:{store.port}/prefix2", 0, 1)
    new_thread, new_queue, new_proc = pg._future_thread, pg._future_queue, pg._p

    # Every handle must have been replaced by the reconfigure.
    self.assertNotEqual(old_thread, new_thread)
    self.assertNotEqual(old_queue, new_queue)
    self.assertNotEqual(old_proc, new_proc)

    # The first generation has been torn down.
    assert old_thread is not None
    self.assertFalse(old_thread.is_alive())
    assert old_queue is not None
    self.assertTrue(old_queue._closed)  # pyre-ignore[16]: no attribute _closed
    assert old_proc is not None
    self.assertFalse(old_proc.is_alive())

    # The second generation is up and usable.
    assert new_thread is not None
    self.assertTrue(new_thread.is_alive())
    assert new_queue is not None
    self.assertFalse(new_queue._closed)
    assert new_proc is not None
    self.assertTrue(new_proc.is_alive())

def test_dummy(self) -> None:
pg = ProcessGroupDummy(0, 1)
m = nn.Linear(3, 4)
Expand Down

0 comments on commit 3893b42

Please sign in to comment.