Skip to content

Commit 235406c

Browse files
committed
Ensure WebSocket is always marked as closed after getting ConnectionClosed error
1 parent 92c3da0 commit 235406c

File tree

1 file changed

+11
-16
lines changed

1 file changed

+11
-16
lines changed

replit_river/session.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,15 @@ async def is_websocket_open(self) -> bool:
116116
async with self._ws_lock:
117117
return await self._ws_wrapper.is_open()
118118

119-
async def _begin_close_session_countdown(self) -> None:
120-
"""Begin the countdown to close session, this should be called when
121-
websocket is closed.
119+
async def _handle_connection_closed(self) -> None:
122120
"""
121+
Handle the WebSocket connection being closing.
122+
This will trigger connection retries, if enabled, and starts a
123+
session reconnection timer.
124+
If the timer expires before reconnection, the session will be closed.
125+
"""
126+
# ensure websocket is closed and initiate connection retries if applicable.
127+
await self.close_websocket(self._ws_wrapper, not self._is_server)
123128
# calculate the value now before establishing it so that there are no
124129
# await points between the check and the assignment to avoid a TOCTOU
125130
# race.
@@ -145,12 +150,7 @@ async def serve(self) -> None:
145150
try:
146151
await self._handle_messages_from_ws(tg)
147152
except ConnectionClosed:
148-
if self._retry_connection_callback:
149-
self._task_manager.create_task(
150-
self._retry_connection_callback()
151-
)
152-
153-
await self._begin_close_session_countdown()
153+
await self._handle_connection_closed()
154154
logger.debug("ConnectionClosed while serving", exc_info=True)
155155
except FailedSendingMessageException:
156156
# Expected error if the connection is closed.
@@ -310,10 +310,7 @@ async def _heartbeat(
310310
"%r closing websocket because of heartbeat misses",
311311
self.session_id,
312312
)
313-
await self.close_websocket(
314-
self._ws_wrapper, should_retry=not self._is_server
315-
)
316-
await self._begin_close_session_countdown()
313+
await self._handle_connection_closed()
317314
continue
318315
except FailedSendingMessageException:
319316
# this is expected during websocket closed period
@@ -344,9 +341,7 @@ async def _send_transport_message(
344341
websocket: websockets.WebSocketCommonProtocol,
345342
) -> None:
346343
try:
347-
await send_transport_message(
348-
msg, websocket, self._begin_close_session_countdown
349-
)
344+
await send_transport_message(msg, websocket, self._handle_connection_closed)
350345
except WebsocketClosedException as e:
351346
raise e
352347
except FailedSendingMessageException as e:

0 commit comments

Comments
 (0)