Commit c1ab7d9

TimeoutManager: delete cuda events on main thread
1 parent 73a6f78 commit c1ab7d9

2 files changed: 95 additions, 4 deletions

torchft/futures.py

49 additions, 2 deletions

```diff
@@ -1,4 +1,6 @@
 import asyncio
+import queue
+import sys
 import threading
 from contextlib import contextmanager
 from datetime import timedelta
@@ -36,8 +38,13 @@ def cancel(self) -> None:
 
 class _TimeoutManager:
     """
-    This class manages timeouts for futures. It uses a background thread with an
-    event loop to schedule the timeouts.
+    This class manages timeouts for code blocks, futures and CUDA events. It
+    uses a background thread with an event loop to schedule the timeouts and
+    call the callback function when the timeout is reached.
+
+    Generally there is a single instance of this class that is used for all
+    timeouts. The callbacks should not block otherwise other timeouts may not
+    be processed.
     """
 
     def __init__(self) -> None:
@@ -46,6 +53,10 @@ def __init__(self) -> None:
         self._event_loop_thread: Optional[threading.Thread] = None
         self._next_timer_id = 0
 
+        # This queue is used to delete events on the main thread as cudaEventDestroy
+        # can block if the CUDA queue is full.
+        self._del_queue: queue.SimpleQueue[object] = queue.SimpleQueue()
+
     def _maybe_start_event_loop(self) -> asyncio.AbstractEventLoop:
         """
         Start the event loop if it has not already been started.
@@ -82,6 +93,8 @@ def register(self, fut: Future[T], timeout: timedelta) -> Future[T]:
         if isinstance(fut, Mock):
             return fut
 
+        self._clear_del_queue()
+
         loop = self._maybe_start_event_loop()
 
         # pyre-fixme[29]: Future is not a function
@@ -114,6 +127,8 @@ def callback(fut: Future[T]) -> None:
         return timed_fut
 
     def stream_timeout(self, callback: Callable[[], None], timeout: timedelta) -> None:
+        self._clear_del_queue()
+
         loop = self._maybe_start_event_loop()
 
         event: torch.cuda.Event = torch.cuda.Event()
@@ -123,6 +138,11 @@ def handler() -> None:
             if not event.query():
                 callback()
 
+            # cudaEventDestroy can block so we never want to delete in the event
+            # loop. Put it on the del queue so we can delete it in the main
+            # thread.
+            self._del_queue.put(event)
+
         loop.call_soon_threadsafe(
             self._register_callback, loop, handler, timeout, _TimerHandle()
         )
@@ -145,6 +165,8 @@ def _register_callback(
     def context_timeout(
         self, callback: Callable[[], None], timeout: timedelta
     ) -> Generator[None, None, None]:
+        self._clear_del_queue()
+
         loop = self._maybe_start_event_loop()
         handle = _TimerHandle()
 
@@ -156,6 +178,31 @@ def context_timeout(
 
         handle.cancel()
 
+    def _clear_del_queue(self) -> int:
+        """
+        Clear the queue of futures to be deleted.
+
+        Returns the number of items deleted.
+        """
+        count = 0
+        while True:
+            try:
+                # get and immediately discard item
+                item = self._del_queue.get_nowait()
+                refcount = sys.getrefcount(item)
+                assert (
+                    # 1 from item, 1 from getrefcount
+                    refcount
+                    == 2
+                ), f"items in del_queue reference should not have other references, found {refcount=}"
+                del item
+
+                count += 1
+            except queue.Empty:
+                break
+
+        return count
+
 
 _TIMEOUT_MANAGER = _TimeoutManager()
 
```
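The core idea of this change is a deferred-deletion queue: the background event loop thread never drops the last reference to a `torch.cuda.Event` (whose destructor calls `cudaEventDestroy` and can block when the CUDA queue is full); instead it hands the event back over `_del_queue`, which the caller (main) thread drains the next time it registers a timeout. Below is a minimal standalone sketch of that pattern; the class and method names are illustrative only and not part of the torchft API.

```python
import queue
import sys
import threading


class DeferredDeleter:
    """Illustrative stand-in for the _del_queue handling in _TimeoutManager."""

    def __init__(self) -> None:
        self._del_queue: queue.SimpleQueue[object] = queue.SimpleQueue()

    def defer(self, item: object) -> None:
        # Called from the background thread: instead of dropping the last
        # reference there (which would run a potentially blocking destructor),
        # hand the object back to the caller thread.
        self._del_queue.put(item)

    def drain(self) -> int:
        # Called from the caller thread before scheduling new work: drop the
        # last reference to each queued object so its destructor runs here.
        count = 0
        while True:
            try:
                item = self._del_queue.get_nowait()
            except queue.Empty:
                return count
            # One reference held by `item`, one by the getrefcount() argument.
            assert sys.getrefcount(item) == 2
            del item
            count += 1


if __name__ == "__main__":
    deleter = DeferredDeleter()
    worker = threading.Thread(target=lambda: deleter.defer(object()))
    worker.start()
    worker.join()
    print(deleter.drain())  # prints 1: one deferred object was reclaimed here
```

The assert mirrors the one in `_clear_del_queue`: if anything else still held a reference, deleting the local name would not actually destroy the object on this thread.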

torchft/futures_test.py

46 additions, 2 deletions

```diff
@@ -1,9 +1,17 @@
+import threading
 from datetime import timedelta
-from unittest import TestCase
+from unittest import TestCase, skipUnless
 
+import torch
 from torch.futures import Future
 
-from torchft.futures import future_timeout, future_wait
+from torchft.futures import (
+    _TIMEOUT_MANAGER,
+    context_timeout,
+    future_timeout,
+    future_wait,
+    stream_timeout,
+)
 
 
 class FuturesTest(TestCase):
@@ -45,3 +53,39 @@ def test_future_timeout_exception(self) -> None:
         fut.set_exception(RuntimeError("test"))
         with self.assertRaisesRegex(RuntimeError, "test"):
             timed_fut.wait()
+
+    def test_context_timeout(self) -> None:
+        barrier: threading.Barrier = threading.Barrier(2)
+
+        def callback() -> None:
+            barrier.wait()
+
+        with context_timeout(callback, timedelta(seconds=0.01)):
+            # block until timeout fires
+            barrier.wait()
+
+        def fail() -> None:
+            self.fail("timeout should be cancelled")
+
+        with context_timeout(fail, timedelta(seconds=10)):
+            pass
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of decorator
+    @skipUnless(torch.cuda.is_available(), "CUDA is required for this test")
+    def test_stream_timeout(self) -> None:
+        torch.cuda.synchronize()
+
+        def callback() -> None:
+            self.fail()
+
+        stream_timeout(callback, timeout=timedelta(seconds=0.01))
+
+        # make sure event completes
+        torch.cuda.synchronize()
+
+        # make sure that event is deleted on the deletion queue
+        item = _TIMEOUT_MANAGER._del_queue.get(timeout=10.0)
+        _TIMEOUT_MANAGER._del_queue.put(item)
+        del item
+
+        self.assertEqual(_TIMEOUT_MANAGER._clear_del_queue(), 1)
```
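For reference, a hedged usage sketch of the two public entry points exercised above (`context_timeout` and `stream_timeout` from `torchft.futures`); the abort callback and the work being guarded are placeholders, not part of this commit.

```python
import time
from datetime import timedelta

import torch

from torchft.futures import context_timeout, stream_timeout


def abort() -> None:
    # Callbacks run on the timeout manager's background event loop thread,
    # so they must not block; printing is enough for this sketch.
    print("operation timed out")


# Guard a blocking code block: abort() fires if the block takes longer than
# the timeout; exiting the block cancels the timer.
with context_timeout(abort, timedelta(seconds=5.0)):
    time.sleep(0.1)  # placeholder for e.g. a blocking collective

# Guard work already enqueued on the current CUDA stream: abort() fires only
# if that work has not completed within the timeout. The CUDA event recorded
# internally is later destroyed on the caller's thread via the deletion queue
# added by this commit.
if torch.cuda.is_available():
    stream_timeout(abort, timeout=timedelta(seconds=5.0))
    torch.cuda.synchronize()
```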
