Skip to content

Commit 7754f0d

Browse files
authored
Merge branch 'main' into exectuion-expansion
2 parents 282cfce + afb074d commit 7754f0d

File tree

4 files changed

+268
-43
lines changed

4 files changed

+268
-43
lines changed

pyproject.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,11 @@ dependencies = [
4040
dev = [
4141
"codespell>=2.4.1",
4242
"docker>=7.1.0",
43-
# Using fork until https://github.com/cunla/fakeredis-py/pull/427 is merged
44-
# This fixes xpending_range to return all 4 required fields (message_id, consumer,
45-
# time_since_delivered, times_delivered) instead of just 2, matching Redis behavior
46-
"fakeredis[lua] @ git+https://github.com/zzstoatzz/fakeredis-py.git@fix-xpending-range-fields",
43+
# Using specific commit until version > 2.32.0 is released to PyPI
44+
# This includes the fix for xpending_range to return all 4 required fields
45+
# Once released, we can use "fakeredis[lua]>=2.33.0" and move to main dependencies
46+
# See: https://github.com/cunla/fakeredis-py/pull/427
47+
"fakeredis[lua] @ git+https://github.com/cunla/fakeredis-py.git@ad50a0de8d6dce554fb629ec284bc4ccbc6a7f12",
4748
"ipython>=8.0.0",
4849
"mypy>=1.14.1",
4950
"opentelemetry-distro>=0.51b0",

src/docket/dependencies.py

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -725,38 +725,50 @@ def __init__(self, parameter: str, error: Exception) -> None:
725725
async def resolved_dependencies(
    worker: "Worker", execution: Execution
) -> AsyncGenerator[dict[str, Any], None]:
    """Resolve the dependency-injected arguments for a task execution.

    Sets the docket/worker/execution contextvars for the duration of the
    resolution, yields the mapping of parameter name -> resolved value, and
    restores every contextvar via its token on the way out so that no state
    leaks to the caller between executions.
    """
    # Remember each token so every contextvar can be restored afterwards.
    tok_docket = Dependency.docket.set(worker.docket)
    tok_worker = Dependency.worker.set(worker)
    tok_execution = Dependency.execution.set(execution)
    tok_cache = _Depends.cache.set({})

    try:
        async with AsyncExitStack() as stack:
            tok_stack = _Depends.stack.set(stack)
            try:
                arguments: dict[str, Any] = {}
                parameters = get_dependency_parameters(execution.function)

                for parameter, dependency in parameters.items():
                    # An explicitly passed keyword argument wins outright.
                    kwargs = execution.kwargs
                    if parameter in kwargs:
                        arguments[parameter] = kwargs[parameter]
                        continue

                    # Special case for TaskArguments, they are "magical" and
                    # infer the parameter they refer to from the parameter name
                    # (unless otherwise specified). At the top-level task
                    # function call, it doesn't make sense to specify one
                    # _without_ a parameter name, so we'll call that a failed
                    # dependency.
                    if isinstance(dependency, _TaskArgument) and not dependency.parameter:
                        arguments[parameter] = FailedDependency(
                            parameter, ValueError("No parameter name specified")
                        )
                        continue

                    try:
                        arguments[parameter] = await stack.enter_async_context(
                            dependency
                        )
                    except Exception as error:
                        arguments[parameter] = FailedDependency(parameter, error)

                yield arguments
            finally:
                _Depends.stack.reset(tok_stack)
    finally:
        # Reset in reverse order of the .set() calls above.
        _Depends.cache.reset(tok_cache)
        Dependency.execution.reset(tok_execution)
        Dependency.worker.reset(tok_worker)
        Dependency.docket.reset(tok_docket)

tests/test_dependencies.py

Lines changed: 214 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
11
import logging
2-
from contextlib import contextmanager
2+
from contextlib import asynccontextmanager, contextmanager
33
from datetime import datetime, timedelta, timezone
44

55
import pytest
66

77
from docket import CurrentDocket, CurrentWorker, Docket, Worker
8-
from docket.dependencies import Depends, ExponentialRetry, Retry, TaskArgument
8+
from docket.dependencies import (
9+
Depends,
10+
Dependency,
11+
ExponentialRetry,
12+
Retry,
13+
TaskArgument,
14+
_Depends, # type: ignore[attr-defined]
15+
resolved_dependencies,
16+
)
17+
from docket.execution import Execution
918

1019

1120
async def test_dependencies_may_be_duplicated(docket: Docket, worker: Worker):
@@ -449,3 +458,206 @@ async def dependent_task(result: int = Depends(sync_adder)):
449458
await worker.run_until_finished()
450459

451460
assert called
461+
462+
463+
async def test_contextvar_isolation_between_tasks(docket: Docket, worker: Worker):
    """Contextvars should be isolated between sequential task executions"""
    executions_seen: list[tuple[str, Execution]] = []

    async def first_task(a: str):
        # Capture the execution context during first task
        execution = Dependency.execution.get()
        executions_seen.append(("first", execution))
        assert a == "first"

    async def second_task(b: str):
        # Capture the execution context during second task
        execution = Dependency.execution.get()
        executions_seen.append(("second", execution))
        assert b == "second"

    await docket.add(first_task)(a="first")
    await docket.add(second_task)(b="second")
    await worker.run_until_finished()

    # Verify we captured both executions
    assert len(executions_seen) == 2

    # Find first and second executions (order may vary); dict() avoids the
    # comprehension variable `exec` shadowing the builtin of the same name
    executions_by_name = dict(executions_seen)
    assert set(executions_by_name) == {"first", "second"}

    # Verify the executions are different and have correct kwargs
    first_execution = executions_by_name["first"]
    second_execution = executions_by_name["second"]
    assert first_execution is not second_execution
    assert first_execution.kwargs["a"] == "first"
    assert second_execution.kwargs["b"] == "second"
496+
497+
498+
async def test_contextvar_cleanup_after_task(docket: Docket, worker: Worker):
    """Contextvars should be reset after task execution completes"""
    captured_stack = None
    captured_cache = None

    async def capture_task():
        nonlocal captured_stack, captured_cache
        # Capture references during task execution
        captured_stack = _Depends.stack.get()
        captured_cache = _Depends.cache.get()

    await docket.add(capture_task)()
    await worker.run_until_finished()

    # Once the task completes, every contextvar must have been reset, so a
    # bare .get() raises LookupError for each one of them.
    for contextvar in (
        _Depends.stack,
        _Depends.cache,
        Dependency.execution,
        Dependency.worker,
        Dependency.docket,
    ):
        with pytest.raises(LookupError):
            contextvar.get()
528+
529+
530+
async def test_dependency_cache_isolated_between_tasks(docket: Docket, worker: Worker):
    """Dependency cache should be fresh for each task, not reused"""
    call_counts = {"task1": 0, "task2": 0}

    def dependency_for_task1() -> str:
        call_counts["task1"] += 1
        return f"task1-call-{call_counts['task1']}"

    def dependency_for_task2() -> str:
        call_counts["task2"] += 1
        return f"task2-call-{call_counts['task2']}"

    async def first_task(val: str = Depends(dependency_for_task1)):
        assert val == "task1-call-1"

    async def second_task(val: str = Depends(dependency_for_task2)):
        assert val == "task2-call-1"

    # Run the tasks sequentially, each in its own worker pass
    await docket.add(first_task)()
    await worker.run_until_finished()

    await docket.add(second_task)()
    await worker.run_until_finished()

    # Each dependency should have been called once (no cache leakage between tasks)
    assert call_counts == {"task1": 1, "task2": 1}
558+
559+
560+
async def test_async_exit_stack_cleanup(docket: Docket, worker: Worker):
    """AsyncExitStack should be properly cleaned up after task execution"""
    cleanup_called: list[str] = []

    @asynccontextmanager
    async def tracked_resource():
        try:
            yield "resource"
        finally:
            cleanup_called.append("cleaned")

    async def task_with_context(res: str = Depends(tracked_resource)):
        assert res == "resource"
        # The resource is still live while the task body runs
        assert not cleanup_called

    await docket.add(task_with_context)()
    await worker.run_until_finished()

    # After the task completes, the context manager's finally block has run once
    assert cleanup_called == ["cleaned"]
580+
581+
582+
async def test_contextvar_reset_on_reentrant_call(docket: Docket, worker: Worker):
    """Contextvars should be properly reset on reentrant calls to resolved_dependencies"""

    # Two throwaway task functions to build distinct executions around
    async def task1(): ...

    async def task2(): ...

    execution1 = Execution(
        key="task1-key",
        function=task1,
        args=(),
        kwargs={},
        attempt=1,
        when=datetime.now(timezone.utc),
    )

    execution2 = Execution(
        key="task2-key",
        function=task2,
        args=(),
        kwargs={},
        attempt=1,
        when=datetime.now(timezone.utc),
    )

    # First entry: the contextvars must reflect execution1
    async with resolved_dependencies(worker, execution1):
        captured_exec1 = Dependency.execution.get()
        captured_stack1 = _Depends.stack.get()
        assert captured_exec1 is execution1

    # After exiting, contextvars should be reset (raise LookupError)
    with pytest.raises(LookupError):
        Dependency.execution.get()

    # Second, independent entry: nothing from the first call is visible
    async with resolved_dependencies(worker, execution2):
        captured_exec2 = Dependency.execution.get()
        captured_stack2 = _Depends.stack.get()
        assert captured_exec2 is execution2
        assert captured_exec2 is not captured_exec1
        # Each entry builds a brand-new AsyncExitStack
        assert captured_stack2 is not captured_stack1
629+
630+
631+
async def test_contextvar_not_leaked_to_caller(docket: Docket):
    """Verify contextvars don't leak outside resolved_dependencies context"""
    # Before ever entering resolved_dependencies, the contextvar is unset
    with pytest.raises(LookupError):
        Dependency.execution.get()

    async def dummy_task(): ...

    execution = Execution(
        key="test-key",
        function=dummy_task,
        args=(),
        kwargs={},
        attempt=1,
        when=datetime.now(timezone.utc),
    )

    async with Docket("test-contextvar-leak", url="memory://leak-test") as test_docket:
        async with Worker(test_docket) as test_worker:
            async with resolved_dependencies(test_worker, execution):
                # Inside the context the execution is visible
                assert Dependency.execution.get() is execution

    # After exiting the context, every contextvar is cleaned up again
    with pytest.raises(LookupError):
        Dependency.execution.get()

    with pytest.raises(LookupError):
        _Depends.stack.get()

    with pytest.raises(LookupError):  # pragma: no branch
        _Depends.cache.get()

uv.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)