Skip to content

Commit f526ae7

Browse files
committed
process_group: wait for futher_thread join before creating new one
1 parent bed29d2 commit f526ae7

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

torchft/process_group.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -609,8 +609,14 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
609609
if self._rx is not None:
610610
self._rx.close()
611611
if self._future_queue is not None:
612+
# wait for the future thread to exit and then close the queue
612613
self._future_queue.put(_QUEUE_CLOSE)
613-
assert self._future_queue is not None
614+
assert self._future_thread is not None
615+
self._future_thread.join(timeout=10.0)
616+
# pyre-ignore[16]: optional value is checked above
617+
if self._future_thread.is_alive():
618+
raise RuntimeError("future thread did not exit")
619+
# pyre-ignore[16]: optional value is checked above
614620
self._future_queue.close()
615621

616622
ctx = mp.get_context("spawn")

torchft/process_group_test.py

+36
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,42 @@ def test_baby_gloo_timeout(self) -> None:
210210
with self.assertRaisesRegex(TimeoutError, "timed out after 0.01 seconds"):
211211
a.configure(store_addr, 0, 2)
212212

213+
def test_reconfigure_baby_process_group(self) -> None:
214+
store = TCPStore(
215+
host_name="localhost", port=0, is_master=True, wait_for_workers=False
216+
)
217+
store_addr = f"localhost:{store.port}/prefix"
218+
219+
a = ProcessGroupBabyGloo()
220+
a.configure(store_addr, 0, 1)
221+
futher_thread_1 = a._future_thread
222+
futher_queue_1 = a._future_queue
223+
p_1 = a._p
224+
225+
store_addr = f"localhost:{store.port}/prefix2"
226+
a.configure(store_addr, 0, 1)
227+
futher_thread_2 = a._future_thread
228+
futher_queue_2 = a._future_queue
229+
p_2 = a._p
230+
231+
self.assertNotEqual(futher_thread_1, futher_thread_2)
232+
self.assertNotEqual(futher_queue_1, futher_queue_2)
233+
self.assertNotEqual(p_1, p_2)
234+
235+
# pyre-ignore[16]: optional
236+
self.assertFalse(futher_thread_1.is_alive())
237+
# pyre-ignore[16]: optional
238+
self.assertTrue(futher_queue_1._closed)
239+
# pyre-ignore[16]: optional
240+
self.assertFalse(p_1.is_alive())
241+
242+
# pyre-ignore[16]: optional
243+
self.assertTrue(futher_thread_2.is_alive())
244+
# pyre-ignore[16]: optional
245+
self.assertFalse(futher_queue_2._closed)
246+
# pyre-ignore[16]: optional
247+
self.assertTrue(p_2.is_alive())
248+
213249
def test_dummy(self) -> None:
214250
pg = ProcessGroupDummy(0, 1)
215251
m = nn.Linear(3, 4)

0 commit comments

Comments
 (0)