Skip to content

Commit e8a30a4

Browse files
chrisguidryclaude
andauthored
Use separate connection pool for result storage (#276)
The py-key-value library's RedisStore requires `decode_responses=True`, but Docket's internal Redis usage expects bytes for performance. This creates two connection pools - one for each mode. Also fixes a bug where the `result_storage` parameter was accepted but silently ignored, and works around a py-key-value issue where usernames in Redis URLs weren't being passed to the client (github.com/strawgate/py-key-value/issues/254). 🤖 Generated with [Claude Code](https://claude.ai/code) Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent e7d855b commit e8a30a4

File tree

3 files changed

+89
-16
lines changed

3 files changed

+89
-16
lines changed

src/docket/_redis.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,24 +34,29 @@ def get_memory_server(url: str) -> "FakeServer | None":
3434
return _memory_servers.get(url)
3535

3636

37-
async def connection_pool_from_url(url: str) -> ConnectionPool:
37+
async def connection_pool_from_url(
38+
url: str, decode_responses: bool = False
39+
) -> ConnectionPool:
3840
"""Create a Redis connection pool from a URL.
3941
4042
Handles both real Redis (redis://) and in-memory fakeredis (memory://).
4143
This is the only place in the codebase that imports fakeredis.
4244
4345
Args:
4446
url: Redis URL (redis://...) or memory:// for in-memory backend
47+
decode_responses: If True, decode Redis responses from bytes to strings
4548
4649
Returns:
4750
A ConnectionPool ready for use with Redis clients
4851
"""
4952
if url.startswith("memory://"):
50-
return await _memory_connection_pool(url)
51-
return ConnectionPool.from_url(url)
53+
return await _memory_connection_pool(url, decode_responses)
54+
return ConnectionPool.from_url(url, decode_responses=decode_responses)
5255

5356

54-
async def _memory_connection_pool(url: str) -> ConnectionPool:
57+
async def _memory_connection_pool(
58+
url: str, decode_responses: bool = False
59+
) -> ConnectionPool:
5560
"""Create a connection pool for a memory:// URL using fakeredis."""
5661
global _memory_servers
5762

@@ -63,16 +68,28 @@ async def _memory_connection_pool(url: str) -> ConnectionPool:
6368
# Fast path: server already exists
6469
server = _memory_servers.get(url)
6570
if server is not None:
66-
return ConnectionPool(connection_class=FakeConnection, server=server)
71+
return ConnectionPool(
72+
connection_class=FakeConnection,
73+
server=server,
74+
decode_responses=decode_responses,
75+
)
6776

6877
async with _memory_servers_lock:
6978
server = _memory_servers.get(url)
7079
if server is not None: # pragma: no cover
71-
return ConnectionPool(connection_class=FakeConnection, server=server)
80+
return ConnectionPool(
81+
connection_class=FakeConnection,
82+
server=server,
83+
decode_responses=decode_responses,
84+
)
7285

7386
server = FakeServer()
7487
_memory_servers[url] = server
75-
return ConnectionPool(connection_class=FakeConnection, server=server)
88+
return ConnectionPool(
89+
connection_class=FakeConnection,
90+
server=server,
91+
decode_responses=decode_responses,
92+
)
7693

7794

7895
# ------------------------------------------------------------------------------

src/docket/docket.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
)
4747
from key_value.aio.protocols.key_value import AsyncKeyValue
4848
from key_value.aio.stores.redis import RedisStore
49-
from key_value.aio.stores.memory import MemoryStore
5049

5150
from .instrumentation import (
5251
TASKS_ADDED,
@@ -192,14 +191,7 @@ def __init__(
192191
self.execution_ttl = execution_ttl
193192
self.enable_internal_instrumentation = enable_internal_instrumentation
194193
self._cancel_task_script = None
195-
196-
self.result_storage: AsyncKeyValue
197-
if url.startswith("memory://"):
198-
self.result_storage = MemoryStore()
199-
else:
200-
self.result_storage = RedisStore(
201-
url=url, default_collection=self.results_collection
202-
)
194+
self._user_result_storage = result_storage
203195

204196
from .tasks import standard_tasks
205197

@@ -240,6 +232,23 @@ async def __aenter__(self) -> Self:
240232
# Connect the strike list to Redis and start monitoring
241233
await self.strike_list.connect()
242234

235+
# Initialize result storage
236+
# We use a separate connection pool for result storage because RedisStore
237+
# requires decode_responses=True while Docket's internal Redis usage
238+
# expects bytes (decode_responses=False). We also pass a client to
239+
# RedisStore to work around a py-key-value bug that ignores username
240+
# in URLs (github.com/strawgate/py-key-value/issues/254).
241+
if self._user_result_storage is not None:
242+
self.result_storage: AsyncKeyValue = self._user_result_storage
243+
else:
244+
self._result_storage_pool = await connection_pool_from_url(
245+
self.url, decode_responses=True
246+
)
247+
result_client = Redis(connection_pool=self._result_storage_pool)
248+
self.result_storage = RedisStore(
249+
client=result_client, default_collection=self.results_collection
250+
)
251+
243252
if isinstance(self.result_storage, BaseContextManagerStore):
244253
await self.result_storage.__aenter__()
245254
else:
@@ -255,6 +264,11 @@ async def __aexit__(
255264
if isinstance(self.result_storage, BaseContextManagerStore):
256265
await self.result_storage.__aexit__(exc_type, exc_value, traceback)
257266

267+
# Close the result storage pool if we created it
268+
if hasattr(self, "_result_storage_pool"):
269+
await asyncio.shield(self._result_storage_pool.disconnect())
270+
del self._result_storage_pool
271+
258272
# Close the strike list (stops monitoring and disconnects)
259273
await self.strike_list.close()
260274
del self.strike_list

tests/test_results.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,3 +498,45 @@ async def waits_forever() -> int:
498498

499499
event.set()
500500
await worker_task
501+
502+
503+
async def test_result_storage_uses_provided_or_default(redis_url: str):
504+
"""Test that result_storage uses your store if provided, RedisStore if not."""
505+
from unittest.mock import AsyncMock, MagicMock
506+
from urllib.parse import urlparse
507+
508+
from key_value.aio.protocols.key_value import AsyncKeyValue
509+
from key_value.aio.stores.redis import RedisStore
510+
511+
# If you give us one, it's yours
512+
custom_storage = MagicMock(spec=AsyncKeyValue)
513+
custom_storage.setup = AsyncMock()
514+
async with Docket(
515+
name="test-custom-storage",
516+
url=redis_url,
517+
result_storage=custom_storage,
518+
) as docket:
519+
assert docket.result_storage is custom_storage
520+
521+
# If you don't, it's a RedisStore pointing at the same server
522+
async with Docket(name="test-default-storage", url=redis_url) as docket:
523+
assert isinstance(docket.result_storage, RedisStore)
524+
525+
# Verify it's connected to the same Redis
526+
result_client = docket.result_storage._client # type: ignore[attr-defined]
527+
pool_kwargs: dict[str, Any] = result_client.connection_pool.connection_kwargs # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
528+
529+
if redis_url.startswith("memory://"): # pragma: no cover
530+
# For memory://, just verify it has a server (fakeredis)
531+
assert "server" in pool_kwargs
532+
else:
533+
# For real Redis, verify host/port/db match
534+
parsed = urlparse(redis_url)
535+
assert pool_kwargs.get("host") == (parsed.hostname or "localhost")
536+
assert pool_kwargs.get("port") == (parsed.port or 6379)
537+
expected_db = (
538+
int(parsed.path.lstrip("/"))
539+
if parsed.path and parsed.path != "/"
540+
else 0
541+
)
542+
assert pool_kwargs.get("db") == expected_db

0 commit comments

Comments
 (0)