Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 45 additions & 6 deletions src/lmstudio/_ws_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ def __init__(
ws_url: str,
auth_details: DictObject,
log_context: LogEventContext | None = None,
max_reconnect_retries: int = 3,
initial_retry_delay: float = 1.0,
) -> None:
self._auth_details = auth_details
self._connection_attempted = asyncio.Event()
Expand All @@ -295,6 +297,10 @@ def __init__(
self._logger = logger = new_logger(type(self).__name__)
logger.update_context(log_context, ws_url=ws_url)
self._mux = MultiplexingManager(logger)
# Reconnection configuration
self._max_reconnect_retries = max_reconnect_retries
self._initial_retry_delay = initial_retry_delay
self._consecutive_failures = 0

async def connect(self) -> bool:
"""Connect websocket from the task manager's event loop."""
Expand Down Expand Up @@ -515,15 +521,48 @@ async def _process_next_message(self) -> bool:
return await self._enqueue_message(message)

async def _receive_messages(self) -> None:
"""Process received messages until task is cancelled."""
"""Process received messages with automatic reconnection on failure."""
while True:
try:
await self._process_next_message()
except (LMStudioWebsocketError, HTTPXWSException):
if self._ws is not None and not self._ws_disconnected.is_set():
# Websocket failed unexpectedly (rather than due to client shutdown)
self._logger.error("Websocket failed, terminating session.")
break
# this Reset failure counter on successful yeah
self._consecutive_failures = 0
except (LMStudioWebsocketError, HTTPXWSException) as exc:
# and it will check if this was an intentional disconnect
if self._ws_disconnected.is_set():
self._logger.debug("Websocket disconnected intentionally")
break

# and this is for Increment failure counter
self._consecutive_failures += 1

# this wiill Check if we should attempt reconnection
if self._consecutive_failures > self._max_reconnect_retries:
self._logger.error(
f"Websocket failed after {self._max_reconnect_retries} reconnection attempts, "
"terminating session.",
consecutive_failures=self._consecutive_failures,
)
break

# Calculate exponential backoff delay
retry_delay = self._initial_retry_delay * (2 ** (self._consecutive_failures - 1))
retry_delay = min(retry_delay, 30.0) # Cap at 30 seconds

self._logger.warning(
f"Websocket error (attempt {self._consecutive_failures}/{self._max_reconnect_retries}), "
f"retrying in {retry_delay:.1f}s: {exc}",
consecutive_failures=self._consecutive_failures,
retry_delay=retry_delay,
error=str(exc),
)

# Wait before attempting to reconnect
await asyncio.sleep(retry_delay)

# there is a note like The actual reconnection happens at a higher level
# This code allows the message loop to continue, giving the
# connection a chance to reestablish itself

async def _enqueue_message(self, message: Any) -> bool:
if message is None:
Expand Down
40 changes: 30 additions & 10 deletions src/lmstudio/json_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,12 +605,22 @@ def acquire_channel_id(self, rx_queue: RxQueue) -> int:
def release_channel_id(self, channel_id: int, rx_queue: RxQueue) -> None:
"""Release a previously acquired streaming channel ID."""
open_channels = self._open_channels
assigned_queue = open_channels.get(channel_id)
if rx_queue is not assigned_queue:
raise LMStudioRuntimeError(
f"Unexpected change to reply queue for channel ({channel_id} in {self!r})"
# this Use pop to safely remove the channel, even if already gone
assigned_queue = open_channels.pop(channel_id, None)

# Make cleanup more forgiving log warnings instead of raising
if assigned_queue is None:
self._logger.warning(
f"Channel {channel_id} already released or never acquired",
channel_id=channel_id,
)
elif rx_queue is not assigned_queue:
# Queue mismatch is suspicious but shouldn't prevent cleanup
self._logger.warning(
f"Channel {channel_id} queue mismatch during release "
f"(expected {rx_queue!r}, found {assigned_queue!r})",
channel_id=channel_id,
)
del open_channels[channel_id]

@contextmanager
def assign_channel_id(self, rx_queue: RxQueue) -> Generator[int, None, None]:
Expand All @@ -636,12 +646,22 @@ def acquire_call_id(self, rx_queue: RxQueue) -> int:
def release_call_id(self, call_id: int, rx_queue: RxQueue) -> None:
"""Release a previously acquired remote call ID."""
pending_calls = self._pending_calls
assigned_queue = pending_calls.get(call_id)
if rx_queue is not assigned_queue:
raise LMStudioRuntimeError(
f"Unexpected change to reply queue for remote call ({call_id} in {self!r})"
# Use pop to safely remove the call, even if already gone
assigned_queue = pending_calls.pop(call_id, None)

# Make cleanup more forgiving log warnings instead of raising
if assigned_queue is None:
self._logger.warning(
f"Remote call {call_id} already released or never acquired",
call_id=call_id,
)
elif rx_queue is not assigned_queue:
# Queue mismatch is suspicious but shouldn't prevent cleanup
self._logger.warning(
f"Remote call {call_id} queue mismatch during release "
f"(expected {rx_queue!r}, found {assigned_queue!r})",
call_id=call_id,
)
del pending_calls[call_id]

@contextmanager
def assign_call_id(self, rx_queue: RxQueue) -> Generator[int, None, None]:
Expand Down
Loading