Skip to content

Commit a4116a6

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Enhance TraceManager async safety, enrich BigQuery plugin logging, and fix serialization
* **Async Safety:** Improved TraceManager context variable handling to ensure correct context isolation in concurrent asynchronous operations. This was achieved by using immutable tuples for the span stack and making copies of context dictionaries before modification. * **Enhanced Logging:** The BigQueryAgentAnalyticsPlugin now captures richer metadata, including: * Root agent name (via a new context variable). * LLM model name and version. * Usage metadata from LLM requests and responses. * **Serialization Fix:** Updated BigQueryAgentAnalyticsPlugin to prevent JSON serialization errors when logging custom objects (e.g., Dataclasses). These are now automatically converted to dictionaries or string representations to ensure successful insertion into BigQuery. PiperOrigin-RevId: 855415320
1 parent 2592f01 commit a4116a6

File tree

2 files changed

+338
-38
lines changed

2 files changed

+338
-38
lines changed

src/google/adk/plugins/bigquery_agent_analytics_plugin.py

Lines changed: 90 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import atexit
1919
from concurrent.futures import ThreadPoolExecutor
2020
import contextvars
21+
import dataclasses
2122
from dataclasses import dataclass
2223
from dataclasses import field
2324
from datetime import datetime
@@ -120,6 +121,8 @@ def _recursive_smart_truncate(obj: Any, max_len: int) -> tuple[Any, bool]:
120121
return obj, False
121122
elif isinstance(obj, dict):
122123
truncated_any = False
124+
# Use dict comprehension for potentially slightly better performance,
125+
# but explicit loop is fine for clarity given recursive nature.
123126
new_dict = {}
124127
for k, v in obj.items():
125128
val, trunc = _recursive_smart_truncate(v, max_len)
@@ -130,13 +133,41 @@ def _recursive_smart_truncate(obj: Any, max_len: int) -> tuple[Any, bool]:
130133
elif isinstance(obj, (list, tuple)):
131134
truncated_any = False
132135
new_list = []
136+
# Explicit loop to handle flag propagation
133137
for i in obj:
134138
val, trunc = _recursive_smart_truncate(i, max_len)
135139
if trunc:
136140
truncated_any = True
137141
new_list.append(val)
138142
return type(obj)(new_list), truncated_any
139-
return obj, False
143+
elif dataclasses.is_dataclass(obj) and not isinstance(obj, type):
144+
# Convert dataclasses to dicts so they become valid JSON objects
145+
return _recursive_smart_truncate(dataclasses.asdict(obj), max_len)
146+
elif hasattr(obj, "model_dump") and callable(obj.model_dump):
147+
# Pydantic v2
148+
try:
149+
return _recursive_smart_truncate(obj.model_dump(), max_len)
150+
except Exception:
151+
pass
152+
elif hasattr(obj, "dict") and callable(obj.dict):
153+
# Pydantic v1
154+
try:
155+
return _recursive_smart_truncate(obj.dict(), max_len)
156+
except Exception:
157+
pass
158+
elif hasattr(obj, "to_dict") and callable(obj.to_dict):
159+
# Common pattern for custom objects
160+
try:
161+
return _recursive_smart_truncate(obj.to_dict(), max_len)
162+
except Exception:
163+
pass
164+
elif obj is None or isinstance(obj, (int, float, bool)):
165+
# Basic types are safe
166+
return obj, False
167+
168+
# Fallback for unknown types: Convert to string to ensure JSON validity
169+
# We return string representation of the object, which is a valid JSON string value.
170+
return str(obj), False
140171

141172

142173
# --- PyArrow Helper Functions ---
@@ -352,9 +383,10 @@ class BigQueryLoggerConfig:
352383
# ==============================================================================
353384

354385
_trace_id_ctx = contextvars.ContextVar("_bq_analytics_trace_id", default=None)
355-
_span_stack_ctx = contextvars.ContextVar(
356-
"_bq_analytics_span_stack", default=None
386+
_root_agent_name_ctx = contextvars.ContextVar(
387+
"_bq_analytics_root_agent_name", default=None
357388
)
389+
_span_stack_ctx = contextvars.ContextVar("_bq_analytics_span_stack", default=())
358390
_span_times_ctx = contextvars.ContextVar(
359391
"_bq_analytics_span_times", default=None
360392
)
@@ -370,7 +402,13 @@ class TraceManager:
370402
def init_trace(callback_context: CallbackContext) -> None:
371403
if _trace_id_ctx.get() is None:
372404
_trace_id_ctx.set(callback_context.invocation_id)
373-
_span_stack_ctx.set([])
405+
# Extract root agent name from invocation context
406+
try:
407+
root_agent = callback_context._invocation_context.agent.root_agent
408+
_root_agent_name_ctx.set(root_agent.name)
409+
except (AttributeError, ValueError):
410+
pass
411+
_span_stack_ctx.set(())
374412
_span_times_ctx.set({})
375413
_span_first_token_times_ctx.set({})
376414

@@ -393,39 +431,29 @@ def push_span(
393431
span_id = span_id or str(uuid.uuid4())
394432

395433
stack = _span_stack_ctx.get()
396-
if stack is None:
397-
# Should have been called by init_trace, but just in case
398-
stack = []
399-
_span_stack_ctx.set(stack)
400-
401-
stack.append(span_id)
402-
403-
times = _span_times_ctx.get()
404-
if times is None:
405-
times = {}
406-
_span_times_ctx.set(times)
407-
408-
first_tokens = _span_first_token_times_ctx.get()
409-
if first_tokens is None:
410-
first_tokens = {}
411-
_span_first_token_times_ctx.set(first_tokens)
434+
new_stack = stack + (span_id,)
435+
_span_stack_ctx.set(new_stack)
412436

437+
times = dict(_span_times_ctx.get() or {})
413438
times[span_id] = time.time()
439+
_span_times_ctx.set(times)
414440
return span_id
415441

416442
@staticmethod
417443
def pop_span() -> tuple[Optional[str], Optional[int]]:
418-
stack = _span_stack_ctx.get()
444+
stack = list(_span_stack_ctx.get())
419445
if not stack:
420446
return None, None
421447
span_id = stack.pop()
448+
_span_stack_ctx.set(tuple(stack))
422449

423-
times = _span_times_ctx.get()
424-
start_time = times.pop(span_id, None) if times else None
450+
times_dict = dict(_span_times_ctx.get() or {})
451+
start_time = times_dict.pop(span_id, None)
452+
_span_times_ctx.set(times_dict)
425453

426-
first_tokens = _span_first_token_times_ctx.get()
427-
if first_tokens:
428-
first_tokens.pop(span_id, None)
454+
ft_dict = dict(_span_first_token_times_ctx.get() or {})
455+
ft_dict.pop(span_id, None)
456+
_span_first_token_times_ctx.set(ft_dict)
429457

430458
duration_ms = int((time.time() - start_time) * 1000) if start_time else None
431459
return span_id, duration_ms
@@ -442,6 +470,10 @@ def get_current_span_id() -> Optional[str]:
442470
stack = _span_stack_ctx.get()
443471
return stack[-1] if stack else None
444472

473+
@staticmethod
474+
def get_root_agent_name() -> Optional[str]:
475+
return _root_agent_name_ctx.get()
476+
445477
@staticmethod
446478
def get_start_time(span_id: str) -> Optional[float]:
447479
times = _span_times_ctx.get()
@@ -454,13 +486,10 @@ def record_first_token(span_id: str) -> bool:
454486
Returns:
455487
True if this was the first token (newly recorded), False otherwise.
456488
"""
457-
first_tokens = _span_first_token_times_ctx.get()
458-
if first_tokens is None:
459-
first_tokens = {}
460-
_span_first_token_times_ctx.set(first_tokens)
461-
489+
first_tokens = dict(_span_first_token_times_ctx.get() or {})
462490
if span_id not in first_tokens:
463491
first_tokens[span_id] = time.time()
492+
_span_first_token_times_ctx.set(first_tokens)
464493
return True
465494
return False
466495

@@ -1218,7 +1247,10 @@ def _get_events_schema() -> list[bigquery.SchemaField]:
12181247
mode="NULLABLE",
12191248
description=(
12201249
"A JSON object containing arbitrary key-value pairs for"
1221-
" additional event metadata not covered by standard fields."
1250+
" additional event metadata. Includes enrichment fields like"
1251+
" 'root_agent_name' (turn orchestration), 'model' (request"
1252+
" model), 'model_version' (response version), and"
1253+
" 'usage_metadata' (detailed token counts)."
12221254
),
12231255
),
12241256
bigquery.SchemaField(
@@ -1420,7 +1452,6 @@ def get_credentials():
14201452
# Use weakref to avoid circular references that prevent garbage collection
14211453
atexit.register(self._atexit_cleanup, weakref.proxy(self.batch_processor))
14221454

1423-
@staticmethod
14241455
@staticmethod
14251456
def _atexit_cleanup(batch_processor: "BatchProcessor") -> None:
14261457
"""Clean up batch processor on script exit."""
@@ -1563,7 +1594,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
15631594
async def _ensure_started(self, **kwargs) -> None:
15641595
"""Ensures that the plugin is started and initialized."""
15651596
if not self._started:
1566-
# Kept original lock name as it was not explicitly changed in the
1597+
# Kept original lock name as it was not explicitly changed.
15671598
if self._setup_lock is None:
15681599
self._setup_lock = asyncio.Lock()
15691600
async with self._setup_lock:
@@ -1660,6 +1691,28 @@ async def _log_event(
16601691
status = kwargs.pop("status", "OK")
16611692
error_message = kwargs.pop("error_message", None)
16621693

1694+
# V2 Metadata Extensions
1695+
model = kwargs.pop("model", None)
1696+
model_version = kwargs.pop("model_version", None)
1697+
usage_metadata = kwargs.pop("usage_metadata", None)
1698+
1699+
# Add new fields to attributes instead of columns
1700+
kwargs["root_agent_name"] = TraceManager.get_root_agent_name()
1701+
if model:
1702+
kwargs["model"] = model
1703+
if model_version:
1704+
kwargs["model_version"] = model_version
1705+
if usage_metadata:
1706+
# Use smart truncate to handle Pydantic, Dataclasses, and other objects
1707+
usage_dict, _ = _recursive_smart_truncate(
1708+
usage_metadata, self.config.max_content_length
1709+
)
1710+
if isinstance(usage_dict, dict):
1711+
kwargs["usage_metadata"] = usage_dict
1712+
else:
1713+
# Fallback if it couldn't be converted to dict
1714+
kwargs["usage_metadata"] = usage_metadata
1715+
16631716
# Serialize remaining kwargs to JSON string for attributes
16641717
try:
16651718
attributes_json = json.dumps(kwargs)
@@ -1822,6 +1875,7 @@ async def before_model_callback(
18221875
"LLM_REQUEST",
18231876
callback_context,
18241877
raw_content=llm_request,
1878+
model=llm_request.model,
18251879
**attributes,
18261880
)
18271881

@@ -1921,6 +1975,8 @@ async def after_model_callback(
19211975
raw_content=content_str,
19221976
is_truncated=is_truncated,
19231977
latency_ms=duration,
1978+
model_version=llm_response.model_version,
1979+
usage_metadata=llm_response.usage_metadata,
19241980
span_id_override=span_id if is_popped else None,
19251981
parent_span_id_override=parent_span_id
19261982
if is_popped

0 commit comments

Comments
 (0)