Skip to content

Commit 8857c86

Browse files
Copilotenzofrnt
andcommitted
Fix asyncio future cancellation in SSE stream to prevent shutdown hangs
Co-authored-by: enzofrnt <[email protected]>
1 parent af8855d commit 8857c86

File tree

2 files changed

+103
-48
lines changed

2 files changed

+103
-48
lines changed

django_eventstream/views.py

Lines changed: 53 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -223,59 +223,64 @@ async def stream(event_request, listener):
223223

224224
while True:
225225
f = asyncio.ensure_future(listener.aevent.wait())
226-
while True:
227-
done, _ = await asyncio.wait([f], timeout=20)
228-
if f in done:
229-
break
230-
body = "event: keep-alive\ndata:\n\n"
231-
yield body
226+
try:
227+
while True:
228+
done, _ = await asyncio.wait([f], timeout=20)
229+
if f in done:
230+
break
231+
body = "event: keep-alive\ndata:\n\n"
232+
yield body
233+
234+
lm.lock.acquire()
235+
236+
channel_items = listener.channel_items
237+
overflow = listener.overflow
238+
error_data = listener.error
239+
240+
listener.aevent.clear()
241+
listener.channel_items = {}
242+
listener.overflow = False
243+
244+
lm.lock.release()
245+
246+
body = ""
247+
for channel, items in channel_items.items():
248+
for item in items:
249+
if channel in last_ids:
250+
if item.id is not None:
251+
last_ids[channel] = item.id
252+
else:
253+
del last_ids[channel]
254+
if last_ids:
255+
event_id = make_id(last_ids)
256+
else:
257+
event_id = None
258+
body += sse_encode_event(
259+
item.type, item.data, event_id=event_id
260+
)
232261

233-
lm.lock.acquire()
262+
more = True
234263

235-
channel_items = listener.channel_items
236-
overflow = listener.overflow
237-
error_data = listener.error
264+
if error_data:
265+
condition = error_data["condition"]
266+
text = error_data["text"]
267+
extra = error_data.get("extra")
268+
body += sse_encode_error(condition, text, extra=extra)
269+
more = False
238270

239-
listener.aevent.clear()
240-
listener.channel_items = {}
241-
listener.overflow = False
271+
if body or not more:
272+
yield body
242273

243-
lm.lock.release()
274+
if not more:
275+
break
244276

245-
body = ""
246-
for channel, items in channel_items.items():
247-
for item in items:
248-
if channel in last_ids:
249-
if item.id is not None:
250-
last_ids[channel] = item.id
251-
else:
252-
del last_ids[channel]
253-
if last_ids:
254-
event_id = make_id(last_ids)
255-
else:
256-
event_id = None
257-
body += sse_encode_event(
258-
item.type, item.data, event_id=event_id
259-
)
260-
261-
more = True
262-
263-
if error_data:
264-
condition = error_data["condition"]
265-
text = error_data["text"]
266-
extra = error_data.get("extra")
267-
body += sse_encode_error(condition, text, extra=extra)
268-
more = False
269-
270-
if body or not more:
271-
yield body
272-
273-
if not more:
274-
break
275-
276-
if overflow:
277-
# check db
278-
break
277+
if overflow:
278+
# check db
279+
break
280+
finally:
281+
# Always cancel the future to prevent it from lingering
282+
if not f.done():
283+
f.cancel()
279284

280285
event_request.channel_last_ids = last_ids
281286
finally:

tests/test_stream.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,56 @@ def __assert_all_events_are_retrieved_only_once(self):
5858
CHANNEL_NAME, EVENTS_LIMIT, limit=EVENTS_LIMIT + 1
5959
)
6060

61+
@patch("django_eventstream.eventstream.get_storage")
62+
async def test_stream_cancellation_during_wait(self, mock_get_storage):
63+
"""Test that stream properly cleans up when cancelled during event wait."""
64+
mock_get_storage.return_value = self.storage
65+
66+
# Create a real listener (not mocked) to test actual wait behavior
67+
listener = Listener()
68+
69+
request = EventRequest()
70+
request.is_next = False
71+
request.is_recover = False
72+
request.channels = [CHANNEL_NAME]
73+
74+
# Get current ID using sync_to_async
75+
get_current_id = sync_to_async(self.storage.get_current_id)
76+
current_id = await get_current_id(CHANNEL_NAME)
77+
request.channel_last_ids = {CHANNEL_NAME: str(current_id)}
78+
79+
# Start streaming - this will wait for events since we're caught up
80+
stream_task = asyncio.create_task(
81+
self.__collect_response(stream(request, listener))
82+
)
83+
84+
# Give it time to enter the wait loop
85+
await asyncio.sleep(0.5)
86+
87+
# Cancel the stream
88+
stream_task.cancel()
89+
90+
try:
91+
await stream_task
92+
raise ValueError("stream completed unexpectedly")
93+
except asyncio.CancelledError:
94+
pass
95+
96+
# Verify no tasks are left running
97+
pending_tasks = [task for task in asyncio.all_tasks()
98+
if not task.done() and task != asyncio.current_task()]
99+
100+
# Allow brief time for cleanup
101+
await asyncio.sleep(0.1)
102+
103+
# Check again after cleanup time
104+
pending_tasks_after = [task for task in asyncio.all_tasks()
105+
if not task.done() and task != asyncio.current_task()]
106+
107+
# The number of pending tasks should not increase after cancellation
108+
self.assertLessEqual(len(pending_tasks_after), len(pending_tasks),
109+
"Stream cancellation should not leave lingering tasks")
110+
61111
async def __initialise_test(self, mock_get_storage):
62112
mock_get_storage.return_value = self.storage
63113

0 commit comments

Comments
 (0)