From 31219a16422487107c02a2e1f6b572f5bf4ade7e Mon Sep 17 00:00:00 2001 From: Falko Schindler Date: Fri, 31 Jan 2025 19:24:27 +0100 Subject: [PATCH] use socket ID to identify connections on auto-index client --- nicegui/air.py | 4 ++-- nicegui/client.py | 14 ++++++++------ nicegui/nicegui.py | 4 ++-- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/nicegui/air.py b/nicegui/air.py index ec1e26288..38a8e6a4f 100644 --- a/nicegui/air.py +++ b/nicegui/air.py @@ -135,7 +135,7 @@ def _handle_handshake(data: Dict[str, Any]) -> bool: core.app.storage.copy_tab(data['old_tab_id'], data['tab_id']) client.tab_id = data['tab_id'] client.on_air = True - client.handle_handshake(data.get('next_message_id')) + client.handle_handshake(data['sid'], data.get('next_message_id')) return True @self.relay.on('client_disconnect') @@ -144,7 +144,7 @@ def _handle_client_disconnect(data: Dict[str, Any]) -> None: client_id = data['client_id'] if client_id not in Client.instances: return - Client.instances[client_id].handle_disconnect() + Client.instances[client_id].handle_disconnect(data['sid']) @self.relay.on('connect') async def _handle_connect() -> None: diff --git a/nicegui/client.py b/nicegui/client.py index bbf11ddea..b53a3bfe1 100644 --- a/nicegui/client.py +++ b/nicegui/client.py @@ -4,6 +4,7 @@ import inspect import time import uuid +from collections import defaultdict from contextlib import contextmanager from pathlib import Path from typing import TYPE_CHECKING, Any, Awaitable, Callable, ClassVar, Dict, Iterable, Iterator, List, Optional, Union @@ -59,7 +60,7 @@ def __init__(self, page: page, *, request: Optional[Request]) -> None: self.environ: Optional[Dict[str, Any]] = None self.shared = request is None self.on_air = False - self._num_connections = 0 + self._num_connections = defaultdict(int) self._delete_task: Optional[asyncio.Task] = None self._deleted = False self.tab_id: Optional[str] = None @@ -235,10 +236,10 @@ def on_disconnect(self, handler: Union[Callable[..., Any], Awaitable]) -> None: """Add a callback to be invoked when the client disconnects.""" self.disconnect_handlers.append(handler) - def handle_handshake(self, next_message_id: Optional[int]) -> None: + def handle_handshake(self, socket_id: str, next_message_id: Optional[int]) -> None: """Cancel pending disconnect task and invoke connect handlers.""" self._cancel_delete_task() - self._num_connections += 1 + self._num_connections[socket_id] += 1 if next_message_id is not None: self.outbox.try_rewind(next_message_id) storage.request_contextvar.set(self.request) @@ -247,10 +248,10 @@ def handle_handshake(self, next_message_id: Optional[int]) -> None: for t in core.app._connect_handlers: # pylint: disable=protected-access self.safe_invoke(t) - def handle_disconnect(self) -> None: + def handle_disconnect(self, socket_id: str) -> None: """Wait for the browser to reconnect; invoke disconnect handlers if it doesn't.""" self._cancel_delete_task() - self._num_connections -= 1 + self._num_connections[socket_id] -= 1 for t in self.disconnect_handlers: self.safe_invoke(t) for t in core.app._disconnect_handlers: # pylint: disable=protected-access @@ -258,7 +259,8 @@ def handle_disconnect(self) -> None: if not self.shared: async def delete_content() -> None: await asyncio.sleep(self.page.resolve_reconnect_timeout()) - if self._num_connections == 0: + if self._num_connections[socket_id] == 0: + self._num_connections.pop(socket_id) self.delete() self._delete_task = background_tasks.create(delete_content()) diff --git a/nicegui/nicegui.py b/nicegui/nicegui.py index fa896b1bd..77efafd02 100644 --- a/nicegui/nicegui.py +++ b/nicegui/nicegui.py @@ -171,7 +171,7 @@ async def _on_handshake(sid: str, data: Dict[str, Any]) -> bool: else: client.environ = sio.get_environ(sid) await sio.enter_room(sid, client.id) - client.handle_handshake(data.get('next_message_id')) + client.handle_handshake(sid, data.get('next_message_id')) assert client.tab_id is not None await core.app.storage._create_tab_storage(client.tab_id) # pylint: disable=protected-access return True @@ -184,7 +184,7 @@ def _on_disconnect(sid: str) -> None: client_id = query['client_id'][0] client = Client.instances.get(client_id) if client: - client.handle_disconnect() + client.handle_disconnect(sid) @sio.on('event')