Skip to content

feat: Propagate W3C trace context headers from clients #2153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion llama_stack/distribution/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,18 @@ async def __call__(self, scope, receive, send):
logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI")
return await self.app(scope, receive, send)

trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
trace_attributes = {"__location__": "server", "raw_path": path}

# Extract W3C trace context headers and store as trace attributes
headers = dict(scope.get("headers", []))
traceparent = headers.get(b"traceparent", b"").decode()
if traceparent:
trace_attributes["traceparent"] = traceparent
tracestate = headers.get(b"tracestate", b"").decode()
if tracestate:
trace_attributes["tracestate"] = tracestate

trace_context = await start_trace(trace_path, trace_attributes)

async def send_with_trace_id(message):
if message["type"] == "http.response.start":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

from llama_stack.apis.telemetry import (
Event,
Expand Down Expand Up @@ -44,6 +45,7 @@
)
from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
from llama_stack.providers.utils.telemetry.tracing import ROOT_SPAN_MARKERS

from .config import TelemetryConfig, TelemetrySink

Expand Down Expand Up @@ -206,6 +208,15 @@ def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
event.attributes = {}
event.attributes["__ttl__"] = ttl_seconds

# Pop the W3C trace context attributes off the event so they are not written
# to underlying storage; they are only needed to propagate the trace context.
traceparent = event.attributes.pop("traceparent", None)
tracestate = event.attributes.pop("tracestate", None)
if traceparent:
# A traceparent value means this span is not the root span, so remove the root-span markers.
for root_attribute in ROOT_SPAN_MARKERS:
event.attributes.pop(root_attribute, None)

if isinstance(event.payload, SpanStartPayload):
# Check if span already exists to prevent duplicates
if span_id in _GLOBAL_STORAGE["active_spans"]:
Expand All @@ -216,8 +227,12 @@ def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
parent_span_id = int(event.payload.parent_span_id, 16)
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
context = trace.set_span_in_context(parent_span)
else:
event.attributes["__root_span__"] = "true"
elif traceparent:
carrier = {
"traceparent": traceparent,
"tracestate": tracestate,
}
context = TraceContextTextMapPropagator().extract(carrier=carrier)

span = tracer.start_span(
name=event.payload.name,
Expand Down
5 changes: 4 additions & 1 deletion llama_stack/providers/utils/telemetry/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
INVALID_SPAN_ID = 0x0000000000000000
INVALID_TRACE_ID = 0x00000000000000000000000000000000

ROOT_SPAN_MARKERS = ["__root__", "__root_span__"]


def trace_id_to_str(trace_id: int) -> str:
"""Convenience trace ID formatting method
Expand Down Expand Up @@ -178,7 +180,8 @@ async def start_trace(name: str, attributes: dict[str, Any] = None) -> TraceCont

trace_id = generate_trace_id()
context = TraceContext(BACKGROUND_LOGGER, trace_id)
context.push_span(name, {"__root__": True, **(attributes or {})})
attributes = {marker: True for marker in ROOT_SPAN_MARKERS} | (attributes or {})
context.push_span(name, attributes)

CURRENT_TRACE_CONTEXT.set(context)
return context
Expand Down
Loading