Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/any_agent/callbacks/wrappers/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ def after_tool_execution(*args, **kwargs):
context = callback.after_tool_execution(context, *args, **kwargs)

class _LangChainTracingCallback(BaseCallbackHandler):
# Propagate exceptions from callbacks instead of swallowing them.
# LangChain defaults to raise_error=False which logs warnings but
# continues execution. We need exceptions (especially AgentCancel)
# to propagate so they can be handled by run_async.
raise_error = True

def on_chat_model_start(
self,
serialized: dict[str, Any],
Expand Down
50 changes: 50 additions & 0 deletions src/any_agent/frameworks/any_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,51 @@ def __repr__(self) -> str:
return f"AgentRunError({self._original_exception!r})"


def _unwrap_agent_cancel(exc: BaseException) -> AgentCancel | None:
"""Traverse an exception chain to find an AgentCancel if present.

When callbacks raise AgentCancel subclasses, some frameworks catch and
re-raise them wrapped in their own error types. For example:

- smolagents wraps with AgentGenerationError using `raise ... from e`
- Other frameworks may use similar patterns

Python's exception chaining stores the original exception in __cause__
(explicit: `raise X from Y`) or __context__ (implicit: `raise X` inside
an except block). This function walks that chain to find any AgentCancel.

Note:
This is a defensive catch-all for frameworks that properly chain
exceptions. Some frameworks may swallow exceptions entirely (e.g.,
LangChain's default callback behavior) and require framework-specific
fixes to ensure AgentCancel propagates. See wrapper implementations
for details.

Args:
exc: The exception to inspect.

Returns:
The first AgentCancel found in the exception chain, or None if the
chain contains no AgentCancel instances.

Example:
try:
framework.run() # Raises FrameworkError from AgentCancel
except Exception as e:
if cancel := _unwrap_agent_cancel(e):
# Found the wrapped AgentCancel, re-raise it directly.
raise cancel from e

"""
current: BaseException | None = exc
while current is not None:
if isinstance(current, AgentCancel):
return current
# Check both explicit (raise from) and implicit (raise in except) chaining.
current = current.__cause__ or current.__context__
return None


class AnyAgent(ABC):
"""Base abstract class for all agent implementations.

Expand Down Expand Up @@ -355,6 +400,11 @@ async def run_async(self, prompt: str, **kwargs: Any) -> AgentTrace:
e._trace = trace
raise

# Check if the framework wrapped an AgentCancel in its own error type.
if cancel := _unwrap_agent_cancel(e):
cancel._trace = trace
raise cancel from e

raise AgentRunError(trace, e) from e

async with self._lock:
Expand Down
45 changes: 44 additions & 1 deletion tests/unit/callbacks/wrappers/test_get_wrapper_and_unwrap.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from unittest.mock import MagicMock
from typing import Any
from unittest.mock import AsyncMock, MagicMock

from any_agent import AgentFramework
from any_agent.callbacks.wrappers import _get_wrapper_by_framework
from any_agent.callbacks.wrappers.langchain import _LangChainWrapper


async def test_unwrap_before_wrap(agent_framework: AgentFramework) -> None:
Expand Down Expand Up @@ -30,3 +32,44 @@ async def test_google_instrument_uninstrument() -> None:
assert agent._agent.after_model_callback is None
assert agent._agent.before_tool_callback is None
assert agent._agent.after_tool_callback is None


async def test_langchain_callback_raises_errors() -> None:
"""LangChain callback handler must have raise_error=True to propagate AgentCancel.

By default, LangChain swallows exceptions in callback handlers and only logs
warnings. Setting raise_error=True ensures exceptions (especially AgentCancel
subclasses) propagate so they can be handled by run_async.
"""
agent = MagicMock()
agent._agent = MagicMock()
agent._agent.ainvoke = AsyncMock()
agent.config = MagicMock()
agent.config.callbacks = []

wrapper = _LangChainWrapper()
await wrapper.wrap(agent)

# Call the wrapped ainvoke to trigger callback injection.
captured_kwargs: dict[str, Any] = {}

async def capture_ainvoke(*args: Any, **kwargs: Any) -> MagicMock:
captured_kwargs.update(kwargs)
return MagicMock()

# Replace the mock's original ainvoke to capture the kwargs.
wrapper._original_ainvoke = capture_ainvoke
await agent._agent.ainvoke("test")

# Verify the callback was added with raise_error=True.
assert "config" in captured_kwargs
config = captured_kwargs["config"]
# Config can be a dict or RunnableConfig, handle both.
callbacks = (
config.get("callbacks") if isinstance(config, dict) else config.callbacks
)
assert callbacks is not None
assert len(callbacks) == 1
callback = callbacks[0]
assert hasattr(callback, "raise_error")
assert callback.raise_error is True
69 changes: 69 additions & 0 deletions tests/unit/frameworks/test_agent_cancel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from any_agent import AgentCancel, AgentConfig, AgentFramework, AgentRunError, AnyAgent
from any_agent.callbacks import Callback, Context
from any_agent.frameworks.any_agent import _unwrap_agent_cancel
from any_agent.testing.helpers import DEFAULT_SMALL_MODEL_ID, LLM_IMPORT_PATHS
from any_agent.tracing.agent_trace import AgentTrace

Expand Down Expand Up @@ -148,3 +149,71 @@ async def test_regular_exception_wrapped_in_agent_run_error(self) -> None:
assert str(exc_info.value.original_exception) == "Unexpected error"
assert exc_info.value.trace is not None
assert len(exc_info.value.trace.spans) > 0


class TestUnwrapAgentCancel:
"""Tests for _unwrap_agent_cancel helper function."""

def test_returns_none_for_regular_exception(self) -> None:
"""Returns None when exception chain contains no AgentCancel."""
exc = RuntimeError("regular error")
assert _unwrap_agent_cancel(exc) is None

def test_returns_none_for_chained_regular_exceptions(self) -> None:
"""Returns None when chained exceptions contain no AgentCancel."""
inner = ValueError("inner")
outer = RuntimeError("outer")
outer.__cause__ = inner
assert _unwrap_agent_cancel(outer) is None

def test_finds_direct_agent_cancel(self) -> None:
"""Returns the exception itself if it is an AgentCancel."""
exc = StopAgent("direct")
result = _unwrap_agent_cancel(exc)
assert result is exc

def test_finds_agent_cancel_via_cause(self) -> None:
"""Finds AgentCancel in __cause__ (explicit raise from)."""
cancel = StopAgent("wrapped")
wrapper = RuntimeError("framework error")
wrapper.__cause__ = cancel
result = _unwrap_agent_cancel(wrapper)
assert result is cancel

def test_finds_agent_cancel_via_context(self) -> None:
"""Finds AgentCancel in __context__ (implicit chaining)."""
cancel = StopAgent("wrapped")
wrapper = RuntimeError("framework error")
wrapper.__context__ = cancel
result = _unwrap_agent_cancel(wrapper)
assert result is cancel

def test_finds_deeply_nested_agent_cancel(self) -> None:
"""Finds AgentCancel nested multiple levels deep."""
cancel = StopAgent("deep")
middle = ValueError("middle")
middle.__cause__ = cancel
outer = RuntimeError("outer")
outer.__cause__ = middle
result = _unwrap_agent_cancel(outer)
assert result is cancel

def test_prefers_cause_over_context(self) -> None:
"""When both __cause__ and __context__ exist, follows __cause__ first."""
cause_cancel = StopAgent("from cause")
context_cancel = SpecificStopAgent("from context")
wrapper = RuntimeError("wrapper")
wrapper.__cause__ = cause_cancel
wrapper.__context__ = context_cancel
result = _unwrap_agent_cancel(wrapper)
assert result is cause_cancel

def test_finds_subclass_of_agent_cancel(self) -> None:
"""Finds subclasses of AgentCancel (e.g., SpecificStopAgent)."""
cancel = SpecificStopAgent("specific")
wrapper = RuntimeError("wrapper")
wrapper.__cause__ = cancel
result = _unwrap_agent_cancel(wrapper)
assert result is cancel
assert isinstance(result, StopAgent)
assert isinstance(result, AgentCancel)