Skip to content

Commit 8576442

Browse files
authored
feat: render markdown when streaming (#124)
1 parent: d75a87c · commit: 8576442

2 files changed

Lines changed: 53 additions & 8 deletions

File tree

src/deepset_mcp/benchmark/runner/repl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ async def streaming_callback(
43 43
return await manager(chunk)
44 44

45 45
# Run the agent
46 -
typer.secho("\n🤖 Agent", fg=typer.colors.BLUE, nl=False)
46 +
typer.secho("\n🤖 Agent\n\n", fg=typer.colors.BLUE, nl=False)
47 47
agent_output = await agent.run_async(messages=history, streaming_callback=streaming_callback)
48 48

49 49
# The streaming callback handles printing the final text output.

src/deepset_mcp/benchmark/runner/streaming.py

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
9 9

10 10
from haystack.dataclasses.streaming_chunk import StreamingChunk
11 11
from rich.console import Console
12 +
from rich.live import Live
13 +
from rich.markdown import Markdown
12 14

13 15

14 16
class StreamingCallbackManager:
@@ -22,6 +24,9 @@ def __init__(self) -> None:
22 24
"""Initialize the streaming callback."""
23 25
self.console = Console()
24 26
self.active_tools: dict[int, dict[str, Any]] = {}
27 +
self.accumulated_text = ""
28 +
self.live_display: Live | None = None
29 +
self.text_started = False
25 30

26 31
async def __call__(self, chunk: StreamingChunk) -> None:
27 32
"""Process each streaming chunk asynchronously."""
@@ -34,7 +39,8 @@ async def _handle_chunk(self, chunk: StreamingChunk) -> None:
34 39
# 1. Handle text streaming (like "I'll help you troubleshoot...")
35 40
if self._is_text_delta(meta):
36 41
text = meta["delta"]["text"]
37 -
self.console.print(text, end="")
42 +
self.accumulated_text += text
43 +
await self._render_markdown_optimistic()
38 44

39 45
# 2. Handle tool call start (like list_pipelines, get_pipeline)
40 46
elif self._is_tool_start(meta):
@@ -52,10 +58,36 @@ async def _handle_chunk(self, chunk: StreamingChunk) -> None:
52 58
elif self._is_message_delta(meta):
53 59
await self._handle_message_delta(meta)
54 60

55 -
# 6. Handle finish events
56 -
elif self._is_finish_event(meta):
61 +
if self._is_finish_event(meta):
57 62
await self._handle_finish_event(meta)
58 63

64 +
async def _render_markdown_optimistic(self) -> None:
65 +
"""Render accumulated text as markdown optimistically."""
66 +
if not self.accumulated_text.strip():
67 +
return
68 +
69 +
try:
70 +
# Attempt to render as markdown
71 +
markdown = Markdown(self.accumulated_text)
72 +
73 +
# Start live display if not already started
74 +
if not self.live_display:
75 +
self.live_display = Live(markdown, console=self.console, refresh_per_second=10)
76 +
self.live_display.start()
77 +
self.text_started = True
78 +
else:
79 +
# Update the live display
80 +
self.live_display.update(markdown)
81 +
82 +
except Exception:
83 +
# Fallback to plain text if markdown parsing fails
84 +
if not self.live_display:
85 +
self.live_display = Live(self.accumulated_text, console=self.console, refresh_per_second=10)
86 +
self.live_display.start()
87 +
self.text_started = True
88 +
else:
89 +
self.live_display.update(self.accumulated_text)
90 +
59 91
def _is_text_delta(self, meta: dict[str, Any]) -> bool:
60 92
"""Check if this is a text streaming chunk."""
61 93
return meta.get("type") == "content_block_delta" and meta.get("delta", {}).get("type") == "text_delta"
@@ -78,7 +110,7 @@ def _is_message_delta(self, meta: dict[str, Any]) -> bool:
78 110

79 111
def _is_finish_event(self, meta: dict[str, Any]) -> bool:
80 112
"""Check if this is a finish event."""
81 -
return "finish_reason" in meta
113 +
return "stop_reason" in meta.get("delta", {})
82 114

83 115
async def _handle_tool_start(self, meta: dict[str, Any]) -> None:
84 116
"""Handle the start of a tool call."""
@@ -87,6 +119,11 @@ async def _handle_tool_start(self, meta: dict[str, Any]) -> None:
87 119
tool_id = content_block["id"]
88 120
index = meta["index"]
89 121

122 +
# Stop live display if active
123 +
if self.live_display:
124 +
self.live_display.stop()
125 +
self.live_display = None
126 +
90 127
# Store tool state
91 128
self.active_tools[index] = {
92 129
"name": tool_name,
@@ -96,7 +133,7 @@ async def _handle_tool_start(self, meta: dict[str, Any]) -> None:
96 133
"args_displayed": False,
97 134
}
98 135

99 -
# Display tool call header
136 +
# Display tool call header (text accumulation continues after tools)
100 137
self.console.print() # New line
101 138
self.console.print("┌─ 🔧 Tool Call", style="bold cyan")
102 139
self.console.print(f"│ Name: {tool_name}", style="cyan")
@@ -307,9 +344,17 @@ async def _handle_message_delta(self, meta: dict[str, Any]) -> None:
307 344

308 345
async def _handle_finish_event(self, meta: dict[str, Any]) -> None:
309 346
"""Handle finish events."""
310 -
finish_reason = meta.get("finish_reason")
311 -
347 +
finish_reason = meta.get("delta", {}).get("stop_reason")
312 348
if finish_reason == "tool_call_results":
313 349
# Clean up after tool calls
314 350
self.active_tools.clear()
315 351
self.console.print() # Extra line after tools
352 +
elif finish_reason == "end_turn":
353 +
# Stop live display and reset for next interaction
354 +
if self.live_display:
355 +
self.live_display.stop()
356 +
self.live_display = None
357 +
# Ensure cursor is on a new line for the next prompt
358 +
self.console.print()
359 +
self.accumulated_text = ""
360 +
self.text_started = False
self.text_started = False

0 commit comments

Comments (0)