Skip to content

Commit 1a960be

Browse files
authored
Preserve Exception in ToolOutput (run-llama#20231)
1 parent 60d102d commit 1a960be

File tree

5 files changed

+41
-2
lines changed

5 files changed

+41
-2
lines changed

llama-index-core/llama_index/core/agent/workflow/base_agent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ async def _call_tool(
301301
raw_input=tool_input,
302302
raw_output=str(e),
303303
is_error=True,
304+
exception=e,
304305
)
305306

306307
return tool_output

llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ async def _call_tool(
314314
raw_input=tool_input,
315315
raw_output=str(e),
316316
is_error=True,
317+
exception=e,
317318
)
318319

319320
return tool_output

llama-index-core/llama_index/core/tools/calling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def call_tool(tool: BaseTool, arguments: dict) -> ToolOutput:
2929
raw_input=arguments,
3030
raw_output=str(e),
3131
is_error=True,
32+
exception=e,
3233
)
3334

3435

@@ -55,6 +56,7 @@ async def acall_tool(tool: BaseTool, arguments: dict) -> ToolOutput:
5556
raw_input=arguments,
5657
raw_output=str(e),
5758
is_error=True,
59+
exception=e,
5860
)
5961

6062

llama-index-core/llama_index/core/tools/types.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
if TYPE_CHECKING:
1111
from llama_index.core.bridge.langchain import StructuredTool, Tool
1212
from deprecated import deprecated
13-
from llama_index.core.bridge.pydantic import BaseModel
13+
from llama_index.core.bridge.pydantic import BaseModel, PrivateAttr
1414

1515

1616
class DefaultToolFnSchema(BaseModel):
@@ -99,6 +99,8 @@ class ToolOutput(BaseModel):
9999
raw_output: Any
100100
is_error: bool = False
101101

102+
_exception: Optional[Exception] = PrivateAttr(default=None)
103+
102104
def __init__(
103105
self,
104106
tool_name: str,
@@ -107,6 +109,7 @@ def __init__(
107109
raw_input: Optional[Dict[str, Any]] = None,
108110
raw_output: Optional[Any] = None,
109111
is_error: bool = False,
112+
exception: Optional[Exception] = None,
110113
):
111114
if content and blocks:
112115
raise ValueError("Cannot provide both content and blocks.")
@@ -125,6 +128,8 @@ def __init__(
125128
is_error=is_error,
126129
)
127130

131+
self._exception = exception
132+
128133
@property
129134
def content(self) -> str:
130135
"""Get the content of the tool output."""
@@ -137,6 +142,11 @@ def content(self, content: str) -> None:
137142
"""Set the content of the tool output."""
138143
self.blocks = [TextBlock(text=content)]
139144

145+
@property
146+
def exception(self) -> Optional[Exception]:
147+
"""Get the exception of the tool output."""
148+
return self._exception
149+
140150
def __str__(self) -> str:
141151
"""String."""
142152
return self.content

llama-index-core/tests/agent/workflow/test_function_call.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
)
1010
from llama_index.core.llms import ChatMessage
1111
from llama_index.core.memory import BaseMemory
12-
from llama_index.core.tools import ToolOutput
12+
from llama_index.core.tools import FunctionTool, ToolOutput
1313
from llama_index.core.workflow.context import Context
1414
from llama_index.core.workflow.events import StopEvent
1515

@@ -432,6 +432,31 @@ async def test_aggregate_tool_results_boolean_logic_verification():
432432
)
433433

434434

435+
@pytest.mark.asyncio
436+
async def test_call_tool_with_exception(mock_context, test_agent):
437+
"""
438+
Test that when a tool raises an exception, _call_tool catches it
439+
and returns a ToolOutput with is_error=True and the exception.
440+
"""
441+
442+
# Arrange
443+
def error_function(x: int) -> str:
444+
raise ValueError("This is a test error")
445+
446+
error_tool = FunctionTool.from_defaults(error_function)
447+
tool_input = {"x": 1}
448+
449+
# Act
450+
tool_output = await test_agent._call_tool(mock_context, error_tool, tool_input)
451+
452+
# Assert
453+
assert tool_output.is_error is True
454+
assert isinstance(tool_output.exception, ValueError)
455+
assert str(tool_output.exception) == "This is a test error"
456+
assert tool_output.tool_name == "error_function"
457+
assert tool_output.raw_input == tool_input
458+
459+
435460
if __name__ == "__main__":
436461
# Run the tests
437462
pytest.main([__file__, "-v"])

0 commit comments

Comments
 (0)