 import atexit
 from concurrent.futures import ThreadPoolExecutor
 import contextvars
+import dataclasses
 from dataclasses import dataclass
 from dataclasses import field
 from datetime import datetime
@@ -120,6 +121,8 @@ def _recursive_smart_truncate(obj: Any, max_len: int) -> tuple[Any, bool]:
         return obj, False
     elif isinstance(obj, dict):
         truncated_any = False
+        # A dict comprehension would be marginally faster, but the explicit
+        # loop keeps the truncation-flag propagation clear in the recursion.
         new_dict = {}
         for k, v in obj.items():
             val, trunc = _recursive_smart_truncate(v, max_len)
@@ -130,13 +133,41 @@ def _recursive_smart_truncate(obj: Any, max_len: int) -> tuple[Any, bool]:
     elif isinstance(obj, (list, tuple)):
         truncated_any = False
         new_list = []
+        # Explicit loop to handle flag propagation
         for i in obj:
             val, trunc = _recursive_smart_truncate(i, max_len)
             if trunc:
                 truncated_any = True
             new_list.append(val)
         return type(obj)(new_list), truncated_any
-    return obj, False
+    elif dataclasses.is_dataclass(obj) and not isinstance(obj, type):
+        # Convert dataclasses to dicts so they become valid JSON objects
+        return _recursive_smart_truncate(dataclasses.asdict(obj), max_len)
+    elif hasattr(obj, "model_dump") and callable(obj.model_dump):
+        # Pydantic v2
+        try:
+            return _recursive_smart_truncate(obj.model_dump(), max_len)
+        except Exception:
+            pass
+    elif hasattr(obj, "dict") and callable(obj.dict):
+        # Pydantic v1
+        try:
+            return _recursive_smart_truncate(obj.dict(), max_len)
+        except Exception:
+            pass
+    elif hasattr(obj, "to_dict") and callable(obj.to_dict):
+        # Common pattern for custom objects
+        try:
+            return _recursive_smart_truncate(obj.to_dict(), max_len)
+        except Exception:
+            pass
+    elif obj is None or isinstance(obj, (int, float, bool)):
+        # Basic types are safe
+        return obj, False
+
+    # Fallback for unknown types: convert to a string so the value stays
+    # valid JSON (the string representation becomes a JSON string value).
+    return str(obj), False
 
 
 # --- PyArrow Helper Functions ---
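The new dataclass/Pydantic/`to_dict` fallbacks above reduce arbitrary objects to JSON-safe values before truncation. A minimal standalone sketch of the same chain (the `UsageStats` type and `to_json_safe` helper are illustrative, not part of this patch, and the truncation-flag bookkeeping is omitted):

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class UsageStats:
    prompt_tokens: int
    notes: str


def to_json_safe(obj: Any, max_len: int) -> Any:
    """Simplified stand-in for the fallback chain in the patch above."""
    if isinstance(obj, str):
        return obj[:max_len]
    if isinstance(obj, dict):
        return {k: to_json_safe(v, max_len) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_json_safe(i, max_len) for i in obj)
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        return to_json_safe(dataclasses.asdict(obj), max_len)
    if hasattr(obj, "model_dump") and callable(obj.model_dump):  # Pydantic v2
        return to_json_safe(obj.model_dump(), max_len)
    if obj is None or isinstance(obj, (int, float, bool)):
        return obj
    return str(obj)  # unknown types become JSON string values


print(to_json_safe(UsageStats(prompt_tokens=42, notes="x" * 100), max_len=16))
# {'prompt_tokens': 42, 'notes': 'xxxxxxxxxxxxxxxx'}
```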
@@ -352,9 +383,10 @@ class BigQueryLoggerConfig:
 # ==============================================================================
 
 _trace_id_ctx = contextvars.ContextVar("_bq_analytics_trace_id", default=None)
-_span_stack_ctx = contextvars.ContextVar(
-    "_bq_analytics_span_stack", default=None
+_root_agent_name_ctx = contextvars.ContextVar(
+    "_bq_analytics_root_agent_name", default=None
 )
+_span_stack_ctx = contextvars.ContextVar("_bq_analytics_span_stack", default=())
 _span_times_ctx = contextvars.ContextVar(
     "_bq_analytics_span_times", default=None
 )
@@ -370,7 +402,13 @@ class TraceManager:
     def init_trace(callback_context: CallbackContext) -> None:
         if _trace_id_ctx.get() is None:
             _trace_id_ctx.set(callback_context.invocation_id)
-            _span_stack_ctx.set([])
+            # Extract root agent name from invocation context
+            try:
+                root_agent = callback_context._invocation_context.agent.root_agent
+                _root_agent_name_ctx.set(root_agent.name)
+            except (AttributeError, ValueError):
+                pass
+            _span_stack_ctx.set(())
             _span_times_ctx.set({})
             _span_first_token_times_ctx.set({})
 
@@ -393,39 +431,29 @@ def push_span(
         span_id = span_id or str(uuid.uuid4())
 
         stack = _span_stack_ctx.get()
-        if stack is None:
-            # Should have been called by init_trace, but just in case
-            stack = []
-            _span_stack_ctx.set(stack)
-
-        stack.append(span_id)
-
-        times = _span_times_ctx.get()
-        if times is None:
-            times = {}
-            _span_times_ctx.set(times)
-
-        first_tokens = _span_first_token_times_ctx.get()
-        if first_tokens is None:
-            first_tokens = {}
-            _span_first_token_times_ctx.set(first_tokens)
+        new_stack = stack + (span_id,)
+        _span_stack_ctx.set(new_stack)
 
+        times = dict(_span_times_ctx.get() or {})
         times[span_id] = time.time()
+        _span_times_ctx.set(times)
         return span_id
 
     @staticmethod
     def pop_span() -> tuple[Optional[str], Optional[int]]:
-        stack = _span_stack_ctx.get()
+        stack = list(_span_stack_ctx.get())
         if not stack:
             return None, None
         span_id = stack.pop()
+        _span_stack_ctx.set(tuple(stack))
 
-        times = _span_times_ctx.get()
-        start_time = times.pop(span_id, None) if times else None
+        times_dict = dict(_span_times_ctx.get() or {})
+        start_time = times_dict.pop(span_id, None)
+        _span_times_ctx.set(times_dict)
 
-        first_tokens = _span_first_token_times_ctx.get()
-        if first_tokens:
-            first_tokens.pop(span_id, None)
+        ft_dict = dict(_span_first_token_times_ctx.get() or {})
+        ft_dict.pop(span_id, None)
+        _span_first_token_times_ctx.set(ft_dict)
 
         duration_ms = int((time.time() - start_time) * 1000) if start_time else None
         return span_id, duration_ms
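`push_span`/`pop_span` now copy into fresh tuples and dicts and call `ContextVar.set()` instead of mutating shared containers in place, so span state written in one asyncio task no longer leaks into sibling tasks. A minimal sketch of why that matters (variable names here are illustrative, not the plugin's):

```python
import asyncio
import contextvars

# Illustrative ContextVar mirroring the copy-and-set pattern above.
_stack = contextvars.ContextVar("stack", default=())


def push(item: str) -> None:
    # Re-setting a fresh tuple keeps the change local to the current task's
    # context; mutating a shared list in place would be visible everywhere.
    _stack.set(_stack.get() + (item,))


async def worker(name: str) -> tuple:
    push(name)
    await asyncio.sleep(0)
    return _stack.get()


async def main() -> None:
    print(await asyncio.gather(worker("a"), worker("b")))
    # [('a',), ('b',)] -- each task sees only its own span stack


asyncio.run(main())
```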
@@ -442,6 +470,10 @@ def get_current_span_id() -> Optional[str]:
         stack = _span_stack_ctx.get()
         return stack[-1] if stack else None
 
+    @staticmethod
+    def get_root_agent_name() -> Optional[str]:
+        return _root_agent_name_ctx.get()
+
     @staticmethod
     def get_start_time(span_id: str) -> Optional[float]:
         times = _span_times_ctx.get()
@@ -454,13 +486,10 @@ def record_first_token(span_id: str) -> bool:
         Returns:
             True if this was the first token (newly recorded), False otherwise.
         """
-        first_tokens = _span_first_token_times_ctx.get()
-        if first_tokens is None:
-            first_tokens = {}
-            _span_first_token_times_ctx.set(first_tokens)
-
+        first_tokens = dict(_span_first_token_times_ctx.get() or {})
         if span_id not in first_tokens:
             first_tokens[span_id] = time.time()
+            _span_first_token_times_ctx.set(first_tokens)
             return True
         return False
 
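`record_first_token` follows the same copy-and-set discipline and only stores a timestamp on the first call per span, which is what makes a time-to-first-token measurement possible. A rough illustration with plain dicts (the plugin keeps these in ContextVars):

```python
import time

_start_times: dict = {}
_first_token_times: dict = {}


def record_first_token(span_id: str) -> bool:
    # Only the first call per span records a timestamp.
    if span_id not in _first_token_times:
        _first_token_times[span_id] = time.time()
        return True
    return False


_start_times["span-1"] = time.time()
time.sleep(0.05)  # pretend the model streams its first chunk now
print(record_first_token("span-1"))  # True  -- first chunk
print(record_first_token("span-1"))  # False -- later chunks ignored
ttft_ms = int((_first_token_times["span-1"] - _start_times["span-1"]) * 1000)
print(f"time to first token: ~{ttft_ms} ms")
```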
@@ -1218,7 +1247,10 @@ def _get_events_schema() -> list[bigquery.SchemaField]:
             mode="NULLABLE",
             description=(
                 "A JSON object containing arbitrary key-value pairs for"
-                " additional event metadata not covered by standard fields."
+                " additional event metadata. Includes enrichment fields such as"
+                " 'root_agent_name' (the turn's orchestrating agent), 'model'"
+                " (requested model), 'model_version' (responding model version),"
+                " and 'usage_metadata' (detailed token counts)."
             ),
         ),
         bigquery.SchemaField(
@@ -1420,7 +1452,6 @@ def get_credentials():
         # Use weakref to avoid circular references that prevent garbage collection
         atexit.register(self._atexit_cleanup, weakref.proxy(self.batch_processor))
 
-    @staticmethod
     @staticmethod
     def _atexit_cleanup(batch_processor: "BatchProcessor") -> None:
         """Clean up batch processor on script exit."""
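For context on the `atexit.register(..., weakref.proxy(...))` line above: registering a staticmethod with a weak proxy keeps `atexit` from holding strong references to the logger or its processor, so they can still be garbage-collected before interpreter exit. A small self-contained sketch of the pattern (class names are illustrative):

```python
import atexit
import weakref


class Processor:
    def shutdown(self) -> None:
        print("flushing pending rows")


class Logger:
    def __init__(self) -> None:
        self.processor = Processor()
        # A staticmethod plus weakref.proxy keeps atexit from pinning
        # self or the processor alive until interpreter exit.
        atexit.register(self._cleanup, weakref.proxy(self.processor))

    @staticmethod
    def _cleanup(processor) -> None:
        try:
            processor.shutdown()
        except ReferenceError:
            pass  # already garbage-collected; nothing left to flush


logger = Logger()  # "flushing pending rows" prints at interpreter exit
```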
@@ -1563,7 +1594,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
     async def _ensure_started(self, **kwargs) -> None:
         """Ensures that the plugin is started and initialized."""
         if not self._started:
-            # Kept original lock name as it was not explicitly changed in the
+            # Kept original lock name as it was not explicitly changed.
             if self._setup_lock is None:
                 self._setup_lock = asyncio.Lock()
             async with self._setup_lock:
@@ -1660,6 +1691,28 @@ async def _log_event(
         status = kwargs.pop("status", "OK")
         error_message = kwargs.pop("error_message", None)
 
+        # V2 metadata extensions
+        model = kwargs.pop("model", None)
+        model_version = kwargs.pop("model_version", None)
+        usage_metadata = kwargs.pop("usage_metadata", None)
+
+        # Add the new fields to attributes instead of dedicated columns
+        kwargs["root_agent_name"] = TraceManager.get_root_agent_name()
+        if model:
+            kwargs["model"] = model
+        if model_version:
+            kwargs["model_version"] = model_version
+        if usage_metadata:
+            # Use smart truncate to handle Pydantic models, dataclasses, and other objects
+            usage_dict, _ = _recursive_smart_truncate(
+                usage_metadata, self.config.max_content_length
+            )
+            if isinstance(usage_dict, dict):
+                kwargs["usage_metadata"] = usage_dict
+            else:
+                # Fall back to the raw value if it could not be converted to a dict
+                kwargs["usage_metadata"] = usage_metadata
+
         # Serialize remaining kwargs to JSON string for attributes
         try:
             attributes_json = json.dumps(kwargs)
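The reason `usage_metadata` is routed through `_recursive_smart_truncate` before the `json.dumps(kwargs)` call above: the raw SDK object is typically not JSON-serializable until it has been reduced to plain dicts. A quick standalone check (`FakeUsage` is a hypothetical stand-in, not the real SDK type):

```python
import dataclasses
import json


@dataclasses.dataclass
class FakeUsage:
    prompt_token_count: int = 120
    candidates_token_count: int = 35


try:
    json.dumps({"usage_metadata": FakeUsage()})
except TypeError as exc:
    print(f"raw object rejected: {exc}")

# Once reduced to plain dicts, serialization succeeds.
print(json.dumps({"usage_metadata": dataclasses.asdict(FakeUsage())}))
```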
@@ -1822,6 +1875,7 @@ async def before_model_callback(
             "LLM_REQUEST",
             callback_context,
             raw_content=llm_request,
+            model=llm_request.model,
             **attributes,
         )
 
@@ -1921,6 +1975,8 @@ async def after_model_callback(
             raw_content=content_str,
             is_truncated=is_truncated,
             latency_ms=duration,
+            model_version=llm_response.model_version,
+            usage_metadata=llm_response.usage_metadata,
             span_id_override=span_id if is_popped else None,
             parent_span_id_override=parent_span_id
             if is_popped