11import asyncio
22import logging
3+ import time
34from contextlib import asynccontextmanager
5+ from contextvars import ContextVar
46from datetime import datetime , timedelta , timezone
5- from typing import AsyncGenerator , Callable
7+ from typing import AsyncGenerator , Callable , Iterable
68from unittest .mock import AsyncMock , patch
79from uuid import uuid4
810
11+ import cloudpickle # type: ignore[import]
912import pytest
1013from redis .asyncio import Redis
1114from redis .exceptions import ConnectionError
1821 Perpetual ,
1922 Worker ,
2023)
24+ from docket .dependencies import Timeout
25+ from docket .execution import Execution
2126from docket .tasks import standard_tasks
2227from docket .worker import ms
2328
@@ -175,7 +180,6 @@ async def task_that_sometimes_fails(
175180 nonlocal failure_count
176181
177182 # Record when this task runs
178- import time
179183
180184 task_executions .append ((customer_id , time .time ()))
181185
@@ -556,7 +560,6 @@ async def perpetual_task(
556560
557561async def test_worker_concurrency_limits_task_queuing_behavior (docket : Docket ):
558562 """Test that concurrency limits control task execution properly"""
559- from contextvars import ContextVar
560563
561564 # Use contextvar for reliable tracking across async execution
562565 execution_log : ContextVar [list [tuple [str , int ]]] = ContextVar ("execution_log" )
@@ -1172,7 +1175,6 @@ async def edge_case_task(
11721175
11731176async def test_worker_timeout_exceeds_redelivery_timeout (docket : Docket ):
11741177 """Test worker handles user timeout longer than redelivery timeout."""
1175- from docket .dependencies import Timeout
11761178
11771179 task_executed = False
11781180
@@ -1251,8 +1253,6 @@ async def task_missing_concurrency_arg(
12511253
12521254async def test_worker_no_concurrency_dependency_in_function (docket : Docket ):
12531255 """Test _can_start_task with function that has no concurrency dependency."""
1254- from docket .execution import Execution
1255- from datetime import datetime , timezone
12561256
12571257 async def task_without_concurrency_dependency ():
12581258 await asyncio .sleep (0.001 )
@@ -1278,8 +1278,6 @@ async def task_without_concurrency_dependency():
12781278
12791279async def test_worker_no_concurrency_dependency_in_release (docket : Docket ):
12801280 """Test _release_concurrency_slot with function that has no concurrency dependency."""
1281- from docket .execution import Execution
1282- from datetime import datetime , timezone
12831281
12841282 async def task_without_concurrency_dependency ():
12851283 await asyncio .sleep (0.001 )
@@ -1304,8 +1302,6 @@ async def task_without_concurrency_dependency():
13041302
13051303async def test_worker_missing_concurrency_argument_in_release (docket : Docket ):
13061304 """Test _release_concurrency_slot when concurrency argument is missing."""
1307- from docket .execution import Execution
1308- from datetime import datetime , timezone
13091305
13101306 async def task_with_missing_arg (
13111307 concurrency : ConcurrencyLimit = ConcurrencyLimit (
@@ -1334,8 +1330,6 @@ async def task_with_missing_arg(
13341330
13351331async def test_worker_concurrency_missing_argument_in_can_start (docket : Docket ):
13361332 """Test _can_start_task with missing concurrency argument during execution."""
1337- from docket .execution import Execution
1338- from datetime import datetime , timezone
13391333
13401334 async def task_with_missing_concurrency_arg (
13411335 concurrency : ConcurrencyLimit = ConcurrencyLimit (
@@ -1384,7 +1378,6 @@ async def task_that_will_fail():
13841378 task_failed = False
13851379
13861380 # Mock resolved_dependencies to fail before setting dependencies
1387- from unittest .mock import patch , AsyncMock
13881381
13891382 await docket .add (task_that_will_fail )()
13901383
@@ -1504,3 +1497,237 @@ async def test_rapid_replace_operations(
15041497 # Should only execute the last replacement
15051498 the_task .assert_awaited_once_with ("arg4" , b = "b4" )
15061499 assert the_task .await_count == 1
1500+
1501+
async def test_wrongtype_error_with_legacy_known_task_key(
    docket: Docket,
    worker: Worker,
    the_task: AsyncMock,
    now: Callable[[], datetime],
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Regression test: tolerate known-task keys stored as plain strings.

    Older docket versions wrote each known-task key as a bare string value
    (a timestamp).  The newer scheduler HSETs stream_message_id onto that
    same key, so Redis raised WRONGTYPE and the scheduler loop failed in
    production.  Here we hand-craft the legacy layout and then assert the
    worker drains it without logging any errors and still runs the task.
    """
    key = f"legacy-task:{uuid4()}"

    # Recreate the pre-upgrade layout by hand: a string-typed known-task
    # key, a queue entry, and the parked-task hash the worker will consume.
    async with docket.redis() as redis:
        when = now() + timedelta(seconds=1)

        await redis.set(docket.known_task_key(key), str(when.timestamp()))
        await redis.zadd(docket.queue_key, {key: when.timestamp()})

        await redis.hset(  # type: ignore
            docket.parked_task_key(key),
            mapping={
                "key": key,
                "when": when.isoformat(),
                "function": "trace",
                "args": cloudpickle.dumps(["legacy task test"]),  # type: ignore[arg-type]
                "kwargs": cloudpickle.dumps({}),  # type: ignore[arg-type]
                "attempt": "1",
            },
        )

    # Drain the docket while capturing logs so we can check both the
    # absence of errors and the presence of the trace output.
    with caplog.at_level(logging.INFO):
        await worker.run_until_finished()

    # With the fix in place the legacy key must not trigger any ERROR logs.
    error_logs = [r for r in caplog.records if r.levelname == "ERROR"]
    assert len(error_logs) == 0, (
        f"Expected no error logs, but got: {[r.message for r in error_logs]}"
    )

    # The parked task used the built-in `trace` function, so a successful
    # run surfaces its payload at INFO level.
    info_logs = [r for r in caplog.records if r.levelname == "INFO"]
    trace_logs = [r for r in info_logs if "legacy task test" in r.message]
    assert len(trace_logs) > 0, (
        f"Expected to see trace log with 'legacy task test', got: {[r.message for r in info_logs]}"
    )
1566+
1567+
async def count_redis_keys_by_type(redis: Redis, prefix: str) -> dict[str, int]:
    """Count Redis keys by type for a given prefix.

    Args:
        redis: Async Redis client to inspect.
        prefix: Only keys beginning with this prefix are counted.

    Returns:
        Mapping of Redis type name (e.g. "hash", "zset") to the number of
        matching keys of that type.
    """
    counts: dict[str, int] = {}

    # Use SCAN rather than KEYS: KEYS walks the entire keyspace in one
    # blocking command, while scan_iter iterates incrementally and stays
    # safe on a shared or busy Redis instance.
    async for key in redis.scan_iter(match=f"{prefix}*"):
        key_type = await redis.type(key)
        # The type reply may be bytes or str depending on the client's
        # decode_responses setting; normalize to str either way.
        key_type_str = (
            key_type.decode() if isinstance(key_type, bytes) else str(key_type)
        )
        counts[key_type_str] = counts.get(key_type_str, 0) + 1

    return counts
1582+
1583+
class KeyCountChecker:
    """Tracks Redis key counts to catch leaked keys across task lifecycles.

    Capture a baseline once the worker is primed, then assert that key
    counts grow when work is scheduled and drop back to the baseline once
    the work is finished, failed, or cancelled.
    """

    def __init__(self, docket: Docket, redis: Redis) -> None:
        self.docket = docket
        self.redis = redis
        # Per-type key counts recorded by capture_baseline().
        self.baseline_counts: dict[str, int] = {}

    async def capture_baseline(self) -> None:
        """Record per-type key counts to compare against later."""
        self.baseline_counts = await count_redis_keys_by_type(
            self.redis, self.docket.name
        )
        print(f"Baseline key counts: {self.baseline_counts}")

    async def verify_keys_increased(self, operation: str) -> None:
        """Assert the total key count grew after a scheduling operation."""
        snapshot = await count_redis_keys_by_type(self.redis, self.docket.name)
        print(f"After {operation} key counts: {snapshot}")

        total_current = sum(snapshot.values())
        total_baseline = sum(self.baseline_counts.values())
        assert total_current > total_baseline, (
            f"Expected more keys after {operation}, but got {total_current} vs {total_baseline}"
        )

    async def verify_keys_returned_to_baseline(self, operation: str) -> None:
        """Assert every key type is back at its baseline count (no leaks)."""
        final_counts = await count_redis_keys_by_type(self.redis, self.docket.name)
        print(f"Final key counts: {final_counts}")

        # Compare per type, covering types that appear on either side only.
        for key_type in self.baseline_counts.keys() | final_counts.keys():
            baseline_count = self.baseline_counts.get(key_type, 0)
            final_count = final_counts.get(key_type, 0)
            assert final_count == baseline_count, (
                f"Memory leak detected after {operation}: {key_type} keys not cleaned up properly. "
                f"Baseline: {baseline_count}, Final: {final_count}"
            )
1624+
1625+
async def test_redis_key_cleanup_successful_task(
    docket: Docket, worker: Worker
) -> None:
    """A successfully executed task must leave no extra Redis keys behind.

    Key counts are sampled before scheduling, after scheduling, and after
    execution; any key type that does not return to its baseline count
    indicates a leak in the cleanup path.
    """
    # Run the worker once with an empty docket so its bookkeeping keys
    # already exist when we take the baseline measurement.
    await worker.run_until_finished()

    task_executed = False

    async def successful_task():
        nonlocal task_executed
        task_executed = True
        await asyncio.sleep(0.01)  # Small delay to ensure proper execution flow

    docket.register(successful_task)

    async with docket.redis() as redis:
        checker = KeyCountChecker(docket, redis)
        await checker.capture_baseline()

        # Scheduling should create keys...
        await docket.add(successful_task)()
        await checker.verify_keys_increased("scheduling")

        # ...and draining the docket should remove them again.
        await worker.run_until_finished()

        assert task_executed, "Task should have executed successfully"

        await checker.verify_keys_returned_to_baseline("successful task execution")
1663+
1664+
async def test_redis_key_cleanup_failed_task(docket: Docket, worker: Worker) -> None:
    """A task that raises must still have all of its Redis keys cleaned up."""
    # Prime the worker so its bookkeeping keys are part of the baseline.
    await worker.run_until_finished()

    task_attempted = False

    async def failing_task():
        nonlocal task_attempted
        task_attempted = True
        raise ValueError("Intentional test failure")

    docket.register(failing_task)

    async with docket.redis() as redis:
        checker = KeyCountChecker(docket, redis)
        await checker.capture_baseline()

        # Scheduling should create keys for the pending task.
        await docket.add(failing_task)()
        await checker.verify_keys_increased("scheduling")

        # The run exercises the failure path; cleanup must still happen.
        await worker.run_until_finished()

        assert task_attempted, "Task should have been attempted"

        await checker.verify_keys_returned_to_baseline("failed task execution")
1696+
1697+
async def test_redis_key_cleanup_cancelled_task(docket: Docket, worker: Worker) -> None:
    """Cancelling a scheduled task must remove every key it created."""
    # Prime the worker so its bookkeeping keys are part of the baseline.
    await worker.run_until_finished()

    task_executed = False

    async def task_to_cancel():
        nonlocal task_executed
        task_executed = True  # pragma: no cover

    docket.register(task_to_cancel)

    async with docket.redis() as redis:
        checker = KeyCountChecker(docket, redis)
        await checker.capture_baseline()

        # Schedule far enough in the future that the worker won't run it
        # before we cancel.
        future_time = datetime.now(timezone.utc) + timedelta(seconds=10)
        execution = await docket.add(task_to_cancel, future_time)()
        await checker.verify_keys_increased("scheduling")

        await docket.cancel(execution.key)

        # Let the worker process any cancellation bookkeeping.
        await worker.run_until_finished()

        assert not task_executed, (
            "Task should not have been executed after cancellation"
        )

        await checker.verify_keys_returned_to_baseline("task cancellation")
0 commit comments