Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions src/docket/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,17 +406,14 @@ async def _scheduler_loop(
task[task_data[j]] = task_data[j+1]
end

local message_id = redis.call('XADD', KEYS[2], '*',
redis.call('XADD', KEYS[2], '*',
'key', task['key'],
'when', task['when'],
'function', task['function'],
'args', task['args'],
'kwargs', task['kwargs'],
'attempt', task['attempt']
)
-- Store the message ID in the known task key
local known_key = ARGV[2] .. ":known:" .. key
redis.call('HSET', known_key, 'stream_message_id', message_id)
redis.call('DEL', hash_key)
due_work = due_work + 1
end
Expand Down
253 changes: 240 additions & 13 deletions tests/test_worker.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import asyncio
import logging
import time
from contextlib import asynccontextmanager
from contextvars import ContextVar
from datetime import datetime, timedelta, timezone
from typing import AsyncGenerator, Callable
from typing import AsyncGenerator, Callable, Iterable
from unittest.mock import AsyncMock, patch
from uuid import uuid4

import cloudpickle # type: ignore[import]
import pytest
from redis.asyncio import Redis
from redis.exceptions import ConnectionError
Expand All @@ -18,6 +21,8 @@
Perpetual,
Worker,
)
from docket.dependencies import Timeout
from docket.execution import Execution
from docket.tasks import standard_tasks
from docket.worker import ms

Expand Down Expand Up @@ -175,7 +180,6 @@ async def task_that_sometimes_fails(
nonlocal failure_count

# Record when this task runs
import time

task_executions.append((customer_id, time.time()))

Expand Down Expand Up @@ -556,7 +560,6 @@ async def perpetual_task(

async def test_worker_concurrency_limits_task_queuing_behavior(docket: Docket):
"""Test that concurrency limits control task execution properly"""
from contextvars import ContextVar

# Use contextvar for reliable tracking across async execution
execution_log: ContextVar[list[tuple[str, int]]] = ContextVar("execution_log")
Expand Down Expand Up @@ -1172,7 +1175,6 @@ async def edge_case_task(

async def test_worker_timeout_exceeds_redelivery_timeout(docket: Docket):
"""Test worker handles user timeout longer than redelivery timeout."""
from docket.dependencies import Timeout

task_executed = False

Expand Down Expand Up @@ -1251,8 +1253,6 @@ async def task_missing_concurrency_arg(

async def test_worker_no_concurrency_dependency_in_function(docket: Docket):
"""Test _can_start_task with function that has no concurrency dependency."""
from docket.execution import Execution
from datetime import datetime, timezone

async def task_without_concurrency_dependency():
await asyncio.sleep(0.001)
Expand All @@ -1278,8 +1278,6 @@ async def task_without_concurrency_dependency():

async def test_worker_no_concurrency_dependency_in_release(docket: Docket):
"""Test _release_concurrency_slot with function that has no concurrency dependency."""
from docket.execution import Execution
from datetime import datetime, timezone

async def task_without_concurrency_dependency():
await asyncio.sleep(0.001)
Expand All @@ -1304,8 +1302,6 @@ async def task_without_concurrency_dependency():

async def test_worker_missing_concurrency_argument_in_release(docket: Docket):
"""Test _release_concurrency_slot when concurrency argument is missing."""
from docket.execution import Execution
from datetime import datetime, timezone

async def task_with_missing_arg(
concurrency: ConcurrencyLimit = ConcurrencyLimit(
Expand Down Expand Up @@ -1334,8 +1330,6 @@ async def task_with_missing_arg(

async def test_worker_concurrency_missing_argument_in_can_start(docket: Docket):
"""Test _can_start_task with missing concurrency argument during execution."""
from docket.execution import Execution
from datetime import datetime, timezone

async def task_with_missing_concurrency_arg(
concurrency: ConcurrencyLimit = ConcurrencyLimit(
Expand Down Expand Up @@ -1384,7 +1378,6 @@ async def task_that_will_fail():
task_failed = False

# Mock resolved_dependencies to fail before setting dependencies
from unittest.mock import patch, AsyncMock

await docket.add(task_that_will_fail)()

Expand Down Expand Up @@ -1504,3 +1497,237 @@ async def test_rapid_replace_operations(
# Should only execute the last replacement
the_task.assert_awaited_once_with("arg4", b="b4")
assert the_task.await_count == 1


async def test_wrongtype_error_with_legacy_known_task_key(
    docket: Docket,
    worker: Worker,
    the_task: AsyncMock,
    now: Callable[[], datetime],
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Verify the scheduler tolerates known task keys stored as plain strings.

    Regression test: older docket versions wrote the known task key as a bare
    string value (a timestamp).  A later scheduler revision issued HSET
    against that same key, and Redis rejected it with a WRONGTYPE error,
    breaking the scheduler loop in production.

    The legacy layout is recreated by hand below, and the current code must
    process the parked task cleanly — no ERROR log records — while still
    executing it (proven by the trace task's INFO output).
    """
    key = f"legacy-task:{uuid4()}"

    # Recreate the legacy on-wire layout: a string-valued known key, a queue
    # entry, and the task payload parked in a hash.
    async with docket.redis() as redis:
        known_task_key = docket.known_task_key(key)
        when = now() + timedelta(seconds=1)

        await redis.set(known_task_key, str(when.timestamp()))
        await redis.zadd(docket.queue_key, {key: when.timestamp()})

        await redis.hset(  # type: ignore
            docket.parked_task_key(key),
            mapping={
                "key": key,
                "when": when.isoformat(),
                "function": "trace",
                "args": cloudpickle.dumps(["legacy task test"]),  # type: ignore[arg-type]
                "kwargs": cloudpickle.dumps({}),  # type: ignore[arg-type]
                "attempt": "1",
            },
        )

    # Drain the docket while recording logs so both the absence of errors and
    # the trace output can be inspected afterwards.
    with caplog.at_level(logging.INFO):
        await worker.run_until_finished()

    # With the fix in place, the scheduler must not emit any ERROR records.
    error_logs = [
        record for record in caplog.records if record.levelname == "ERROR"
    ]
    assert len(error_logs) == 0, (
        f"Expected no error logs, but got: {[r.message for r in error_logs]}"
    )

    # The trace task logs its message at INFO level, proving it actually ran.
    info_logs = [record for record in caplog.records if record.levelname == "INFO"]
    trace_logs = [
        record for record in info_logs if "legacy task test" in record.message
    ]
    assert len(trace_logs) > 0, (
        f"Expected to see trace log with 'legacy task test', got: {[r.message for r in info_logs]}"
    )


async def count_redis_keys_by_type(redis: Redis, prefix: str) -> dict[str, int]:
    """Tally the Redis keys matching ``prefix*``, grouped by Redis value type.

    Returns a mapping such as ``{"hash": 3, "zset": 1}``.  TYPE replies may
    arrive as bytes or str depending on the client configuration, so both
    forms are normalized to a plain string before counting.
    """
    matched: Iterable[str] = await redis.keys(f"{prefix}*")  # type: ignore
    tally: dict[str, int] = {}

    for redis_key in matched:
        raw_type = await redis.type(redis_key)
        if isinstance(raw_type, bytes):
            type_name = raw_type.decode()
        else:
            type_name = str(raw_type)
        tally[type_name] = tally.get(type_name, 0) + 1

    return tally


class KeyCountChecker:
    """Tracks per-type Redis key counts to detect leaked keys across operations.

    Usage: call ``capture_baseline`` once after the worker is primed, then
    ``verify_keys_increased`` after scheduling work and
    ``verify_keys_returned_to_baseline`` once the work has been processed.
    """

    def __init__(self, docket: Docket, redis: Redis) -> None:
        self.docket = docket
        self.redis = redis
        # Populated by capture_baseline(); maps Redis type name -> key count.
        self.baseline_counts: dict[str, int] = {}

    async def capture_baseline(self) -> None:
        """Record the key counts that exist before any task is scheduled."""
        self.baseline_counts = await count_redis_keys_by_type(
            self.redis, self.docket.name
        )
        print(f"Baseline key counts: {self.baseline_counts}")

    async def verify_keys_increased(self, operation: str) -> None:
        """Assert that *operation* created at least one new key overall."""
        current_counts = await count_redis_keys_by_type(self.redis, self.docket.name)
        print(f"After {operation} key counts: {current_counts}")

        total_current = sum(current_counts.values())
        total_baseline = sum(self.baseline_counts.values())
        assert total_current > total_baseline, (
            f"Expected more keys after {operation}, but got {total_current} vs {total_baseline}"
        )

    async def verify_keys_returned_to_baseline(self, operation: str) -> None:
        """Assert every key type is back at its baseline count (no leaks)."""
        final_counts = await count_redis_keys_by_type(self.redis, self.docket.name)
        print(f"Final key counts: {final_counts}")

        # Union of both mappings' types, so keys that appeared OR vanished
        # relative to the baseline are both caught.
        for key_type in set(self.baseline_counts) | set(final_counts):
            baseline_count = self.baseline_counts.get(key_type, 0)
            final_count = final_counts.get(key_type, 0)
            assert final_count == baseline_count, (
                f"Memory leak detected after {operation}: {key_type} keys not cleaned up properly. "
                f"Baseline: {baseline_count}, Final: {final_count}"
            )


async def test_redis_key_cleanup_successful_task(
    docket: Docket, worker: Worker
) -> None:
    """A task that succeeds must leave Redis key counts at their baseline.

    Key counts per type are compared before and after the run; any surplus
    key indicates state that was never cleaned up (a memory leak).
    """
    # An empty run primes the worker so its bookkeeping keys exist up front.
    await worker.run_until_finished()

    did_run = False

    async def successful_task():
        nonlocal did_run
        did_run = True
        # Small delay to ensure proper execution flow
        await asyncio.sleep(0.01)

    docket.register(successful_task)

    async with docket.redis() as connection:
        leak_checker = KeyCountChecker(docket, connection)
        await leak_checker.capture_baseline()

        # Scheduling must add keys (queue entry, parked payload, ...).
        await docket.add(successful_task)()
        await leak_checker.verify_keys_increased("scheduling")

        # Run the task to completion.
        await worker.run_until_finished()

        assert did_run, "Task should have executed successfully"

        # All scheduling keys must be gone again after successful execution.
        await leak_checker.verify_keys_returned_to_baseline("successful task execution")


async def test_redis_key_cleanup_failed_task(docket: Docket, worker: Worker) -> None:
    """A task that raises must still leave Redis key counts at baseline."""
    # Prime the worker so its bookkeeping keys exist before the baseline.
    await worker.run_until_finished()

    was_attempted = False

    async def failing_task():
        nonlocal was_attempted
        was_attempted = True
        raise ValueError("Intentional test failure")

    docket.register(failing_task)

    async with docket.redis() as connection:
        leak_checker = KeyCountChecker(docket, connection)
        await leak_checker.capture_baseline()

        # Scheduling must add keys relative to the baseline.
        await docket.add(failing_task)()
        await leak_checker.verify_keys_increased("scheduling")

        # Run the task; it is expected to raise inside the worker.
        await worker.run_until_finished()

        assert was_attempted, "Task should have been attempted"

        # Failure must not leak keys either.
        await leak_checker.verify_keys_returned_to_baseline("failed task execution")


async def test_redis_key_cleanup_cancelled_task(docket: Docket, worker: Worker) -> None:
    """Cancelling a scheduled task must return Redis key counts to baseline."""
    # Prime the worker so its bookkeeping keys exist before the baseline.
    await worker.run_until_finished()

    did_run = False

    async def task_to_cancel():
        nonlocal did_run
        did_run = True  # pragma: no cover

    docket.register(task_to_cancel)

    async with docket.redis() as connection:
        leak_checker = KeyCountChecker(docket, connection)
        await leak_checker.capture_baseline()

        # Schedule far enough in the future that the worker cannot run it
        # before we cancel.
        scheduled_for = datetime.now(timezone.utc) + timedelta(seconds=10)
        execution = await docket.add(task_to_cancel, scheduled_for)()
        await leak_checker.verify_keys_increased("scheduling")

        # Cancel before execution.
        await docket.cancel(execution.key)

        # Let the worker process any cleanup that cancellation triggers.
        await worker.run_until_finished()

        assert not did_run, (
            "Task should not have been executed after cancellation"
        )

        # Cancellation must remove every key that scheduling created.
        await leak_checker.verify_keys_returned_to_baseline("task cancellation")