99
1010from haystack .dataclasses .streaming_chunk import StreamingChunk
1111from rich .console import Console
12+ from rich .live import Live
13+ from rich .markdown import Markdown
1214
1315
1416class StreamingCallbackManager :
@@ -22,6 +24,9 @@ def __init__(self) -> None:
2224 """Initialize the streaming callback."""
2325 self .console = Console ()
2426 self .active_tools : dict [int , dict [str , Any ]] = {}
27+ self .accumulated_text = ""
28+ self .live_display : Live | None = None
29+ self .text_started = False
2530
2631 async def __call__ (self , chunk : StreamingChunk ) -> None :
2732 """Process each streaming chunk asynchronously."""
@@ -34,7 +39,8 @@ async def _handle_chunk(self, chunk: StreamingChunk) -> None:
3439 # 1. Handle text streaming (like "I'll help you troubleshoot...")
3540 if self ._is_text_delta (meta ):
3641 text = meta ["delta" ]["text" ]
37- self .console .print (text , end = "" )
42+ self .accumulated_text += text
43+ await self ._render_markdown_optimistic ()
3844
3945 # 2. Handle tool call start (like list_pipelines, get_pipeline)
4046 elif self ._is_tool_start (meta ):
@@ -52,10 +58,36 @@ async def _handle_chunk(self, chunk: StreamingChunk) -> None:
5258 elif self ._is_message_delta (meta ):
5359 await self ._handle_message_delta (meta )
5460
55- # 6. Handle finish events
56- elif self ._is_finish_event (meta ):
61+ if self ._is_finish_event (meta ):
5762 await self ._handle_finish_event (meta )
5863
64+ async def _render_markdown_optimistic (self ) -> None :
65+ """Render accumulated text as markdown optimistically."""
66+ if not self .accumulated_text .strip ():
67+ return
68+
69+ try :
70+ # Attempt to render as markdown
71+ markdown = Markdown (self .accumulated_text )
72+
73+ # Start live display if not already started
74+ if not self .live_display :
75+ self .live_display = Live (markdown , console = self .console , refresh_per_second = 10 )
76+ self .live_display .start ()
77+ self .text_started = True
78+ else :
79+ # Update the live display
80+ self .live_display .update (markdown )
81+
82+ except Exception :
83+ # Fallback to plain text if markdown parsing fails
84+ if not self .live_display :
85+ self .live_display = Live (self .accumulated_text , console = self .console , refresh_per_second = 10 )
86+ self .live_display .start ()
87+ self .text_started = True
88+ else :
89+ self .live_display .update (self .accumulated_text )
90+
5991 def _is_text_delta (self , meta : dict [str , Any ]) -> bool :
6092 """Check if this is a text streaming chunk."""
6193 return meta .get ("type" ) == "content_block_delta" and meta .get ("delta" , {}).get ("type" ) == "text_delta"
@@ -78,7 +110,7 @@ def _is_message_delta(self, meta: dict[str, Any]) -> bool:
78110
79111 def _is_finish_event (self , meta : dict [str , Any ]) -> bool :
80112 """Check if this is a finish event."""
81- return "finish_reason " in meta
113+ return "stop_reason " in meta . get ( "delta" , {})
82114
83115 async def _handle_tool_start (self , meta : dict [str , Any ]) -> None :
84116 """Handle the start of a tool call."""
@@ -87,6 +119,11 @@ async def _handle_tool_start(self, meta: dict[str, Any]) -> None:
87119 tool_id = content_block ["id" ]
88120 index = meta ["index" ]
89121
122+ # Stop live display if active
123+ if self .live_display :
124+ self .live_display .stop ()
125+ self .live_display = None
126+
90127 # Store tool state
91128 self .active_tools [index ] = {
92129 "name" : tool_name ,
@@ -96,7 +133,7 @@ async def _handle_tool_start(self, meta: dict[str, Any]) -> None:
96133 "args_displayed" : False ,
97134 }
98135
99- # Display tool call header
136+ # Display tool call header (text accumulation continues after tools)
100137 self .console .print () # New line
101138 self .console .print ("┌─ 🔧 Tool Call" , style = "bold cyan" )
102139 self .console .print (f"│ Name: { tool_name } " , style = "cyan" )
@@ -307,9 +344,17 @@ async def _handle_message_delta(self, meta: dict[str, Any]) -> None:
307344
308345 async def _handle_finish_event (self , meta : dict [str , Any ]) -> None :
309346 """Handle finish events."""
310- finish_reason = meta .get ("finish_reason" )
311-
347+ finish_reason = meta .get ("delta" , {}).get ("stop_reason" )
312348 if finish_reason == "tool_call_results" :
313349 # Clean up after tool calls
314350 self .active_tools .clear ()
315351 self .console .print () # Extra line after tools
352+ elif finish_reason == "end_turn" :
353+ # Stop live display and reset for next interaction
354+ if self .live_display :
355+ self .live_display .stop ()
356+ self .live_display = None
357+ # Ensure cursor is on a new line for the next prompt
358+ self .console .print ()
359+ self .accumulated_text = ""
360+ self .text_started = False
0 commit comments