|
17 | 17 | from __future__ import annotations |
18 | 18 |
|
19 | 19 | import asyncio |
| 20 | +from concurrent.futures import ThreadPoolExecutor |
20 | 21 | import copy |
| 22 | +import functools |
21 | 23 | import inspect |
22 | 24 | import logging |
23 | 25 | import threading |
|
53 | 55 |
|
54 | 56 | logger = logging.getLogger('google_adk.' + __name__) |
55 | 57 |
|
| 58 | +# Global thread pool executors for running tools in background threads. |
| 59 | +# This prevents blocking tools from blocking the event loop in Live API mode. |
| 60 | +# Key is max_workers, value is the executor. |
| 61 | +_TOOL_THREAD_POOLS: dict[int, ThreadPoolExecutor] = {} |
| 62 | +_TOOL_THREAD_POOL_LOCK = threading.Lock() |
| 63 | + |
| 64 | + |
| 65 | +def _get_tool_thread_pool(max_workers: int = 4) -> ThreadPoolExecutor: |
| 66 | + """Gets or creates a thread pool executor for tool execution. |
| 67 | +
|
| 68 | + Args: |
| 69 | + max_workers: Maximum number of worker threads in the pool. |
| 70 | +
|
| 71 | + Returns: |
| 72 | + A ThreadPoolExecutor with the specified max_workers. |
| 73 | + """ |
| 74 | + if max_workers not in _TOOL_THREAD_POOLS: |
| 75 | + with _TOOL_THREAD_POOL_LOCK: |
| 76 | + if max_workers not in _TOOL_THREAD_POOLS: |
| 77 | + _TOOL_THREAD_POOLS[max_workers] = ThreadPoolExecutor( |
| 78 | + max_workers=max_workers, thread_name_prefix='adk_tool_executor' |
| 79 | + ) |
| 80 | + return _TOOL_THREAD_POOLS[max_workers] |
| 81 | + |
| 82 | + |
| 83 | +def _is_sync_tool(tool: BaseTool) -> bool: |
| 84 | + """Checks if a tool's underlying function is synchronous.""" |
| 85 | + if not hasattr(tool, 'func'): |
| 86 | + return False |
| 87 | + func = tool.func |
| 88 | + return not ( |
| 89 | + inspect.iscoroutinefunction(func) |
| 90 | + or inspect.isasyncgenfunction(func) |
| 91 | + or ( |
| 92 | + hasattr(func, '__call__') |
| 93 | + and inspect.iscoroutinefunction(func.__call__) |
| 94 | + ) |
| 95 | + ) |
| 96 | + |
| 97 | + |
| 98 | +async def _call_tool_in_thread_pool( |
| 99 | + tool: BaseTool, |
| 100 | + args: dict[str, Any], |
| 101 | + tool_context: ToolContext, |
| 102 | + max_workers: int = 4, |
| 103 | +) -> Any: |
| 104 | + """Runs a tool in a thread pool to avoid blocking the event loop. |
| 105 | +
|
| 106 | + For sync tools, this runs the tool's function directly in a background thread. |
| 107 | + For async tools, this creates a new event loop in the background thread and |
| 108 | + runs the async function there. This helps catch blocking I/O (like time.sleep, |
| 109 | + network calls, file I/O) that was mistakenly used inside async functions. |
| 110 | +
|
| 111 | + Note: Due to Python's GIL, this does NOT help with pure Python CPU-bound code. |
| 112 | + Thread pool only helps when the GIL is released (blocking I/O, C extensions). |
| 113 | +
|
| 114 | + Args: |
| 115 | + tool: The tool to execute. |
| 116 | + args: Arguments to pass to the tool. |
| 117 | + tool_context: The tool context. |
| 118 | + max_workers: Maximum number of worker threads in the pool. |
| 119 | +
|
| 120 | + Returns: |
| 121 | + The result of running the tool. |
| 122 | + """ |
| 123 | + from ...tools.function_tool import FunctionTool |
| 124 | + |
| 125 | + loop = asyncio.get_running_loop() |
| 126 | + executor = _get_tool_thread_pool(max_workers) |
| 127 | + |
| 128 | + if _is_sync_tool(tool): |
| 129 | + # For sync FunctionTool, call the underlying function directly |
| 130 | + def run_sync_tool(): |
| 131 | + if isinstance(tool, FunctionTool): |
| 132 | + args_to_call = tool._preprocess_args(args) |
| 133 | + signature = inspect.signature(tool.func) |
| 134 | + valid_params = {param for param in signature.parameters} |
| 135 | + if 'tool_context' in valid_params: |
| 136 | + args_to_call['tool_context'] = tool_context |
| 137 | + args_to_call = { |
| 138 | + k: v for k, v in args_to_call.items() if k in valid_params |
| 139 | + } |
| 140 | + return tool.func(**args_to_call) |
| 141 | + else: |
| 142 | + # For other sync tool types, we can't easily run them in thread pool |
| 143 | + return None |
| 144 | + |
| 145 | + result = await loop.run_in_executor(executor, run_sync_tool) |
| 146 | + if result is not None: |
| 147 | + return result |
| 148 | + else: |
| 149 | + # For async tools, run them in a new event loop in a background thread. |
| 150 | + # This helps when async functions contain blocking I/O (common user mistake) |
| 151 | + # that would otherwise block the main event loop. |
| 152 | + def run_async_tool_in_new_loop(): |
| 153 | + # Create a new event loop for this thread |
| 154 | + return asyncio.run(tool.run_async(args=args, tool_context=tool_context)) |
| 155 | + |
| 156 | + return await loop.run_in_executor(executor, run_async_tool_in_new_loop) |
| 157 | + |
| 158 | + # Fall back to normal async execution for non-FunctionTool sync tools |
| 159 | + return await tool.run_async(args=args, tool_context=tool_context) |
| 160 | + |
56 | 161 |
|
57 | 162 | def generate_client_function_call_id() -> str: |
58 | 163 | return f'{AF_FUNCTION_CALL_ID_PREFIX}{uuid.uuid4()}' |
@@ -706,9 +811,19 @@ async def run_tool_and_update_queue(tool, function_args, tool_context): |
706 | 811 | ) |
707 | 812 | } |
708 | 813 | else: |
709 | | - function_response = await __call_tool_async( |
710 | | - tool, args=function_args, tool_context=tool_context |
711 | | - ) |
| 814 | + # Check if we should run tools in thread pool to avoid blocking event loop |
| 815 | + thread_pool_config = invocation_context.run_config.tool_thread_pool_config |
| 816 | + if thread_pool_config is not None: |
| 817 | + function_response = await _call_tool_in_thread_pool( |
| 818 | + tool, |
| 819 | + args=function_args, |
| 820 | + tool_context=tool_context, |
| 821 | + max_workers=thread_pool_config.max_workers, |
| 822 | + ) |
| 823 | + else: |
| 824 | + function_response = await __call_tool_async( |
| 825 | + tool, args=function_args, tool_context=tool_context |
| 826 | + ) |
712 | 827 | return function_response |
713 | 828 |
|
714 | 829 |
|
|
0 commit comments