Skip to content

Commit 0fada28

Browse files
committed
Use redelivery timeout as a hard limit on task execution
1 parent 63283a2 commit 0fada28

File tree

7 files changed

+1251 additions, -265 deletions

src/docket/dependencies.py

Lines changed: 21 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -39,9 +39,9 @@ async def __aenter__(self) -> Any: ... # pragma: no cover
3939

4040
async def __aexit__(
4141
self,
42-
exc_type: type[BaseException] | None,
43-
exc_value: BaseException | None,
44-
traceback: TracebackType | None,
42+
_exc_type: type[BaseException] | None,
43+
_exc_value: BaseException | None,
44+
_traceback: TracebackType | None,
4545
) -> bool: ... # pragma: no cover
4646

4747

@@ -545,6 +545,7 @@ def __init__(
545545
self.max_concurrent = max_concurrent
546546
self.scope = scope
547547
self._concurrency_key: str | None = None
548+
self._initialized: bool = False
548549

549550
async def __aenter__(self) -> "ConcurrencyLimit":
550551
execution = self.execution.get()
@@ -554,9 +555,13 @@ async def __aenter__(self) -> "ConcurrencyLimit":
554555
try:
555556
argument_value = execution.get_argument(self.argument_name)
556557
except KeyError:
557-
raise ValueError(
558-
f"Argument '{self.argument_name}' not found in task arguments"
558+
# If argument not found, create a bypass limit that doesn't apply concurrency control
559+
limit = ConcurrencyLimit(
560+
self.argument_name, self.max_concurrent, self.scope
559561
)
562+
limit._concurrency_key = None # Special marker for bypassed concurrency
563+
limit._initialized = True # Mark as initialized but bypassed
564+
return limit
560565

561566
# Create a concurrency key for this specific argument value
562567
scope = self.scope or docket.name
@@ -566,17 +571,25 @@ async def __aenter__(self) -> "ConcurrencyLimit":
566571

567572
limit = ConcurrencyLimit(self.argument_name, self.max_concurrent, self.scope)
568573
limit._concurrency_key = self._concurrency_key
574+
limit._initialized = True # Mark as initialized
569575
return limit
570576

571577
@property
572-
def concurrency_key(self) -> str:
573-
"""Redis key used for tracking concurrency for this specific argument value."""
574-
if self._concurrency_key is None:
578+
def concurrency_key(self) -> str | None:
579+
"""Redis key used for tracking concurrency for this specific argument value.
580+
Returns None when concurrency control is bypassed due to missing arguments.
581+
Raises RuntimeError if accessed before initialization."""
582+
if not self._initialized:
575583
raise RuntimeError(
576584
"ConcurrencyLimit not initialized - use within task context"
577585
)
578586
return self._concurrency_key
579587

588+
@property
589+
def is_bypassed(self) -> bool:
590+
"""Returns True if concurrency control is bypassed due to missing arguments."""
591+
return self._initialized and self._concurrency_key is None
592+
580593

581594
D = TypeVar("D", bound=Dependency)
582595

src/docket/worker.py

Lines changed: 39 additions & 122 deletions
Original file line number | Diff line number | Diff line change
@@ -88,12 +88,6 @@ class Worker:
8888
scheduling_resolution: timedelta
8989
schedule_automatic_tasks: bool
9090

91-
# Concurrency refresh management
92-
_concurrency_refresh_threshold: timedelta
93-
_concurrency_refresh_interval: timedelta
94-
_concurrency_slots_needing_refresh: dict[str, tuple[str, float]]
95-
_concurrency_refresh_manager_task: asyncio.Task[None] | None
96-
9791
def __init__(
9892
self,
9993
docket: Docket,
@@ -114,20 +108,9 @@ def __init__(
114108
self.scheduling_resolution = scheduling_resolution
115109
self.schedule_automatic_tasks = schedule_automatic_tasks
116110

117-
# Concurrency refresh settings
118-
self._concurrency_refresh_threshold = timedelta(
119-
seconds=self.redelivery_timeout.total_seconds() * 0.5
120-
)
121-
self._concurrency_refresh_interval = timedelta(seconds=30)
122-
self._concurrency_slots_needing_refresh = {}
123-
self._concurrency_refresh_manager_task = None
124-
125111
async def __aenter__(self) -> Self:
126112
self._heartbeat_task = asyncio.create_task(self._heartbeat())
127113
self._execution_counts = {}
128-
self._concurrency_refresh_manager_task = asyncio.create_task(
129-
self._concurrency_refresh_manager()
130-
)
131114
return self
132115

133116
async def __aexit__(
@@ -145,14 +128,6 @@ async def __aexit__(
145128
pass
146129
del self._heartbeat_task
147130

148-
if self._concurrency_refresh_manager_task:
149-
self._concurrency_refresh_manager_task.cancel()
150-
try:
151-
await self._concurrency_refresh_manager_task
152-
except asyncio.CancelledError:
153-
pass
154-
del self._concurrency_refresh_manager_task
155-
156131
def labels(self) -> Mapping[str, str]:
157132
return {
158133
**self.docket.labels(),
@@ -569,8 +544,7 @@ async def _execute(self, execution: Execution) -> None:
569544
concurrency_limit = get_single_dependency_of_type(
570545
dependencies, ConcurrencyLimit
571546
)
572-
concurrency_key = None
573-
if concurrency_limit:
547+
if concurrency_limit and not concurrency_limit.is_bypassed:
574548
async with self.docket.redis() as redis:
575549
# Check if we can acquire a concurrency slot
576550
if not await self._can_start_task(redis, execution):
@@ -589,12 +563,8 @@ async def _execute(self, execution: Execution) -> None:
589563
)
590564
return
591565
else:
592-
# Successfully acquired slot - register for refresh management
593-
concurrency_key = self._get_concurrency_key(execution)
594-
member = f"{self.name}:{execution.key}"
595-
self._concurrency_slots_needing_refresh[
596-
concurrency_key
597-
] = (member, start)
566+
# Successfully acquired slot
567+
pass
598568

599569
# Preemptively reschedule the perpetual task for the future, or clear
600570
# the known task key for this task
@@ -622,17 +592,34 @@ async def _execute(self, execution: Execution) -> None:
622592
],
623593
)
624594

625-
if timeout := get_single_dependency_of_type(dependencies, Timeout):
626-
await self._run_function_with_timeout(
627-
execution, dependencies, timeout
628-
)
595+
# Apply timeout logic - either user's timeout or redelivery timeout
596+
user_timeout = get_single_dependency_of_type(dependencies, Timeout)
597+
if user_timeout:
598+
# If user timeout is longer than redelivery timeout, limit it
599+
if user_timeout.base > self.redelivery_timeout:
600+
# Create a new timeout limited by redelivery timeout
601+
# Remove the user timeout from dependencies to avoid conflicts
602+
limited_dependencies = {
603+
k: v
604+
for k, v in dependencies.items()
605+
if not isinstance(v, Timeout)
606+
}
607+
limited_timeout = Timeout(self.redelivery_timeout)
608+
limited_timeout.start()
609+
await self._run_function_with_timeout(
610+
execution, limited_dependencies, limited_timeout
611+
)
612+
else:
613+
# User timeout is within redelivery timeout, use as-is
614+
await self._run_function_with_timeout(
615+
execution, dependencies, user_timeout
616+
)
629617
else:
630-
await execution.function(
631-
*execution.args,
632-
**{
633-
**execution.kwargs,
634-
**dependencies,
635-
},
618+
# No user timeout - apply redelivery timeout as hard limit
619+
redelivery_timeout = Timeout(self.redelivery_timeout)
620+
redelivery_timeout.start()
621+
await self._run_function_with_timeout(
622+
execution, dependencies, redelivery_timeout
636623
)
637624

638625
duration = log_context["duration"] = time.time() - start
@@ -667,15 +654,11 @@ async def _execute(self, execution: Execution) -> None:
667654
)
668655
finally:
669656
# Release concurrency slot if we acquired one
670-
if "dependencies" in locals():
657+
if dependencies:
671658
concurrency_limit = get_single_dependency_of_type(
672659
dependencies, ConcurrencyLimit
673660
)
674-
if concurrency_limit and concurrency_key:
675-
# Remove from refresh management
676-
self._concurrency_slots_needing_refresh.pop(
677-
concurrency_key, None
678-
)
661+
if concurrency_limit and not concurrency_limit.is_bypassed:
679662
async with self.docket.redis() as redis:
680663
await self._release_concurrency_slot(redis, execution)
681664

@@ -691,7 +674,13 @@ async def _run_function_with_timeout(
691674
) -> None:
692675
task_coro = cast(
693676
Coroutine[None, None, None],
694-
execution.function(*execution.args, **execution.kwargs, **dependencies),
677+
execution.function(
678+
*execution.args,
679+
**{
680+
**execution.kwargs,
681+
**dependencies,
682+
},
683+
),
695684
)
696685
task = asyncio.create_task(task_coro)
697686
try:
@@ -770,21 +759,6 @@ def _startup_log(self) -> None:
770759
def workers_set(self) -> str:
771760
return self.docket.workers_set
772761

773-
@property
774-
def concurrency_refresh_interval(self) -> timedelta:
775-
"""Get the concurrency refresh interval for testing purposes."""
776-
return self._concurrency_refresh_interval
777-
778-
@concurrency_refresh_interval.setter
779-
def concurrency_refresh_interval(self, value: timedelta) -> None:
780-
"""Set the concurrency refresh interval for testing purposes."""
781-
self._concurrency_refresh_interval = value
782-
783-
@property
784-
def concurrency_slots_needing_refresh(self) -> dict[str, tuple[str, float]]:
785-
"""Get the concurrency slots needing refresh for testing purposes."""
786-
return self._concurrency_slots_needing_refresh
787-
788762
def worker_tasks_set(self, worker_name: str) -> str:
789763
return self.docket.worker_tasks_set(worker_name)
790764

@@ -941,63 +915,6 @@ async def _release_concurrency_slot(
941915
# Remove this worker's task from the sorted set
942916
await redis.zrem(concurrency_key, f"{self.name}:{execution.key}") # type: ignore
943917

944-
def _get_concurrency_key(self, execution: Execution) -> str:
945-
"""Get the Redis key for concurrency tracking."""
946-
concurrency_limit = get_single_dependency_parameter_of_type(
947-
execution.function, ConcurrencyLimit
948-
)
949-
if not concurrency_limit:
950-
raise ValueError("No concurrency limit found for execution")
951-
952-
try:
953-
argument_value = execution.get_argument(concurrency_limit.argument_name)
954-
except KeyError:
955-
raise ValueError(
956-
f"Argument '{concurrency_limit.argument_name}' not found in task arguments"
957-
)
958-
959-
scope = concurrency_limit.scope or self.docket.name
960-
return f"{scope}:concurrency:{concurrency_limit.argument_name}:{argument_value}"
961-
962-
async def _concurrency_refresh_manager(self) -> None:
963-
"""Single coroutine that manages all concurrency slot refreshes."""
964-
while True:
965-
try:
966-
await asyncio.sleep(self._concurrency_refresh_interval.total_seconds())
967-
968-
current_time = time.time()
969-
slots_to_refresh = []
970-
971-
# Find slots that need refreshing (running longer than threshold)
972-
for slot_key, (
973-
member,
974-
start_time,
975-
) in self._concurrency_slots_needing_refresh.items():
976-
elapsed = current_time - start_time
977-
if elapsed >= self._concurrency_refresh_threshold.total_seconds():
978-
slots_to_refresh.append((slot_key, member))
979-
980-
# Batch refresh in a single Redis operation
981-
if slots_to_refresh:
982-
async with self.docket.redis() as redis:
983-
async with redis.pipeline() as pipe:
984-
for slot_key, member in slots_to_refresh:
985-
pipe.zadd(slot_key, {member: current_time})
986-
await pipe.execute()
987-
988-
logger.debug(
989-
"Refreshed %d concurrency slots",
990-
len(slots_to_refresh),
991-
extra=self._log_context(),
992-
)
993-
994-
except asyncio.CancelledError:
995-
break
996-
except Exception:
997-
logger.exception(
998-
"Error in concurrency refresh manager", extra=self._log_context()
999-
)
1000-
1001918

1002919
def ms(seconds: float) -> str:
1003920
if seconds < 100:

tests/cli/test_worker.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -55,12 +55,14 @@ def test_worker_command_exposes_all_the_options_of_worker():
5555
)
5656

5757

58-
def test_worker_command(
58+
async def test_worker_command(
5959
runner: CliRunner,
6060
docket: Docket,
6161
):
6262
"""Should run a worker until there are no more tasks to process"""
63-
result = runner.invoke(
63+
result = await asyncio.get_running_loop().run_in_executor(
64+
None,
65+
runner.invoke,
6466
app,
6567
[
6668
"worker",
@@ -70,7 +72,6 @@ def test_worker_command(
7072
"--docket",
7173
docket.name,
7274
],
73-
color=True,
7475
)
7576
assert result.exit_code == 0
7677

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -151,7 +151,7 @@ def redis_url(redis_port: int, redis_db: int) -> str:
151151

152152

153153
@pytest.fixture
154-
async def docket(redis_url: str, aiolib: str) -> AsyncGenerator[Docket, None]:
154+
async def docket(redis_url: str) -> AsyncGenerator[Docket, None]:
155155
async with Docket(name=f"test-docket-{uuid4()}", url=redis_url) as docket:
156156
yield docket
157157

0 commit comments

Comments
 (0)