88"""
99
1010from dataclasses import dataclass , field
11+ import json
1112from typing import (
1213 Any ,
1314 AsyncIterator ,
3031import pydash
3132
3233from ..utils .helper import merge , MergeOptions
34+ from ..utils .reasoning import (
35+ get_reasoning_content ,
36+ is_thinking_enabled_from_env ,
37+ )
3338from .model import (
3439 AgentEvent ,
3540 AgentRequest ,
@@ -60,6 +65,14 @@ class TextState:
6065 message_id : str = field (default_factory = lambda : str (uuid .uuid4 ()))
6166
6267
68+ @dataclass
69+ class ReasoningState :
70+ started : bool = False
71+ message_started : bool = False
72+ phase_id : str = field (default_factory = lambda : str (uuid .uuid4 ()))
73+ message_id : str = field (default_factory = lambda : str (uuid .uuid4 ()))
74+
75+
6376@dataclass
6477class ToolCallState :
6578 name : str = ""
@@ -72,6 +85,7 @@ class ToolCallState:
7285@dataclass
7386class StreamStateMachine :
7487 text : TextState = field (default_factory = TextState )
88+ reasoning : ReasoningState = field (default_factory = ReasoningState )
7589 tool_call_states : Dict [str , ToolCallState ] = field (default_factory = dict )
7690 tool_result_chunks : Dict [str , List [str ]] = field (default_factory = dict )
7791 run_errored : bool = False
@@ -121,6 +135,43 @@ def cache_tool_result_chunk(self, tool_id: str, delta: str) -> None:
121135 def pop_tool_result_chunks (self , tool_id : str ) -> str :
122136 return "" .join (self .tool_result_chunks .pop (tool_id , []))
123137
138+ def ensure_reasoning_started (self ) -> Iterator [str ]:
139+ if not self .reasoning .started :
140+ yield _encode_reasoning_event (
141+ "REASONING_START" ,
142+ messageId = self .reasoning .phase_id ,
143+ )
144+ self .reasoning .started = True
145+ if not self .reasoning .message_started :
146+ yield _encode_reasoning_event (
147+ "REASONING_MESSAGE_START" ,
148+ messageId = self .reasoning .message_id ,
149+ role = "reasoning" ,
150+ )
151+ self .reasoning .message_started = True
152+
153+ def end_reasoning_if_open (self ) -> Iterator [str ]:
154+ if self .reasoning .message_started :
155+ yield _encode_reasoning_event (
156+ "REASONING_MESSAGE_END" ,
157+ messageId = self .reasoning .message_id ,
158+ )
159+ self .reasoning .message_started = False
160+ if self .reasoning .started :
161+ yield _encode_reasoning_event (
162+ "REASONING_END" ,
163+ messageId = self .reasoning .phase_id ,
164+ )
165+ self .reasoning = ReasoningState ()
166+
167+
168+ def _encode_reasoning_event (event_type : str , ** payload : Any ) -> str :
169+ return (
170+ "data: "
171+ + json .dumps ({"type" : event_type , ** payload }, ensure_ascii = False )
172+ + "\n \n "
173+ )
174+
124175
125176class AGUIProtocolHandler (BaseProtocolHandler ):
126177 """AG-UI 协议处理器
@@ -376,6 +427,10 @@ async def _format_stream(
376427 if state .run_errored :
377428 return
378429
430+ # 结束未结束的 reasoning 消息
431+ for sse_data in state .end_reasoning_if_open ():
432+ yield sse_data
433+
379434 # 结束所有未结束的工具调用
380435 for sse_data in state .end_all_tools (self ._encoder ):
381436 yield sse_data
@@ -399,8 +454,6 @@ def _process_event_with_boundaries(
399454 state : StreamStateMachine ,
400455 ) -> Iterator [str ]:
401456 """处理事件并注入边界事件"""
402- import json
403-
404457 from ag_ui .core import CustomEvent as AguiCustomEvent
405458 from ag_ui .core import (
406459 RunErrorEvent ,
@@ -413,6 +466,8 @@ def _process_event_with_boundaries(
413466 ToolCallStartEvent ,
414467 )
415468
469+ thinking_enabled = is_thinking_enabled_from_env ()
470+
416471 # RAW 事件直接透传
417472 if event .event == EventType .RAW :
418473 raw_data = event .data .get ("raw" , "" )
@@ -422,9 +477,46 @@ def _process_event_with_boundaries(
422477 yield raw_data
423478 return
424479
480+ if event .event == EventType .REASONING :
481+ if thinking_enabled :
482+ reasoning_content = (
483+ event .data .get ("delta" )
484+ or get_reasoning_content (event .data )
485+ or ""
486+ )
487+ if reasoning_content :
488+ for sse_data in state .end_text_if_open (self ._encoder ):
489+ yield sse_data
490+ for sse_data in state .end_all_tools (self ._encoder ):
491+ yield sse_data
492+ for sse_data in state .ensure_reasoning_started ():
493+ yield sse_data
494+ yield _encode_reasoning_event (
495+ "REASONING_MESSAGE_CONTENT" ,
496+ messageId = state .reasoning .message_id ,
497+ delta = reasoning_content ,
498+ )
499+ return
500+
425501 # TEXT 事件:在首个 TEXT 前注入 TEXT_MESSAGE_START
426502 # AG-UI 协议要求:发送 TEXT_MESSAGE_START 前必须先结束所有未结束的 TOOL_CALL
427503 if event .event == EventType .TEXT :
504+ addition = self ._strip_reasoning_from_addition (
505+ event .addition , thinking_enabled
506+ )
507+ addition_reasoning = get_reasoning_content (event .addition or {})
508+ if thinking_enabled and addition_reasoning :
509+ for sse_data in state .ensure_reasoning_started ():
510+ yield sse_data
511+ yield _encode_reasoning_event (
512+ "REASONING_MESSAGE_CONTENT" ,
513+ messageId = state .reasoning .message_id ,
514+ delta = addition_reasoning ,
515+ )
516+
517+ for sse_data in state .end_reasoning_if_open ():
518+ yield sse_data
519+
428520 for sse_data in state .end_all_tools (self ._encoder ):
429521 yield sse_data
430522
@@ -435,13 +527,13 @@ def _process_event_with_boundaries(
435527 message_id = state .text .message_id ,
436528 delta = event .data .get ("delta" , "" ),
437529 )
438- if event . addition :
530+ if addition :
439531 event_dict = agui_event .model_dump (
440532 by_alias = True , exclude_none = True
441533 )
442534 event_dict = self ._apply_addition (
443535 event_dict ,
444- event . addition ,
536+ addition ,
445537 event .addition_merge_options ,
446538 )
447539 json_str = json .dumps (event_dict , ensure_ascii = False )
@@ -455,6 +547,9 @@ def _process_event_with_boundaries(
455547 tool_id = event .data .get ("id" , "" )
456548 tool_name = event .data .get ("name" , "" )
457549
550+ for sse_data in state .end_reasoning_if_open ():
551+ yield sse_data
552+
458553 for sse_data in state .end_text_if_open (self ._encoder ):
459554 yield sse_data
460555
@@ -491,6 +586,9 @@ def _process_event_with_boundaries(
491586 tool_name = event .data .get ("name" , "" )
492587 tool_args = event .data .get ("args" , "" )
493588
589+ for sse_data in state .end_reasoning_if_open ():
590+ yield sse_data
591+
494592 for sse_data in state .end_text_if_open (self ._encoder ):
495593 yield sse_data
496594
@@ -541,6 +639,9 @@ def _process_event_with_boundaries(
541639 timeout = event .data .get ("timeout" )
542640 schema = event .data .get ("schema" )
543641
642+ for sse_data in state .end_reasoning_if_open ():
643+ yield sse_data
644+
544645 for sse_data in state .end_text_if_open (self ._encoder ):
545646 yield sse_data
546647
@@ -601,6 +702,9 @@ def _process_event_with_boundaries(
601702 tool_id = event .data .get ("id" , "" )
602703 tool_name = event .data .get ("name" , "" )
603704
705+ for sse_data in state .end_reasoning_if_open ():
706+ yield sse_data
707+
604708 for sse_data in state .end_text_if_open (self ._encoder ):
605709 yield sse_data
606710
@@ -767,6 +871,29 @@ def _apply_addition(
767871
768872 return merge (event_data , addition , ** (merge_options or {}))
769873
874+ def _strip_reasoning_from_addition (
875+ self ,
876+ addition : Optional [Dict [str , Any ]],
877+ thinking_enabled : bool ,
878+ ) -> Optional [Dict [str , Any ]]:
879+ if not addition :
880+ return addition
881+
882+ stripped = dict (addition )
883+ stripped .pop ("reasoning_content" , None )
884+ additional_kwargs = stripped .get ("additional_kwargs" )
885+ if isinstance (additional_kwargs , dict ):
886+ additional_kwargs = dict (additional_kwargs )
887+ additional_kwargs .pop ("reasoning_content" , None )
888+ if additional_kwargs :
889+ stripped ["additional_kwargs" ] = additional_kwargs
890+ else :
891+ stripped .pop ("additional_kwargs" , None )
892+
893+ if not thinking_enabled :
894+ return stripped
895+ return stripped or None
896+
770897 async def _error_stream (self , message : str ) -> AsyncIterator [str ]:
771898 """生成错误事件流
772899
0 commit comments