Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/docket/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)

from opentelemetry import trace
from opentelemetry.trace import Tracer
from opentelemetry.trace import Status, StatusCode, Tracer
from redis.asyncio import Redis
from redis.exceptions import ConnectionError, LockError

Expand Down Expand Up @@ -531,7 +531,7 @@ async def _execute(self, execution: Execution) -> None:
"code.function.name": execution.function.__name__,
},
links=execution.incoming_span_links(),
):
) as span:
try:
async with resolved_dependencies(self, execution) as dependencies:
# Preemptively reschedule the perpetual task for the future, or clear
Expand Down Expand Up @@ -576,6 +576,8 @@ async def _execute(self, execution: Execution) -> None:
duration = log_context["duration"] = time.time() - start
TASKS_SUCCEEDED.add(1, counter_labels)

span.set_status(Status(StatusCode.OK))

rescheduled = await self._perpetuate_if_requested(
execution, dependencies, timedelta(seconds=duration)
)
Expand All @@ -584,10 +586,13 @@ async def _execute(self, execution: Execution) -> None:
logger.info(
"%s [%s] %s", arrow, ms(duration), call, extra=log_context
)
except Exception:
except Exception as e:
duration = log_context["duration"] = time.time() - start
TASKS_FAILED.add(1, counter_labels)

span.record_exception(e)
span.set_status(Status(StatusCode.ERROR, str(e)))

retried = await self._retry_if_requested(execution, dependencies)
if not retried:
retried = await self._perpetuate_if_requested(
Expand Down
92 changes: 92 additions & 0 deletions tests/test_instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from opentelemetry.metrics import Counter, Histogram, UpDownCounter
from opentelemetry.metrics import _Gauge as Gauge
from opentelemetry.sdk.trace import Span, TracerProvider
from opentelemetry.trace import StatusCode

from docket import Docket, Worker
from docket.dependencies import Retry
Expand Down Expand Up @@ -98,6 +99,97 @@ async def the_task():
assert link.context.span_id == originating_span.context.span_id


async def test_failed_task_span_has_error_status(docket: Docket, worker: Worker):
"""When a task fails, its span should have ERROR status."""
captured: list[Span] = []

async def the_failing_task():
span = trace.get_current_span()
assert isinstance(span, Span)
captured.append(span)
raise ValueError("Task failed")

await docket.add(the_failing_task)()
await worker.run_until_finished()

assert len(captured) == 1
(task_span,) = captured

assert isinstance(task_span, Span)
assert task_span.status is not None
assert task_span.status.status_code == StatusCode.ERROR
assert task_span.status.description is not None
assert "Task failed" in task_span.status.description


async def test_retried_task_spans_have_error_status(docket: Docket, worker: Worker):
"""When a task fails and is retried, each failed attempt's span should have ERROR status."""
captured: list[Span] = []
attempt_count = 0

async def the_retrying_task(retry: Retry = Retry(attempts=3)):
nonlocal attempt_count
attempt_count += 1
span = trace.get_current_span()
assert isinstance(span, Span)
captured.append(span)

if attempt_count < 3:
raise ValueError(f"Attempt {attempt_count} failed")
# Third attempt succeeds

await docket.add(the_retrying_task)()
await worker.run_until_finished()

assert len(captured) == 3

# First two attempts should have ERROR status
for i in range(2):
span = captured[i]
assert isinstance(span, Span)
assert span.status is not None
assert span.status.status_code == StatusCode.ERROR
assert span.status.description is not None
assert f"Attempt {i + 1} failed" in span.status.description

# Third attempt should have OK status (or no status set, which is treated as OK)
success_span = captured[2]
assert isinstance(success_span, Span)
assert (
success_span.status is None or success_span.status.status_code == StatusCode.OK
)


async def test_infinitely_retrying_task_spans_have_error_status(
docket: Docket, worker: Worker
):
"""When a task with infinite retries fails, each attempt's span should have ERROR status."""
captured: list[Span] = []
attempt_count = 0

async def the_infinite_retry_task(retry: Retry = Retry(attempts=None)):
nonlocal attempt_count
attempt_count += 1
span = trace.get_current_span()
assert isinstance(span, Span)
captured.append(span)
raise ValueError(f"Attempt {attempt_count} failed")

execution = await docket.add(the_infinite_retry_task)()

# Run worker for only 3 task executions of this specific task
await worker.run_at_most({execution.key: 3})

# All captured spans should have ERROR status
assert len(captured) == 3
for i, span in enumerate(captured):
assert isinstance(span, Span)
assert span.status is not None
assert span.status.status_code == StatusCode.ERROR
assert span.status.description is not None
assert f"Attempt {i + 1} failed" in span.status.description


async def test_message_getter_returns_none_for_missing_key():
"""Should return None when a key is not present in the message."""

Expand Down