Skip to content

Commit 87aaf82

Browse files
committed
Guard websocket reconnect lifecycle
1 parent 05cf475 commit 87aaf82

2 files changed

Lines changed: 224 additions & 2 deletions

File tree

custom_components/bhyve/pybhyve/websocket.py

Lines changed: 20 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
import json
44
import logging
5-
from asyncio import AbstractEventLoop, ensure_future
5+
from asyncio import AbstractEventLoop, TimerHandle, ensure_future
66
from collections.abc import Callable
77
from math import ceil
88
from typing import Any
@@ -50,12 +50,18 @@ def __init__(
5050
self._ws: ClientWebSocketResponse | None = None
5151
self._closing: bool = False
5252
self._reconnect_delay: int = RECONNECT_DELAY
53+
self._reconnect_cb: TimerHandle | None = None
5354

5455
def _cancel_heartbeat(self) -> None:
    """Cancel the pending heartbeat callback, if any, and drop its handle."""
    if self._heartbeat_cb is not None:
        self._heartbeat_cb.cancel()
        # Clear the reference so a later call is a no-op.
        self._heartbeat_cb = None
5859

60+
def _cancel_reconnect(self) -> None:
    """Cancel the pending reconnect timer, if any, and drop its handle."""
    if self._reconnect_cb is not None:
        self._reconnect_cb.cancel()
        # Clear the reference so a later call is a no-op.
        self._reconnect_cb = None
64+
5965
def _reset_heartbeat(self) -> None:
6066
self._cancel_heartbeat()
6167

@@ -90,6 +96,11 @@ def state(self, value: str) -> None:
9096

9197
def start(self) -> None:
9298
"""Start the websocket."""
99+
self._cancel_reconnect()
100+
if self._closing:
101+
_LOGGER.info("Websocket closed intentionally, not starting")
102+
return
103+
93104
if self.state != STATE_RUNNING:
94105
self.state = STATE_STARTING
95106
self._loop.create_task(self.running())
@@ -211,19 +222,26 @@ async def stop(self) -> None:
211222
self._closing = True
212223
self.state = STATE_STOPPED
213224
self._cancel_heartbeat()
225+
self._cancel_reconnect()
214226
if self._ws is not None:
215227
await self._ws.close()
216228

217229
def retry(self) -> None:
218230
"""Retry to connect to Orbit."""
231+
if self._closing:
232+
_LOGGER.info("Websocket closed intentionally, not scheduling reconnect")
233+
return
234+
219235
if self.state != STATE_STARTING:
220236
_LOGGER.info(
221237
"Reconnecting to Orbit in %i; state: %s",
222238
self._reconnect_delay,
223239
self.state,
224240
)
225241
self.state = STATE_STARTING
226-
self._loop.call_later(self._reconnect_delay, self.start)
242+
self._reconnect_cb = self._loop.call_later(
243+
self._reconnect_delay, self.start
244+
)
227245
self._reconnect_delay = min(
228246
self._reconnect_delay * 2,
229247
MAX_RECONNECT_DELAY,

tests/test_websocket.py

Lines changed: 204 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,204 @@
1+
"""Test Orbit BHyve websocket transport."""
2+
3+
import json
4+
from collections.abc import Callable, Coroutine
5+
from types import TracebackType
6+
from typing import Any
7+
from unittest.mock import AsyncMock, MagicMock
8+
9+
from aiohttp import WSMsgType
10+
11+
from custom_components.bhyve.pybhyve.websocket import (
12+
MAX_RECONNECT_DELAY,
13+
RECONNECT_DELAY,
14+
STATE_STOPPED,
15+
OrbitWebsocket,
16+
)
17+
18+
19+
class FakeTimerHandle:
    """Fake asyncio timer handle.

    Records the delay (or absolute time) and callback it was created with,
    and remembers whether ``cancel()`` has been called so tests can assert
    on it. Note that ``cancelled`` is a plain attribute here, not a method.
    """

    def __init__(self, delay: float, callback: Callable[[], None]) -> None:
        """Initialize the fake timer handle in the un-cancelled state."""
        self.delay = delay
        self.callback = callback
        self.cancelled = False

    def cancel(self) -> None:
        """Mark the fake timer handle as cancelled."""
        self.cancelled = True
31+
32+
33+
class FakeLoop:
    """Fake event loop for websocket scheduling tests.

    Nothing is ever executed: scheduled callbacks and created tasks are only
    recorded, so a test can inspect them and fire callbacks by hand.
    """

    def __init__(self) -> None:
        """Initialize the fake loop with empty recording lists."""
        self.call_later_handles: list[FakeTimerHandle] = []
        self.call_at_handles: list[FakeTimerHandle] = []
        self.created_tasks: list[Coroutine[Any, Any, Any]] = []

    def time(self) -> float:
        """Return the fake loop time, which is always zero."""
        return 0

    def call_at(self, when: float, callback: Callable[[], None]) -> FakeTimerHandle:
        """Record a scheduled call_at callback and return its handle."""
        handle = FakeTimerHandle(when, callback)
        self.call_at_handles.append(handle)
        return handle

    def call_later(self, delay: float, callback: Callable[[], None]) -> FakeTimerHandle:
        """Record a scheduled call_later callback and return its handle."""
        handle = FakeTimerHandle(delay, callback)
        self.call_later_handles.append(handle)
        return handle

    def create_task(self, coro: Coroutine[Any, Any, Any]) -> MagicMock:
        """Record a created coroutine task and close it without running it."""
        self.created_tasks.append(coro)
        # Closing the coroutine avoids a "never awaited" RuntimeWarning.
        coro.close()
        return MagicMock()
63+
64+
65+
class FakeMessage:
    """Fake websocket message whose type is always ``WSMsgType.CLOSE``."""

    type = WSMsgType.CLOSE
69+
70+
71+
class FakeWebSocket:
    """Fake aiohttp websocket response.

    Records every string sent through it and always answers ``receive()``
    with a close message.
    """

    def __init__(self) -> None:
        """Initialize the fake websocket response as open with nothing sent."""
        self.closed = False
        self.sent: list[str] = []

    async def send_str(self, data: str) -> None:
        """Record sent websocket data."""
        self.sent.append(data)

    async def receive(self) -> FakeMessage:
        """Return a close message to end the receive loop."""
        return FakeMessage()

    async def pong(self) -> None:
        """Handle websocket pong calls (no-op)."""

    async def close(self) -> None:
        """Close the fake websocket."""
        self.closed = True

    def exception(self) -> None:
        """Return no websocket exception."""
96+
97+
98+
class FakeWebsocketContext:
    """Fake websocket async context manager wrapping a ``FakeWebSocket``."""

    def __init__(self, websocket: FakeWebSocket) -> None:
        """Initialize the fake websocket context."""
        self.websocket = websocket

    async def __aenter__(self) -> FakeWebSocket:
        """Enter the fake websocket context and hand out the websocket."""
        return self.websocket

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit the fake websocket context, marking the websocket closed.

        Returns ``None``, so an exception raised inside the ``async with``
        body is not suppressed.
        """
        self.websocket.closed = True
117+
118+
119+
class FakeSession:
    """Fake aiohttp client session that always yields the same fake websocket."""

    def __init__(self) -> None:
        """Initialize the fake session with a single shared fake websocket."""
        self.websocket = FakeWebSocket()

    def ws_connect(self, _url: str) -> FakeWebsocketContext:
        """Return a fake websocket context; the URL is ignored."""
        return FakeWebsocketContext(self.websocket)
129+
130+
131+
def create_websocket(
    loop: FakeLoop | None = None,
    session: FakeSession | None = None,
) -> OrbitWebsocket:
    """Create an Orbit websocket with test doubles.

    Args:
        loop: Fake loop to inject; a fresh ``FakeLoop`` is used when omitted.
        session: Fake session to inject; a fresh ``FakeSession`` is used
            when omitted.

    Returns:
        An ``OrbitWebsocket`` built with a fixed token and URL and an
        ``AsyncMock`` as its async callback.
    """
    return OrbitWebsocket(
        token="token",
        loop=loop or FakeLoop(),
        session=session or FakeSession(),
        url="wss://example.test",
        async_callback=AsyncMock(),
    )
143+
144+
145+
def test_retry_uses_exponential_backoff_until_max() -> None:
    """Test reconnect retries back off to the maximum delay."""
    loop = FakeLoop()
    websocket = create_websocket(loop=loop)

    for expected_delay in [5, 10, 20, 40, 80, 160, 300, 300]:
        # Force a non-starting state so retry() schedules a new reconnect.
        websocket.state = STATE_STOPPED

        websocket.retry()

        handle = loop.call_later_handles[-1]
        assert handle.delay == expected_delay
        # Fire the timer by hand, as the real loop would after the delay.
        handle.callback()

    assert websocket._reconnect_delay == MAX_RECONNECT_DELAY
160+
161+
162+
async def test_successful_connection_resets_reconnect_delay() -> None:
    """Test a successful websocket connection resets the backoff delay."""
    loop = FakeLoop()
    session = FakeSession()
    websocket = create_websocket(loop=loop, session=session)
    # Start from an already backed-off delay so a reset is observable.
    websocket._reconnect_delay = 80

    await websocket.running()

    # The fake websocket answers receive() with CLOSE. The assertions below
    # expect the reconnect that follows to be scheduled with the base delay,
    # after which the stored delay has doubled once.
    assert loop.call_later_handles[-1].delay == RECONNECT_DELAY
    assert websocket._reconnect_delay == RECONNECT_DELAY * 2
    assert json.loads(session.websocket.sent[0]) == {
        "event": "app_connection",
        "orbit_session_token": "token",
    }
177+
178+
179+
async def test_stop_cancels_pending_reconnect() -> None:
    """Test stop cancels a pending reconnect timer."""
    loop = FakeLoop()
    websocket = create_websocket(loop=loop)

    websocket.retry()
    # retry() schedules the reconnect through loop.call_later.
    handle = loop.call_later_handles[-1]

    await websocket.stop()

    assert handle.cancelled
    assert websocket._reconnect_cb is None
191+
192+
193+
async def test_stale_reconnect_callback_does_not_restart_after_stop() -> None:
    """Test a stale reconnect callback cannot restart a stopped websocket."""
    loop = FakeLoop()
    websocket = create_websocket(loop=loop)

    websocket.retry()
    handle = loop.call_later_handles[-1]
    await websocket.stop()
    # Simulate the timer firing even though stop() already cancelled it.
    handle.callback()

    assert websocket.state == STATE_STOPPED
    assert not loop.created_tasks

0 commit comments

Comments
 (0)