-
Notifications
You must be signed in to change notification settings - Fork 603
Expand file tree
/
Copy pathcallback_handler.py
More file actions
363 lines (309 loc) · 14.5 KB
/
callback_handler.py
File metadata and controls
363 lines (309 loc) · 14.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import threading
import time
from collections.abc import Callable
from typing import Any
from nat.builder.context import Context
from nat.builder.framework_enum import LLMFrameworkEnum
from nat.data_models.intermediate_step import IntermediateStepPayload
from nat.data_models.intermediate_step import IntermediateStepType
from nat.data_models.intermediate_step import StreamEventData
from nat.data_models.intermediate_step import ToolErrorData
from nat.data_models.intermediate_step import TraceMetadata
from nat.data_models.intermediate_step import UsageInfo
from nat.data_models.profiler_callback import BaseProfilerCallback
from nat.data_models.token_usage import TokenUsageBaseModel
logger = logging.getLogger(__name__)
class ADKProfilerHandler(BaseProfilerCallback):
    """
    A callback manager/handler for Google ADK that intercepts calls to:

    - Tools
    - LLMs

    to collect usage statistics (tokens, inputs, outputs, time intervals, etc.)
    and store them in the usage_stats queue for subsequent analysis.
    """

    # Process-wide singleton instance; created lazily by __new__.
    _instance: "ADKProfilerHandler | None" = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Because __new__ always returns the singleton, __init__ runs on EVERY
        # ADKProfilerHandler() call. Without this guard a second construction
        # would clear _original_tool_call/_original_llm_call and reset
        # _instrumented while the monkey-patches were still applied, making
        # uninstrument() a no-op and allowing double instrumentation (the
        # "originals" saved next time would be the wrappers themselves).
        if getattr(self, "_initialized", False):
            return
        self._initialized = True

        self._lock = threading.Lock()
        # Timestamp of the most recent instrumented call (0.0 = not yet set).
        self.last_call_ts = 0.0
        self.step_manager = Context.get().intermediate_step_manager
        # Original references to Google ADK Tool and LLM methods (for uninstrumenting if needed)
        self._original_tool_call = None
        self._original_llm_call = None
        self._instrumented = False

    def instrument(self) -> None:
        """
        Monkey-patch the relevant Google ADK methods with usage-stat collection logic.

        Assumes the 'google-adk' library is installed. If 'litellm' or
        'google-adk' cannot be imported, instrumentation is skipped (logged,
        not raised). Idempotent: a second call is a no-op.
        """
        if self._instrumented:
            logger.debug("ADKProfilerHandler already instrumented; skipping.")
            return

        try:
            import litellm
        except Exception:
            logger.exception("litellm import failed; skipping instrumentation")
            return

        try:
            from google.adk.tools.function_tool import FunctionTool
        except Exception:
            logger.exception("ADK import failed; skipping instrumentation")
            return

        # Save the originals so uninstrument() can restore them later.
        self._original_tool_call = FunctionTool.run_async
        self._original_llm_call = litellm.acompletion

        FunctionTool.run_async = self._tool_use_monkey_patch()
        litellm.acompletion = self._llm_call_monkey_patch()

        logger.debug("ADKProfilerHandler instrumentation applied successfully.")
        self._instrumented = True

    def uninstrument(self) -> None:
        """Restore the original Google ADK methods.

        Add an explicit unpatch to avoid side-effects across tests/process lifetime.
        """
        try:
            import litellm
            from google.adk.tools.function_tool import FunctionTool

            if self._original_tool_call is not None:
                FunctionTool.run_async = self._original_tool_call
                self._original_tool_call = None
            if self._original_llm_call is not None:
                litellm.acompletion = self._original_llm_call
                self._original_llm_call = None

            self._instrumented = False
            self.last_call_ts = 0.0
            logger.debug("ADKProfilerHandler uninstrumented successfully.")
        except Exception:
            logger.exception("Failed to uninstrument ADKProfilerHandler")

    def ensure_last_call_ts_initialized(self) -> float:
        """Ensure that last_call_ts is initialized to avoid issues in async calls.

        Uses double-checked locking so concurrent first callers initialize the
        timestamp exactly once.

        Returns:
            float: The (possibly just-initialized) last-call timestamp.
        """
        if self.last_call_ts == 0.0:
            with self._lock:
                # Now that we have the lock, double-check
                if self.last_call_ts == 0.0:
                    self.last_call_ts = time.time()
        return self.last_call_ts

    def _tool_use_monkey_patch(self) -> Callable[..., Any]:
        """
        Returns a function that wraps calls to FunctionTool.run_async with usage-logging.
        """
        original_func = self._original_tool_call

        def _tool_kwargs_args(call_kwargs: dict[str, Any]) -> dict[str, Any]:
            # Safely extract the 'args' kwarg when it is a dict; anything else
            # (missing, None, non-dict) becomes an empty dict.
            maybe_args = call_kwargs.get("args")
            return dict(maybe_args) if isinstance(maybe_args, dict) else {}

        async def wrapped_tool_use(base_tool_instance, *args, **kwargs) -> Any:
            """
            Replicates _tool_use_wrapper logic without wrapt: collects usage stats,
            calls the original, and captures output stats.

            Args:
                base_tool_instance (FunctionTool): The instance of the tool being called.
                *args: Positional arguments to the tool.
                **kwargs: Keyword arguments to the tool.

            Returns:
                Any: The result of the tool execution.
            """
            self.ensure_last_call_ts_initialized()
            now = time.time()

            tool_name = ""
            try:
                tool_name = base_tool_instance.name
            except Exception:
                logger.exception("Error getting tool name")
                tool_name = ""

            # Defined BEFORE the try so the except handler can safely reference
            # it even when building/pushing the START payload itself raises
            # (previously that path failed with UnboundLocalError on step_uuid).
            step_uuid: str | None = None

            try:
                # Pre-call usage event
                stats = IntermediateStepPayload(
                    event_type=IntermediateStepType.TOOL_START,
                    framework=LLMFrameworkEnum.ADK,
                    name=tool_name,
                    data=StreamEventData(),
                    metadata=TraceMetadata(tool_inputs={
                        "args": args, "kwargs": _tool_kwargs_args(kwargs)
                    }),
                    usage_info=UsageInfo(token_usage=TokenUsageBaseModel()),
                )
                # Store the UUID to ensure the END event uses the same ID
                step_uuid = stats.UUID
                self.step_manager.push_intermediate_step(stats)

                with self._lock:
                    self.last_call_ts = now

                # Call the original run_async(...)
                if original_func is None:
                    raise RuntimeError(
                        "Original tool function is None - instrumentation may not have been set up correctly")
                result = await original_func(base_tool_instance, *args, **kwargs)
                now = time.time()

                # Post-call usage stats
                usage_stat = IntermediateStepPayload(
                    event_type=IntermediateStepType.TOOL_END,
                    span_event_timestamp=now,
                    framework=LLMFrameworkEnum.ADK,
                    name=tool_name,
                    data=StreamEventData(
                        input={
                            "args": args, "kwargs": _tool_kwargs_args(kwargs)
                        },
                        output=str(result),
                    ),
                    metadata=TraceMetadata(tool_outputs={"result": str(result)}),
                    usage_info=UsageInfo(token_usage=TokenUsageBaseModel()),
                    UUID=step_uuid,  # Use the same UUID as the START event
                )
                self.step_manager.push_intermediate_step(usage_stat)
                return result

            except Exception as e:
                logger.error("BaseTool error: %s", e)
                tool_error = ToolErrorData(
                    content=f"{type(e).__name__}: {e!s}",
                    error_type=type(e).__name__,
                    error_message=str(e),
                )
                # Only pin the UUID when the START event was actually emitted;
                # otherwise let the payload generate a fresh one.
                end_kwargs: dict[str, Any] = {"UUID": step_uuid} if step_uuid is not None else {}
                self.step_manager.push_intermediate_step(
                    IntermediateStepPayload(
                        event_type=IntermediateStepType.TOOL_END,
                        span_event_timestamp=time.time(),
                        framework=LLMFrameworkEnum.ADK,
                        name=tool_name,
                        data=StreamEventData(
                            input={
                                "args": args, "kwargs": _tool_kwargs_args(kwargs)
                            },
                            output=tool_error,
                        ),
                        usage_info=UsageInfo(token_usage=TokenUsageBaseModel()),
                        **end_kwargs,
                    ))
                raise

        return wrapped_tool_use

    def _llm_call_monkey_patch(self) -> Callable[..., Any]:
        """
        Returns a function that wraps calls to litellm.acompletion(...) with usage-logging.

        Returns:
            Callable[..., Any]: The wrapped function.
        """
        original_func = self._original_llm_call

        async def wrapped_llm_call(*args, **kwargs) -> Any:
            """
            Replicates _llm_call_wrapper logic without wrapt: collects usage stats,
            calls the original, and captures output stats.

            Args:
                *args: Positional arguments to the LLM call.
                **kwargs: Keyword arguments to the LLM call.

            Returns:
                Any: The result of the LLM call.
            """
            self.ensure_last_call_ts_initialized()
            now = time.time()
            with self._lock:
                seconds_between_calls = int(now - self.last_call_ts)

            # Model name: prefer the 'model' kwarg, fall back to a leading
            # positional string (the litellm.acompletion(model, ...) call form).
            model_name = kwargs.get("model")
            if not model_name and args:
                first = args[0]
                if isinstance(first, str):
                    model_name = first
            model_name = model_name or ""

            # Flatten the chat messages into one input string. 'content' may be
            # a plain string, None, or a list of parts (dicts carrying 'text',
            # or arbitrary objects).
            input_parts: list[str] = []
            try:
                for message in kwargs.get("messages", []):
                    content = message.get("content", "")
                    if isinstance(content, list):
                        for part in content:
                            if isinstance(part, dict):
                                input_parts.append(str(part.get("text", "")))  # text parts
                            else:
                                input_parts.append(str(part))
                    else:
                        input_parts.append("" if content is None else str(content))
            except Exception:
                logger.exception("Error getting model input")
            model_input = "".join(input_parts)

            # Record the start event
            input_stats = IntermediateStepPayload(
                event_type=IntermediateStepType.LLM_START,
                framework=LLMFrameworkEnum.ADK,
                name=model_name,
                data=StreamEventData(input=model_input, payload=kwargs.get("messages", [])),
                metadata=TraceMetadata(chat_inputs=copy.deepcopy(kwargs.get("messages", []))),
                usage_info=UsageInfo(
                    token_usage=TokenUsageBaseModel(),
                    num_llm_calls=1,
                    seconds_between_calls=seconds_between_calls,
                ),
            )
            # Store the UUID to ensure the END event uses the same ID
            step_uuid = input_stats.UUID
            self.step_manager.push_intermediate_step(input_stats)

            # Call the original litellm.acompletion(...)
            if original_func is None:
                raise RuntimeError("Original LLM function is None - instrumentation may not have been set up correctly")
            output = await original_func(*args, **kwargs)

            # Concatenate the choice message contents; keep a dump of the first
            # choice for the END event payload.
            choice_dump = None
            output_parts: list[str] = []
            try:
                for choice in output.choices:
                    if not choice_dump:
                        choice_dump = choice.model_dump() if hasattr(
                            choice, "model_dump") else getattr(choice, "__dict__", {}) or {}
                    output_parts.append(choice.message.content or "")
            except Exception:
                logger.exception("Error getting model output")
            model_output = "".join(output_parts)
            now = time.time()

            # Prepare safe metadata (first choice as a plain dict) for the END event.
            chat_resp: dict[str, Any] = {}
            try:
                if getattr(output, "choices", []):
                    first_choice = output.choices[0]
                    chat_resp = first_choice.model_dump() if hasattr(
                        first_choice, "model_dump") else getattr(first_choice, "__dict__", {}) or {}
            except Exception:
                logger.exception("Error preparing chat_responses")

            # Token usage may live on output.usage or under model_extra['usage'].
            usage_payload: dict[str, Any] = {}
            try:
                usage_obj = getattr(output, "usage", None) or (getattr(output, "model_extra", {}) or {}).get("usage")
                if usage_obj:
                    if hasattr(usage_obj, "model_dump"):
                        usage_payload = usage_obj.model_dump()
                    elif isinstance(usage_obj, dict):
                        usage_payload = usage_obj
            except Exception:
                logger.exception("Error preparing token usage")

            # Record the end event
            output_stats = IntermediateStepPayload(
                event_type=IntermediateStepType.LLM_END,
                span_event_timestamp=now,
                framework=LLMFrameworkEnum.ADK,
                name=model_name,
                data=StreamEventData(input=model_input, output=model_output, payload=choice_dump),
                metadata=TraceMetadata(chat_responses=chat_resp),
                usage_info=UsageInfo(
                    token_usage=TokenUsageBaseModel(**usage_payload),
                    num_llm_calls=1,
                    seconds_between_calls=seconds_between_calls,
                ),
                UUID=step_uuid,  # Use the same UUID as the START event
            )
            self.step_manager.push_intermediate_step(output_stats)

            with self._lock:
                self.last_call_ts = now
            return output

        return wrapped_llm_call