Skip to content

Commit c2b8187

Browse files
chrisguidry and Claude authored
Fix WRONGTYPE error and add memory leak detection tests (#153)
## Summary

- Fixes a Redis WRONGTYPE error that occurred when the worker scheduler tried to HSET on known task keys that existed as strings from legacy implementations
- Adds systematic Redis key counting tests that can detect memory leaks by comparing key counts before and after task operations
- Improves code organization by hoisting imports to the top of the test file

## What was the problem?

The error occurred because we were redundantly storing stream message IDs in known task keys — this wasn't needed for the cancellation functionality to work. When legacy docket versions created known task keys as simple string values, the new scheduler would fail trying to HSET on those string keys.

## What changed?

- Remove the redundant HSET of stream_message_id in the worker scheduler Lua script
- Add memory leak detection tests with a KeyCountChecker helper class that covers successful tasks, failed tasks, and cancelled tasks
- Add a regression test for the original WRONGTYPE error scenario
- Hoist all imports to the top of the test file per code style guidelines

## Test plan

The new tests systematically verify that Redis keys are properly cleaned up in all scenarios and can detect memory leaks by counting keys before and after operations. We validated that the tests work by temporarily introducing bugs and confirming they catch the issues.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-authored-by: Claude <noreply@anthropic.com>
1 parent 9dc405c commit c2b8187

File tree

2 files changed

+241
-17
lines changed

2 files changed

+241
-17
lines changed

src/docket/worker.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -406,17 +406,14 @@ async def _scheduler_loop(
406406
task[task_data[j]] = task_data[j+1]
407407
end
408408
409-
local message_id = redis.call('XADD', KEYS[2], '*',
409+
redis.call('XADD', KEYS[2], '*',
410410
'key', task['key'],
411411
'when', task['when'],
412412
'function', task['function'],
413413
'args', task['args'],
414414
'kwargs', task['kwargs'],
415415
'attempt', task['attempt']
416416
)
417-
-- Store the message ID in the known task key
418-
local known_key = ARGV[2] .. ":known:" .. key
419-
redis.call('HSET', known_key, 'stream_message_id', message_id)
420417
redis.call('DEL', hash_key)
421418
due_work = due_work + 1
422419
end

tests/test_worker.py

Lines changed: 240 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import asyncio
22
import logging
3+
import time
34
from contextlib import asynccontextmanager
5+
from contextvars import ContextVar
46
from datetime import datetime, timedelta, timezone
5-
from typing import AsyncGenerator, Callable
7+
from typing import AsyncGenerator, Callable, Iterable
68
from unittest.mock import AsyncMock, patch
79
from uuid import uuid4
810

11+
import cloudpickle # type: ignore[import]
912
import pytest
1013
from redis.asyncio import Redis
1114
from redis.exceptions import ConnectionError
@@ -18,6 +21,8 @@
1821
Perpetual,
1922
Worker,
2023
)
24+
from docket.dependencies import Timeout
25+
from docket.execution import Execution
2126
from docket.tasks import standard_tasks
2227
from docket.worker import ms
2328

@@ -175,7 +180,6 @@ async def task_that_sometimes_fails(
175180
nonlocal failure_count
176181

177182
# Record when this task runs
178-
import time
179183

180184
task_executions.append((customer_id, time.time()))
181185

@@ -556,7 +560,6 @@ async def perpetual_task(
556560

557561
async def test_worker_concurrency_limits_task_queuing_behavior(docket: Docket):
558562
"""Test that concurrency limits control task execution properly"""
559-
from contextvars import ContextVar
560563

561564
# Use contextvar for reliable tracking across async execution
562565
execution_log: ContextVar[list[tuple[str, int]]] = ContextVar("execution_log")
@@ -1172,7 +1175,6 @@ async def edge_case_task(
11721175

11731176
async def test_worker_timeout_exceeds_redelivery_timeout(docket: Docket):
11741177
"""Test worker handles user timeout longer than redelivery timeout."""
1175-
from docket.dependencies import Timeout
11761178

11771179
task_executed = False
11781180

@@ -1251,8 +1253,6 @@ async def task_missing_concurrency_arg(
12511253

12521254
async def test_worker_no_concurrency_dependency_in_function(docket: Docket):
12531255
"""Test _can_start_task with function that has no concurrency dependency."""
1254-
from docket.execution import Execution
1255-
from datetime import datetime, timezone
12561256

12571257
async def task_without_concurrency_dependency():
12581258
await asyncio.sleep(0.001)
@@ -1278,8 +1278,6 @@ async def task_without_concurrency_dependency():
12781278

12791279
async def test_worker_no_concurrency_dependency_in_release(docket: Docket):
12801280
"""Test _release_concurrency_slot with function that has no concurrency dependency."""
1281-
from docket.execution import Execution
1282-
from datetime import datetime, timezone
12831281

12841282
async def task_without_concurrency_dependency():
12851283
await asyncio.sleep(0.001)
@@ -1304,8 +1302,6 @@ async def task_without_concurrency_dependency():
13041302

13051303
async def test_worker_missing_concurrency_argument_in_release(docket: Docket):
13061304
"""Test _release_concurrency_slot when concurrency argument is missing."""
1307-
from docket.execution import Execution
1308-
from datetime import datetime, timezone
13091305

13101306
async def task_with_missing_arg(
13111307
concurrency: ConcurrencyLimit = ConcurrencyLimit(
@@ -1334,8 +1330,6 @@ async def task_with_missing_arg(
13341330

13351331
async def test_worker_concurrency_missing_argument_in_can_start(docket: Docket):
13361332
"""Test _can_start_task with missing concurrency argument during execution."""
1337-
from docket.execution import Execution
1338-
from datetime import datetime, timezone
13391333

13401334
async def task_with_missing_concurrency_arg(
13411335
concurrency: ConcurrencyLimit = ConcurrencyLimit(
@@ -1384,7 +1378,6 @@ async def task_that_will_fail():
13841378
task_failed = False
13851379

13861380
# Mock resolved_dependencies to fail before setting dependencies
1387-
from unittest.mock import patch, AsyncMock
13881381

13891382
await docket.add(task_that_will_fail)()
13901383

@@ -1504,3 +1497,237 @@ async def test_rapid_replace_operations(
15041497
# Should only execute the last replacement
15051498
the_task.assert_awaited_once_with("arg4", b="b4")
15061499
assert the_task.await_count == 1
1500+
1501+
1502+
async def test_wrongtype_error_with_legacy_known_task_key(
1503+
docket: Docket,
1504+
worker: Worker,
1505+
the_task: AsyncMock,
1506+
now: Callable[[], datetime],
1507+
caplog: pytest.LogCaptureFixture,
1508+
) -> None:
1509+
"""Test graceful handling when known task keys exist as strings from legacy implementations.
1510+
1511+
Regression test for issue where worker scheduler would get WRONGTYPE errors when trying to
1512+
HSET on known task keys that existed as string values from older docket versions.
1513+
1514+
The original error occurred when:
1515+
1. A legacy docket created known task keys as simple string values (timestamps)
1516+
2. The new scheduler tried to HSET stream_message_id on these keys
1517+
3. Redis threw WRONGTYPE error because you can't HSET on a string key
1518+
4. This caused scheduler loop failures in production
1519+
1520+
This test reproduces that scenario by manually setting up the legacy state,
1521+
then verifies the new code handles it gracefully without errors.
1522+
"""
1523+
key = f"legacy-task:{uuid4()}"
1524+
1525+
# Simulate legacy behavior: create the known task key as a string
1526+
# This is what older versions of docket would have done
1527+
async with docket.redis() as redis:
1528+
known_task_key = docket.known_task_key(key)
1529+
when = now() + timedelta(seconds=1)
1530+
1531+
# Set up legacy state: known key as string, task in queue with parked data
1532+
await redis.set(known_task_key, str(when.timestamp()))
1533+
await redis.zadd(docket.queue_key, {key: when.timestamp()})
1534+
1535+
await redis.hset( # type: ignore
1536+
docket.parked_task_key(key),
1537+
mapping={
1538+
"key": key,
1539+
"when": when.isoformat(),
1540+
"function": "trace",
1541+
"args": cloudpickle.dumps(["legacy task test"]), # type: ignore[arg-type]
1542+
"kwargs": cloudpickle.dumps({}), # type: ignore[arg-type]
1543+
"attempt": "1",
1544+
},
1545+
)
1546+
1547+
# Capture logs to ensure no errors occur and see task execution
1548+
with caplog.at_level(logging.INFO):
1549+
await worker.run_until_finished()
1550+
1551+
# Should not have any ERROR logs now that the issue is fixed
1552+
error_logs = [record for record in caplog.records if record.levelname == "ERROR"]
1553+
assert len(error_logs) == 0, (
1554+
f"Expected no error logs, but got: {[r.message for r in error_logs]}"
1555+
)
1556+
1557+
# The task should execute successfully
1558+
# Since we used trace, we should see an INFO log with the message
1559+
info_logs = [record for record in caplog.records if record.levelname == "INFO"]
1560+
trace_logs = [
1561+
record for record in info_logs if "legacy task test" in record.message
1562+
]
1563+
assert len(trace_logs) > 0, (
1564+
f"Expected to see trace log with 'legacy task test', got: {[r.message for r in info_logs]}"
1565+
)
1566+
1567+
1568+
async def count_redis_keys_by_type(redis: Redis, prefix: str) -> dict[str, int]:
1569+
"""Count Redis keys by type for a given prefix."""
1570+
pattern = f"{prefix}*"
1571+
keys: Iterable[str] = await redis.keys(pattern) # type: ignore
1572+
counts: dict[str, int] = {}
1573+
1574+
for key in keys:
1575+
key_type = await redis.type(key)
1576+
key_type_str = (
1577+
key_type.decode() if isinstance(key_type, bytes) else str(key_type)
1578+
)
1579+
counts[key_type_str] = counts.get(key_type_str, 0) + 1
1580+
1581+
return counts
1582+
1583+
1584+
class KeyCountChecker:
1585+
"""Helper to verify Redis key counts remain consistent across operations."""
1586+
1587+
def __init__(self, docket: Docket, redis: Redis) -> None:
1588+
self.docket = docket
1589+
self.redis = redis
1590+
self.baseline_counts: dict[str, int] = {}
1591+
1592+
async def capture_baseline(self) -> None:
1593+
"""Capture baseline key counts after worker priming."""
1594+
self.baseline_counts = await count_redis_keys_by_type(
1595+
self.redis, self.docket.name
1596+
)
1597+
print(f"Baseline key counts: {self.baseline_counts}")
1598+
1599+
async def verify_keys_increased(self, operation: str) -> None:
1600+
"""Verify that key counts increased after scheduling operation."""
1601+
current_counts = await count_redis_keys_by_type(self.redis, self.docket.name)
1602+
print(f"After {operation} key counts: {current_counts}")
1603+
1604+
total_current = sum(current_counts.values())
1605+
total_baseline = sum(self.baseline_counts.values())
1606+
assert total_current > total_baseline, (
1607+
f"Expected more keys after {operation}, but got {total_current} vs {total_baseline}"
1608+
)
1609+
1610+
async def verify_keys_returned_to_baseline(self, operation: str) -> None:
1611+
"""Verify that key counts returned to baseline after operation completion."""
1612+
final_counts = await count_redis_keys_by_type(self.redis, self.docket.name)
1613+
print(f"Final key counts: {final_counts}")
1614+
1615+
# Check each key type matches baseline
1616+
all_key_types = set(self.baseline_counts.keys()) | set(final_counts.keys())
1617+
for key_type in all_key_types:
1618+
baseline_count = self.baseline_counts.get(key_type, 0)
1619+
final_count = final_counts.get(key_type, 0)
1620+
assert final_count == baseline_count, (
1621+
f"Memory leak detected after {operation}: {key_type} keys not cleaned up properly. "
1622+
f"Baseline: {baseline_count}, Final: {final_count}"
1623+
)
1624+
1625+
1626+
async def test_redis_key_cleanup_successful_task(
1627+
docket: Docket, worker: Worker
1628+
) -> None:
1629+
"""Test that Redis keys are properly cleaned up after successful task execution.
1630+
1631+
This test systematically counts Redis keys before and after task operations to detect
1632+
memory leaks where keys are not properly cleaned up.
1633+
"""
1634+
# Prime the worker (run once with no tasks to establish baseline)
1635+
await worker.run_until_finished()
1636+
1637+
# Create and register a simple task
1638+
task_executed = False
1639+
1640+
async def successful_task():
1641+
nonlocal task_executed
1642+
task_executed = True
1643+
await asyncio.sleep(0.01) # Small delay to ensure proper execution flow
1644+
1645+
docket.register(successful_task)
1646+
1647+
async with docket.redis() as redis:
1648+
checker = KeyCountChecker(docket, redis)
1649+
await checker.capture_baseline()
1650+
1651+
# Schedule the task
1652+
await docket.add(successful_task)()
1653+
await checker.verify_keys_increased("scheduling")
1654+
1655+
# Execute the task
1656+
await worker.run_until_finished()
1657+
1658+
# Verify task executed successfully
1659+
assert task_executed, "Task should have executed successfully"
1660+
1661+
# Verify cleanup
1662+
await checker.verify_keys_returned_to_baseline("successful task execution")
1663+
1664+
1665+
async def test_redis_key_cleanup_failed_task(docket: Docket, worker: Worker) -> None:
1666+
"""Test that Redis keys are properly cleaned up after failed task execution."""
1667+
# Prime the worker
1668+
await worker.run_until_finished()
1669+
1670+
# Create a task that will fail
1671+
task_attempted = False
1672+
1673+
async def failing_task():
1674+
nonlocal task_attempted
1675+
task_attempted = True
1676+
raise ValueError("Intentional test failure")
1677+
1678+
docket.register(failing_task)
1679+
1680+
async with docket.redis() as redis:
1681+
checker = KeyCountChecker(docket, redis)
1682+
await checker.capture_baseline()
1683+
1684+
# Schedule the task
1685+
await docket.add(failing_task)()
1686+
await checker.verify_keys_increased("scheduling")
1687+
1688+
# Execute the task (should fail)
1689+
await worker.run_until_finished()
1690+
1691+
# Verify task was attempted
1692+
assert task_attempted, "Task should have been attempted"
1693+
1694+
# Verify cleanup despite failure
1695+
await checker.verify_keys_returned_to_baseline("failed task execution")
1696+
1697+
1698+
async def test_redis_key_cleanup_cancelled_task(docket: Docket, worker: Worker) -> None:
1699+
"""Test that Redis keys are properly cleaned up after task cancellation."""
1700+
# Prime the worker
1701+
await worker.run_until_finished()
1702+
1703+
# Create a task that won't be executed
1704+
task_executed = False
1705+
1706+
async def task_to_cancel():
1707+
nonlocal task_executed
1708+
task_executed = True # pragma: no cover
1709+
1710+
docket.register(task_to_cancel)
1711+
1712+
async with docket.redis() as redis:
1713+
checker = KeyCountChecker(docket, redis)
1714+
await checker.capture_baseline()
1715+
1716+
# Schedule the task for future execution
1717+
future_time = datetime.now(timezone.utc) + timedelta(seconds=10)
1718+
execution = await docket.add(task_to_cancel, future_time)()
1719+
await checker.verify_keys_increased("scheduling")
1720+
1721+
# Cancel the task
1722+
await docket.cancel(execution.key)
1723+
1724+
# Run worker to process any cleanup
1725+
await worker.run_until_finished()
1726+
1727+
# Verify task was not executed
1728+
assert not task_executed, (
1729+
"Task should not have been executed after cancellation"
1730+
)
1731+
1732+
# Verify cleanup after cancellation
1733+
await checker.verify_keys_returned_to_baseline("task cancellation")

0 commit comments

Comments (0)