Skip to content

Commit a851875

Browse files
committed
Allow room to sync with provider
1 parent 748bfdf commit a851875

10 files changed

+242
-346
lines changed

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ test = [
4848
"hypercorn >=0.16.0",
4949
"trio >=0.25.0",
5050
"sniffio",
51+
"channels",
5152
]
5253
docs = [
5354
"mkdocs",

src/pycrdt/websocket/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from .asgi_server import ASGIServer as ASGIServer
2-
from .websocket_provider import WebsocketProvider as WebsocketProvider
32
from .websocket_server import WebsocketServer as WebsocketServer
43
from .websocket_server import exception_logger as exception_logger
54
from .yroom import YRoom as YRoom

src/pycrdt/websocket/django_channels_consumer.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,21 @@
33
from logging import getLogger
44
from typing import TypedDict
55

6-
from channels.generic.websocket import AsyncWebsocketConsumer # type: ignore[import-not-found]
6+
from channels.generic.websocket import AsyncWebsocketConsumer # type: ignore[import-untyped]
77

88
from pycrdt import (
9+
Channel,
910
Doc,
1011
YMessageType,
1112
YSyncMessageType,
1213
create_sync_message,
1314
handle_sync_message,
1415
)
1516

16-
from .websocket import Websocket
17-
1817
logger = getLogger(__name__)
1918

2019

21-
class _WebsocketShim(Websocket):
20+
class _WebsocketShim(Channel):
2221
def __init__(self, path, send_func) -> None:
2322
self._path = path
2423
self._send_func = send_func

src/pycrdt/websocket/websocket.py

+6-60
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,13 @@
1-
from typing import Protocol
2-
31
from anyio import Lock
42

3+
from pycrdt import Channel
54

6-
class Websocket(Protocol):
7-
"""WebSocket.
8-
9-
The Websocket instance can receive messages using an async iterator,
10-
until the connection is closed:
11-
```py
12-
async for message in websocket:
13-
...
14-
```
15-
Or directly by calling `recv()`:
16-
```py
17-
message = await websocket.recv()
18-
```
19-
Sending messages is done with `send()`:
20-
```py
21-
await websocket.send(message)
22-
```
23-
"""
245

25-
@property
26-
def path(self) -> str:
27-
"""The WebSocket path."""
28-
...
29-
30-
def __aiter__(self):
31-
return self
6+
class HttpxWebsocket(Channel):
7+
def __init__(self, websocket, path: str):
8+
self._websocket = websocket
9+
self._path = path
10+
self._send_lock = Lock()
3211

3312
async def __anext__(self) -> bytes:
3413
try:
@@ -38,43 +17,10 @@ async def __anext__(self) -> bytes:
3817

3918
return message
4019

41-
async def send(self, message: bytes) -> None:
42-
"""Send a message.
43-
44-
Arguments:
45-
message: The message to send.
46-
"""
47-
...
48-
49-
async def recv(self) -> bytes:
50-
"""Receive a message.
51-
52-
Returns:
53-
The received message.
54-
"""
55-
...
56-
57-
58-
class HttpxWebsocket(Websocket):
59-
def __init__(self, websocket, path: str):
60-
self._websocket = websocket
61-
self._path = path
62-
self._send_lock = Lock()
63-
6420
@property
6521
def path(self) -> str:
6622
return self._path
6723

68-
def __aiter__(self):
69-
return self
70-
71-
async def __anext__(self) -> bytes:
72-
try:
73-
message = await self.recv()
74-
except Exception:
75-
raise StopAsyncIteration()
76-
return message
77-
7824
async def send(self, message: bytes):
7925
async with self._send_lock:
8026
await self._websocket.send_bytes(message)

src/pycrdt/websocket/websocket_provider.py

-170
This file was deleted.

src/pycrdt/websocket/websocket_server.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group
99
from anyio.abc import TaskGroup, TaskStatus
1010

11-
from .websocket import Websocket
12-
from .yroom import YRoom
11+
from pycrdt import Channel
12+
13+
from .yroom import ProviderFactory, YRoom
1314

1415

1516
class WebsocketServer:
@@ -28,6 +29,7 @@ def __init__(
2829
auto_clean_rooms: bool = True,
2930
exception_handler: Callable[[Exception, Logger], bool] | None = None,
3031
log: Logger | None = None,
32+
provider_factory: ProviderFactory | None = None,
3133
) -> None:
3234
"""Initialize the object.
3335
@@ -50,11 +52,14 @@ def __init__(
5052
exception_handler: An optional callback to call when an exception is raised, that
5153
returns True if the exception was handled.
5254
log: An optional logger.
55+
provider_factory: An optional provider factory used to synchronize the rooms with
56+
external documents.
5357
"""
5458
self.rooms_ready = rooms_ready
5559
self.auto_clean_rooms = auto_clean_rooms
5660
self.exception_handler = exception_handler
5761
self.log = log or getLogger(__name__)
62+
self.provider_factory = provider_factory
5863
self.rooms = {}
5964
self._stopped = Event()
6065

@@ -81,7 +86,14 @@ async def get_room(self, name: str) -> YRoom:
8186
The room with the given name, or a new one if no room with that name was found.
8287
"""
8388
if name not in self.rooms.keys():
84-
self.rooms[name] = YRoom(ready=self.rooms_ready, log=self.log)
89+
provider_factory = (
90+
partial(self.provider_factory, path=name)
91+
if self.provider_factory is not None
92+
else None
93+
)
94+
self.rooms[name] = YRoom(
95+
ready=self.rooms_ready, log=self.log, provider_factory=provider_factory
96+
)
8597
room = self.rooms[name]
8698
await self.start_room(room)
8799
return room
@@ -144,7 +156,7 @@ async def delete_room(self, *, name: str | None = None, room: YRoom | None = Non
144156
room = self.rooms.pop(name)
145157
await room.stop()
146158

147-
async def serve(self, websocket: Websocket) -> None:
159+
async def serve(self, websocket: Channel) -> None:
148160
"""Serve a client through a WebSocket.
149161
150162
Arguments:

0 commit comments

Comments
 (0)