Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

process_group: wait for futher_thread join before creating new one #68

Merged
merged 1 commit into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion torchft/process_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,8 +609,14 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
if self._rx is not None:
self._rx.close()
if self._future_queue is not None:
# wait for the future thread to exit and then close the queue
self._future_queue.put(_QUEUE_CLOSE)
assert self._future_queue is not None
assert self._future_thread is not None
self._future_thread.join(timeout=10.0)
# pyre-ignore[16]: optional value is checked above
if self._future_thread.is_alive():
raise RuntimeError("future thread did not exit")
# pyre-ignore[16]: optional value is checked above
self._future_queue.close()

ctx = mp.get_context("spawn")
Expand Down
36 changes: 36 additions & 0 deletions torchft/process_group_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,42 @@ def test_baby_gloo_timeout(self) -> None:
with self.assertRaisesRegex(TimeoutError, "timed out after 0.01 seconds"):
a.configure(store_addr, 0, 2)

def test_reconfigure_baby_process_group(self) -> None:
store = TCPStore(
host_name="localhost", port=0, is_master=True, wait_for_workers=False
)
store_addr = f"localhost:{store.port}/prefix"

a = ProcessGroupBabyGloo()
a.configure(store_addr, 0, 1)
future_thread_1 = a._future_thread
future_queue_1 = a._future_queue
p_1 = a._p

store_addr = f"localhost:{store.port}/prefix2"
a.configure(store_addr, 0, 1)
future_thread_2 = a._future_thread
future_queue_2 = a._future_queue
p_2 = a._p

self.assertNotEqual(future_thread_1, future_thread_2)
self.assertNotEqual(future_queue_1, future_queue_2)
self.assertNotEqual(p_1, p_2)

assert future_thread_1 is not None
self.assertFalse(future_thread_1.is_alive())
assert future_queue_1 is not None
self.assertTrue(future_queue_1._closed) # pyre-ignore[16]: no attribute _closed
assert p_1 is not None
self.assertFalse(p_1.is_alive())

assert future_thread_2 is not None
self.assertTrue(future_thread_2.is_alive())
assert future_queue_2 is not None
self.assertFalse(future_queue_2._closed)
assert p_2 is not None
self.assertTrue(p_2.is_alive())

def test_dummy(self) -> None:
pg = ProcessGroupDummy(0, 1)
m = nn.Linear(3, 4)
Expand Down
Loading