Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MonitoredQueue: fail fast when subprocess exits #99

Merged
merged 1 commit into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions torchft/multiprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import queue
import time
from datetime import timedelta
from typing import Union

import torch.multiprocessing as mp


class _MonitoredQueue:
def __init__(
self,
p: mp.Process,
q: mp.Queue,
poll_interval: timedelta = timedelta(seconds=1),
) -> None:
"""
Args:
p: process to monitor
q: queue to monitor
poll_interval: interval to poll the Process health when calling get/put
"""
self._p = p
self._q = q
self._poll_interval_s: float = poll_interval.total_seconds()

def get(self, timeout: Union[float, timedelta]) -> object:
"""
Get an item from the queue. If the process is not alive, raise RuntimeError.
If the queue is empty, wait for up to timeout seconds for an item to be
available. If no item is available after timeout seconds, raise TimeoutError.

Args:
timeout: timeout in seconds
"""

if isinstance(timeout, timedelta):
timeout = timeout.total_seconds()

start = time.perf_counter()
while True:
elapsed = time.perf_counter() - start
if elapsed > timeout:
raise TimeoutError(f"queue.get() timed out after {timeout} seconds")
if not self._p.is_alive():
raise RuntimeError(f"process is not alive {self._p.exitcode}")

try:
v = self._q.get(timeout=self._poll_interval_s)
break
except queue.Empty:
continue

if isinstance(v, Exception):
raise v
return v

def put(self, obj: object, timeout: Union[float, timedelta]) -> None:
"""
Put an item into the queue. If the process is not alive, raise RuntimeError.
If the queue is full, wait for up to timeout seconds for an item to be
available. If queue is full after timeout seconds, raise TimeoutError.

If an exception is put into the queue, it will be raised when calling get().

Args:
obj: object to put into the queue
timeout: timeout in seconds
"""
if isinstance(timeout, timedelta):
timeout = timeout.total_seconds()

start = time.perf_counter()
while True:
elapsed = time.perf_counter() - start
if elapsed > timeout:
raise TimeoutError(f"queue.put() timed out after {timeout} seconds")
if not self._p.is_alive():
raise RuntimeError(f"process is not alive {self._p.exitcode}")

try:
self._q.put(obj, timeout=self._poll_interval_s)
break
except queue.Full:
continue

def close(self) -> None:
self._q.close()

def closed(self) -> bool:
# pyre-ignore[16]: no attribute _closed
return self._q._closed
48 changes: 48 additions & 0 deletions torchft/multiprocessing_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from unittest import TestCase

import torch.multiprocessing as mp

from torchft.multiprocessing import _MonitoredQueue


def queue_get(q: mp.Queue) -> None:
q.get()


def queue_put(q: mp.Queue) -> None:
q.put(1)


class MultiprocessingTest(TestCase):
def test_monitored_queue_put(self) -> None:
ctx = mp.get_context("fork")
q = ctx.Queue(maxsize=1)
p = ctx.Process(target=queue_get, args=(q,), daemon=True)
p.start()

mq = _MonitoredQueue(p, q)
mq.put(1, timeout=10)
mq.put(1, timeout=10)
with self.assertRaisesRegex(RuntimeError, "process is not alive 0"):
mq.put(1, timeout=10)

with self.assertRaisesRegex(TimeoutError, "timed out after 0.0 seconds"):
mq.put(1, timeout=0.0)

mq.close()

def test_monitored_queue_get(self) -> None:
ctx = mp.get_context("fork")
q = ctx.Queue(maxsize=1)
p = ctx.Process(target=queue_put, args=(q,), daemon=True)
p.start()

mq = _MonitoredQueue(p, q)
self.assertEqual(mq.get(timeout=10), 1)
with self.assertRaisesRegex(RuntimeError, "process is not alive 0"):
mq.get(timeout=10)

with self.assertRaisesRegex(TimeoutError, "timed out after 0.0 seconds"):
mq.get(timeout=0.0)

mq.close()
Loading