# Patch summary (commentary only; the patch text below is unchanged).
#
# Files touched: src/docket/worker.py and tests/test_worker.py.
#
# src/docket/worker.py (hunk headers name `_scheduler_loop`):
#   - The embedded Redis script no longer binds the result of
#     redis.call('XADD', ...) to a local `message_id`.
#   - The follow-up redis.call('HSET', known_key, 'stream_message_id', message_id)
#     on the key built as ARGV[2] .. ":known:" .. key is removed, together with
#     its comment. The redis.call('DEL', hash_key) and the due_work increment stay.
#
# tests/test_worker.py:
#   - Function-local imports (time, ContextVar, Timeout, Execution,
#     datetime/timezone, patch/AsyncMock) are deleted from individual tests;
#     module-level imports for time, ContextVar, Iterable, cloudpickle, Timeout
#     and Execution are added instead.
#   - test_wrongtype_error_with_legacy_known_task_key: writes the known-task key
#     as a plain string, queues the task and its parked hash by hand (function
#     "trace"), runs the worker, then asserts that no ERROR records were logged
#     and that an INFO record contains 'legacy task test'.
#   - count_redis_keys_by_type / KeyCountChecker: helpers that count keys
#     matching f"{prefix}*" grouped by Redis type and compare them to a baseline.
#   - test_redis_key_cleanup_{successful,failed,cancelled}_task: capture a
#     baseline, schedule a task, assert the key count grew, run (or cancel) it,
#     then assert per-type key counts equal the baseline again.
#
# NOTE(review): the line structure of this patch appears to have been collapsed
# (file headers, hunk headers and hunk bodies share physical lines, and a hunk
# header such as "+1497,237" cannot correspond to the lines present here).
# Restore the patch from the original commit before trying to apply it.
diff --git a/src/docket/worker.py b/src/docket/worker.py index 76ad9edc..2c739482 100644 --- a/src/docket/worker.py +++ b/src/docket/worker.py @@ -406,7 +406,7 @@ async def _scheduler_loop( task[task_data[j]] = task_data[j+1] end - local message_id = redis.call('XADD', KEYS[2], '*', + redis.call('XADD', KEYS[2], '*', 'key', task['key'], 'when', task['when'], 'function', task['function'], @@ -414,9 +414,6 @@ async def _scheduler_loop( 'kwargs', task['kwargs'], 'attempt', task['attempt'] ) - -- Store the message ID in the known task key - local known_key = ARGV[2] .. ":known:" .. key - redis.call('HSET', known_key, 'stream_message_id', message_id) redis.call('DEL', hash_key) due_work = due_work + 1 end diff --git a/tests/test_worker.py b/tests/test_worker.py index 758f17d4..e56f59e5 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -1,11 +1,14 @@ import asyncio import logging +import time from contextlib import asynccontextmanager +from contextvars import ContextVar from datetime import datetime, timedelta, timezone -from typing import AsyncGenerator, Callable +from typing import AsyncGenerator, Callable, Iterable from unittest.mock import AsyncMock, patch from uuid import uuid4 +import cloudpickle # type: ignore[import] import pytest from redis.asyncio import Redis from redis.exceptions import ConnectionError @@ -18,6 +21,8 @@ Perpetual, Worker, ) +from docket.dependencies import Timeout +from docket.execution import Execution from docket.tasks import standard_tasks from docket.worker import ms @@ -175,7 +180,6 @@ async def task_that_sometimes_fails( nonlocal failure_count # Record when this task runs - import time task_executions.append((customer_id, time.time())) @@ -556,7 +560,6 @@ async def perpetual_task( async def test_worker_concurrency_limits_task_queuing_behavior(docket: Docket): """Test that concurrency limits control task execution properly""" - from contextvars import ContextVar # Use contextvar for reliable tracking across async execution 
execution_log: ContextVar[list[tuple[str, int]]] = ContextVar("execution_log") @@ -1172,7 +1175,6 @@ async def edge_case_task( async def test_worker_timeout_exceeds_redelivery_timeout(docket: Docket): """Test worker handles user timeout longer than redelivery timeout.""" - from docket.dependencies import Timeout task_executed = False @@ -1251,8 +1253,6 @@ async def task_missing_concurrency_arg( async def test_worker_no_concurrency_dependency_in_function(docket: Docket): """Test _can_start_task with function that has no concurrency dependency.""" - from docket.execution import Execution - from datetime import datetime, timezone async def task_without_concurrency_dependency(): await asyncio.sleep(0.001) @@ -1278,8 +1278,6 @@ async def task_without_concurrency_dependency(): async def test_worker_no_concurrency_dependency_in_release(docket: Docket): """Test _release_concurrency_slot with function that has no concurrency dependency.""" - from docket.execution import Execution - from datetime import datetime, timezone async def task_without_concurrency_dependency(): await asyncio.sleep(0.001) @@ -1304,8 +1302,6 @@ async def task_without_concurrency_dependency(): async def test_worker_missing_concurrency_argument_in_release(docket: Docket): """Test _release_concurrency_slot when concurrency argument is missing.""" - from docket.execution import Execution - from datetime import datetime, timezone async def task_with_missing_arg( concurrency: ConcurrencyLimit = ConcurrencyLimit( @@ -1334,8 +1330,6 @@ async def task_with_missing_arg( async def test_worker_concurrency_missing_argument_in_can_start(docket: Docket): """Test _can_start_task with missing concurrency argument during execution.""" - from docket.execution import Execution - from datetime import datetime, timezone async def task_with_missing_concurrency_arg( concurrency: ConcurrencyLimit = ConcurrencyLimit( @@ -1384,7 +1378,6 @@ async def task_that_will_fail(): task_failed = False # Mock resolved_dependencies to fail 
before setting dependencies - from unittest.mock import patch, AsyncMock await docket.add(task_that_will_fail)() @@ -1504,3 +1497,237 @@ async def test_rapid_replace_operations( # Should only execute the last replacement the_task.assert_awaited_once_with("arg4", b="b4") assert the_task.await_count == 1 + + +async def test_wrongtype_error_with_legacy_known_task_key( + docket: Docket, + worker: Worker, + the_task: AsyncMock, + now: Callable[[], datetime], + caplog: pytest.LogCaptureFixture, +) -> None: + """Test graceful handling when known task keys exist as strings from legacy implementations. + + Regression test for issue where worker scheduler would get WRONGTYPE errors when trying to + HSET on known task keys that existed as string values from older docket versions. + + The original error occurred when: + 1. A legacy docket created known task keys as simple string values (timestamps) + 2. The new scheduler tried to HSET stream_message_id on these keys + 3. Redis threw WRONGTYPE error because you can't HSET on a string key + 4. This caused scheduler loop failures in production + + This test reproduces that scenario by manually setting up the legacy state, + then verifies the new code handles it gracefully without errors. 
+ """ + key = f"legacy-task:{uuid4()}" + + # Simulate legacy behavior: create the known task key as a string + # This is what older versions of docket would have done + async with docket.redis() as redis: + known_task_key = docket.known_task_key(key) + when = now() + timedelta(seconds=1) + + # Set up legacy state: known key as string, task in queue with parked data + await redis.set(known_task_key, str(when.timestamp())) + await redis.zadd(docket.queue_key, {key: when.timestamp()}) + + await redis.hset( # type: ignore + docket.parked_task_key(key), + mapping={ + "key": key, + "when": when.isoformat(), + "function": "trace", + "args": cloudpickle.dumps(["legacy task test"]), # type: ignore[arg-type] + "kwargs": cloudpickle.dumps({}), # type: ignore[arg-type] + "attempt": "1", + }, + ) + + # Capture logs to ensure no errors occur and see task execution + with caplog.at_level(logging.INFO): + await worker.run_until_finished() + + # Should not have any ERROR logs now that the issue is fixed + error_logs = [record for record in caplog.records if record.levelname == "ERROR"] + assert len(error_logs) == 0, ( + f"Expected no error logs, but got: {[r.message for r in error_logs]}" + ) + + # The task should execute successfully + # Since we used trace, we should see an INFO log with the message + info_logs = [record for record in caplog.records if record.levelname == "INFO"] + trace_logs = [ + record for record in info_logs if "legacy task test" in record.message + ] + assert len(trace_logs) > 0, ( + f"Expected to see trace log with 'legacy task test', got: {[r.message for r in info_logs]}" + ) + + +async def count_redis_keys_by_type(redis: Redis, prefix: str) -> dict[str, int]: + """Count Redis keys by type for a given prefix.""" + pattern = f"{prefix}*" + keys: Iterable[str] = await redis.keys(pattern) # type: ignore + counts: dict[str, int] = {} + + for key in keys: + key_type = await redis.type(key) + key_type_str = ( + key_type.decode() if isinstance(key_type, bytes) 
else str(key_type) + ) + counts[key_type_str] = counts.get(key_type_str, 0) + 1 + + return counts + + +class KeyCountChecker: + """Helper to verify Redis key counts remain consistent across operations.""" + + def __init__(self, docket: Docket, redis: Redis) -> None: + self.docket = docket + self.redis = redis + self.baseline_counts: dict[str, int] = {} + + async def capture_baseline(self) -> None: + """Capture baseline key counts after worker priming.""" + self.baseline_counts = await count_redis_keys_by_type( + self.redis, self.docket.name + ) + print(f"Baseline key counts: {self.baseline_counts}") + + async def verify_keys_increased(self, operation: str) -> None: + """Verify that key counts increased after scheduling operation.""" + current_counts = await count_redis_keys_by_type(self.redis, self.docket.name) + print(f"After {operation} key counts: {current_counts}") + + total_current = sum(current_counts.values()) + total_baseline = sum(self.baseline_counts.values()) + assert total_current > total_baseline, ( + f"Expected more keys after {operation}, but got {total_current} vs {total_baseline}" + ) + + async def verify_keys_returned_to_baseline(self, operation: str) -> None: + """Verify that key counts returned to baseline after operation completion.""" + final_counts = await count_redis_keys_by_type(self.redis, self.docket.name) + print(f"Final key counts: {final_counts}") + + # Check each key type matches baseline + all_key_types = set(self.baseline_counts.keys()) | set(final_counts.keys()) + for key_type in all_key_types: + baseline_count = self.baseline_counts.get(key_type, 0) + final_count = final_counts.get(key_type, 0) + assert final_count == baseline_count, ( + f"Memory leak detected after {operation}: {key_type} keys not cleaned up properly. 
" + f"Baseline: {baseline_count}, Final: {final_count}" + ) + + +async def test_redis_key_cleanup_successful_task( + docket: Docket, worker: Worker +) -> None: + """Test that Redis keys are properly cleaned up after successful task execution. + + This test systematically counts Redis keys before and after task operations to detect + memory leaks where keys are not properly cleaned up. + """ + # Prime the worker (run once with no tasks to establish baseline) + await worker.run_until_finished() + + # Create and register a simple task + task_executed = False + + async def successful_task(): + nonlocal task_executed + task_executed = True + await asyncio.sleep(0.01) # Small delay to ensure proper execution flow + + docket.register(successful_task) + + async with docket.redis() as redis: + checker = KeyCountChecker(docket, redis) + await checker.capture_baseline() + + # Schedule the task + await docket.add(successful_task)() + await checker.verify_keys_increased("scheduling") + + # Execute the task + await worker.run_until_finished() + + # Verify task executed successfully + assert task_executed, "Task should have executed successfully" + + # Verify cleanup + await checker.verify_keys_returned_to_baseline("successful task execution") + + +async def test_redis_key_cleanup_failed_task(docket: Docket, worker: Worker) -> None: + """Test that Redis keys are properly cleaned up after failed task execution.""" + # Prime the worker + await worker.run_until_finished() + + # Create a task that will fail + task_attempted = False + + async def failing_task(): + nonlocal task_attempted + task_attempted = True + raise ValueError("Intentional test failure") + + docket.register(failing_task) + + async with docket.redis() as redis: + checker = KeyCountChecker(docket, redis) + await checker.capture_baseline() + + # Schedule the task + await docket.add(failing_task)() + await checker.verify_keys_increased("scheduling") + + # Execute the task (should fail) + await 
worker.run_until_finished() + + # Verify task was attempted + assert task_attempted, "Task should have been attempted" + + # Verify cleanup despite failure + await checker.verify_keys_returned_to_baseline("failed task execution") + + +async def test_redis_key_cleanup_cancelled_task(docket: Docket, worker: Worker) -> None: + """Test that Redis keys are properly cleaned up after task cancellation.""" + # Prime the worker + await worker.run_until_finished() + + # Create a task that won't be executed + task_executed = False + + async def task_to_cancel(): + nonlocal task_executed + task_executed = True # pragma: no cover + + docket.register(task_to_cancel) + + async with docket.redis() as redis: + checker = KeyCountChecker(docket, redis) + await checker.capture_baseline() + + # Schedule the task for future execution + future_time = datetime.now(timezone.utc) + timedelta(seconds=10) + execution = await docket.add(task_to_cancel, future_time)() + await checker.verify_keys_increased("scheduling") + + # Cancel the task + await docket.cancel(execution.key) + + # Run worker to process any cleanup + await worker.run_until_finished() + + # Verify task was not executed + assert not task_executed, ( + "Task should not have been executed after cancellation" + ) + + # Verify cleanup after cancellation + await checker.verify_keys_returned_to_baseline("task cancellation")