@@ -60,6 +60,7 @@ def call_llm(prompt, temperature=0.7):
6060from galileo .schema .metrics import LocalMetricConfig
6161from galileo .schema .trace import SPAN_TYPE
6262from galileo .utils import _get_timestamp
63+ from galileo .utils .distributed_tracing import extract_tracing_headers
6364from galileo .utils .logging import is_concludable_span_type , is_textual_span_type
6465from galileo .utils .serialization import EventSerializer , serialize_to_str
6566from galileo .utils .singleton import GalileoLoggerSingleton
@@ -311,7 +312,10 @@ async def async_wrapper(*args, **kwargs) -> Any:
311312 func_args = args ,
312313 func_kwargs = kwargs ,
313314 )
314- self ._prepare_call (span_type , span_params , dataset_record )
315+ if span_params is None :
316+ return await func (* args , ** kwargs )
317+
318+ self ._prepare_call (span_type , span_params , dataset_record , func_args = args , func_kwargs = kwargs )
315319 result = None
316320
317321 try :
@@ -365,7 +369,10 @@ def sync_wrapper(*args, **kwargs) -> Any:
365369 func_args = args ,
366370 func_kwargs = kwargs ,
367371 )
368- self ._prepare_call (span_type , span_params , dataset_record )
372+ if span_params is None :
373+ return func (* args , ** kwargs )
374+
375+ self ._prepare_call (span_type , span_params , dataset_record , func_args = args , func_kwargs = kwargs )
369376 result = None
370377
371378 try :
@@ -553,7 +560,12 @@ def _get_span_param_names(self, span_type: SPAN_TYPE) -> list[str]:
553560 return span_params .get (span_type , common_params )
554561
555562 def _prepare_call (
556- self , span_type : Optional [SPAN_TYPE ], span_params : dict [str , Any ], dataset_record : Optional [DatasetRecord ]
563+ self ,
564+ span_type : Optional [SPAN_TYPE ],
565+ span_params : dict [str , Any ],
566+ dataset_record : Optional [DatasetRecord ],
567+ func_args : tuple = (),
568+ func_kwargs : Optional [dict ] = None ,
557569 ) -> None :
558570 """
559571 Prepare the call for logging by setting up trace and span contexts.
@@ -564,23 +576,45 @@ def _prepare_call(
564576 Type of span to create
565577 span_params
566578 Parameters for the span
579+ dataset_record
580+ Optional dataset record
581+ func_args
582+ Function arguments (used to extract distributed tracing headers)
583+ func_kwargs
584+ Function keyword arguments (used to extract distributed tracing headers)
567585 """
568- client_instance = self .get_logger_instance ()
586+ # Extract distributed tracing headers from function arguments
587+ trace_id , span_id = extract_tracing_headers (func_args = func_args , func_kwargs = func_kwargs )
588+
589+ client_instance = self .get_logger_instance (trace_id = trace_id , span_id = span_id )
569590 _logger .debug (f"client_instance { id (client_instance )} { client_instance } " )
570591
571592 input_ = span_params .get ("input_serialized" , "" )
572593 name = span_params .get ("name" , "" )
573594
574- if not _trace_context .get ():
575- # If the singleton logger has an active trace, use it
576- if client_instance .has_active_trace ():
595+ # If we have trace_id/span_id (distributed tracing in streaming mode), the logger should have loaded an existing trace
596+ # Set the trace context immediately so we don't create a new trace
597+ # In streaming mode, traces are created immediately so we can add spans to them
598+ if trace_id or span_id :
599+ # In streaming mode with distributed tracing, the trace should be in traces[0] after _init_trace() or _init_span()
600+ if client_instance .traces :
601+ # Trace is loaded in traces list - use it!
602+ _trace_context .set (client_instance .traces [0 ])
603+ _logger .debug (f"Set trace context from distributed tracing: trace_id={ client_instance .traces [0 ].id } " )
604+ else :
605+ # This should not happen in streaming mode - if trace_id/span_id was provided, trace should be loaded
606+ raise ValueError (
607+ f"Distributed tracing trace not found in streaming mode (trace_id={ trace_id } , span_id={ span_id } ). "
608+ "The trace should have been loaded during logger initialization."
609+ )
610+ elif not _trace_context .get ():
611+ # Normal mode: no distributed tracing, start a new trace if needed
612+ if client_instance .has_active_trace () and client_instance .traces :
577613 trace = client_instance .traces [- 1 ]
578614 else :
579- # If no trace is available, start a new one
580615 trace = client_instance .start_trace (
581616 input = input_ ,
582617 name = name ,
583- # TODO: add dataset_row_id
584618 dataset_input = dataset_record .input if dataset_record else None ,
585619 dataset_output = dataset_record .output if dataset_record else None ,
586620 dataset_metadata = dataset_record .metadata if dataset_record else None ,
@@ -707,7 +741,10 @@ def _handle_call_result(self, span_type: Optional[SPAN_TYPE], span_params: dict[
707741 span_params ["created_at" ] = created_at
708742 span_params ["duration_ns" ] = 0
709743
710- logger = self .get_logger_instance ()
744+ # Get logger instance - extract trace_id/span_id from context for nested calls
745+ # to ensure we get the same cached logger instance (cache key includes trace_id/span_id)
746+ trace_id , span_id = extract_tracing_headers ()
747+ logger = self .get_logger_instance (trace_id = trace_id , span_id = span_id )
711748
712749 # If the span type is a workflow or agent, conclude it
713750 _logger .debug (f"{ span_type = } { stack = } { span_params = } " )
@@ -829,7 +866,12 @@ async def _wrap_async_generator_result(
829866 self ._handle_call_result (span_type , span_params , output )
830867
831868 def get_logger_instance (
832- self , project : Optional [str ] = None , log_stream : Optional [str ] = None , experiment_id : Optional [str ] = None
869+ self ,
870+ project : Optional [str ] = None ,
871+ log_stream : Optional [str ] = None ,
872+ experiment_id : Optional [str ] = None ,
873+ trace_id : Optional [str ] = None ,
874+ span_id : Optional [str ] = None ,
833875 ) -> GalileoLogger :
834876 """
835877 Get the Galileo Logger instance for the current decorator context.
@@ -840,15 +882,28 @@ def get_logger_instance(
840882 Optional project name to use
841883 log_stream
842884 Optional log stream name to use
885+ experiment_id
886+ Optional experiment ID to use
887+ trace_id
888+ Optional trace ID for distributed tracing (automatically extracted from headers if not provided)
889+ span_id
890+ Optional span ID for distributed tracing (automatically extracted from headers if not provided)
843891
844892 Returns
845893 -------
846894 GalileoLogger instance configured with the specified project and log stream
847895 """
896+ # Get mode from context (defaults to "batch" if not set)
897+ # Mode will be overridden to "streaming" if trace_id/span_id is provided
898+ mode = _mode_context .get () or "batch"
899+
848900 return GalileoLoggerSingleton ().get (
849901 project = project or _project_context .get (),
850902 log_stream = log_stream or _log_stream_context .get (),
851903 experiment_id = experiment_id or _experiment_id_context .get (),
904+ mode = mode ,
905+ trace_id = trace_id ,
906+ span_id = span_id ,
852907 )
853908
854909 def get_current_project (self ) -> Optional [str ]:
@@ -976,6 +1031,7 @@ def init(
9761031 log_stream : Optional [str ] = None ,
9771032 experiment_id : Optional [str ] = None ,
9781033 local_metrics : Optional [list [LocalMetricConfig ]] = None ,
1034+ mode : str = "batch" ,
9791035 ) -> None :
9801036 """
9811037 Initialize the context with a project and log stream. Optionally, it can also be used
@@ -994,15 +1050,19 @@ def init(
9941050 The experiment id. Defaults to None.
9951051 local_metrics
9961052 Local metrics configs to run on the traces/spans before submitting them for ingestion. Defaults to None.
1053+ mode
1054+ The logging mode. Use "streaming" for distributed tracing or real-time logging.
1055+ Use "batch" for batch processing. Defaults to "batch".
9971056 """
9981057 GalileoLoggerSingleton ().reset (project = project , log_stream = log_stream , experiment_id = experiment_id )
9991058 GalileoLoggerSingleton ().get (
1000- project = project , log_stream = log_stream , experiment_id = experiment_id , local_metrics = local_metrics
1059+ project = project , log_stream = log_stream , experiment_id = experiment_id , local_metrics = local_metrics , mode = mode
10011060 )
10021061
10031062 _project_context .set (project )
10041063 _log_stream_context .set (log_stream )
10051064 _experiment_id_context .set (experiment_id )
1065+ _mode_context .set (mode )
10061066 _span_stack_context .set ([])
10071067 _trace_context .set (None )
10081068
@@ -1045,6 +1105,35 @@ def set_session(self, session_id: str) -> None:
10451105 """
10461106 self .get_logger_instance ().set_session (session_id )
10471107
def get_tracing_headers(self) -> dict[str, str]:
    """
    Get current trace and span IDs as headers for distributed tracing.

    The returned dictionary can be attached to an outgoing HTTP request so
    that the receiving service can continue the same trace.

    The trace comes from ``self.get_current_trace()`` and the span from the
    top (last element) of ``self.get_current_span_stack()``. When the span
    stack is empty but a trace exists, the trace ID is reused as the span ID.

    Returns
    -------
    dict[str, str]
        Dictionary with ``X-Trace-ID`` and/or ``X-Span-ID`` headers. A header
        is omitted when the corresponding ID is unavailable (no current
        trace/span, or the object's ``id`` is ``None``). The result is an
        empty dict when there is neither a trace nor a span.
    """
    headers: dict[str, str] = {}
    trace = self.get_current_trace()
    span_stack = self.get_current_span_stack()

    # Guard against an unset id: str(None) would emit the literal header
    # value "None", which a downstream service could not resolve.
    trace_id = trace.id if trace else None
    if trace_id is not None:
        headers["X-Trace-ID"] = str(trace_id)

    # Prefer the most recent span (top of stack); with no open span, fall
    # back to the trace ID so the receiver still has a parent to attach to.
    span_id = span_stack[-1].id if span_stack else trace_id
    if span_id is not None:
        headers["X-Span-ID"] = str(span_id)

    return headers
1135+ return headers
1136+
10481137
10491138galileo_context = GalileoDecorator ()
10501139log = galileo_context .log
0 commit comments