Skip to content

Commit 3b782a7

Browse files
authored
Prettier tool calls when streaming (#71)
* Limit the amount of dataframe rows to show when streaming * Always show raw tool call responses when streaming * Show tool call names when streaming
1 parent 458cf7a commit 3b782a7

File tree

2 files changed

+29
-14
lines changed

2 files changed

+29
-14
lines changed

databao/agents/frontend/messages.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
1-
from langchain_core.messages import AIMessage, AIMessageChunk, ToolCall
1+
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, ToolCall, ToolMessage
2+
3+
4+
def get_tool_call(messages: list[BaseMessage], tool_message: ToolMessage) -> ToolCall | None:
5+
"""Returns the tool call which caused the ToolMessage."""
6+
for message in reversed(messages):
7+
if isinstance(message, AIMessage):
8+
for tool_call in message.tool_calls:
9+
if tool_call["id"] == tool_message.tool_call_id:
10+
return tool_call
11+
return None
212

313

414
def get_tool_call_sql(tool_call: ToolCall) -> str | None:

databao/agents/frontend/text_frontend.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pandas as pd
55
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, BaseMessageChunk, ToolMessage
66

7-
from databao.agents.frontend.messages import get_reasoning_content, get_tool_call_sql
7+
from databao.agents.frontend.messages import get_reasoning_content, get_tool_call, get_tool_call_sql
88

99

1010
class TextStreamFrontend:
@@ -32,9 +32,10 @@ def write(self, text: str) -> None:
3232
self.start()
3333
print(text, end="", flush=True, file=self._writer)
3434

35-
def write_dataframe(self, df: pd.DataFrame) -> None:
36-
self.write(df.to_markdown())
37-
self.write("\n\n")
35+
def write_dataframe(self, df: pd.DataFrame, *, name: str | None = None, max_rows: int = 10) -> None:
36+
rows_to_show = min(max_rows, len(df))
37+
self.write(f"[df: name={name or ''}, showing {rows_to_show} / {len(df)} rows]\n")
38+
self.write(df.head(rows_to_show).to_markdown() + "\n\n")
3839

3940
def write_message_chunk(self, chunk: BaseMessageChunk) -> None:
4041
if not isinstance(chunk, AIMessageChunk):
@@ -50,7 +51,10 @@ def write_message_chunk(self, chunk: BaseMessageChunk) -> None:
5051
# N.B. LangChain sometimes waits for the whole string to complete before yielding chunks
5152
# That's why long "sql" tool calls take some time to show up and then the whole sql is shown in a batch
5253
if not self._is_tool_calling:
53-
self.write("\n\n```\n") # Open code block
54+
self.write("\n\n")
55+
for tool_call_chunk in chunk.tool_call_chunks:
56+
self.write(f"[tool_call: '{tool_call_chunk['name']}']\n")
57+
self.write("```\n") # Open code block
5458
self._is_tool_calling = True
5559
for tool_call_chunk in chunk.tool_call_chunks:
5660
if tool_call_chunk["args"] is not None:
@@ -74,19 +78,20 @@ def write_state_chunk(self, state_chunk: dict[str, Any]) -> None:
7478

7579
for message in new_messages:
7680
if isinstance(message, ToolMessage):
77-
if message.artifact is not None:
78-
if "df" in message.artifact and message.artifact["df"] is not None:
79-
self.write_dataframe(message.artifact["df"])
80-
else:
81-
self.write(f"\n```\n{message.content}\n```\n\n")
82-
else:
83-
self.write(f"\n```\n{message.content}\n```\n\n")
81+
tool_call = get_tool_call(messages, message)
82+
tool_name = tool_call["name"] if tool_call is not None else "unknown"
83+
self.write(f"\n[tool_call_output: '{tool_name}']")
84+
self.write(f"\n```\n{message.text().strip()}\n```\n\n")
85+
if message.artifact is not None and isinstance(message.artifact, dict):
86+
for art_name, art_value in message.artifact.items():
87+
if isinstance(art_value, pd.DataFrame):
88+
self.write_dataframe(art_value, name=art_name)
8489
elif self._pretty_sql and isinstance(message, AIMessage):
8590
# During tool calling we show raw JSON chunks, but for SQL we also want pretty formatting.
8691
for tool_call in message.tool_calls:
8792
sql = get_tool_call_sql(tool_call)
8893
if sql is not None:
89-
self.write(f"\n```sql\n{sql}\n```\n\n")
94+
self.write(f"\n```sql\n{sql.strip()}\n```\n\n")
9095

9196
def write_stream_chunk(self, mode: str, chunk: Any) -> None:
9297
if mode == "messages":

0 commit comments

Comments
 (0)