Skip to content

Commit 0828c0b

Browse files
Merge pull request #8 from Darshan808/compress-yupdates
Add callback for compressing updates in sqlite store
2 parents 2d75d02 + f6bf966 commit 0828c0b

File tree

2 files changed

+56
-2
lines changed

2 files changed

+56
-2
lines changed

src/pycrdt/store/sqlite.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ class MySQLiteYStore(SQLiteYStore):
3737
lock: Lock
3838
db_initialized: Event | None
3939
_db: Connection
40+
# Optional callbacks for compressing and decompressing data, default: no compression
41+
_compress: Callable[[bytes], bytes] | None = None
42+
_decompress: Callable[[bytes], bytes] | None = None
4043

4144
def __init__(
4245
self,
@@ -149,6 +152,14 @@ async def _init_db(self):
149152
assert self.db_initialized is not None
150153
self.db_initialized.set()
151154

155+
def register_compression_callbacks(
156+
self, compress: Callable[[bytes], bytes], decompress: Callable[[bytes], bytes]
157+
) -> None:
158+
if not callable(compress) or not callable(decompress):
159+
raise TypeError("Both compress and decompress must be callable.")
160+
self._compress = compress
161+
self._decompress = decompress
162+
152163
async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
153164
"""Async iterator for reading the store content.
154165
@@ -168,6 +179,11 @@ async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
168179
(self.path,),
169180
)
170181
for update, metadata, timestamp in await cursor.fetchall():
182+
if self._decompress:
183+
try:
184+
update = self._decompress(update)
185+
except Exception:
186+
pass
171187
found = True
172188
yield update, metadata, timestamp
173189
if not found:
@@ -209,15 +225,19 @@ async def write(self, data: bytes) -> None:
209225
await cursor.execute("DELETE FROM yupdates WHERE path = ?", (self.path,))
210226
# insert squashed updates
211227
squashed_update = ydoc.get_update()
228+
compressed_update = (
229+
self._compress(squashed_update) if self._compress else squashed_update
230+
)
212231
metadata = await self.get_metadata()
213232
await cursor.execute(
214233
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
215-
(self.path, squashed_update, metadata, time.time()),
234+
(self.path, compressed_update, metadata, time.time()),
216235
)
217236

218237
# finally, write this update to the DB
219238
metadata = await self.get_metadata()
239+
compressed_data = self._compress(data) if self._compress else data
220240
await cursor.execute(
221241
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
222-
(self.path, data, metadata, time.time()),
242+
(self.path, compressed_data, metadata, time.time()),
223243
)

tests/test_store.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import tempfile
22
import time
3+
import zlib
34
from pathlib import Path
45
from unittest.mock import patch
56

@@ -124,3 +125,36 @@ async def test_version(YStore, ystore_api, caplog):
124125
YStore.version = prev_version
125126
async with ystore as ystore:
126127
await ystore.write(b"bar")
128+
129+
130+
@pytest.mark.parametrize("ystore_api", ("ystore_context_manager", "ystore_start_stop"))
131+
async def test_compression_callbacks_zlib(ystore_api):
132+
"""
133+
Verify that registering zlib.compress as a compression callback
134+
correctly round-trips data through the SQLiteYStore.
135+
"""
136+
async with create_task_group() as tg:
137+
store_name = f"compress_test_with_api_{ystore_api}"
138+
ystore = MySQLiteYStore(store_name, metadata_callback=MetadataCallback(), delete=True)
139+
if ystore_api == "ystore_start_stop":
140+
ystore = StartStopContextManager(ystore, tg)
141+
142+
async with ystore as ystore:
143+
# register zlib compression and no-op decompression
144+
ystore.register_compression_callbacks(zlib.compress, lambda x: x)
145+
146+
data = [b"alpha", b"beta", b"gamma"]
147+
# write compressed
148+
for d in data:
149+
await ystore.write(d)
150+
151+
assert Path(MySQLiteYStore.db_path).exists()
152+
153+
# read back and ensure correct decompression
154+
i = 0
155+
async for d_read, m, t in ystore.read():
156+
assert zlib.decompress(d_read) == data[i]
157+
assert m == str(i).encode()
158+
i += 1
159+
160+
assert i == len(data)

0 commit comments

Comments
 (0)