Skip to content

Commit 9abc936

Browse files
committed
MonitoredQueue: fail fast when subprocess exits
1 parent 927f8b4 commit 9abc936

4 files changed

+214
-71
lines changed

torchft/multiprocessing.py

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import queue
2+
import time
3+
from datetime import timedelta
4+
from typing import Union
5+
6+
import torch.multiprocessing as mp
7+
8+
9+
class _MonitoredQueue:
10+
def __init__(
11+
self,
12+
p: mp.Process,
13+
q: mp.Queue,
14+
poll_interval: timedelta = timedelta(seconds=1),
15+
) -> None:
16+
"""
17+
Args:
18+
p: process to monitor
19+
q: queue to monitor
20+
poll_interval: interval to poll the Process health when calling get/put
21+
"""
22+
self._p = p
23+
self._q = q
24+
self._poll_interval_s: float = poll_interval.total_seconds()
25+
26+
def get(self, timeout: Union[float, timedelta]) -> object:
27+
"""
28+
Get an item from the queue. If the process is not alive, raise RuntimeError.
29+
If the queue is empty, wait for up to timeout seconds for an item to be
30+
available. If no item is available after timeout seconds, raise TimeoutError.
31+
32+
Args:
33+
timeout: timeout in seconds
34+
"""
35+
36+
if isinstance(timeout, timedelta):
37+
timeout = timeout.total_seconds()
38+
39+
start = time.perf_counter()
40+
while True:
41+
elapsed = time.perf_counter() - start
42+
if elapsed > timeout:
43+
raise TimeoutError(f"queue.get() timed out after {timeout} seconds")
44+
if not self._p.is_alive():
45+
raise RuntimeError(f"process is not alive {self._p.exitcode}")
46+
47+
try:
48+
v = self._q.get(timeout=self._poll_interval_s)
49+
break
50+
except queue.Empty:
51+
continue
52+
53+
if isinstance(v, Exception):
54+
raise v
55+
return v
56+
57+
def put(self, obj: object, timeout: Union[float, timedelta]) -> None:
58+
"""
59+
Put an item into the queue. If the process is not alive, raise RuntimeError.
60+
If the queue is full, wait for up to timeout seconds for an item to be
61+
available. If queue is full after timeout seconds, raise TimeoutError.
62+
63+
If an exception is put into the queue, it will be raised when calling get().
64+
65+
Args:
66+
obj: object to put into the queue
67+
timeout: timeout in seconds
68+
"""
69+
if isinstance(timeout, timedelta):
70+
timeout = timeout.total_seconds()
71+
72+
start = time.perf_counter()
73+
while True:
74+
elapsed = time.perf_counter() - start
75+
if elapsed > timeout:
76+
raise TimeoutError(f"queue.put() timed out after {timeout} seconds")
77+
if not self._p.is_alive():
78+
raise RuntimeError(f"process is not alive {self._p.exitcode}")
79+
80+
try:
81+
self._q.put(obj, timeout=self._poll_interval_s)
82+
break
83+
except queue.Full:
84+
continue
85+
86+
def close(self) -> None:
87+
self._q.close()
88+
89+
def closed(self) -> bool:
90+
# pyre-ignore[16]: no attribute _closed
91+
return self._q._closed

torchft/multiprocessing_test.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from unittest import TestCase
2+
3+
import torch.multiprocessing as mp
4+
5+
from torchft.multiprocessing import _MonitoredQueue
6+
7+
8+
def queue_get(q: mp.Queue) -> None:
9+
q.get()
10+
11+
12+
def queue_put(q: mp.Queue) -> None:
13+
q.put(1)
14+
15+
16+
class MultiprocessingTest(TestCase):
17+
def test_monitored_queue_put(self) -> None:
18+
ctx = mp.get_context("fork")
19+
q = ctx.Queue(maxsize=1)
20+
p = ctx.Process(target=queue_get, args=(q,), daemon=True)
21+
p.start()
22+
23+
mq = _MonitoredQueue(p, q)
24+
mq.put(1, timeout=10)
25+
mq.put(1, timeout=10)
26+
with self.assertRaisesRegex(RuntimeError, "process is not alive 0"):
27+
mq.put(1, timeout=10)
28+
29+
with self.assertRaisesRegex(TimeoutError, "timed out after 0.0 seconds"):
30+
mq.put(1, timeout=0.0)
31+
32+
mq.close()
33+
34+
def test_monitored_queue_get(self) -> None:
35+
ctx = mp.get_context("fork")
36+
q = ctx.Queue(maxsize=1)
37+
p = ctx.Process(target=queue_put, args=(q,), daemon=True)
38+
p.start()
39+
40+
mq = _MonitoredQueue(p, q)
41+
self.assertEqual(mq.get(timeout=10), 1)
42+
with self.assertRaisesRegex(RuntimeError, "process is not alive 0"):
43+
mq.get(timeout=10)
44+
45+
with self.assertRaisesRegex(TimeoutError, "timed out after 0.0 seconds"):
46+
mq.get(timeout=0.0)
47+
48+
mq.close()

torchft/process_group.py

+73-69
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"""
1818

1919
import logging
20-
import queue
20+
import sys
2121
import threading
2222
from contextlib import contextmanager, nullcontext
2323
from dataclasses import dataclass
@@ -63,6 +63,8 @@
6363
from torch.futures import Future
6464
from torch.utils._pytree import tree_any
6565

66+
from torchft.multiprocessing import _MonitoredQueue
67+
6668
if TYPE_CHECKING:
6769
from torchft.manager import Manager
6870

@@ -77,28 +79,6 @@
7779
T = TypeVar("T")
7880

7981

80-
def _get(q: mp.Queue, timeout: Union[float, timedelta]) -> object:
81-
"""
82-
Gets an item from a queue with a timeout. If the timeout is exceeded then
83-
a TimeoutError is raised.
84-
85-
If an exception is returned from the queue then it is raised.
86-
87-
Args:
88-
q: queue to get from
89-
timeout: timeout in seconds
90-
"""
91-
if isinstance(timeout, timedelta):
92-
timeout = timeout.total_seconds()
93-
try:
94-
v = q.get(timeout=timeout)
95-
except queue.Empty as e:
96-
raise TimeoutError(f"queue.get() timed out after {timeout} seconds") from e
97-
if isinstance(v, Exception):
98-
raise v
99-
return v
100-
101-
10282
def create_store_client(store_addr: str) -> Store:
10383
"""
10484
Creates a PrefixStore(TCPStore(...)) client from an address in the format:
@@ -573,8 +553,8 @@ class _BabyWork(Work):
573553
def __init__(
574554
self,
575555
pg: "ProcessGroupBaby",
576-
tx: mp.Queue,
577-
rx: mp.Queue,
556+
tx: _MonitoredQueue,
557+
rx: _MonitoredQueue,
578558
op_id: int,
579559
timeout: float,
580560
) -> None:
@@ -592,7 +572,7 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:
592572
self._tx.put(("wait", self._op_id), timeout=self._timeout)
593573
op_id, event = cast(
594574
Tuple[int, Optional[torch.cuda.Event]],
595-
_get(self._rx, timeout or self._timeout),
575+
self._rx.get(timeout or self._timeout),
596576
)
597577
assert op_id == self._op_id
598578
if event is not None:
@@ -649,9 +629,9 @@ def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
649629
self._world_size = -1
650630

651631
self._p: Optional[mp.Process] = None
652-
self._tx: Optional[mp.Queue] = None
653-
self._rx: Optional[mp.Queue] = None
654-
self._future_queue: Optional[mp.Queue] = None
632+
self._tx: Optional[_MonitoredQueue] = None
633+
self._rx: Optional[_MonitoredQueue] = None
634+
self._future_queue: Optional[_MonitoredQueue] = None
655635
self._future_thread: Optional[threading.Thread] = None
656636
self._futures: Dict[int, Future[object]] = {}
657637
self._futures_lock = threading.Lock()
@@ -661,60 +641,80 @@ def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
661641

662642
self._timeout: float = timeout
663643

664-
def configure(self, store_addr: str, rank: int, world_size: int) -> None:
665-
if self._p is not None:
666-
self._p.kill()
644+
def shutdown(self) -> None:
645+
"""
646+
Shutdown the process group. This will kill the underlying process and
647+
close all queues.
667648
668-
self._world_size = world_size
649+
This is a no-op if the process group is already shutdown.
650+
651+
ProcessGroup can be reconfigured after shutdown.
652+
"""
669653

670654
if self._tx is not None:
671655
self._tx.close()
672656
if self._rx is not None:
673657
self._rx.close()
674-
if self._future_queue is not None:
658+
659+
future_queue = self._future_queue
660+
if future_queue is not None:
675661
# wait for the future thread to exit and then close the queue
676-
self._future_queue.put(_QUEUE_CLOSE)
677-
assert self._future_thread is not None
678-
self._future_thread.join(timeout=10.0)
679-
# pyre-ignore[16]: optional value is checked above
680-
if self._future_thread.is_alive():
662+
future_queue.put(_QUEUE_CLOSE, timeout=timedelta(seconds=10.0))
663+
664+
future_thread = self._future_thread
665+
assert future_thread is not None
666+
future_thread.join(timeout=10.0)
667+
if future_thread.is_alive():
681668
raise RuntimeError("future thread did not exit")
682-
# pyre-ignore[16]: optional value is checked above
683-
self._future_queue.close()
669+
670+
future_queue.close()
671+
672+
# Kill after closing queues to avoid log spam.
673+
if self._p is not None:
674+
self._p.kill()
675+
676+
def configure(self, store_addr: str, rank: int, world_size: int) -> None:
677+
self._world_size = world_size
678+
679+
self.shutdown()
684680

685681
ctx = mp.get_context("spawn")
686-
self._tx = ctx.Queue()
687-
self._rx = rx = ctx.Queue()
682+
tx = ctx.Queue()
683+
rx = ctx.Queue()
684+
future_queue = ctx.Queue()
685+
686+
self._p = p = ctx.Process(
687+
target=self._worker,
688+
args=(
689+
store_addr,
690+
rank,
691+
world_size,
692+
tx,
693+
rx,
694+
future_queue,
695+
),
696+
daemon=True,
697+
)
698+
p.start()
699+
700+
self._tx = tx = _MonitoredQueue(p, tx)
701+
self._rx = rx = _MonitoredQueue(p, rx)
702+
self._future_queue = future_queue = _MonitoredQueue(p, future_queue)
688703

689704
# futures need thread to fire callbacks
690-
self._future_queue = ctx.Queue()
691705
# this lock needs to be held when manipulating _futures
692706
self._futures_lock = threading.Lock()
693707
self._futures = {}
694708
self._future_thread = threading.Thread(
695709
target=self._future_handler,
696-
args=(self._future_queue,),
710+
args=(future_queue,),
697711
daemon=True,
698712
)
699713
self._future_thread.start()
700714

701-
self._p = ctx.Process(
702-
target=self._worker,
703-
args=(
704-
store_addr,
705-
rank,
706-
world_size,
707-
self._tx,
708-
self._rx,
709-
self._future_queue,
710-
),
711-
daemon=True,
712-
)
713-
self._p.start()
714-
715715
# fetch the status of the PG init
716-
# if an exception was returned _get will throw
717-
assert _get(rx, self._timeout) is None
716+
# if an exception was returned get will throw
717+
assert rx.get(self._timeout) is None
718718

719719
@classmethod
720720
def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
@@ -739,7 +739,7 @@ def _worker(
739739
try:
740740
pg = cls._create_pg(store, rank, world_size)
741741
except Exception as e:
742-
logger.exception(f"got exception in worker: {e}")
742+
print(f"got exception in worker: {e}", file=sys.stderr)
743743
tx.put(e)
744744
return
745745
tx.put(None)
@@ -829,17 +829,21 @@ def callback(fut: Future[object]) -> None:
829829
raise ValueError(f"unknown cmd: {cmd}")
830830

831831
except Exception as e:
832-
logger.exception("worker errored")
832+
print(f"worker errored: {e}", file=sys.stderr)
833833
tx.put(e)
834834
raise
835835

836-
def _future_handler(self, future_queue: mp.Queue) -> None:
836+
def _future_handler(self, future_queue: _MonitoredQueue) -> None:
837837
try:
838838
while True:
839-
cmd = future_queue.get()
839+
try:
840+
# timeout doesn't really matter here
841+
cmd = future_queue.get(timeout=timedelta(seconds=10.0))
842+
except TimeoutError:
843+
continue
840844
if cmd == _QUEUE_CLOSE:
841845
break
842-
op_id, mode, data = cmd
846+
op_id, mode, data = cast(Tuple[int, str, object], cmd)
843847
with self._futures_lock:
844848
fut = self._futures[op_id]
845849
del self._futures[op_id]
@@ -862,7 +866,7 @@ def _get_future(self, op_id: int) -> Future[object]:
862866
self._tx.put(("future", op_id), timeout=self._timeout)
863867

864868
assert self._rx is not None
865-
assert _get(self._rx, self._timeout) == op_id
869+
assert self._rx.get(self._timeout) == op_id
866870
# TODO: return correct tensor instead of None
867871
return fut
868872

@@ -899,7 +903,7 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
899903
timeout=self._timeout,
900904
)
901905

902-
op_id = _get(rx, self._timeout)
906+
op_id = rx.get(self._timeout)
903907
assert isinstance(op_id, int), f"invalid return {op_id}"
904908

905909
return _BabyWork(pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout)
@@ -968,7 +972,7 @@ def num_active_work(self) -> int:
968972
self._tx.put(("num_active_work",), timeout=self._timeout)
969973

970974
assert self._rx is not None
971-
return cast(int, _get(self._rx, self._timeout))
975+
return cast(int, self._rx.get(self._timeout))
972976

973977

974978
@dataclass

0 commit comments

Comments
 (0)