7 changes: 7 additions & 0 deletions chaos/driver.py
@@ -246,6 +246,13 @@ async def spawn_worker() -> Process:
                 "driver: Redis connection error (%s), retrying in 5s...", e
             )
             await asyncio.sleep(5)
+        except redis.exceptions.ResponseError as e:
+            if "NOGROUP" in str(e):
+                # Consumer group not created yet; the workers haven't started
+                logger.debug("driver: Consumer group not yet created, waiting...")
+                await asyncio.sleep(1)
+            else:
+                raise

         # Now apply some chaos to the system:

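For context on where this handler sits: the chaos driver polls the stream in a retry loop, and the new branch covers the window before any worker has created the consumer group. A minimal sketch of that loop shape, assuming a plain redis.asyncio client and an XPENDING-style poll (the command, names, and loop structure here are illustrative, not the driver's actual code):

import asyncio
import logging

import redis.asyncio
import redis.exceptions

logger = logging.getLogger(__name__)

async def poll_pending(r: redis.asyncio.Redis, stream: str, group: str) -> None:
    while True:
        try:
            # XPENDING raises NOGROUP until some worker creates the group
            summary = await r.xpending(stream, group)
            logger.debug("driver: %s pending message(s)", summary["pending"])
            await asyncio.sleep(1)
        except redis.exceptions.ConnectionError as e:
            logger.warning("driver: Redis connection error (%s), retrying in 5s...", e)
            await asyncio.sleep(5)
        except redis.exceptions.ResponseError as e:
            if "NOGROUP" in str(e):
                # Benign during startup: keep polling until a worker bootstraps
                logger.debug("driver: Consumer group not yet created, waiting...")
                await asyncio.sleep(1)
            else:
                raise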
62 changes: 42 additions & 20 deletions src/docket/docket.py
@@ -234,19 +234,6 @@ async def __aenter__(self) -> Self:

         self._monitor_strikes_task = asyncio.create_task(self._monitor_strikes())

-        # Ensure that the stream and worker group exist
-        try:
-            async with self.redis() as r:
-                await r.xgroup_create(
-                    groupname=self.worker_group_name,
-                    name=self.stream_key,
-                    id="0-0",
-                    mkstream=True,
-                )
-        except redis.exceptions.RedisError as e:
-            if "BUSYGROUP" not in repr(e):
-                raise
-
         if isinstance(self.result_storage, BaseContextManagerStore):
             await self.result_storage.__aenter__()
         else:
@@ -625,6 +612,25 @@ def parked_task_key(self, key: str) -> str:
     def stream_id_key(self, key: str) -> str:
         return f"{self.name}:stream-id:{key}"

+    async def _ensure_stream_and_group(self) -> None:
+        """Create the stream and consumer group if they don't exist (idempotent).
+
+        This is safe to call from multiple workers racing to initialize: the
+        BUSYGROUP error is silently ignored, since it just means another worker
+        created the group first.
+        """
+        try:
+            async with self.redis() as r:
+                await r.xgroup_create(
+                    groupname=self.worker_group_name,
+                    name=self.stream_key,
+                    id="0-0",
+                    mkstream=True,
+                )
+        except redis.exceptions.ResponseError as e:
+            if "BUSYGROUP" not in str(e):
+                raise  # pragma: no cover
+
     async def _cancel(self, redis: Redis, key: str) -> None:
        """Cancel a task atomically.

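Stepping outside the diff for a moment: the reason _ensure_stream_and_group can be called from every code path that touches the group is that XGROUP CREATE is the only racy step, and BUSYGROUP simply means someone else won the race. A hypothetical demonstration of that property (the docket argument is assumed to be an already-entered Docket; this snippet is not part of the PR):

import asyncio

async def demo_bootstrap_race(docket) -> None:
    # Ten callers race to create the group; exactly one XGROUP CREATE
    # succeeds and the other nine swallow BUSYGROUP, so none of them raise.
    await asyncio.gather(
        *(docket._ensure_stream_and_group() for _ in range(10))
    )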
@@ -814,6 +820,12 @@ async def snapshot(self) -> DocketSnapshot:
         Returns:
             A snapshot of the Docket.
         """
+        # For memory:// URLs (fakeredis), ensure the group exists upfront. This
+        # avoids a fakeredis bug where xpending_range raises TypeError instead
+        # of NOGROUP when the consumer group doesn't exist.
+        if self.url.startswith("memory://"):
+            await self._ensure_stream_and_group()
+
         running: list[RunningExecution] = []
         future: list[Execution] = []

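The guard above exists because the failure mode differs by backend. On real Redis, a missing group surfaces as a ResponseError whose message contains NOGROUP, which the handler further down can catch and repair; under fakeredis, xpending_range raises TypeError instead, which would sail past that handler. A sketch of the two paths (the stream and group names are placeholders, and the helper itself is illustrative):

import redis.exceptions

async def probe_pending(r, stream: str, group: str):
    try:
        return await r.xpending_range(stream, group, min="-", max="+", count=10)
    except redis.exceptions.ResponseError as e:
        # Real Redis: recoverable, snapshot() creates the group and retries
        assert "NOGROUP" in str(e)
        return None
    except TypeError:
        # fakeredis bug: only reachable on memory:// backends, hence the
        # proactive _ensure_stream_and_group() call above
        return None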
@@ -842,13 +854,23 @@ async def snapshot(self) -> DocketSnapshot:
             scheduled_task_keys: list[bytes]

             now = datetime.now(timezone.utc)
-            (
-                total_stream_messages,
-                total_schedule_messages,
-                pending_messages,
-                stream_messages,
-                scheduled_task_keys,
-            ) = await pipeline.execute()
+            try:
+                (
+                    total_stream_messages,
+                    total_schedule_messages,
+                    pending_messages,
+                    stream_messages,
+                    scheduled_task_keys,
+                ) = await pipeline.execute()
+            except redis.exceptions.ResponseError as e:
+                # Check for NOGROUP. Also check for XPENDING, because redis-py
+                # 7.0 has a bug where pipeline errors lose the original NOGROUP
+                # message (it shows "{exception.args}" instead).
+                error_str = str(e)
+                if "NOGROUP" in error_str or "XPENDING" in error_str:
+                    await self._ensure_stream_and_group()
+                    return await self.snapshot()
+                raise  # pragma: no cover

             for task_key in scheduled_task_keys:
                 pipeline.hgetall(self.parked_task_key(task_key.decode()))
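The try/except in snapshot() is an instance of a general recover-and-retry shape: attempt the read, create the stream and group on NOGROUP, then retry. Written out as a standalone helper (illustrative only; the PR inlines this logic), it also makes clear why the retry terminates: once _ensure_stream_and_group() succeeds the group exists, so a second attempt cannot raise NOGROUP again.

from typing import Awaitable, Callable, TypeVar

import redis.exceptions

T = TypeVar("T")

async def with_group_bootstrap(docket, operation: Callable[[], Awaitable[T]]) -> T:
    try:
        return await operation()
    except redis.exceptions.ResponseError as e:
        # "XPENDING" is matched as well to cover the redis-py 7.0 pipeline
        # bug described above, which can mask the NOGROUP message.
        if "NOGROUP" not in str(e) and "XPENDING" not in str(e):
            raise
        await docket._ensure_stream_and_group()
        return await operation()  # the group now exists, so this can't NOGROUP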
48 changes: 30 additions & 18 deletions src/docket/worker.py
@@ -18,7 +18,7 @@
 from opentelemetry import trace
 from opentelemetry.trace import Status, StatusCode, Tracer
 from redis.asyncio import Redis
-from redis.exceptions import ConnectionError, LockError
+from redis.exceptions import ConnectionError, LockError, ResponseError
 from typing_extensions import Self

 from .dependencies import (
@@ -326,14 +326,20 @@ async def check_for_work() -> bool:

         async def get_redeliveries(redis: Redis) -> RedisReadGroupResponse:
             logger.debug("Getting redeliveries", extra=log_context)
-            _, redeliveries, *_ = await redis.xautoclaim(
-                name=self.docket.stream_key,
-                groupname=self.docket.worker_group_name,
-                consumername=self.name,
-                min_idle_time=int(self.redelivery_timeout.total_seconds() * 1000),
-                start_id="0-0",
-                count=available_slots,
-            )
+            try:
+                _, redeliveries, *_ = await redis.xautoclaim(
+                    name=self.docket.stream_key,
+                    groupname=self.docket.worker_group_name,
+                    consumername=self.name,
+                    min_idle_time=int(self.redelivery_timeout.total_seconds() * 1000),
+                    start_id="0-0",
+                    count=available_slots,
+                )
+            except ResponseError as e:
+                if "NOGROUP" in str(e):
+                    await self.docket._ensure_stream_and_group()
+                    return await get_redeliveries(redis)
+                raise  # pragma: no cover
             return [(b"__redelivery__", redeliveries)]

         async def get_new_deliveries(redis: Redis) -> RedisReadGroupResponse:
@@ -342,15 +348,21 @@ async def get_new_deliveries(redis: Redis) -> RedisReadGroupResponse:
             # This is necessary because fakeredis's async blocking operations don't
             # properly yield control to the asyncio event loop
             is_memory = self.docket.url.startswith("memory://")
-            result = await redis.xreadgroup(
-                groupname=self.docket.worker_group_name,
-                consumername=self.name,
-                streams={self.docket.stream_key: ">"},
-                block=0
-                if is_memory
-                else int(self.minimum_check_interval.total_seconds() * 1000),
-                count=available_slots,
-            )
+            try:
+                result = await redis.xreadgroup(
+                    groupname=self.docket.worker_group_name,
+                    consumername=self.name,
+                    streams={self.docket.stream_key: ">"},
+                    block=0
+                    if is_memory
+                    else int(self.minimum_check_interval.total_seconds() * 1000),
+                    count=available_slots,
+                )
+            except ResponseError as e:
+                if "NOGROUP" in str(e):
+                    await self.docket._ensure_stream_and_group()
+                    return await get_new_deliveries(redis)
+                raise  # pragma: no cover
             if is_memory and not result:
                 await asyncio.sleep(self.minimum_check_interval.total_seconds())
             return result
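Both stream readers above pair the NOGROUP recovery with a second, unrelated accommodation for the memory:// backend: judging by the `if is_memory and not result` branch, a blocking XREADGROUP under fakeredis would busy-wait without yielding to the event loop, so the worker issues a read that returns immediately and sleeps in asyncio instead. A condensed sketch of just that fallback, with placeholder stream, group, and consumer names:

import asyncio

async def read_new_deliveries(redis, is_memory: bool, interval: float = 0.1):
    result = await redis.xreadgroup(
        groupname="docket:workers",      # placeholder
        consumername="worker-1",         # placeholder
        streams={"docket:stream": ">"},  # placeholder
        # Under fakeredis, block=0 appears to return immediately when the
        # stream is empty; with real Redis the block is bounded instead.
        block=0 if is_memory else int(interval * 1000),
        count=10,
    )
    if is_memory and not result:
        # Sleep in the event loop rather than blocking inside fakeredis
        await asyncio.sleep(interval)
    return result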
112 changes: 108 additions & 4 deletions tests/test_docket.py
@@ -10,12 +10,20 @@
 from tests._key_leak_checker import KeyCountChecker


-async def test_docket_aenter_propagates_connection_errors():
-    """The docket should propagate Redis connection errors"""
-
+async def test_docket_propagates_connection_errors_on_operation():
+    """Connection errors should propagate when operations are attempted."""
     docket = Docket(name="test-docket", url="redis://nonexistent-host:12345/0")

+    # __aenter__ succeeds because it doesn't actually connect to Redis
+    # (connection is lazy; it happens when operations are performed)
+    await docket.__aenter__()
+
+    # But actual operations should fail with connection errors
+    async def some_task(): ...
+
+    docket.register(some_task)
     with pytest.raises(redis.exceptions.RedisError):
-        await docket.__aenter__()
+        await docket.add(some_task)()

     await docket.__aexit__(None, None, None)

@@ -644,3 +652,99 @@ async def my_task() -> None: ...
     async with docket:
         assert "my_task" in docket.tasks
         assert "trace" in docket.tasks
+
+
+async def test_stream_not_created_on_docket_init(redis_url: str):
+    """Stream and consumer group should NOT be created when Docket is initialized.
+
+    Issue #206: Lazy stream/consumer group bootstrap.
+    """
+    from uuid import uuid4
+
+    docket = Docket(name=f"fresh-docket-{uuid4()}", url=redis_url)
+    async with docket:
+        async with docket.redis() as redis:
+            stream_exists = await redis.exists(docket.stream_key)
+            assert not stream_exists, "Stream should not exist on Docket init"
+
+
+async def test_ensure_stream_and_group_is_idempotent(redis_url: str):
+    """Calling _ensure_stream_and_group multiple times should not raise errors.
+
+    Issue #206: Lazy stream/consumer group bootstrap.
+    """
+    from uuid import uuid4
+
+    docket = Docket(name=f"fresh-docket-{uuid4()}", url=redis_url)
+    async with docket:
+        await docket._ensure_stream_and_group()  # pyright: ignore[reportPrivateUsage]
+        await docket._ensure_stream_and_group()  # pyright: ignore[reportPrivateUsage]
+        await docket._ensure_stream_and_group()  # pyright: ignore[reportPrivateUsage]
+
+        async with docket.redis() as redis:
+            groups = await redis.xinfo_groups(docket.stream_key)
+            assert len(groups) == 1
+            assert groups[0]["name"] == docket.worker_group_name.encode()
+
+
+async def test_docket_without_worker_does_not_create_group(redis_url: str):
+    """A Docket used only for adding tasks should not create the consumer group.
+
+    Issue #206: Lazy stream/consumer group bootstrap.
+    """
+    from uuid import uuid4
+
+    docket = Docket(name=f"fresh-docket-{uuid4()}", url=redis_url)
+
+    async def dummy_task(): ...
+
+    async with docket:
+        docket.register(dummy_task)
+
+        for _ in range(5):
+            await docket.add(dummy_task)()
+
+        async with docket.redis() as redis:
+            assert await redis.exists(docket.stream_key)
+            groups = await redis.xinfo_groups(docket.stream_key)
+            assert len(groups) == 0, "Consumer group should not exist without a worker"
+
+
+@pytest.mark.parametrize("redis_url", ["real"], indirect=True)
+async def test_snapshot_handles_nogroup_with_real_redis(redis_url: str):
+    """Snapshot should handle the NOGROUP error and create the group automatically.
+
+    Issue #206: Lazy stream/consumer group bootstrap.
+
+    This test uses real Redis (not memory://) to verify the NOGROUP error
+    handling path in snapshot(), since the memory:// backend proactively
+    creates the group to work around a fakeredis bug.
+    """
+    from uuid import uuid4
+
+    docket = Docket(name=f"fresh-docket-{uuid4()}", url=redis_url)
+
+    async def dummy_task(): ...
+
+    async with docket:
+        docket.register(dummy_task)
+
+        # Add a task to create the stream (but not the consumer group)
+        await docket.add(dummy_task)()
+
+        # Verify the stream exists but the group doesn't
+        async with docket.redis() as redis:
+            assert await redis.exists(docket.stream_key)
+            groups = await redis.xinfo_groups(docket.stream_key)
+            assert len(groups) == 0
+
+        # Calling snapshot() should trigger NOGROUP and handle it
+        snapshot = await docket.snapshot()
+
+        # The snapshot should succeed after creating the group
+        assert snapshot.total_tasks == 1
+
+        # The group should now exist
+        async with docket.redis() as redis:
+            groups = await redis.xinfo_groups(docket.stream_key)
+            assert len(groups) == 1