Skip to content

Commit ba3c7a0

Browse files
committed
process_group: wait for future_thread join before creating new one
1 parent 2f97660 commit ba3c7a0

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

torchft/process_group.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -609,8 +609,12 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
609609
if self._rx is not None:
610610
self._rx.close()
611611
if self._future_queue is not None:
612+
# wait for the future thread to exit and then close the queue
612613
self._future_queue.put(_QUEUE_CLOSE)
613-
assert self._future_queue is not None
614+
assert self._future_thread is not None
615+
self._future_thread.join(timeout=10.0)
616+
if self._future_thread.is_alive():
617+
raise RuntimeError("future thread did not exit")
614618
self._future_queue.close()
615619

616620
ctx = mp.get_context("spawn")

torchft/process_group_test.py

+31
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,37 @@ def test_baby_gloo_timeout(self) -> None:
208208
with self.assertRaisesRegex(TimeoutError, "timed out after 0.01 seconds"):
209209
a.configure(store_addr, 0, 2)
210210

211+
def test_reconfigure_baby_process_group(self) -> None:
212+
store = TCPStore(
213+
host_name="localhost", port=0, is_master=True, wait_for_workers=False
214+
)
215+
store_addr = f"localhost:{store.port}/prefix"
216+
217+
a = ProcessGroupBabyGloo()
218+
a.configure(store_addr, 0, 1)
219+
futher_thread_1 = a._future_thread
220+
futher_queue_1 = a._future_queue
221+
p_1 = a._p
222+
223+
store_addr = f"localhost:{store.port}/prefix2"
224+
a.configure(store_addr, 0, 1)
225+
futher_thread_2 = a._future_thread
226+
futher_queue_2 = a._future_queue
227+
p_2 = a._p
228+
229+
self.assertNotEqual(futher_thread_1, futher_thread_2)
230+
self.assertNotEqual(futher_queue_1, futher_queue_2)
231+
self.assertNotEqual(p_1, p_2)
232+
233+
self.assertFalse(futher_thread_1.is_alive())
234+
self.assertTrue(futher_queue_1._closed)
235+
self.assertFalse(p_1.is_alive())
236+
237+
self.assertTrue(futher_thread_2.is_alive())
238+
self.assertFalse(futher_queue_2._closed)
239+
self.assertTrue(p_2.is_alive())
240+
241+
211242
def test_dummy(self) -> None:
212243
pg = ProcessGroupDummy(0, 1)
213244
m = nn.Linear(3, 4)

0 commit comments

Comments
 (0)