Skip to content

Commit bc76d35

Browse files
authored
Exit early if trying to close the peer connections twice (#234)
* Exit early if trying to close the peer connections twice Previously, the repeated close() calls were hanging indefinitely for PublisherPeerConnection and SubscriberPeerConnection. Now, we set `_closed=True` guard after the connection is closed for the first time and exit early on the repeated call. * Fix makefile * Set _closed=True in peer connections before super is called
1 parent f3972a6 commit bc76d35

File tree

5 files changed

+77
-5
lines changed

5 files changed

+77
-5
lines changed

Makefile

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ VIDEO_PATHS := \
1212
getstream/video \
1313
tests/rtc \
1414
tests/test_audio_stream_track.py \
15-
tests/test_connection_manager.py \
1615
tests/test_connection_utils.py \
1716
tests/test_signaling.py \
1817
tests/test_video_examples.py \

getstream/video/rtc/pc.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def __init__(
5050
)
5151
super().__init__(configuration)
5252
self.manager = manager
53+
self._closed = False
5354
self._connected_event = asyncio.Event()
5455

5556
for transceiver in self.getTransceivers():
@@ -113,6 +114,14 @@ async def wait_for_connected(self, timeout: float = 15.0):
113114
logger.error(f"Publisher connection timed out after {timeout}s")
114115
raise TimeoutError(f"Connection timed out after {timeout} seconds")
115116

117+
async def close(self):
118+
# Using self._closed guard here
119+
# to avoid closing RTCPeerConnectionTwice by accident (it freezes on second time)
120+
if self._closed:
121+
return
122+
self._closed = True
123+
await super().close()
124+
116125
async def restartIce(self):
117126
"""Restart ICE connection for reconnection scenarios."""
118127
logger.info("Restarting ICE connection for publisher")
@@ -138,6 +147,7 @@ def __init__(
138147
)
139148
super().__init__(configuration)
140149
self.connection = connection
150+
self._closed = False
141151
self._drain_video_frames = drain_video_frames
142152

143153
self.track_map = {} # track_id -> (MediaRelay, original_track)
@@ -245,6 +255,31 @@ def get_video_frame_tracker(self) -> Optional[Any]:
245255
return next(iter(self.video_frame_trackers.values()))
246256
return None
247257

258+
async def close(self):
259+
# Using self._closed guard here
260+
# to avoid closing RTCPeerConnectionTwice by accident (it freezes on second time)
261+
if self._closed:
262+
return
263+
264+
# Clean up video drains
265+
for blackhole, drain_task, drain_proxy in list(self._video_drains.values()):
266+
drain_task.cancel()
267+
drain_proxy.stop()
268+
await blackhole.stop()
269+
self._video_drains.clear()
270+
271+
# Cancel background tasks
272+
for task in list(self._background_tasks):
273+
task.cancel()
274+
self._background_tasks.clear()
275+
276+
# Clear track maps
277+
self.track_map.clear()
278+
self.video_frame_trackers.clear()
279+
280+
self._closed = True
281+
await super().close()
282+
248283
async def restartIce(self):
249284
"""Restart ICE connection for reconnection scenarios."""
250285
logger.info("Restarting ICE connection for subscriber")

getstream/video/rtc/reconnection.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,10 @@ async def _reconnect_migrate(self):
284284
current_publisher = self.connection_manager.publisher_pc
285285
current_subscriber = self.connection_manager.subscriber_pc
286286

287+
# Clear old references so _connect_internal creates fresh PCs
288+
self.connection_manager.publisher_pc = None
289+
self.connection_manager.subscriber_pc = None
290+
287291
self.connection_manager.connection_state = ConnectionState.MIGRATING
288292

289293
if current_publisher and hasattr(current_publisher, "removeListener"):
Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,23 @@
1+
import asyncio
12
import contextlib
3+
import uuid
4+
from unittest.mock import AsyncMock, MagicMock, patch
25

36
import pytest
4-
from unittest.mock import AsyncMock, patch, MagicMock
7+
from dotenv import load_dotenv
58

9+
from getstream import AsyncStream
10+
from getstream.video import rtc
611
from getstream.video.rtc.connection_manager import ConnectionManager
7-
from getstream.video.rtc.connection_utils import SfuJoinError, SfuConnectionError
12+
from getstream.video.rtc.connection_utils import (
13+
ConnectionState,
14+
SfuConnectionError,
15+
SfuJoinError,
16+
)
817
from getstream.video.rtc.pb.stream.video.sfu.models import models_pb2
918

19+
load_dotenv()
20+
1021

1122
@contextlib.contextmanager
1223
def patched_dependencies():
@@ -45,8 +56,30 @@ def connection_manager(request):
4556
yield cm
4657

4758

48-
class TestConnectRetry:
49-
"""Tests for connect() retry logic when SFU is full."""
59+
@pytest.fixture
60+
def client():
61+
return AsyncStream(timeout=10.0)
62+
63+
64+
class TestConnectionManager:
65+
@pytest.mark.asyncio
66+
@pytest.mark.integration
67+
async def test_leave_twice_does_not_hang(self, client: AsyncStream):
68+
"""Integration test: join a real call and leave twice without hanging."""
69+
call_id = str(uuid.uuid4())
70+
call = client.video.call("default", call_id)
71+
72+
async with await rtc.join(call, "test-user") as connection:
73+
assert connection.connection_state == ConnectionState.JOINED
74+
75+
await asyncio.sleep(2)
76+
77+
await asyncio.wait_for(connection.leave(), timeout=10.0)
78+
assert connection.connection_state == ConnectionState.LEFT
79+
80+
# Second leave must not hang
81+
await asyncio.wait_for(connection.leave(), timeout=10.0)
82+
assert connection.connection_state == ConnectionState.LEFT
5083

5184
@pytest.mark.asyncio
5285
@pytest.mark.parametrize("connection_manager", [2], indirect=True)

tests/rtc/test_subscriber_drain.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def subscriber_pc():
1313
"""Create a SubscriberPeerConnection bypassing heavy parent inits."""
1414
pc = SubscriberPeerConnection.__new__(SubscriberPeerConnection)
1515
pc.connection = Mock()
16+
pc._closed = False
1617
pc._drain_video_frames = True
1718
pc.track_map = {}
1819
pc.video_frame_trackers = {}

0 commit comments

Comments
 (0)