Commit 75350d4

TimeoutManager: delete cuda events on main thread
1 parent ad0ca0a commit 75350d4
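
For context: as the inline comments in the diff note, cudaEventDestroy can block when the CUDA queue is full, so the event created by stream_timeout must not be dropped on the timeout event-loop thread. Instead it is pushed onto a deletion queue that is drained from the calling (main) thread on the next public call. Below is a minimal sketch of that deferred-deletion pattern, illustrative only (simplified names, not the torchft code):

# Minimal sketch of the deferred-deletion pattern used by this commit.
# Illustrative only: names and structure are simplified, not torchft's code.
import queue


class DeferredDeleter:
    def __init__(self) -> None:
        # Objects whose destruction may block (e.g. CUDA events) go here
        # instead of being dropped on the background thread.
        self._del_queue: queue.SimpleQueue[object] = queue.SimpleQueue()

    def background_work(self, obj: object) -> None:
        # Runs on a background/event-loop thread: never destroy obj here,
        # just hand it to the queue so the loop stays responsive.
        self._del_queue.put(obj)

    def public_call(self) -> None:
        # Runs on the main thread: drain the queue so the (possibly
        # blocking) destruction happens here, not on the event loop.
        while True:
            try:
                self._del_queue.get_nowait()  # last reference dropped here
            except queue.Empty:
                break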

2 files changed: +76 -4 lines changed

torchft/futures.py (+34 -2)
@@ -1,4 +1,5 @@
 import asyncio
+import queue
 import threading
 from contextlib import contextmanager
 from datetime import timedelta
@@ -36,8 +37,13 @@ def cancel(self) -> None:

 class _TimeoutManager:
     """
-    This class manages timeouts for futures. It uses a background thread with an
-    event loop to schedule the timeouts.
+    This class manages timeouts for code blocks, futures and CUDA events. It
+    uses a background thread with an event loop to schedule the timeouts and
+    call the callback function when the timeout is reached.
+
+    Generally there is a single instance of this class that is used for all
+    timeouts. The callbacks should not block otherwise other timeouts may not
+    be processed.
     """

     def __init__(self) -> None:
@@ -46,6 +52,10 @@ def __init__(self) -> None:
         self._event_loop_thread: Optional[threading.Thread] = None
         self._next_timer_id = 0

+        # This queue is used to delete events on the main thread as cudaEventDestroy
+        # can block if the CUDA queue is full.
+        self._del_queue: queue.SimpleQueue[object] = queue.SimpleQueue()
+
     def _maybe_start_event_loop(self) -> asyncio.AbstractEventLoop:
         """
         Start the event loop if it has not already been started.
@@ -82,6 +92,8 @@ def register(self, fut: Future[T], timeout: timedelta) -> Future[T]:
         if isinstance(fut, Mock):
             return fut

+        self._clear_del_queue()
+
         loop = self._maybe_start_event_loop()

         # pyre-fixme[29]: Future is not a function
@@ -114,6 +126,8 @@ def callback(fut: Future[T]) -> None:
         return timed_fut

     def stream_timeout(self, callback: Callable[[], None], timeout: timedelta) -> None:
+        self._clear_del_queue()
+
         loop = self._maybe_start_event_loop()

         event: torch.cuda.Event = torch.cuda.Event()
@@ -123,6 +137,11 @@ def handler() -> None:
             if not event.query():
                 callback()

+            # cudaEventDestroy can block so we never want to delete in the event
+            # loop. Put it on the del queue so we can delete it in the main
+            # thread.
+            self._del_queue.put(event)
+
         loop.call_soon_threadsafe(
             self._register_callback, loop, handler, timeout, _TimerHandle()
         )
@@ -145,6 +164,8 @@ def _register_callback(
     def context_timeout(
         self, callback: Callable[[], None], timeout: timedelta
     ) -> Generator[None, None, None]:
+        self._clear_del_queue()
+
         loop = self._maybe_start_event_loop()
         handle = _TimerHandle()

@@ -156,6 +177,17 @@ def context_timeout(

         handle.cancel()

+    def _clear_del_queue(self) -> None:
+        """
+        Clear the queue of futures to be deleted.
+        """
+        while True:
+            try:
+                # get and immediately discard item
+                self._del_queue.get_nowait()
+            except queue.Empty:
+                break
+

 _TIMEOUT_MANAGER = _TimeoutManager()
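
For reference, a short usage sketch of the timeout helpers this change touches, based on the signatures visible in the diff and in the tests below. This is illustrative rather than documented API usage; it assumes torchft is installed and, for the CUDA path, that a GPU is available.

# Usage sketch (illustrative; based on the signatures in this diff).
import time
from datetime import timedelta

import torch

from torchft.futures import context_timeout, stream_timeout


def on_timeout() -> None:
    print("timed out")


# Run a host-side block under a timeout; on_timeout fires if the block is
# still running when the timeout expires.
with context_timeout(on_timeout, timedelta(seconds=30)):
    time.sleep(0.1)  # stand-in for real work

# Watch the current CUDA stream; on_timeout fires if the GPU work queued so
# far has not completed within the timeout (requires CUDA).
if torch.cuda.is_available():
    stream_timeout(on_timeout, timeout=timedelta(seconds=30))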

torchft/futures_test.py (+42 -2)
@@ -1,9 +1,17 @@
+import threading
 from datetime import timedelta
-from unittest import TestCase
+from unittest import TestCase, skipUnless

+import torch
 from torch.futures import Future

-from torchft.futures import future_timeout, future_wait
+from torchft.futures import (
+    _TIMEOUT_MANAGER,
+    context_timeout,
+    future_timeout,
+    future_wait,
+    stream_timeout,
+)


 class FuturesTest(TestCase):
@@ -45,3 +53,35 @@ def test_future_timeout_exception(self) -> None:
         fut.set_exception(RuntimeError("test"))
         with self.assertRaisesRegex(RuntimeError, "test"):
             timed_fut.wait()
+
+    def test_context_timeout(self) -> None:
+        barrier: threading.Barrier = threading.Barrier(2)
+
+        def callback() -> None:
+            barrier.wait()
+
+        with context_timeout(callback, timedelta(seconds=0.01)):
+            # block until timeout fires
+            barrier.wait()
+
+        def fail() -> None:
+            self.fail("timeout should be cancelled")
+
+        with context_timeout(fail, timedelta(seconds=10)):
+            pass
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of decorator
+    @skipUnless(torch.cuda.is_available(), "CUDA is required for this test")
+    def test_stream_timeout(self) -> None:
+        torch.cuda.synchronize()
+
+        def callback() -> None:
+            self.fail()
+
+        stream_timeout(callback, timeout=timedelta(seconds=0.01))
+
+        # make sure event completes
+        torch.cuda.synchronize()
+
+        # make sure that event is deleted on the deletion queue
+        _TIMEOUT_MANAGER._del_queue.get(timeout=10.0)
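
A hypothetical follow-up check, not part of this commit: because every public entry point calls _clear_del_queue() first, an event queued by an earlier stream_timeout is discarded on the calling thread the next time any of the timeout APIs is used.

# Hypothetical check (not in the commit); assumes a prior stream_timeout has
# already enqueued an event on _TIMEOUT_MANAGER._del_queue.
import queue
from datetime import timedelta

from torchft.futures import _TIMEOUT_MANAGER, context_timeout

with context_timeout(lambda: None, timedelta(seconds=10)):
    pass  # entering the context drains the deletion queue first

try:
    _TIMEOUT_MANAGER._del_queue.get_nowait()
except queue.Empty:
    pass  # expected: the queued event was discarded before the context ran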
