-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
MonitoredQueue: fail fast when subprocess exits
- Loading branch information
Showing
3 changed files
with
208 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import queue | ||
import time | ||
from datetime import timedelta | ||
from typing import Union | ||
|
||
import torch.multiprocessing as mp | ||
|
||
|
||
class _MonitoredQueue:
    """
    Pair a queue with the process on its other end so that blocking get/put
    calls fail fast once that process has exited.

    Both operations wait in slices of at most ``poll_interval`` (never longer
    than the time left before the caller's timeout) and re-check the process
    between slices instead of blocking for the whole timeout.
    """

    def __init__(
        self,
        p: mp.Process,
        q: mp.Queue,
        poll_interval: timedelta = timedelta(seconds=1),
    ) -> None:
        """
        Args:
            p: process to monitor
            q: queue to monitor
            poll_interval: interval to poll the Process health when calling get/put
        """
        self._p = p
        self._q = q
        self._poll_interval_s: float = poll_interval.total_seconds()

    @staticmethod
    def _to_seconds(timeout: Union[float, timedelta]) -> float:
        """Normalize a timeout given as seconds or as a timedelta to seconds."""
        if isinstance(timeout, timedelta):
            return timeout.total_seconds()
        return timeout

    def _next_wait(self, op: str, start: float, timeout_s: float) -> float:
        """
        Return how long the next blocking queue call may wait.

        Args:
            op: name of the queue operation ("get" or "put"), used in the message
            start: ``time.perf_counter()`` value taken when the operation began
            timeout_s: total time budget of the operation in seconds

        Raises:
            TimeoutError: if the time budget is already used up.
        """
        elapsed = time.perf_counter() - start
        # ">=" so that a zero timeout expires deterministically, independent of
        # the clock resolution.
        if elapsed >= timeout_s:
            raise TimeoutError(f"queue.{op}() timed out after {timeout_s} seconds")
        # Never wait past the caller's deadline, even if the poll interval is
        # longer than the time that is left.
        return min(self._poll_interval_s, timeout_s - elapsed)

    def get(self, timeout: Union[float, timedelta]) -> object:
        """
        Get an item from the queue.

        The timeout is checked first, then the process health, then the queue
        is polled. If the process is no longer alive, one non-blocking read is
        attempted so that an item already sitting in the queue is still
        delivered; only if the queue is empty is RuntimeError raised.
        If an Exception instance is received from the queue it is raised
        instead of being returned.

        Args:
            timeout: timeout in seconds (or a timedelta)

        Raises:
            TimeoutError: if no item is available within the timeout.
            RuntimeError: if the process is not alive and the queue is empty.
        """
        timeout_s = self._to_seconds(timeout)

        start = time.perf_counter()
        while True:
            wait_s = self._next_wait("get", start, timeout_s)

            if not self._p.is_alive():
                # The process may have enqueued a final item right before it
                # exited; hand that over rather than dropping it.
                try:
                    v = self._q.get_nowait()
                except queue.Empty:
                    raise RuntimeError(
                        f"process is not alive {self._p.exitcode}"
                    ) from None
                break

            try:
                v = self._q.get(timeout=wait_s)
                break
            except queue.Empty:
                continue

        if isinstance(v, Exception):
            raise v
        return v

    def put(self, obj: object, timeout: Union[float, timedelta]) -> None:
        """
        Put an item into the queue.

        The timeout is checked first, then the process health, then the put is
        attempted. If the queue is full, wait for up to timeout seconds for a
        free slot. If an exception is put into the queue, it will be raised
        when calling get().

        Args:
            obj: object to put into the queue
            timeout: timeout in seconds (or a timedelta)

        Raises:
            TimeoutError: if the queue is still full after the timeout.
            RuntimeError: if the process is not alive.
        """
        timeout_s = self._to_seconds(timeout)

        start = time.perf_counter()
        while True:
            wait_s = self._next_wait("put", start, timeout_s)

            if not self._p.is_alive():
                raise RuntimeError(f"process is not alive {self._p.exitcode}")

            try:
                self._q.put(obj, timeout=wait_s)
                break
            except queue.Full:
                continue

    def close(self) -> None:
        """Close the underlying queue."""
        self._q.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from unittest import TestCase | ||
|
||
import torch.multiprocessing as mp | ||
|
||
from torchft.multiprocessing import _MonitoredQueue | ||
|
||
|
||
def queue_get(q: mp.Queue) -> None:
    """Process target: block until one item arrives on ``q`` and discard it."""
    _ = q.get()
|
||
|
||
def queue_put(q: mp.Queue) -> None:
    """Process target: enqueue the integer ``1`` on ``q``."""
    item = 1
    q.put(item)
|
||
|
||
class MultiprocessingTest(TestCase):
    """Exercise _MonitoredQueue against a forked child that does one queue operation."""

    def test_monitored_queue_put(self) -> None:
        """put() works twice, then reports the exited consumer, then a zero timeout."""
        fork_ctx = mp.get_context("fork")
        raw_queue = fork_ctx.Queue(maxsize=1)
        # The consumer takes a single item (see queue_get) and then returns.
        consumer = fork_ctx.Process(target=queue_get, args=(raw_queue,), daemon=True)
        consumer.start()

        monitored = _MonitoredQueue(consumer, raw_queue)
        for _ in range(2):
            monitored.put(1, timeout=10)

        # Third put: the queue has maxsize=1 and the consumer only reads once,
        # so this is expected to surface the consumer's exit (exit code 0).
        with self.assertRaisesRegex(RuntimeError, "process is not alive 0"):
            monitored.put(1, timeout=10)

        # With a zero timeout the TimeoutError is expected instead.
        with self.assertRaisesRegex(TimeoutError, "timed out after 0.0 seconds"):
            monitored.put(1, timeout=0.0)

        monitored.close()

    def test_monitored_queue_get(self) -> None:
        """get() returns the produced item, then reports the exited producer, then a zero timeout."""
        fork_ctx = mp.get_context("fork")
        raw_queue = fork_ctx.Queue(maxsize=1)
        # The producer enqueues a single 1 (see queue_put) and then returns.
        producer = fork_ctx.Process(target=queue_put, args=(raw_queue,), daemon=True)
        producer.start()

        monitored = _MonitoredQueue(producer, raw_queue)
        received = monitored.get(timeout=10)
        self.assertEqual(received, 1)

        # Nothing else will ever be produced, so the next get is expected to
        # report the producer's exit (exit code 0).
        with self.assertRaisesRegex(RuntimeError, "process is not alive 0"):
            monitored.get(timeout=10)

        # With a zero timeout the TimeoutError is expected instead.
        with self.assertRaisesRegex(TimeoutError, "timed out after 0.0 seconds"):
            monitored.get(timeout=0.0)

        monitored.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters