|
17 | 17 | import inspect |
18 | 18 | import logging |
19 | 19 | import textwrap |
| 20 | +import threading |
20 | 21 | import warnings |
21 | 22 | from concurrent.futures import ThreadPoolExecutor |
22 | 23 | from inspect import Parameter, getsource, signature |
|
37 | 38 | # Shared thread pool for running sync tools without blocking the event loop |
38 | 39 | _SYNC_TOOL_EXECUTOR = ThreadPoolExecutor(max_workers=64) |
39 | 40 |
|
| 41 | +# Persistent event loop to avoid httpx connection pool issues |
| 42 | +_PERSISTENT_LOOP: Optional[asyncio.AbstractEventLoop] = None |
| 43 | +_PERSISTENT_LOOP_LOCK = threading.Lock() |
| 44 | + |
40 | 45 |
|
41 | 46 | def _remove_a_key(d: Dict, remove_key: Any) -> None: |
42 | 47 | r"""Remove a key from a dictionary recursively.""" |
@@ -482,37 +487,102 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: |
482 | 487 | if self.synthesize_output: |
483 | 488 | result = self.synthesize_execution_output(args, kwargs) |
484 | 489 | return result |
485 | | - else: |
486 | | - # Pass the extracted arguments to the indicated function |
| 490 | + |
| 491 | + # Call the function first |
| 492 | + try: |
| 493 | + result = self.func(*args, **kwargs) |
| 494 | + except Exception as e: |
| 495 | + parts = [] |
| 496 | + if args: |
| 497 | + parts.append(f"args={args}") |
| 498 | + if kwargs: |
| 499 | + parts.append(f"kwargs={kwargs}") |
| 500 | + args_str = ", ".join(parts) if parts else "no arguments" |
| 501 | + raise ValueError( |
| 502 | + f"Execution of function {self.func.__name__} failed with " |
| 503 | + f"{args_str}. Error: {e}" |
| 504 | + ) |
| 505 | + |
| 506 | + # Handle coroutine result (from async function or sync wrapper |
| 507 | + # returning coroutine) |
| 508 | + if inspect.iscoroutine(result): |
| 509 | + # Check if there's already a running event loop |
487 | 510 | try: |
488 | | - result = self.func(*args, **kwargs) |
489 | | - return result |
490 | | - except Exception as e: |
491 | | - parts = [] |
492 | | - if args: |
493 | | - parts.append(f"args={args}") |
494 | | - if kwargs: |
495 | | - parts.append(f"kwargs={kwargs}") |
496 | | - args_str = ", ".join(parts) if parts else "no arguments" |
497 | | - raise ValueError( |
498 | | - f"Execution of function {self.func.__name__} failed with " |
499 | | - f"{args_str}. Error: {e}" |
| 511 | + asyncio.get_running_loop() |
| 512 | + has_running_loop = True |
| 513 | + except RuntimeError: |
| 514 | + has_running_loop = False |
| 515 | + |
| 516 | + if has_running_loop: |
| 517 | + # Already in an async context |
| 518 | + warnings.warn( |
| 519 | + f"Async tool '{self.func.__name__}' is being called " |
| 520 | + f"synchronously within an async context. Consider using " |
| 521 | + f"'await tool.async_call()' or 'await agent.astep()' for " |
| 522 | + f"better performance.", |
| 523 | + RuntimeWarning, |
| 524 | + stacklevel=2, |
| 525 | + ) |
| 526 | + # Must run in separate thread to avoid blocking current loop |
| 527 | + future = _SYNC_TOOL_EXECUTOR.submit( |
| 528 | + self._run_async_in_persistent_loop, result |
| 529 | + ) |
| 530 | + return future.result() |
| 531 | + else: |
| 532 | + warnings.warn( |
| 533 | + f"Async tool '{self.func.__name__}' is being called " |
| 534 | + f"synchronously. Consider using 'await tool.async_call()' " |
| 535 | + f"or 'await agent.astep()' for better performance.", |
| 536 | + RuntimeWarning, |
| 537 | + stacklevel=2, |
500 | 538 | ) |
| 539 | + return self._run_async_in_persistent_loop(result) |
| 540 | + |
| 541 | + return result |
| 542 | + |
| 543 | + @staticmethod |
| 544 | + def _run_async_in_persistent_loop(coro): |
| 545 | + r"""Run coroutine in persistent loop to preserve httpx connections.""" |
| 546 | + global _PERSISTENT_LOOP |
| 547 | + with _PERSISTENT_LOOP_LOCK: |
| 548 | + need_new_loop = ( |
| 549 | + _PERSISTENT_LOOP is None |
| 550 | + or _PERSISTENT_LOOP.is_closed() |
| 551 | + or not _PERSISTENT_LOOP.is_running() |
| 552 | + ) |
| 553 | + if need_new_loop: |
| 554 | + _PERSISTENT_LOOP = asyncio.new_event_loop() |
| 555 | + t = threading.Thread( |
| 556 | + target=_PERSISTENT_LOOP.run_forever, daemon=True |
| 557 | + ) |
| 558 | + t.start() |
| 559 | + while not _PERSISTENT_LOOP.is_running(): |
| 560 | + pass # Wait for loop to start |
| 561 | + future = asyncio.run_coroutine_threadsafe(coro, _PERSISTENT_LOOP) |
| 562 | + return future.result() |
501 | 563 |
|
502 | 564 | async def async_call(self, *args: Any, **kwargs: Any) -> Any: |
503 | 565 | if self.synthesize_output: |
504 | 566 | result = self.synthesize_execution_output(args, kwargs) |
505 | 567 | return result |
506 | | - if self.is_async: |
| 568 | + |
| 569 | + # Check if the function itself (not unwrapped) is a coroutine function |
| 570 | + if inspect.iscoroutinefunction(self.func): |
507 | 571 | return await self.func(*args, **kwargs) |
508 | | - else: |
509 | | - # Run sync function in executor to avoid blocking event loop |
510 | | - # Use functools.partial to properly capture args/kwargs |
511 | | - loop = asyncio.get_running_loop() |
512 | | - return await loop.run_in_executor( |
513 | | - _SYNC_TOOL_EXECUTOR, |
514 | | - functools.partial(self.func, *args, **kwargs), |
515 | | - ) |
| 572 | + |
| 573 | + # For sync functions (including sync wrappers around async functions), |
| 574 | + # run in executor to avoid blocking |
| 575 | + loop = asyncio.get_running_loop() |
| 576 | + result = await loop.run_in_executor( |
| 577 | + _SYNC_TOOL_EXECUTOR, |
| 578 | + functools.partial(self.func, *args, **kwargs), |
| 579 | + ) |
| 580 | + |
| 581 | + # If the sync wrapper returned a coroutine, await it |
| 582 | + if inspect.iscoroutine(result): |
| 583 | + return await result |
| 584 | + |
| 585 | + return result |
516 | 586 |
|
517 | 587 | @property |
518 | 588 | def is_async(self) -> bool: |
|
0 commit comments