Skip to content

Commit 9533676

Browse files
authored
MonitoredQueue: fail fast when subprocess exits (#99)
1 parent 4d4d260 commit 9533676

4 files changed

+235
-86
lines changed

torchft/multiprocessing.py

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import queue
2+
import time
3+
from datetime import timedelta
4+
from typing import Union
5+
6+
import torch.multiprocessing as mp
7+
8+
9+
class _MonitoredQueue:
10+
def __init__(
11+
self,
12+
p: mp.Process,
13+
q: mp.Queue,
14+
poll_interval: timedelta = timedelta(seconds=1),
15+
) -> None:
16+
"""
17+
Args:
18+
p: process to monitor
19+
q: queue to monitor
20+
poll_interval: interval to poll the Process health when calling get/put
21+
"""
22+
self._p = p
23+
self._q = q
24+
self._poll_interval_s: float = poll_interval.total_seconds()
25+
26+
def get(self, timeout: Union[float, timedelta]) -> object:
27+
"""
28+
Get an item from the queue. If the process is not alive, raise RuntimeError.
29+
If the queue is empty, wait for up to timeout seconds for an item to be
30+
available. If no item is available after timeout seconds, raise TimeoutError.
31+
32+
Args:
33+
timeout: timeout in seconds
34+
"""
35+
36+
if isinstance(timeout, timedelta):
37+
timeout = timeout.total_seconds()
38+
39+
start = time.perf_counter()
40+
while True:
41+
elapsed = time.perf_counter() - start
42+
if elapsed > timeout:
43+
raise TimeoutError(f"queue.get() timed out after {timeout} seconds")
44+
if not self._p.is_alive():
45+
raise RuntimeError(f"process is not alive {self._p.exitcode}")
46+
47+
try:
48+
v = self._q.get(timeout=self._poll_interval_s)
49+
break
50+
except queue.Empty:
51+
continue
52+
53+
if isinstance(v, Exception):
54+
raise v
55+
return v
56+
57+
def put(self, obj: object, timeout: Union[float, timedelta]) -> None:
58+
"""
59+
Put an item into the queue. If the process is not alive, raise RuntimeError.
60+
If the queue is full, wait for up to timeout seconds for an item to be
61+
available. If queue is full after timeout seconds, raise TimeoutError.
62+
63+
If an exception is put into the queue, it will be raised when calling get().
64+
65+
Args:
66+
obj: object to put into the queue
67+
timeout: timeout in seconds
68+
"""
69+
if isinstance(timeout, timedelta):
70+
timeout = timeout.total_seconds()
71+
72+
start = time.perf_counter()
73+
while True:
74+
elapsed = time.perf_counter() - start
75+
if elapsed > timeout:
76+
raise TimeoutError(f"queue.put() timed out after {timeout} seconds")
77+
if not self._p.is_alive():
78+
raise RuntimeError(f"process is not alive {self._p.exitcode}")
79+
80+
try:
81+
self._q.put(obj, timeout=self._poll_interval_s)
82+
break
83+
except queue.Full:
84+
continue
85+
86+
def close(self) -> None:
87+
self._q.close()
88+
89+
def closed(self) -> bool:
90+
# pyre-ignore[16]: no attribute _closed
91+
return self._q._closed

torchft/multiprocessing_test.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from unittest import TestCase
2+
3+
import torch.multiprocessing as mp
4+
5+
from torchft.multiprocessing import _MonitoredQueue
6+
7+
8+
def queue_get(q: mp.Queue) -> None:
9+
q.get()
10+
11+
12+
def queue_put(q: mp.Queue) -> None:
13+
q.put(1)
14+
15+
16+
class MultiprocessingTest(TestCase):
17+
def test_monitored_queue_put(self) -> None:
18+
ctx = mp.get_context("fork")
19+
q = ctx.Queue(maxsize=1)
20+
p = ctx.Process(target=queue_get, args=(q,), daemon=True)
21+
p.start()
22+
23+
mq = _MonitoredQueue(p, q)
24+
mq.put(1, timeout=10)
25+
mq.put(1, timeout=10)
26+
with self.assertRaisesRegex(RuntimeError, "process is not alive 0"):
27+
mq.put(1, timeout=10)
28+
29+
with self.assertRaisesRegex(TimeoutError, "timed out after 0.0 seconds"):
30+
mq.put(1, timeout=0.0)
31+
32+
mq.close()
33+
34+
def test_monitored_queue_get(self) -> None:
35+
ctx = mp.get_context("fork")
36+
q = ctx.Queue(maxsize=1)
37+
p = ctx.Process(target=queue_put, args=(q,), daemon=True)
38+
p.start()
39+
40+
mq = _MonitoredQueue(p, q)
41+
self.assertEqual(mq.get(timeout=10), 1)
42+
with self.assertRaisesRegex(RuntimeError, "process is not alive 0"):
43+
mq.get(timeout=10)
44+
45+
with self.assertRaisesRegex(TimeoutError, "timed out after 0.0 seconds"):
46+
mq.get(timeout=0.0)
47+
48+
mq.close()

0 commit comments

Comments
 (0)