Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TimeoutManager: delete cuda events on main thread #142

Merged
merged 1 commit into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions torchft/futures.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import asyncio
import queue
import sys
import threading
from contextlib import contextmanager
from datetime import timedelta
Expand Down Expand Up @@ -36,8 +38,13 @@ def cancel(self) -> None:

class _TimeoutManager:
"""
This class manages timeouts for futures. It uses a background thread with an
event loop to schedule the timeouts.
This class manages timeouts for code blocks, futures and CUDA events. It
uses a background thread with an event loop to schedule the timeouts and
call the callback function when the timeout is reached.

Generally there is a single instance of this class that is used for all
timeouts. The callbacks should not block otherwise other timeouts may not
be processed.
"""

def __init__(self) -> None:
Expand All @@ -46,6 +53,10 @@ def __init__(self) -> None:
self._event_loop_thread: Optional[threading.Thread] = None
self._next_timer_id = 0

# This queue is used to delete events on the main thread as cudaEventDestroy
# can block if the CUDA queue is full.
self._del_queue: queue.SimpleQueue[object] = queue.SimpleQueue()

def _maybe_start_event_loop(self) -> asyncio.AbstractEventLoop:
"""
Start the event loop if it has not already been started.
Expand Down Expand Up @@ -82,6 +93,8 @@ def register(self, fut: Future[T], timeout: timedelta) -> Future[T]:
if isinstance(fut, Mock):
return fut

self._clear_del_queue()

loop = self._maybe_start_event_loop()

# pyre-fixme[29]: Future is not a function
Expand Down Expand Up @@ -114,6 +127,8 @@ def callback(fut: Future[T]) -> None:
return timed_fut

def stream_timeout(self, callback: Callable[[], None], timeout: timedelta) -> None:
self._clear_del_queue()

loop = self._maybe_start_event_loop()

event: torch.cuda.Event = torch.cuda.Event()
Expand All @@ -123,6 +138,11 @@ def handler() -> None:
if not event.query():
callback()

# cudaEventDestroy can block so we never want to delete in the event
# loop. Put it on the del queue so we can delete it in the main
# thread.
self._del_queue.put(event)

loop.call_soon_threadsafe(
self._register_callback, loop, handler, timeout, _TimerHandle()
)
Expand All @@ -145,6 +165,8 @@ def _register_callback(
def context_timeout(
self, callback: Callable[[], None], timeout: timedelta
) -> Generator[None, None, None]:
self._clear_del_queue()

loop = self._maybe_start_event_loop()
handle = _TimerHandle()

Expand All @@ -156,6 +178,31 @@ def context_timeout(

handle.cancel()

def _clear_del_queue(self) -> int:
"""
Clear the queue of futures to be deleted.

Returns the number of items deleted.
"""
count = 0
while True:
try:
# get and immediately discard item
item = self._del_queue.get_nowait()
refcount = sys.getrefcount(item)
assert (
# 1 from item, 1 from getrefcount
refcount
== 2
), f"items in del_queue reference should not have other references, found {refcount=}"
del item

count += 1
except queue.Empty:
break

return count


_TIMEOUT_MANAGER = _TimeoutManager()

Expand Down
48 changes: 46 additions & 2 deletions torchft/futures_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
import threading
from datetime import timedelta
from unittest import TestCase
from unittest import TestCase, skipUnless

import torch
from torch.futures import Future

from torchft.futures import future_timeout, future_wait
from torchft.futures import (
_TIMEOUT_MANAGER,
context_timeout,
future_timeout,
future_wait,
stream_timeout,
)


class FuturesTest(TestCase):
Expand Down Expand Up @@ -45,3 +53,39 @@ def test_future_timeout_exception(self) -> None:
fut.set_exception(RuntimeError("test"))
with self.assertRaisesRegex(RuntimeError, "test"):
timed_fut.wait()

def test_context_timeout(self) -> None:
    # Rendezvous point shared between this thread and the timeout callback;
    # both sides must arrive before either proceeds.
    rendezvous: threading.Barrier = threading.Barrier(2)

    def on_timeout() -> None:
        rendezvous.wait()

    with context_timeout(on_timeout, timedelta(seconds=0.01)):
        # Park here until the timeout fires and the callback joins us.
        rendezvous.wait()

    def must_not_fire() -> None:
        self.fail("timeout should be cancelled")

    # Leaving the context before the deadline must cancel the timer,
    # so the callback never runs.
    with context_timeout(must_not_fire, timedelta(seconds=10)):
        pass

# pyre-fixme[56]: Pyre was not able to infer the type of decorator
@skipUnless(torch.cuda.is_available(), "CUDA is required for this test")
def test_stream_timeout(self) -> None:
    # Start from an idle stream so the recorded event completes right away.
    torch.cuda.synchronize()

    def on_timeout() -> None:
        self.fail()

    stream_timeout(on_timeout, timeout=timedelta(seconds=0.01))

    # Wait for the recorded CUDA event to complete; with the stream idle
    # the timeout callback must never fire (on_timeout would fail the test).
    torch.cuda.synchronize()

    # The handler parks the completed event on the deletion queue. Peek at
    # it (blocking get, then put it back) without keeping a reference, so
    # the refcount invariant in _clear_del_queue still holds.
    queued = _TIMEOUT_MANAGER._del_queue.get(timeout=10.0)
    _TIMEOUT_MANAGER._del_queue.put(queued)
    del queued

    self.assertEqual(_TIMEOUT_MANAGER._clear_del_queue(), 1)