diff --git a/docs/reference/Django_Channels.md b/docs/reference/Django_Channels.md new file mode 100644 index 0000000..5aee0ca --- /dev/null +++ b/docs/reference/Django_Channels.md @@ -0,0 +1,11 @@ +## Consumer + +::: pycrdt_websocket.django_channels.yjs_consumer.YjsConsumer + +## Storage + +### BaseYRoomStorage +::: pycrdt_websocket.django_channels.storage.base_yroom_storage.BaseYRoomStorage + +### RedisYRoomStorage +::: pycrdt_websocket.django_channels.storage.redis_yroom_storage.RedisYRoomStorage diff --git a/docs/reference/Django_Channels_consumer.md b/docs/reference/Django_Channels_consumer.md deleted file mode 100644 index 8548b4e..0000000 --- a/docs/reference/Django_Channels_consumer.md +++ /dev/null @@ -1 +0,0 @@ -::: pycrdt_websocket.django_channels_consumer.YjsConsumer diff --git a/mkdocs.yml b/mkdocs.yml index efa55a9..872d860 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -39,7 +39,7 @@ nav: - reference/WebSocket_provider.md - reference/WebSocket_server.md - reference/ASGI_server.md - - reference/Django_Channels_consumer.md + - reference/Django_Channels.md - reference/WebSocket.md - reference/Room.md - reference/Store.md diff --git a/pycrdt_websocket/django_channels/__init__.py b/pycrdt_websocket/django_channels/__init__.py new file mode 100644 index 0000000..7be8da3 --- /dev/null +++ b/pycrdt_websocket/django_channels/__init__.py @@ -0,0 +1,2 @@ +from .storage.base_yroom_storage import BaseYRoomStorage as BaseYRoomStorage +from .yjs_consumer import YjsConsumer as YjsConsumer diff --git a/pycrdt_websocket/django_channels/storage/base_yroom_storage.py b/pycrdt_websocket/django_channels/storage/base_yroom_storage.py new file mode 100644 index 0000000..7a4c272 --- /dev/null +++ b/pycrdt_websocket/django_channels/storage/base_yroom_storage.py @@ -0,0 +1,101 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from pycrdt import Doc + + +class BaseYRoomStorage(ABC): + """Base class for YRoom storage.
+ This class is responsible for storing, retrieving, updating and persisting the pycrdt document. + Each Django Channels Consumer should have its own YRoomStorage instance, although all consumers + and rooms with the same room name will be connected to the same document in the end. + Updates to the document should be sent to the shared storage, instead of each + consumer having its own version of the YDoc. + + A full example using Redis as temporary storage and Postgres as persistent storage is: + ```py + from typing import Optional + from django.db import models + from pycrdt_websocket.django_channels.storage.redis_yroom_storage import RedisYRoomStorage + + class YDocSnapshotManager(models.Manager): + async def aget_snapshot(self, name) -> Optional[bytes]: + try: + instance: YDocSnapshot = await self.aget(name=name) + result = instance.data + if not isinstance(result, bytes): + # Postgres on psycopg2 returns memoryview + return bytes(result) + except YDocSnapshot.DoesNotExist: + return None + else: + return result + + async def asave_snapshot(self, name, data): + return await self.aupdate_or_create(name=name, defaults={"data": data}) + + class YDocSnapshot(models.Model): + name = models.CharField(max_length=255, primary_key=True) + data = models.BinaryField() + objects = YDocSnapshotManager() + + class CustomRoomStorage(RedisYRoomStorage): + async def load_snapshot(self) -> Optional[bytes]: + return await YDocSnapshot.objects.aget_snapshot(self.room_name) + + async def save_snapshot(self): + current_snapshot = await self.redis.get(self.redis_key) + if not current_snapshot: + return + await YDocSnapshot.objects.asave_snapshot( + self.room_name, + current_snapshot, + ) + ``` + """ + + def __init__(self, room_name: str) -> None: + self.room_name = room_name + + @abstractmethod + async def get_document(self) -> Doc: + """Gets the document from the storage. + Ideally it should be retrieved first from temporary storage (e.g. Redis) and then from + persistent storage (e.g. a database).
+ Returns: + The document with the latest changes. + """ + ... + + @abstractmethod + async def update_document(self, update: bytes) -> None: + """Updates the document in the storage. + Updates could be received by Yjs client (e.g. from a WebSocket) or from the server + (e.g. from a Django Celery job). + Args: + update: The update to apply to the document. + """ + ... + + @abstractmethod + async def load_snapshot(self) -> Optional[bytes]: + """Gets the document encoded as update from the database. Override this method to + implement a persistent storage. + Defaults to None. + Returns: + The latest document snapshot. + """ + ... + + @abstractmethod + async def save_snapshot(self) -> None: + """Saves the document encoded as update to the database.""" + ... + + async def close(self) -> None: + """Closes the storage connection. + + Useful for cleaning up resources like closing a database + connection or saving the document before exiting. + """ + pass diff --git a/pycrdt_websocket/django_channels/storage/redis_yroom_storage.py b/pycrdt_websocket/django_channels/storage/redis_yroom_storage.py new file mode 100644 index 0000000..3b02919 --- /dev/null +++ b/pycrdt_websocket/django_channels/storage/redis_yroom_storage.py @@ -0,0 +1,110 @@ +import time +from typing import Optional + +import redis.asyncio as redis +from pycrdt import Doc + +from .base_yroom_storage import BaseYRoomStorage + + +class RedisYRoomStorage(BaseYRoomStorage): + """A YRoom storage that uses Redis as main storage, without + persistent storage. + Args: + room_name: The name of the room. 
+ """ + + def __init__( + self, + room_name: str, + save_throttle_interval: int | None = None, + redis_expiration_seconds: int | None = 60 * 10, # 10 minutes, + ): + super().__init__(room_name) + + self.save_throttle_interval = save_throttle_interval + self.last_saved_at = time.time() + + self.redis_key = f"document:{self.room_name}" + self.redis = self.make_redis() + self.redis_expiration_seconds = redis_expiration_seconds + + async def get_document(self) -> Doc: + snapshot = await self.redis.get(self.redis_key) + + if not snapshot: + snapshot = await self.load_snapshot() + + document = Doc() + + if snapshot: + document.apply_update(snapshot) + + return document + + async def update_document(self, update: bytes): + await self.redis.watch(self.redis_key) + + try: + current_document = await self.get_document() + updated_snapshot = self._apply_update_to_document(current_document, update) + + async with self.redis.pipeline() as pipe: + while True: + try: + pipe.multi() + pipe.set( + name=self.redis_key, + value=updated_snapshot, + ex=self.redis_expiration_seconds, + ) + + await pipe.execute() + + break + except redis.WatchError: + current_document = await self.get_document() + updated_snapshot = self._apply_update_to_document( + current_document, + update, + ) + + continue + finally: + await self.redis.unwatch() + + await self.throttled_save_snapshot() + + async def load_snapshot(self) -> Optional[bytes]: + return None + + async def save_snapshot(self) -> None: + return None + + async def throttled_save_snapshot(self) -> None: + """Saves the document encoded as update to the database, throttled.""" + + if ( + not self.save_throttle_interval + or time.time() - self.last_saved_at <= self.save_throttle_interval + ): + return + + await self.save_snapshot() + + self.last_saved_at = time.time() + + def make_redis(self): + """Makes a Redis client. 
+ Defaults to a local client""" + + return redis.Redis(host="localhost", port=6379, db=0) + + async def close(self): + await self.save_snapshot() + await self.redis.close() + + def _apply_update_to_document(self, document: Doc, update: bytes) -> bytes: + document.apply_update(update) + + return document.get_update() diff --git a/pycrdt_websocket/django_channels_consumer.py b/pycrdt_websocket/django_channels/yjs_consumer.py similarity index 67% rename from pycrdt_websocket/django_channels_consumer.py rename to pycrdt_websocket/django_channels/yjs_consumer.py index 9f917b4..ff750df 100644 --- a/pycrdt_websocket/django_channels_consumer.py +++ b/pycrdt_websocket/django_channels/yjs_consumer.py @@ -6,8 +6,17 @@ from channels.generic.websocket import AsyncWebsocketConsumer # type: ignore[import-not-found] from pycrdt import Doc -from .websocket import Websocket -from .yutils import YMessageType, process_sync_message, sync +from pycrdt_websocket.django_channels.storage.base_yroom_storage import BaseYRoomStorage + +from ..websocket import Websocket +from ..yutils import ( + EMPTY_UPDATE, + YMessageType, + YSyncMessageType, + process_sync_message, + read_message, + sync, +) logger = getLogger(__name__) @@ -70,63 +79,82 @@ class YjsConsumer(AsyncWebsocketConsumer): In particular, - Override `make_room_name` to customize the room name. - - Override `make_ydoc` to initialize the YDoc. This is useful to initialize it with data - from your database, or to add observers to it). + - Override `make_room_storage` to initialize the room storage. Create your own storage class + by subclassing `BaseYRoomStorage` and implementing the methods. - Override `connect` to do custom validation (like auth) on connect, but be sure to call `await super().connect()` in the end. - Call `group_send_message` to send a message to an entire group/room. - Call `send_message` to send a message to a single client, although this is not recommended. 
- A full example of a custom consumer showcasing all of these options is: + A full example of a custom consumer showcasing all of these options is below. The example also + includes an example function `propagate_document_update_from_external` that demonstrates how to + send a message to all connected clients from an external source (like a Celery job). + ```py from pycrdt import Doc from asgiref.sync import async_to_sync from channels.layers import get_channel_layer from pycrdt_websocket.django_channels_consumer import YjsConsumer from pycrdt_websocket.yutils import create_update_message + from pycrdt_websocket.django_channels.storage.redis_yroom_storage import RedisYRoomStorage class DocConsumer(YjsConsumer): + def make_room_storage(self) -> BaseYRoomStorage: + # Modify the room storage here + + return RedisYRoomStorage(room_name=self.room_name) + def make_room_name(self) -> str: - # modify the room name here - return self.scope["url_route"]["kwargs"]["room"] + # Modify the room name here - async def make_ydoc(self) -> Doc: - doc = Doc() - # fill doc with data from DB here - doc.observe(self.on_update_event) - return doc + return self.scope["url_route"]["kwargs"]["room"] async def connect(self): user = self.scope["user"] + if user is None or user.is_anonymous: await self.close() return - await super().connect() - def on_update_event(self, event): - # process event here - ... 
+ await super().connect() - async def doc_update(self, update_wrapper): + async def propagate_document_update(self, update_wrapper): update = update_wrapper["update"] - self.ydoc.apply_update(update) - await self.group_send_message(create_update_message(update)) + await self.send(create_update_message(update)) - def send_doc_update(room_name, update): - layer = get_channel_layer() - async_to_sync(layer.group_send)(room_name, {"type": "doc_update", "update": update}) - ``` + async def propagate_document_update_from_external(room_name, update): + channel_layer = get_channel_layer() + + await channel_layer.group_send( + room_name, + {"type": "propagate_document_update", "update": update}, + ) + ``` """ def __init__(self): super().__init__() self.room_name = None self.ydoc = None + self.room_storage = None self._websocket_shim = None + def make_room_storage(self) -> BaseYRoomStorage | None: + """Make the room storage for a new channel to persist the YDoc permanently. + + Defaults to not using any (just broadcast updates between consumers). + + Example: + self.room_storage = YourCustomRedisYRoomStorage( + room_name=self.room_name, + save_throttle_interval=5 + ) + """ + return None + def make_room_name(self) -> str: """Make the room name for a new channel. @@ -137,15 +165,10 @@ def make_room_name(self) -> str: """ return self.scope["url_route"]["kwargs"]["room"] - async def make_ydoc(self) -> Doc: - """Make the YDoc for a new channel. - - Override to customize the YDoc when a channel is created - (useful to initialize it with data from your database, or to add observers to it). + async def _make_ydoc(self) -> Doc: + if self.room_storage: + return await self.room_storage.get_document() - Returns: - The YDoc for a new channel. Defaults to a new empty YDoc. 
- """ return Doc() def _make_websocket_shim(self, path: str) -> _WebsocketShim: @@ -153,7 +176,9 @@ def _make_websocket_shim(self, path: str) -> _WebsocketShim: async def connect(self) -> None: self.room_name = self.make_room_name() - self.ydoc = await self.make_ydoc() + self.room_storage = self.make_room_storage() + + self.ydoc = await self._make_ydoc() self._websocket_shim = self._make_websocket_shim(self.scope["path"]) await self.channel_layer.group_add(self.room_name, self.channel_name) @@ -162,14 +187,32 @@ async def connect(self) -> None: await sync(self.ydoc, self._websocket_shim, logger) async def disconnect(self, code) -> None: + if self.room_storage: + await self.room_storage.close() + + if not self.room_name: + return + await self.channel_layer.group_discard(self.room_name, self.channel_name) async def receive(self, text_data=None, bytes_data=None): if bytes_data is None: return + await self.group_send_message(bytes_data) + if bytes_data[0] != YMessageType.SYNC: return + + # If it's an update message, apply it to the storage document + if self.room_storage and bytes_data[1] == YSyncMessageType.SYNC_UPDATE: + update = read_message(bytes_data[2:]) + + if update != EMPTY_UPDATE: + await self.room_storage.update_document(update) + + return + await process_sync_message(bytes_data[1:], self.ydoc, self._websocket_shim, logger) class WrappedMessage(TypedDict): diff --git a/pycrdt_websocket/yutils.py b/pycrdt_websocket/yutils.py index 2d363b4..4f609c8 100644 --- a/pycrdt_websocket/yutils.py +++ b/pycrdt_websocket/yutils.py @@ -19,6 +19,10 @@ class YSyncMessageType(IntEnum): SYNC_UPDATE = 2 +# Empty updates (see https://github.com/y-crdt/ypy/issues/98) +EMPTY_UPDATE = b"\x00\x00" + + def write_var_uint(num: int) -> bytes: res = [] while num > 127: @@ -128,7 +132,7 @@ async def process_sync_message(message: bytes, ydoc: Doc, websocket, log) -> Non YSyncMessageType.SYNC_UPDATE, ): update = read_message(msg) - if update != b"\x00\x00": + if update != EMPTY_UPDATE: 
ydoc.apply_update(update) diff --git a/pyproject.toml b/pyproject.toml index ebc5bad..33a7133 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ test = [ "hypercorn >=0.16.0", "trio >=0.25.0", "sniffio", + "types-redis", ] docs = [ "mkdocs", @@ -51,6 +52,9 @@ docs = [ django = [ "channels", ] +redis = [ + "redis", +] [project.urls] Homepage = "https://github.com/jupyter-server/pycrdt-websocket"