Skip to content

Commit a803af3

Browse files
author
Jialin Zhang
committed
change yroom class attribute to instance attribute and stop ystore in stop method
1 parent 1cd727f commit a803af3

File tree

3 files changed

+57
-11
lines changed

3 files changed

+57
-11
lines changed

pycrdt_websocket/yroom.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ class YRoom:
3737
_on_message: Callable[[bytes], Awaitable[bool] | bool] | None
3838
_update_send_stream: MemoryObjectSendStream
3939
_update_receive_stream: MemoryObjectReceiveStream
40-
_task_group: TaskGroup | None = None
41-
_started: Event | None = None
40+
_task_group: TaskGroup | None
41+
_started: Event | None
4242
_stopped: Event
43-
__start_lock: Lock | None = None
44-
_subscription: Subscription | None = None
43+
__start_lock: Lock | None
44+
_subscription: Subscription | None
4545

4646
def __init__(
4747
self,
@@ -82,6 +82,10 @@ def __init__(
8282
self._on_message = None
8383
self.exception_handler = exception_handler
8484
self._stopped = Event()
85+
self._task_group = None
86+
self._started = None
87+
self.__start_lock = None
88+
self._subscription = None
8589

8690
@property
8791
def _start_lock(self) -> Lock:
@@ -230,6 +234,11 @@ async def stop(self) -> None:
230234
self._stopped.set()
231235
self._task_group.cancel_scope.cancel()
232236
self._task_group = None
237+
if self.ystore is not None:
238+
try:
239+
await self.ystore.stop()
240+
except RuntimeError:
241+
pass
233242
if self._subscription is not None:
234243
self.ydoc.unobserve(self._subscription)
235244

pycrdt_websocket/ystore.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class BaseYStore(ABC):
2828
metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None = None
2929
version = 2
3030
_started: Event | None = None
31+
_stopped: Event | None = None
3132
_task_group: TaskGroup | None = None
3233
__start_lock: Lock | None = None
3334

@@ -50,6 +51,12 @@ def started(self) -> Event:
5051
self._started = Event()
5152
return self._started
5253

54+
@property
55+
def stopped(self) -> Event:
56+
if self._stopped is None:
57+
self._stopped = Event()
58+
return self._stopped
59+
5360
@property
5461
def _start_lock(self) -> Lock:
5562
if self.__start_lock is None:
@@ -96,12 +103,14 @@ async def start(
96103
async with create_task_group() as self._task_group:
97104
task_status.started()
98105
self.started.set()
106+
await self.stopped.wait()
99107

100108
async def stop(self) -> None:
101109
"""Stop the store."""
102110
if self._task_group is None:
103111
raise RuntimeError("YStore not running")
104112

113+
self.stopped.set()
105114
self._task_group.cancel_scope.cancel()
106115
self._task_group = None
107116

@@ -309,7 +318,7 @@ class MySQLiteYStore(SQLiteYStore):
309318
document_ttl: int | None = None
310319
path: str
311320
lock: Lock
312-
db_initialized: Event
321+
db_initialized: Event | None
313322
_db: Connection
314323

315324
def __init__(
@@ -329,6 +338,7 @@ def __init__(
329338
self.metadata_callback = metadata_callback
330339
self.log = log or getLogger(__name__)
331340
self.lock = Lock()
341+
self.db_initialized = None
332342

333343
async def start(
334344
self,
@@ -356,10 +366,11 @@ async def start(
356366
self._task_group.start_soon(self._init_db)
357367
task_status.started()
358368
self.started.set()
369+
await self.stopped.wait()
359370

360371
async def stop(self) -> None:
361372
"""Stop the store."""
362-
if hasattr(self, "db_initialized") and self.db_initialized.is_set():
373+
if self.db_initialized is not None and self.db_initialized.is_set():
363374
await self._db.close()
364375
await super().stop()
365376

@@ -405,6 +416,7 @@ async def _init_db(self):
405416
await db.commit()
406417
await db.close()
407418
self._db = await connect(self.db_path)
419+
assert self.db_initialized is not None
408420
self.db_initialized.set()
409421

410422
async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
@@ -413,8 +425,8 @@ async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
413425
Returns:
414426
A tuple of (update, metadata, timestamp) for each update.
415427
"""
416-
if not hasattr(self, "db_initialized"):
417-
raise RuntimeError("ystore is not started")
428+
if self.db_initialized is None:
429+
raise RuntimeError("YStore not started")
418430
await self.db_initialized.wait()
419431
try:
420432
async with self.lock:
@@ -438,8 +450,8 @@ async def write(self, data: bytes) -> None:
438450
Arguments:
439451
data: The update to store.
440452
"""
441-
if not hasattr(self, "db_initialized"):
442-
raise RuntimeError("ystore is not started")
453+
if self.db_initialized is None:
454+
raise RuntimeError("YStore not started")
443455
await self.db_initialized.wait()
444456
async with self.lock:
445457
# first, determine time elapsed since last update

tests/test_ystore.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
from unittest.mock import patch
55

66
import pytest
7-
from anyio import create_task_group
7+
from anyio import create_task_group, sleep
8+
from pycrdt import Map
89
from sqlite_anyio import connect
910
from utils import StartStopContextManager, YDocTest
1011

12+
from pycrdt_websocket.websocket_server import exception_logger
13+
from pycrdt_websocket.yroom import YRoom
1114
from pycrdt_websocket.ystore import SQLiteYStore, TempFileYStore
1215

1316
pytestmark = pytest.mark.anyio
@@ -124,3 +127,25 @@ async def test_version(YStore, ystore_api, caplog):
124127
YStore.version = prev_version
125128
async with ystore as ystore:
126129
await ystore.write(b"bar")
130+
131+
132+
@pytest.mark.parametrize("websocket_server_api", ["websocket_server_start_stop"], indirect=True)
133+
@pytest.mark.parametrize("yws_server", [{"exception_handler": exception_logger}], indirect=True)
134+
@pytest.mark.parametrize("YStore", (MyTempFileYStore, MySQLiteYStore))
135+
async def test_yroom_stop(yws_server, yws_provider, YStore):
136+
port, server = yws_server
137+
ystore = YStore("ystore", metadata_callback=MetadataCallback())
138+
yroom = YRoom(ystore=ystore, exception_handler=exception_logger)
139+
yroom.ydoc, _ = yws_provider
140+
await server.start_room(yroom)
141+
yroom.ydoc["map"] = ymap1 = Map()
142+
ymap1["key"] = "value"
143+
ymap1["key2"] = "value2"
144+
await sleep(1)
145+
assert yroom._task_group is not None
146+
assert not yroom._task_group.cancel_scope.cancel_called
147+
assert ystore._task_group is not None
148+
assert not ystore._task_group.cancel_scope.cancel_called
149+
await yroom.stop()
150+
assert yroom._task_group is None
151+
assert ystore._task_group is None

0 commit comments

Comments
 (0)