Skip to content

Add callback for compressing updates in sqlite store #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
May 19, 2025
24 changes: 22 additions & 2 deletions src/pycrdt/store/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ class MySQLiteYStore(SQLiteYStore):
lock: Lock
db_initialized: Event | None
_db: Connection
# Optional callbacks for compressing and decompressing data, default: no compression
_compress: Callable[[bytes], bytes] | None = None
_decompress: Callable[[bytes], bytes] | None = None

def __init__(
self,
Expand Down Expand Up @@ -149,6 +152,14 @@ async def _init_db(self):
assert self.db_initialized is not None
self.db_initialized.set()

def register_compression_callbacks(
self, compress: Callable[[bytes], bytes], decompress: Callable[[bytes], bytes]
) -> None:
if not callable(compress) or not callable(decompress):
raise TypeError("Both compress and decompress must be callable.")
self._compress = compress
self._decompress = decompress

async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
"""Async iterator for reading the store content.

Expand All @@ -168,6 +179,11 @@ async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
(self.path,),
)
for update, metadata, timestamp in await cursor.fetchall():
if self._decompress:
try:
update = self._decompress(update)
except Exception:
pass
found = True
yield update, metadata, timestamp
if not found:
Expand Down Expand Up @@ -209,15 +225,19 @@ async def write(self, data: bytes) -> None:
await cursor.execute("DELETE FROM yupdates WHERE path = ?", (self.path,))
# insert squashed updates
squashed_update = ydoc.get_update()
compressed_update = (
self._compress(squashed_update) if self._compress else squashed_update
)
metadata = await self.get_metadata()
await cursor.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, squashed_update, metadata, time.time()),
(self.path, compressed_update, metadata, time.time()),
)

# finally, write this update to the DB
metadata = await self.get_metadata()
compressed_data = self._compress(data) if self._compress else data
await cursor.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, data, metadata, time.time()),
(self.path, compressed_data, metadata, time.time()),
)
34 changes: 34 additions & 0 deletions tests/test_store.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import tempfile
import time
import zlib
from pathlib import Path
from unittest.mock import patch

Expand Down Expand Up @@ -124,3 +125,36 @@ async def test_version(YStore, ystore_api, caplog):
YStore.version = prev_version
async with ystore as ystore:
await ystore.write(b"bar")


@pytest.mark.parametrize("ystore_api", ("ystore_context_manager", "ystore_start_stop"))
async def test_compression_callbacks_zlib(ystore_api):
"""
Verify that registering zlib.compress as a compression callback
correctly round-trips data through the SQLiteYStore.
"""
async with create_task_group() as tg:
store_name = f"compress_test_with_api_{ystore_api}"
ystore = MySQLiteYStore(store_name, metadata_callback=MetadataCallback(), delete=True)
if ystore_api == "ystore_start_stop":
ystore = StartStopContextManager(ystore, tg)

async with ystore as ystore:
# register zlib compression and no-op decompression
ystore.register_compression_callbacks(zlib.compress, lambda x: x)

data = [b"alpha", b"beta", b"gamma"]
# write compressed
for d in data:
await ystore.write(d)

assert Path(MySQLiteYStore.db_path).exists()

# read back and ensure correct decompression
i = 0
async for d_read, m, t in ystore.read():
assert zlib.decompress(d_read) == data[i]
assert m == str(i).encode()
i += 1

assert i == len(data)
Loading