diff --git a/pyproject.toml b/pyproject.toml index eeb56a4..7992416 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ classifiers = [ ] dependencies = [ "anyio >=3.6.2,<5", - "pycrdt >=0.12.13,<0.13.0", + "pycrdt >=0.12.16,<0.13.0", "pycrdt-store >=0.1.0,<0.2.0", ] @@ -48,6 +48,7 @@ test = [ "hypercorn >=0.16.0", "trio >=0.25.0", "sniffio", + "channels", ] docs = [ "mkdocs", diff --git a/src/pycrdt/websocket/__init__.py b/src/pycrdt/websocket/__init__.py index d67afb6..3fd423c 100644 --- a/src/pycrdt/websocket/__init__.py +++ b/src/pycrdt/websocket/__init__.py @@ -1,5 +1,4 @@ from .asgi_server import ASGIServer as ASGIServer -from .websocket_provider import WebsocketProvider as WebsocketProvider from .websocket_server import WebsocketServer as WebsocketServer from .websocket_server import exception_logger as exception_logger from .yroom import YRoom as YRoom diff --git a/src/pycrdt/websocket/django_channels_consumer.py b/src/pycrdt/websocket/django_channels_consumer.py index c553da7..30311a9 100644 --- a/src/pycrdt/websocket/django_channels_consumer.py +++ b/src/pycrdt/websocket/django_channels_consumer.py @@ -3,9 +3,10 @@ from logging import getLogger from typing import TypedDict -from channels.generic.websocket import AsyncWebsocketConsumer # type: ignore[import-not-found] +from channels.generic.websocket import AsyncWebsocketConsumer # type: ignore[import-untyped] from pycrdt import ( + Channel, Doc, YMessageType, YSyncMessageType, @@ -13,12 +14,10 @@ handle_sync_message, ) -from .websocket import Websocket - logger = getLogger(__name__) -class _WebsocketShim(Websocket): +class _WebsocketShim(Channel): def __init__(self, path, send_func) -> None: self._path = path self._send_func = send_func diff --git a/src/pycrdt/websocket/websocket.py b/src/pycrdt/websocket/websocket.py index 7be2521..ff44ffb 100644 --- a/src/pycrdt/websocket/websocket.py +++ b/src/pycrdt/websocket/websocket.py @@ -1,34 +1,13 @@ -from typing import Protocol - from anyio import Lock +from pycrdt import Channel -class Websocket(Protocol): - """WebSocket. - - The Websocket instance can receive messages using an async iterator, - until the connection is closed: - ```py - async for message in websocket: - ... - ``` - Or directly by calling `recv()`: - ```py - message = await websocket.recv() - ``` - Sending messages is done with `send()`: - ```py - await websocket.send(message) - ``` - """ - @property - def path(self) -> str: - """The WebSocket path.""" - ... - - def __aiter__(self): - return self +class HttpxWebsocket(Channel): + def __init__(self, websocket, path: str): + self._websocket = websocket + self._path = path + self._send_lock = Lock() async def __anext__(self) -> bytes: try: @@ -38,43 +17,10 @@ async def __anext__(self) -> bytes: return message - async def send(self, message: bytes) -> None: - """Send a message. - - Arguments: - message: The message to send. - """ - ... - - async def recv(self) -> bytes: - """Receive a message. - - Returns: - The received message. - """ - ... - - -class HttpxWebsocket(Websocket): - def __init__(self, websocket, path: str): - self._websocket = websocket - self._path = path - self._send_lock = Lock() - @property def path(self) -> str: return self._path - def __aiter__(self): - return self - - async def __anext__(self) -> bytes: - try: - message = await self.recv() - except Exception: - raise StopAsyncIteration() - return message - async def send(self, message: bytes): async with self._send_lock: await self._websocket.send_bytes(message) diff --git a/src/pycrdt/websocket/websocket_provider.py b/src/pycrdt/websocket/websocket_provider.py deleted file mode 100644 index 98311b4..0000000 --- a/src/pycrdt/websocket/websocket_provider.py +++ /dev/null @@ -1,170 +0,0 @@ -from __future__ import annotations - -from contextlib import AsyncExitStack -from functools import partial -from logging import Logger, getLogger - -from anyio import ( - TASK_STATUS_IGNORED, - Event, - Lock, - create_memory_object_stream, - create_task_group, -) -from anyio.abc import TaskGroup, TaskStatus -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream - -from pycrdt import ( - Doc, - Subscription, - YMessageType, - YSyncMessageType, - create_sync_message, - create_update_message, - handle_sync_message, -) - -from .websocket import Websocket -from .yutils import put_updates - - -class WebsocketProvider: - """WebSocket provider.""" - - _ydoc: Doc - _update_send_stream: MemoryObjectSendStream - _update_receive_stream: MemoryObjectReceiveStream - _subscription: Subscription - _started: Event | None = None - _task_group: TaskGroup | None = None - __start_lock: Lock | None = None - - def __init__(self, ydoc: Doc, websocket: Websocket, log: Logger | None = None) -> None: - """Initialize the object. - - The WebsocketProvider instance should preferably be used as an async context manager: - ```py - async with websocket_provider: - ... - ``` - However, a lower-level API can also be used: - ```py - task = asyncio.create_task(websocket_provider.start()) - await websocket_provider.started.wait() - ... - await websocket_provider.stop() - ``` - - Arguments: - ydoc: The YDoc to connect through the WebSocket. - websocket: The WebSocket through which to connect the YDoc. - log: An optional logger. - """ - self._ydoc = ydoc - self._websocket = websocket - self.log = log or getLogger(__name__) - self._update_send_stream, self._update_receive_stream = create_memory_object_stream( - max_buffer_size=65536 - ) - - @property - def started(self) -> Event: - """An async event that is set when the WebSocket provider has started.""" - if self._started is None: - self._started = Event() - return self._started - - @property - def _start_lock(self) -> Lock: - if self.__start_lock is None: - self.__start_lock = Lock() - return self.__start_lock - - async def _run(self): - sync_message = create_sync_message(self._ydoc) - self.log.debug( - "Sending %s message to endpoint: %s", - YSyncMessageType.SYNC_STEP1.name, - self._websocket.path, - ) - await self._websocket.send(sync_message) - self._task_group.start_soon(self._send) - async for message in self._websocket: - if message[0] == YMessageType.SYNC: - self.log.debug( - "Received %s message from endpoint: %s", - YSyncMessageType(message[1]).name, - self._websocket.path, - ) - reply = handle_sync_message(message[1:], self._ydoc) - if reply is not None: - self.log.debug( - "Sending %s message to endpoint: %s", - YSyncMessageType.SYNC_STEP2.name, - self._websocket.path, - ) - await self._websocket.send(reply) - - async def _send(self): - async with self._update_receive_stream: - async for update in self._update_receive_stream: - message = create_update_message(update) - try: - await self._websocket.send(message) - except Exception: - pass - - async def __aenter__(self) -> WebsocketProvider: - async with self._start_lock: - if self._task_group is not None: - raise RuntimeError("WebsocketProvider already running") - - async with AsyncExitStack() as exit_stack: - tg = create_task_group() - self._task_group = await exit_stack.enter_async_context(tg) - self._exit_stack = exit_stack.pop_all() - await tg.start(partial(self.start, from_context_manager=True)) - - return self - - async def __aexit__(self, exc_type, exc_value, exc_tb): - await self.stop() - return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) - - async def start( - self, - *, - task_status: TaskStatus[None] = TASK_STATUS_IGNORED, - from_context_manager: bool = False, - ): - """Start the WebSocket provider. - - Arguments: - task_status: The status to set when the task has started. - """ - self._subscription = self._ydoc.observe(partial(put_updates, self._update_send_stream)) - - if from_context_manager: - task_status.started() - self.started.set() - assert self._task_group is not None - self._task_group.start_soon(self._run) - return - - async with self._start_lock: - if self._task_group is not None: - raise RuntimeError("WebsocketProvider already running") - - async with create_task_group() as self._task_group: - task_status.started() - self.started.set() - self._task_group.start_soon(self._run) - - async def stop(self): - """Stop the WebSocket provider.""" - if self._task_group is None: - raise RuntimeError("WebsocketProvider not running") - - self._task_group.cancel_scope.cancel() - self._task_group = None - self._ydoc.unobserve(self._subscription) diff --git a/src/pycrdt/websocket/websocket_server.py b/src/pycrdt/websocket/websocket_server.py index 1846346..acbdeb5 100644 --- a/src/pycrdt/websocket/websocket_server.py +++ b/src/pycrdt/websocket/websocket_server.py @@ -8,8 +8,9 @@ from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group from anyio.abc import TaskGroup, TaskStatus -from .websocket import Websocket -from .yroom import YRoom +from pycrdt import Channel + +from .yroom import ProviderFactory, YRoom class WebsocketServer: @@ -28,6 +29,7 @@ def __init__( auto_clean_rooms: bool = True, exception_handler: Callable[[Exception, Logger], bool] | None = None, log: Logger | None = None, + provider_factory: ProviderFactory | None = None, ) -> None: """Initialize the object. @@ -50,11 +52,14 @@ def __init__( exception_handler: An optional callback to call when an exception is raised, that returns True if the exception was handled. log: An optional logger. + provider_factory: An optional provider factory used to synchronize the rooms with + external documents. """ self.rooms_ready = rooms_ready self.auto_clean_rooms = auto_clean_rooms self.exception_handler = exception_handler self.log = log or getLogger(__name__) + self.provider_factory = provider_factory self.rooms = {} self._stopped = Event() @@ -81,7 +86,14 @@ async def get_room(self, name: str) -> YRoom: The room with the given name, or a new one if no room with that name was found. """ if name not in self.rooms.keys(): - self.rooms[name] = YRoom(ready=self.rooms_ready, log=self.log) + provider_factory = ( + partial(self.provider_factory, path=name) + if self.provider_factory is not None + else None + ) + self.rooms[name] = YRoom( + ready=self.rooms_ready, log=self.log, provider_factory=provider_factory + ) room = self.rooms[name] await self.start_room(room) return room @@ -144,7 +156,7 @@ async def delete_room(self, *, name: str | None = None, room: YRoom | None = Non room = self.rooms.pop(name) await room.stop() - async def serve(self, websocket: Websocket) -> None: + async def serve(self, websocket: Channel) -> None: """Serve a client through a WebSocket. Arguments: diff --git a/src/pycrdt/websocket/yroom.py b/src/pycrdt/websocket/yroom.py index 24e28e0..37c892b 100644 --- a/src/pycrdt/websocket/yroom.py +++ b/src/pycrdt/websocket/yroom.py @@ -5,7 +5,7 @@ from functools import partial from inspect import isawaitable from logging import Logger, getLogger -from typing import Any, Callable +from typing import Any, Callable, Protocol from anyio import ( TASK_STATUS_IGNORED, @@ -19,7 +19,9 @@ from pycrdt import ( Awareness, + Channel, Doc, + Provider, Subscription, YMessageType, YSyncMessageType, @@ -32,12 +34,20 @@ ) from pycrdt.store import BaseYStore -from .websocket import Websocket from .yutils import put_updates +class ProviderFactory(Protocol): + def __call__( + self, + doc: Doc | None = None, + log: Logger | None = None, + path: str | None = None, + ) -> Provider: ... + + class YRoom: - clients: set[Websocket] + clients: set[Channel] ydoc: Doc ystore: BaseYStore | None ready_event: Event @@ -57,6 +67,7 @@ def __init__( exception_handler: Callable[[Exception, Logger], bool] | None = None, log: Logger | None = None, ydoc: Doc | None = None, + provider_factory: ProviderFactory | None = None, ): """Initialize the object. @@ -76,6 +87,8 @@ def __init__( Arguments: ready: Whether the internal YDoc is ready to be synchronized right away. ystore: An optional store in which to persist document updates. + provider_factory: An optional provider factory used to synchronize the room with + an external document. exception_handler: An optional callback to call when an exception is raised, that returns True if the exception was handled. log: An optional logger. @@ -85,6 +98,7 @@ def __init__( self.ready_event = Event() self.ready = ready self.ystore = ystore + self.provider_factory = provider_factory self.log = log or getLogger(__name__) self.awareness = Awareness(self.ydoc) self.awareness.observe(self.send_server_awareness) @@ -92,6 +106,7 @@ def __init__( self._on_message = None self.exception_handler = exception_handler self._stopped = Event() + self._provider_stop_event = Event() @property def _start_lock(self) -> Lock: @@ -207,16 +222,7 @@ async def start( task_status: The status to set when the task has started. """ if from_context_manager: - task_status.started() - self.started.set() - self._update_send_stream, self._update_receive_stream = create_memory_object_stream( - max_buffer_size=65536 - ) - assert self._task_group is not None - self._task_group.start_soon(self._stopped.wait) - self._task_group.start_soon(self._watch_ready) - self._task_group.start_soon(self._broadcast_updates) - self._task_group.start_soon(self.awareness.start) + await self._start(task_status) return async with self._start_lock: @@ -226,21 +232,40 @@ async def start( while True: try: async with create_task_group() as self._task_group: - if not self.started.is_set(): - task_status.started() - self.started.set() - self._update_send_stream, self._update_receive_stream = ( - create_memory_object_stream(max_buffer_size=65536) - ) - self._task_group.start_soon(self._stopped.wait) - self._task_group.start_soon(self._watch_ready) - self._task_group.start_soon(self._broadcast_updates) - self._task_group.start_soon(self.awareness.start) + await self._start(task_status) return except Exception as exception: await self.awareness.stop() + self._provider_stop_event.set() + self._provider_stop_event = Event() self._handle_exception(exception) + async def _start( + self, + task_status: TaskStatus[None], + ): + if not self.started.is_set(): + task_status.started() + self.started.set() + self._update_send_stream, self._update_receive_stream = create_memory_object_stream( + max_buffer_size=65536 + ) + assert self._task_group is not None + self._task_group.start_soon(self._stopped.wait) + self._task_group.start_soon(self._watch_ready) + self._task_group.start_soon(self._broadcast_updates) + await self._task_group.start(self.awareness.start) + self._task_group.start_soon(self._run_provider) + + async def _run_provider(self): + if self.provider_factory is not None: + provider_factory = self.provider_factory(doc=self.ydoc, log=self.log) + try: + async with provider_factory: + await self._provider_stop_event.wait() + except Exception as exception: + self._handle_exception(exception) + async def stop(self) -> None: """Stop the room.""" if self._task_group is None: @@ -252,23 +277,23 @@ async def stop(self) -> None: if self._subscription is not None: self.ydoc.unobserve(self._subscription) - async def serve(self, websocket: Websocket): + async def serve(self, channel: Channel): """Serve a client. Arguments: - websocket: The WebSocket through which to serve the client. + channel: The WebSocket through which to serve the client. """ try: async with create_task_group() as tg: - self.clients.add(websocket) + self.clients.add(channel) sync_message = create_sync_message(self.ydoc) self.log.debug( "Sending %s message to endpoint: %s", YSyncMessageType.SYNC_STEP1.name, - websocket.path, + channel.path, ) - await websocket.send(sync_message) - async for message in websocket: + await channel.send(sync_message) + async for message in channel: # filter messages (e.g. awareness) skip = False if self.on_message: @@ -284,23 +309,23 @@ async def serve(self, websocket: Websocket): self.log.debug( "Received %s message from endpoint: %s", YSyncMessageType(message[1]).name, - websocket.path, + channel.path, ) reply = handle_sync_message(message[1:], self.ydoc) if reply is not None: self.log.debug( "Sending %s message to endpoint: %s", YSyncMessageType.SYNC_STEP2.name, - websocket.path, + channel.path, ) - tg.start_soon(websocket.send, reply) + tg.start_soon(channel.send, reply) elif message_type == YMessageType.AWARENESS: # forward awareness messages from this client to all clients, # including itself, because it's used to keep the connection alive self.log.debug( "Received %s message from endpoint: %s", YMessageType.AWARENESS.name, - websocket.path, + channel.path, ) # Check if the message is a client awareness disconnect. @@ -310,13 +335,13 @@ async def serve(self, websocket: Websocket): # disconnection from the client. This avoid an error when trying # to send the message to the disconnected client. for client in self.clients: - if disconnection and client == websocket: + if disconnection and client == channel: continue self.log.debug( "Sending Y awareness from client with endpoint " "%s to client with endpoint: %s", - websocket.path, + channel.path, client.path, ) tg.start_soon(client.send, message) @@ -326,7 +351,7 @@ async def serve(self, websocket: Websocket): self._handle_exception(exception) finally: # remove this client - self.clients.remove(websocket) + self.clients.remove(channel) def send_server_awareness(self, type: str, changes: tuple[dict[str, Any], Any]) -> None: """ diff --git a/tests/conftest.py b/tests/conftest.py index 55aefcb..527c3ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,17 +1,15 @@ import subprocess -from contextlib import asynccontextmanager from functools import partial -from socket import socket import pytest -from anyio import Event, create_task_group -from httpx_ws import aconnect_ws -from hypercorn import Config -from sniffio import current_async_library -from utils import StartStopContextManager, Websocket, connected_websockets, ensure_server_running +from utils import ( + StartStopContextManager, + create_yws_provider, + create_yws_server, + get_unused_tcp_port, +) -from pycrdt import Doc -from pycrdt.websocket import ASGIServer, WebsocketProvider, WebsocketServer, YRoom +from pycrdt.websocket import YRoom @pytest.fixture(params=("websocket_server_context_manager", "websocket_server_start_stop")) @@ -41,51 +39,23 @@ def ystore_api(request): @pytest.fixture async def yws_server(request, unused_tcp_port, websocket_server_api): - async with create_task_group() as tg: - try: - kwargs = request.param - except AttributeError: - kwargs = {} - websocket_server = WebsocketServer(**kwargs) - app = ASGIServer(websocket_server) - config = Config() - config.bind = [f"localhost:{unused_tcp_port}"] - shutdown_event = Event() - if websocket_server_api == "websocket_server_start_stop": - websocket_server = StartStopContextManager(websocket_server, tg) - if current_async_library() == "trio": - from hypercorn.trio import serve - else: - from hypercorn.asyncio import serve - async with websocket_server as websocket_server: - tg.start_soon( - partial(serve, app, config, shutdown_trigger=shutdown_event.wait, mode="asgi") - ) - await ensure_server_running("localhost", unused_tcp_port) - pytest.port = unused_tcp_port - yield unused_tcp_port, websocket_server - shutdown_event.set() + try: + kwargs = request.param + except AttributeError: + kwargs = {} + async with create_yws_server(unused_tcp_port, websocket_server_api, **kwargs) as server: + yield server @pytest.fixture def yws_provider_factory(room_name, websocket_provider_api, websocket_provider_connect): - @asynccontextmanager - async def factory(): - ydoc = Doc() - if websocket_provider_connect == "real_websocket": - server_websocket = None - connect = aconnect_ws(f"http://localhost:{pytest.port}/{room_name}") - else: - server_websocket, connect = connected_websockets() - async with connect as websocket: - async with create_task_group() as tg: - websocket_provider = WebsocketProvider(ydoc, Websocket(websocket, room_name)) - if websocket_provider_api == "websocket_provider_start_stop": - websocket_provider = StartStopContextManager(websocket_provider, tg) - async with websocket_provider as websocket_provider: - yield ydoc, server_websocket - - return factory + return partial( + create_yws_provider, + pytest.port, + room_name, + websocket_provider_api, + websocket_provider_connect, + ) @pytest.fixture @@ -98,21 +68,20 @@ async def yws_provider(yws_provider_factory): @pytest.fixture async def yws_providers(request, yws_provider_factory): number = request.param - yield [yws_provider_factory() for idx in range(number)] + yield [yws_provider_factory() for _ in range(number)] @pytest.fixture async def yroom(request, yroom_api): - async with create_task_group() as tg: - try: - kwargs = request.param - except AttributeError: - kwargs = {} - room = YRoom(**kwargs) - if yroom_api == "yroom_start_stop": - room = StartStopContextManager(room, tg) - async with room as room: - yield room + try: + kwargs = request.param + except AttributeError: + kwargs = {} + room = YRoom(**kwargs) + if yroom_api == "yroom_start_stop": + room = StartStopContextManager(room) + async with room as room: + yield room @pytest.fixture @@ -130,6 +99,4 @@ def room_name(): @pytest.fixture def unused_tcp_port() -> int: - with socket() as sock: - sock.bind(("localhost", 0)) - return sock.getsockname()[1] + return get_unused_tcp_port() diff --git a/tests/test_server.py b/tests/test_server.py index 48c9872..c6e4e5a 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,6 +1,8 @@ import pytest -from anyio import sleep +from anyio import fail_after, sleep +from utils import create_yws_provider, create_yws_server, get_unused_tcp_port +from pycrdt import Text from pycrdt.websocket import exception_logger pytestmark = pytest.mark.anyio @@ -16,3 +18,43 @@ async def raise_error(): server._task_group.start_soon(raise_error) await sleep(0.1) + + +async def test_server_provider(): + # the server that synchronizes the parallel servers + sync_port = get_unused_tcp_port() + sync_server = create_yws_server(sync_port) + + # the provider factory that synchronizes each parallel server with + # the sync_server + def provider_factory(path, doc, log): + return create_yws_provider(sync_port, path, ydoc=doc, log=log) + + # the parallel servers + port1 = get_unused_tcp_port() + server1 = create_yws_server(port1, provider_factory=provider_factory) + port2 = get_unused_tcp_port() + server2 = create_yws_server(port2, provider_factory=provider_factory) + + # the clients connecting to the parallel servers + client1 = create_yws_provider(port1, "myroom") + client2 = create_yws_provider(port2, "myroom") + + async with ( + sync_server as sync_server, + server1 as server1, + server2 as server2, + client1 as client1, + client2 as client2, + ): + doc1, _ = client1 + doc2, _ = client2 + text1 = doc1.get("text", type=Text) + text2 = doc2.get("text", type=Text) + text1 += "Hello" + with fail_after(1): + async with text2.events() as events: + async for event in events: + break + + assert str(text2) == "Hello" diff --git a/tests/utils.py b/tests/utils.py index 182491a..55fe8d1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,8 +1,24 @@ from __future__ import annotations -from anyio import Lock, connect_tcp, create_memory_object_stream - -from pycrdt import Array, Doc +from contextlib import AsyncExitStack, asynccontextmanager +from functools import partial +from socket import socket + +import pytest +from anyio import ( + Event, + Lock, + connect_tcp, + create_memory_object_stream, + create_task_group, + get_cancelled_exc_class, +) +from httpx_ws import aconnect_ws +from hypercorn import Config +from sniffio import current_async_library + +from pycrdt import Array, Doc, Provider +from pycrdt.websocket import ASGIServer, WebsocketServer class YDocTest: @@ -21,17 +37,20 @@ def update(self): class StartStopContextManager: - def __init__(self, service, task_group): + def __init__(self, service): self._service = service - self._task_group = task_group async def __aenter__(self): - await self._task_group.start(self._service.start) + async with AsyncExitStack() as exit_stack: + self._task_group = await exit_stack.enter_async_context(create_task_group()) + await self._task_group.start(self._service.start) + self._exit_stack = exit_stack.pop_all() await self._service.started.wait() return self._service async def __aexit__(self, exc_type, exc_value, exc_tb): - await self._service.stop() + self._task_group.start_soon(self._service.stop) + return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) class Websocket: @@ -110,3 +129,61 @@ async def ensure_server_running(host: str, port: int) -> None: pass else: break + + +@asynccontextmanager +async def create_yws_provider( + port, + room_name, + websocket_provider_api="websocket_provider_context_manager", + websocket_provider_connect="real_websocket", + ydoc=None, + log=None, +): + ydoc = Doc() if ydoc is None else ydoc + if websocket_provider_connect == "real_websocket": + server_websocket = None + connect = aconnect_ws(f"http://localhost:{port}/{room_name}") + else: + server_websocket, connect = connected_websockets() + try: + async with connect as websocket: + websocket_provider = Provider(ydoc, Websocket(websocket, room_name), log) + if websocket_provider_api == "websocket_provider_start_stop": + websocket_provider = StartStopContextManager(websocket_provider) + async with websocket_provider as websocket_provider: + yield ydoc, server_websocket + except get_cancelled_exc_class(): + pass + + +@asynccontextmanager +async def create_yws_server( + port, websocket_server_api="websocket_server_context_manager", **kwargs +): + async with create_task_group() as tg: + websocket_server = WebsocketServer(**kwargs) + app = ASGIServer(websocket_server) + config = Config() + config.bind = [f"localhost:{port}"] + shutdown_event = Event() + if websocket_server_api == "websocket_server_start_stop": + websocket_server = StartStopContextManager(websocket_server) + if current_async_library() == "trio": + from hypercorn.trio import serve + else: + from hypercorn.asyncio import serve + async with websocket_server as websocket_server: + tg.start_soon( + partial(serve, app, config, shutdown_trigger=shutdown_event.wait, mode="asgi") + ) + await ensure_server_running("localhost", port) + pytest.port = port + yield port, websocket_server + shutdown_event.set() + + +def get_unused_tcp_port(): + with socket() as sock: + sock.bind(("localhost", 0)) + return sock.getsockname()[1]