Skip to content

Commit e177f9c

Browse files
authored
process_group: wait for future_thread join before creating new one (#68)
1 parent beb94f0 commit e177f9c

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

torchft/process_group.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -609,8 +609,14 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
609609
if self._rx is not None:
610610
self._rx.close()
611611
if self._future_queue is not None:
612+
# wait for the future thread to exit and then close the queue
612613
self._future_queue.put(_QUEUE_CLOSE)
613-
assert self._future_queue is not None
614+
assert self._future_thread is not None
615+
self._future_thread.join(timeout=10.0)
616+
# pyre-ignore[16]: optional value is checked above
617+
if self._future_thread.is_alive():
618+
raise RuntimeError("future thread did not exit")
619+
# pyre-ignore[16]: optional value is checked above
614620
self._future_queue.close()
615621

616622
ctx = mp.get_context("spawn")

torchft/process_group_test.py

+36
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,42 @@ def test_baby_gloo_timeout(self) -> None:
210210
with self.assertRaisesRegex(TimeoutError, "timed out after 0.01 seconds"):
211211
a.configure(store_addr, 0, 2)
212212

213+
def test_reconfigure_baby_process_group(self) -> None:
214+
store = TCPStore(
215+
host_name="localhost", port=0, is_master=True, wait_for_workers=False
216+
)
217+
store_addr = f"localhost:{store.port}/prefix"
218+
219+
a = ProcessGroupBabyGloo()
220+
a.configure(store_addr, 0, 1)
221+
future_thread_1 = a._future_thread
222+
future_queue_1 = a._future_queue
223+
p_1 = a._p
224+
225+
store_addr = f"localhost:{store.port}/prefix2"
226+
a.configure(store_addr, 0, 1)
227+
future_thread_2 = a._future_thread
228+
future_queue_2 = a._future_queue
229+
p_2 = a._p
230+
231+
self.assertNotEqual(future_thread_1, future_thread_2)
232+
self.assertNotEqual(future_queue_1, future_queue_2)
233+
self.assertNotEqual(p_1, p_2)
234+
235+
assert future_thread_1 is not None
236+
self.assertFalse(future_thread_1.is_alive())
237+
assert future_queue_1 is not None
238+
self.assertTrue(future_queue_1._closed) # pyre-ignore[16]: no attribute _closed
239+
assert p_1 is not None
240+
self.assertFalse(p_1.is_alive())
241+
242+
assert future_thread_2 is not None
243+
self.assertTrue(future_thread_2.is_alive())
244+
assert future_queue_2 is not None
245+
self.assertFalse(future_queue_2._closed)
246+
assert p_2 is not None
247+
self.assertTrue(p_2.is_alive())
248+
213249
def test_dummy(self) -> None:
214250
pg = ProcessGroupDummy(0, 1)
215251
m = nn.Linear(3, 4)

0 commit comments

Comments
 (0)