Skip to content

Commit 8928f4b

Browse files
committed
feat: Adding distributed tracing to decorator
1 parent 6952abf commit 8928f4b

File tree

6 files changed

+735
-29
lines changed

6 files changed

+735
-29
lines changed

src/galileo/decorator.py

Lines changed: 101 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def call_llm(prompt, temperature=0.7):
6060
from galileo.schema.metrics import LocalMetricConfig
6161
from galileo.schema.trace import SPAN_TYPE
6262
from galileo.utils import _get_timestamp
63+
from galileo.utils.distributed_tracing import extract_tracing_headers
6364
from galileo.utils.logging import is_concludable_span_type, is_textual_span_type
6465
from galileo.utils.serialization import EventSerializer, serialize_to_str
6566
from galileo.utils.singleton import GalileoLoggerSingleton
@@ -311,7 +312,10 @@ async def async_wrapper(*args, **kwargs) -> Any:
311312
func_args=args,
312313
func_kwargs=kwargs,
313314
)
314-
self._prepare_call(span_type, span_params, dataset_record)
315+
if span_params is None:
316+
return await func(*args, **kwargs)
317+
318+
self._prepare_call(span_type, span_params, dataset_record, func_args=args, func_kwargs=kwargs)
315319
result = None
316320

317321
try:
@@ -365,7 +369,10 @@ def sync_wrapper(*args, **kwargs) -> Any:
365369
func_args=args,
366370
func_kwargs=kwargs,
367371
)
368-
self._prepare_call(span_type, span_params, dataset_record)
372+
if span_params is None:
373+
return func(*args, **kwargs)
374+
375+
self._prepare_call(span_type, span_params, dataset_record, func_args=args, func_kwargs=kwargs)
369376
result = None
370377

371378
try:
@@ -553,7 +560,12 @@ def _get_span_param_names(self, span_type: SPAN_TYPE) -> list[str]:
553560
return span_params.get(span_type, common_params)
554561

555562
def _prepare_call(
556-
self, span_type: Optional[SPAN_TYPE], span_params: dict[str, Any], dataset_record: Optional[DatasetRecord]
563+
self,
564+
span_type: Optional[SPAN_TYPE],
565+
span_params: dict[str, Any],
566+
dataset_record: Optional[DatasetRecord],
567+
func_args: tuple = (),
568+
func_kwargs: Optional[dict] = None,
557569
) -> None:
558570
"""
559571
Prepare the call for logging by setting up trace and span contexts.
@@ -564,23 +576,45 @@ def _prepare_call(
564576
Type of span to create
565577
span_params
566578
Parameters for the span
579+
dataset_record
580+
Optional dataset record
581+
func_args
582+
Function arguments (used to extract distributed tracing headers)
583+
func_kwargs
584+
Function keyword arguments (used to extract distributed tracing headers)
567585
"""
568-
client_instance = self.get_logger_instance()
586+
# Extract distributed tracing headers from function arguments
587+
trace_id, span_id = extract_tracing_headers(func_args=func_args, func_kwargs=func_kwargs)
588+
589+
client_instance = self.get_logger_instance(trace_id=trace_id, span_id=span_id)
569590
_logger.debug(f"client_instance {id(client_instance)} {client_instance}")
570591

571592
input_ = span_params.get("input_serialized", "")
572593
name = span_params.get("name", "")
573594

574-
if not _trace_context.get():
575-
# If the singleton logger has an active trace, use it
576-
if client_instance.has_active_trace():
595+
# If we have trace_id/span_id (distributed tracing in streaming mode), the logger should have loaded an existing trace
596+
# Set the trace context immediately so we don't create a new trace
597+
# In streaming mode, traces are created immediately so we can add spans to them
598+
if trace_id or span_id:
599+
# In streaming mode with distributed tracing, the trace should be in traces[0] after _init_trace() or _init_span()
600+
if client_instance.traces:
601+
# Trace is loaded in traces list - use it!
602+
_trace_context.set(client_instance.traces[0])
603+
_logger.debug(f"Set trace context from distributed tracing: trace_id={client_instance.traces[0].id}")
604+
else:
605+
# This should not happen in streaming mode - if trace_id/span_id was provided, trace should be loaded
606+
raise ValueError(
607+
f"Distributed tracing trace not found in streaming mode (trace_id={trace_id}, span_id={span_id}). "
608+
"The trace should have been loaded during logger initialization."
609+
)
610+
elif not _trace_context.get():
611+
# Normal mode: no distributed tracing, start a new trace if needed
612+
if client_instance.has_active_trace() and client_instance.traces:
577613
trace = client_instance.traces[-1]
578614
else:
579-
# If no trace is available, start a new one
580615
trace = client_instance.start_trace(
581616
input=input_,
582617
name=name,
583-
# TODO: add dataset_row_id
584618
dataset_input=dataset_record.input if dataset_record else None,
585619
dataset_output=dataset_record.output if dataset_record else None,
586620
dataset_metadata=dataset_record.metadata if dataset_record else None,
@@ -707,7 +741,10 @@ def _handle_call_result(self, span_type: Optional[SPAN_TYPE], span_params: dict[
707741
span_params["created_at"] = created_at
708742
span_params["duration_ns"] = 0
709743

710-
logger = self.get_logger_instance()
744+
# Get logger instance - extract trace_id/span_id from context for nested calls
745+
# to ensure we get the same cached logger instance (cache key includes trace_id/span_id)
746+
trace_id, span_id = extract_tracing_headers()
747+
logger = self.get_logger_instance(trace_id=trace_id, span_id=span_id)
711748

712749
# If the span type is a workflow or agent, conclude it
713750
_logger.debug(f"{span_type=} {stack=} {span_params=}")
@@ -829,7 +866,12 @@ async def _wrap_async_generator_result(
829866
self._handle_call_result(span_type, span_params, output)
830867

831868
def get_logger_instance(
832-
self, project: Optional[str] = None, log_stream: Optional[str] = None, experiment_id: Optional[str] = None
869+
self,
870+
project: Optional[str] = None,
871+
log_stream: Optional[str] = None,
872+
experiment_id: Optional[str] = None,
873+
trace_id: Optional[str] = None,
874+
span_id: Optional[str] = None,
833875
) -> GalileoLogger:
834876
"""
835877
Get the Galileo Logger instance for the current decorator context.
@@ -840,15 +882,28 @@ def get_logger_instance(
840882
Optional project name to use
841883
log_stream
842884
Optional log stream name to use
885+
experiment_id
886+
Optional experiment ID to use
887+
trace_id
888+
Optional trace ID for distributed tracing (automatically extracted from headers if not provided)
889+
span_id
890+
Optional span ID for distributed tracing (automatically extracted from headers if not provided)
843891
844892
Returns
845893
-------
846894
GalileoLogger instance configured with the specified project and log stream
847895
"""
896+
# Get mode from context (defaults to "batch" if not set)
897+
# Mode will be overridden to "streaming" if trace_id/span_id is provided
898+
mode = _mode_context.get() or "batch"
899+
848900
return GalileoLoggerSingleton().get(
849901
project=project or _project_context.get(),
850902
log_stream=log_stream or _log_stream_context.get(),
851903
experiment_id=experiment_id or _experiment_id_context.get(),
904+
mode=mode,
905+
trace_id=trace_id,
906+
span_id=span_id,
852907
)
853908

854909
def get_current_project(self) -> Optional[str]:
@@ -976,6 +1031,7 @@ def init(
9761031
log_stream: Optional[str] = None,
9771032
experiment_id: Optional[str] = None,
9781033
local_metrics: Optional[list[LocalMetricConfig]] = None,
1034+
mode: str = "batch",
9791035
) -> None:
9801036
"""
9811037
Initialize the context with a project and log stream. Optionally, it can also be used
@@ -994,15 +1050,19 @@ def init(
9941050
The experiment id. Defaults to None.
9951051
local_metrics
9961052
Local metrics configs to run on the traces/spans before submitting them for ingestion. Defaults to None.
1053+
mode
1054+
The logging mode. Use "streaming" for distributed tracing or real-time logging.
1055+
Use "batch" for batch processing. Defaults to "batch".
9971056
"""
9981057
GalileoLoggerSingleton().reset(project=project, log_stream=log_stream, experiment_id=experiment_id)
9991058
GalileoLoggerSingleton().get(
1000-
project=project, log_stream=log_stream, experiment_id=experiment_id, local_metrics=local_metrics
1059+
project=project, log_stream=log_stream, experiment_id=experiment_id, local_metrics=local_metrics, mode=mode
10011060
)
10021061

10031062
_project_context.set(project)
10041063
_log_stream_context.set(log_stream)
10051064
_experiment_id_context.set(experiment_id)
1065+
_mode_context.set(mode)
10061066
_span_stack_context.set([])
10071067
_trace_context.set(None)
10081068

@@ -1045,6 +1105,35 @@ def set_session(self, session_id: str) -> None:
10451105
"""
10461106
self.get_logger_instance().set_session(session_id)
10471107

1108+
def get_tracing_headers(self) -> dict[str, str]:
    """
    Build HTTP headers that carry the current trace/span IDs to a downstream service.

    Comparable to LangSmith's `get_current_run_tree().to_headers()`: attach the
    returned mapping to an outgoing HTTP request to propagate the distributed
    tracing context.

    Returns
    -------
    dict[str, str]
        ``X-Trace-ID`` is present when there is a current trace. ``X-Span-ID`` is
        the ID of the span on top of the current span stack, or the trace ID when
        the stack is empty but a trace exists. Empty when neither is available.
    """
    active_trace = self.get_current_trace()
    stack = self.get_current_span_stack()
    result = {}

    if active_trace:
        result["X-Trace-ID"] = str(active_trace.id)

    # Pick the object whose ID is advertised as the span: the innermost open
    # span wins; otherwise fall back to the trace itself.
    if stack:
        parent = stack[-1]
    elif active_trace:
        parent = active_trace
    else:
        return result

    result["X-Span-ID"] = str(parent.id)
    return result
1136+
10481137

10491138
galileo_context = GalileoDecorator()
10501139
log = galileo_context.log

src/galileo/middleware/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""
2+
Galileo middleware for web frameworks.
3+
4+
This module provides middleware for automatically extracting distributed tracing
5+
headers from HTTP requests and making them available to the @log decorator.
6+
"""
7+
8+
from galileo.middleware.tracing import TracingMiddleware
9+
10+
__all__ = ["TracingMiddleware"]

src/galileo/middleware/tracing.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
"""
2+
Tracing middleware for FastAPI/Starlette applications.
3+
4+
This middleware automatically extracts distributed tracing headers from incoming
5+
HTTP requests and stores them in a context variable, making them available to
6+
the @log decorator throughout the request lifecycle.
7+
8+
How it works:
9+
1. The `dispatch` method intercepts incoming HTTP requests
10+
2. Extracts X-Trace-ID and X-Span-ID headers from the request
11+
3. Stores them in ContextVar variables (context-local: each asyncio task or thread has its own value)
12+
4. The @log decorator (via extract_tracing_headers in distributed_tracing.py)
13+
reads these context variables to automatically configure distributed tracing
14+
15+
This middleware is only for ASGI frameworks (FastAPI/Starlette). For Flask (WSGI),
16+
users can manually pass request objects to decorated functions, and the decorator
17+
will extract headers from the request object directly.
18+
"""
19+
20+
import logging
21+
from collections.abc import Awaitable, Callable
22+
from contextvars import ContextVar
23+
from typing import Optional
24+
25+
from starlette.middleware.base import BaseHTTPMiddleware
26+
from starlette.requests import Request
27+
from starlette.responses import Response
28+
29+
_logger = logging.getLogger(__name__)
30+
31+
# Context variables to store trace and span IDs
32+
_trace_id_context: ContextVar[Optional[str]] = ContextVar("trace_id_context", default=None)
33+
_span_id_context: ContextVar[Optional[str]] = ContextVar("span_id_context", default=None)
34+
35+
36+
def get_trace_id() -> Optional[str]:
    """
    Return the distributed-tracing trace ID bound to the current context.

    Returns
    -------
    Optional[str]
        The value held by ``_trace_id_context``; ``None`` (the variable's default)
        when no trace ID has been set in this context.
    """
    current_trace_id = _trace_id_context.get()
    return current_trace_id
39+
40+
41+
def get_span_id() -> Optional[str]:
    """
    Return the distributed-tracing span ID bound to the current context.

    Returns
    -------
    Optional[str]
        The value held by ``_span_id_context``; ``None`` (the variable's default)
        when no span ID has been set in this context.
    """
    current_span_id = _span_id_context.get()
    return current_span_id
44+
45+
46+
class TracingMiddleware(BaseHTTPMiddleware):
    """
    Middleware that extracts distributed tracing headers from HTTP requests.

    This middleware reads the X-Trace-ID and X-Span-ID headers from each incoming
    request and binds them to context variables for the duration of the request.
    The @log decorator can then read these values to automatically configure
    distributed tracing.

    Usage:
        from fastapi import FastAPI
        from galileo.middleware import TracingMiddleware

        app = FastAPI()
        app.add_middleware(TracingMiddleware)
    """

    async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
        """
        Bind the request's tracing headers to context variables while it is handled.

        This method:
        1. Reads the X-Trace-ID and X-Span-ID headers from the HTTP request
        2. Stores each non-empty value in its ContextVar so it can be read via
           get_trace_id() and get_span_id() while the request is being handled
        3. Restores the previous values once the downstream handler has finished,
           so IDs from this request do not remain bound in the calling context

        ContextVar values are scoped to the current execution context (each asyncio
        task works on its own copy), not to a thread, which is what keeps
        concurrent requests from seeing each other's IDs.

        Parameters
        ----------
        request
            The incoming HTTP request
        call_next
            The next middleware or route handler in the chain

        Returns
        -------
        Response
            The HTTP response from the next handler
        """
        # Starlette's Headers mapping is case-insensitive, so a single lookup
        # matches "X-Trace-ID", "x-trace-id", and any other casing.
        trace_id = request.headers.get("x-trace-id")
        span_id = request.headers.get("x-span-id")

        # Only bind values that are actually present; an absent or empty header
        # leaves whatever the surrounding context already holds untouched.
        # Tokens are kept so each binding can be undone afterwards.
        bindings = []
        if trace_id:
            bindings.append((_trace_id_context, _trace_id_context.set(trace_id)))
        if span_id:
            bindings.append((_span_id_context, _span_id_context.set(span_id)))

        try:
            # Call the next middleware/route handler
            return await call_next(request)
        finally:
            # Undo in reverse order, even when the handler raised.
            for context_var, token in reversed(bindings):
                context_var.reset(token)

0 commit comments

Comments
 (0)