Skip to content

Commit

Permalink
MonitoredQueue: fail fast when subprocess exits
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Feb 5, 2025
1 parent 927f8b4 commit 9bc6269
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 69 deletions.
87 changes: 87 additions & 0 deletions torchft/multiprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import queue
import time
from datetime import timedelta
from typing import Union

import torch.multiprocessing as mp


class _MonitoredQueue:
    """
    Pairs a queue with the process on the other end of it so that blocking
    get/put calls fail fast with a RuntimeError when that process exits,
    instead of waiting for the full timeout.
    """

    def __init__(
        self,
        p: mp.Process,
        q: mp.Queue,
        poll_interval: timedelta = timedelta(seconds=1),
    ) -> None:
        """
        Args:
            p: process to monitor
            q: queue to monitor
            poll_interval: interval to poll the Process health when calling get/put
        """
        self._p = p
        self._q = q
        self._poll_interval_s: float = poll_interval.total_seconds()

    def _poll_timeout(self, remaining: float) -> float:
        """
        Returns how long a single blocking queue call may wait: the poll
        interval, capped by the caller's remaining time budget and never
        negative (queue timeouts must be >= 0).
        """
        return max(0.0, min(self._poll_interval_s, remaining))

    def get(self, timeout: Union[float, timedelta]) -> object:
        """
        Get an item from the queue.

        The timeout is checked first, then the process health. If the process
        has exited, any item it already enqueued is still returned; only when
        the queue is empty is RuntimeError raised. Otherwise this waits for up
        to timeout seconds for an item to become available.

        Args:
            timeout: timeout in seconds

        Returns:
            the item that was read from the queue

        Raises:
            TimeoutError: if no item is available after timeout seconds
            RuntimeError: if the process is not alive and the queue is empty
            Exception: if the item read from the queue is an exception
                instance, it is raised instead of being returned
        """

        if isinstance(timeout, timedelta):
            timeout = timeout.total_seconds()

        start = time.perf_counter()
        while True:
            elapsed = time.perf_counter() - start
            if elapsed > timeout:
                raise TimeoutError(f"queue.get() timed out after {timeout} seconds")
            if not self._p.is_alive():
                # The process may have enqueued a final result or the
                # exception that made it exit. Deliver that rather than
                # masking it with a generic "not alive" error.
                try:
                    v = self._q.get_nowait()
                    break
                except queue.Empty:
                    raise RuntimeError(
                        f"process is not alive {self._p.exitcode}"
                    ) from None

            try:
                # Never block longer than the caller's remaining budget.
                v = self._q.get(timeout=self._poll_timeout(timeout - elapsed))
                break
            except queue.Empty:
                continue

        if isinstance(v, Exception):
            raise v
        return v

    def put(self, obj: object, timeout: Union[float, timedelta]) -> None:
        """
        Put an item into the queue.

        The timeout is checked first, then the process health. If the queue is
        full, this waits for up to timeout seconds for space to become
        available. If an exception is put into the queue, it will be raised
        when calling get().

        Args:
            obj: object to put into the queue
            timeout: timeout in seconds

        Raises:
            TimeoutError: if the queue is still full after timeout seconds
            RuntimeError: if the process is not alive
        """
        if isinstance(timeout, timedelta):
            timeout = timeout.total_seconds()

        start = time.perf_counter()
        while True:
            elapsed = time.perf_counter() - start
            if elapsed > timeout:
                raise TimeoutError(f"queue.put() timed out after {timeout} seconds")
            if not self._p.is_alive():
                raise RuntimeError(f"process is not alive {self._p.exitcode}")

            try:
                # Never block longer than the caller's remaining budget.
                self._q.put(obj, timeout=self._poll_timeout(timeout - elapsed))
                break
            except queue.Full:
                continue

    def close(self) -> None:
        """Close the underlying queue."""
        self._q.close()
48 changes: 48 additions & 0 deletions torchft/multiprocessing_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from unittest import TestCase

import torch.multiprocessing as mp

from torchft.multiprocessing import _MonitoredQueue


def queue_get(q: mp.Queue) -> None:
    """Subprocess target: block until one item can be read from ``q``, discard it and return."""
    q.get(block=True, timeout=None)


def queue_put(q: mp.Queue) -> None:
    """Subprocess target: enqueue the single integer ``1`` on ``q`` and return."""
    item = 1
    q.put(item, block=True, timeout=None)


class MultiprocessingTest(TestCase):
    """Exercises _MonitoredQueue against a real forked subprocess."""

    def test_monitored_queue_put(self) -> None:
        ctx = mp.get_context("fork")
        q = ctx.Queue(maxsize=1)
        # The child reads exactly one item and then returns.
        p = ctx.Process(target=queue_get, args=(q,), daemon=True)
        p.start()
        # Cleanups run in reverse registration order and also run when an
        # assertion fails: close the queue, make sure the child is gone,
        # then reap it so no process or pipe is leaked between tests.
        self.addCleanup(p.join, 10)
        self.addCleanup(p.kill)

        mq = _MonitoredQueue(p, q)
        self.addCleanup(mq.close)

        mq.put(1, timeout=10)
        mq.put(1, timeout=10)
        with self.assertRaisesRegex(RuntimeError, "process is not alive 0"):
            mq.put(1, timeout=10)

        # The timeout is checked before the process health.
        with self.assertRaisesRegex(TimeoutError, "timed out after 0.0 seconds"):
            mq.put(1, timeout=0.0)

    def test_monitored_queue_get(self) -> None:
        ctx = mp.get_context("fork")
        q = ctx.Queue(maxsize=1)
        # The child writes exactly one item and then returns.
        p = ctx.Process(target=queue_put, args=(q,), daemon=True)
        p.start()
        # See test_monitored_queue_put for why cleanups are registered here.
        self.addCleanup(p.join, 10)
        self.addCleanup(p.kill)

        mq = _MonitoredQueue(p, q)
        self.addCleanup(mq.close)

        self.assertEqual(mq.get(timeout=10), 1)
        with self.assertRaisesRegex(RuntimeError, "process is not alive 0"):
            mq.get(timeout=10)

        # The timeout is checked before the process health.
        with self.assertRaisesRegex(TimeoutError, "timed out after 0.0 seconds"):
            mq.get(timeout=0.0)
142 changes: 73 additions & 69 deletions torchft/process_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""

import logging
import queue
import sys
import threading
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
Expand Down Expand Up @@ -63,6 +63,8 @@
from torch.futures import Future
from torch.utils._pytree import tree_any

from torchft.multiprocessing import _MonitoredQueue

if TYPE_CHECKING:
from torchft.manager import Manager

Expand All @@ -77,28 +79,6 @@
T = TypeVar("T")


def _get(q: mp.Queue, timeout: Union[float, timedelta]) -> object:
    """
    Reads one item from ``q``, waiting at most ``timeout``.

    An item that is itself an exception instance is raised rather than
    returned.

    Args:
        q: queue to get from
        timeout: timeout in seconds

    Raises:
        TimeoutError: if nothing could be read within the timeout
    """
    # Normalise to plain seconds, which is what Queue.get expects.
    timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
    try:
        item = q.get(timeout=timeout)
    except queue.Empty as e:
        raise TimeoutError(f"queue.get() timed out after {timeout} seconds") from e
    if not isinstance(item, Exception):
        return item
    raise item


def create_store_client(store_addr: str) -> Store:
"""
Creates a PrefixStore(TCPStore(...)) client from an address in the format:
Expand Down Expand Up @@ -573,8 +553,8 @@ class _BabyWork(Work):
def __init__(
self,
pg: "ProcessGroupBaby",
tx: mp.Queue,
rx: mp.Queue,
tx: _MonitoredQueue,
rx: _MonitoredQueue,
op_id: int,
timeout: float,
) -> None:
Expand All @@ -592,7 +572,7 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:
self._tx.put(("wait", self._op_id), timeout=self._timeout)
op_id, event = cast(
Tuple[int, Optional[torch.cuda.Event]],
_get(self._rx, timeout or self._timeout),
self._rx.get(timeout or self._timeout),
)
assert op_id == self._op_id
if event is not None:
Expand Down Expand Up @@ -649,9 +629,9 @@ def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
self._world_size = -1

self._p: Optional[mp.Process] = None
self._tx: Optional[mp.Queue] = None
self._rx: Optional[mp.Queue] = None
self._future_queue: Optional[mp.Queue] = None
self._tx: Optional[_MonitoredQueue] = None
self._rx: Optional[_MonitoredQueue] = None
self._future_queue: Optional[_MonitoredQueue] = None
self._future_thread: Optional[threading.Thread] = None
self._futures: Dict[int, Future[object]] = {}
self._futures_lock = threading.Lock()
Expand All @@ -661,60 +641,80 @@ def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:

self._timeout: float = timeout

def configure(self, store_addr: str, rank: int, world_size: int) -> None:
if self._p is not None:
self._p.kill()
def shutdown(self) -> None:
"""
Shutdown the process group. This will kill the underlying process and
close all queues.
self._world_size = world_size
This is a no-op if the process group is already shutdown.
ProcessGroup can be reconfigured after shutdown.
"""

if self._tx is not None:
self._tx.close()
if self._rx is not None:
self._rx.close()
if self._future_queue is not None:

future_queue = self._future_queue
if future_queue is not None:
# wait for the future thread to exit and then close the queue
self._future_queue.put(_QUEUE_CLOSE)
assert self._future_thread is not None
self._future_thread.join(timeout=10.0)
# pyre-ignore[16]: optional value is checked above
if self._future_thread.is_alive():
future_queue.put(_QUEUE_CLOSE, timeout=timedelta(seconds=10.0))

future_thread = self._future_thread
assert future_thread is not None
future_thread.join(timeout=10.0)
if future_thread.is_alive():
raise RuntimeError("future thread did not exit")
# pyre-ignore[16]: optional value is checked above
self._future_queue.close()

future_queue.close()

# Kill after closing queues to avoid log spam.
if self._p is not None:
self._p.kill()

def configure(self, store_addr: str, rank: int, world_size: int) -> None:
self._world_size = world_size

self.shutdown()

ctx = mp.get_context("spawn")
self._tx = ctx.Queue()
self._rx = rx = ctx.Queue()
tx = ctx.Queue()
rx = ctx.Queue()
future_queue = ctx.Queue()

self._p = p = ctx.Process(
target=self._worker,
args=(
store_addr,
rank,
world_size,
tx,
rx,
future_queue,
),
daemon=True,
)
p.start()

self._tx = tx = _MonitoredQueue(p, tx)
self._rx = rx = _MonitoredQueue(p, rx)
self._future_queue = future_queue = _MonitoredQueue(p, future_queue)

# futures need thread to fire callbacks
self._future_queue = ctx.Queue()
# this lock needs to be held when manipulating _futures
self._futures_lock = threading.Lock()
self._futures = {}
self._future_thread = threading.Thread(
target=self._future_handler,
args=(self._future_queue,),
args=(future_queue,),
daemon=True,
)
self._future_thread.start()

self._p = ctx.Process(
target=self._worker,
args=(
store_addr,
rank,
world_size,
self._tx,
self._rx,
self._future_queue,
),
daemon=True,
)
self._p.start()

# fetch the status of the PG init
# if an exception was returned _get will throw
assert _get(rx, self._timeout) is None
# if an exception was returned get will throw
assert rx.get(self._timeout) is None

@classmethod
def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
Expand All @@ -739,7 +739,7 @@ def _worker(
try:
pg = cls._create_pg(store, rank, world_size)
except Exception as e:
logger.exception(f"got exception in worker: {e}")
print(f"got exception in worker: {e}", file=sys.stderr)
tx.put(e)
return
tx.put(None)
Expand Down Expand Up @@ -829,17 +829,21 @@ def callback(fut: Future[object]) -> None:
raise ValueError(f"unknown cmd: {cmd}")

except Exception as e:
logger.exception("worker errored")
print(f"worker errored: {e}", file=sys.stderr)
tx.put(e)
raise

def _future_handler(self, future_queue: mp.Queue) -> None:
def _future_handler(self, future_queue: _MonitoredQueue) -> None:
try:
while True:
cmd = future_queue.get()
try:
# timeout doesn't really matter here
cmd = future_queue.get(timeout=timedelta(seconds=10.0))
except TimeoutError:
continue
if cmd == _QUEUE_CLOSE:
break
op_id, mode, data = cmd
op_id, mode, data = cast(Tuple[int, str, object], cmd)
with self._futures_lock:
fut = self._futures[op_id]
del self._futures[op_id]
Expand All @@ -862,7 +866,7 @@ def _get_future(self, op_id: int) -> Future[object]:
self._tx.put(("future", op_id), timeout=self._timeout)

assert self._rx is not None
assert _get(self._rx, self._timeout) == op_id
assert self._rx.get(self._timeout) == op_id
# TODO: return correct tensor instead of None
return fut

Expand Down Expand Up @@ -899,7 +903,7 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
timeout=self._timeout,
)

op_id = _get(rx, self._timeout)
op_id = rx.get(self._timeout)
assert isinstance(op_id, int), f"invalid return {op_id}"

return _BabyWork(pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout)
Expand Down Expand Up @@ -968,7 +972,7 @@ def num_active_work(self) -> int:
self._tx.put(("num_active_work",), timeout=self._timeout)

assert self._rx is not None
return cast(int, _get(self._rx, self._timeout))
return cast(int, self._rx.get(self._timeout))


@dataclass
Expand Down

0 comments on commit 9bc6269

Please sign in to comment.