@@ -76,6 +76,11 @@ async def exec_query(self, res):
7676 # 构建请求参数
7777 dify_service_url , body_params , headers = self ._build_request (chat_id , cleaned_query , app_key , qa_type )
7878
79+ # 收集流式输出结果
80+ t02_answer_data = []
81+ # 收集业务数据流式输出结果
82+ t04_answer_data = {}
83+
7984 async with aiohttp .ClientSession (read_bufsize = 1024 * 16 ) as session :
8085 async with session .post (
8186 dify_service_url ,
@@ -114,43 +119,28 @@ async def exec_query(self, res):
114119 if data_type == DataTypeEnum .ANSWER .value [0 ]:
115120 await self .send_message (
116121 res ,
117- qa_context ,
118122 answer ,
119123 {"data" : {"messageType" : "begin" }, "dataType" : data_type },
120- qa_type ,
121- conversation_id ,
122- message_id ,
123- task_id ,
124124 )
125125 elif event_list [1 ] == "1" :
126126 # 输出结束
127127 data_type = event_list [2 ]
128128 if data_type == DataTypeEnum .ANSWER .value [0 ]:
129129 await self .send_message (
130130 res ,
131- qa_context ,
132131 answer ,
133132 {"data" : {"messageType" : "end" }, "dataType" : data_type },
134- qa_type ,
135- conversation_id ,
136- message_id ,
137- task_id ,
138133 )
139134
140135 # 输出业务数据
141136 elif bus_data and data_type == DataTypeEnum .BUS_DATA .value [0 ]:
142137 res_data = process (json .loads (bus_data )["data" ])
143- # logging.info(f"chart_data: {res_data}")
144138 await self .send_message (
145139 res ,
146- qa_context ,
147140 answer ,
148141 {"data" : res_data , "dataType" : data_type },
149- qa_type ,
150- conversation_id ,
151- message_id ,
152- task_id ,
153142 )
143+ t04_answer_data = {"data" : res_data , "dataType" : data_type }
154144
155145 data_type = ""
156146
@@ -159,15 +149,12 @@ async def exec_query(self, res):
159149 if data_type == DataTypeEnum .ANSWER .value [0 ]:
160150 await self .send_message (
161151 res ,
162- qa_context ,
163152 answer ,
164153 {"data" : {"messageType" : "continue" , "content" : answer }, "dataType" : data_type },
165- qa_type ,
166- conversation_id ,
167- message_id ,
168- task_id ,
169154 )
170155
156+ t02_answer_data .append (answer )
157+
171158 # 这里设置业务数据
172159 if data_type == DataTypeEnum .BUS_DATA .value [0 ]:
173160 bus_data = answer
@@ -188,14 +175,30 @@ async def exec_query(self, res):
188175 + "\n \n "
189176 )
190177
178+ elif DiFyCodeEnum .MESSAGE_END .value [0 ] == event_name :
179+ t02_message_json = {
180+ "data" : {"messageType" : "continue" , "content" : "" .join (t02_answer_data )},
181+ "dataType" : DataTypeEnum .ANSWER .value [0 ],
182+ }
183+ print (t02_message_json )
184+
185+ if t02_message_json :
186+ await self ._save_message (t02_message_json , qa_context , conversation_id , message_id , task_id , qa_type )
187+ if t04_answer_data :
188+ await self ._save_message (t04_answer_data , qa_context , conversation_id , message_id , task_id , qa_type )
189+
190+ t02_answer_data = []
191+ t04_answer_data = {}
192+
191193 except Exception as e :
192194 logging .error (f"Error during get_answer: { e } " )
193195 traceback .print_exception (e )
194196 return {"error" : str (e )} # 返回错误信息作为字典
195197 finally :
196198 await self .res_end (res )
197199
198- async def handle_think_tag (self , answer ):
200+ @staticmethod
201+ async def handle_think_tag (answer ):
199202 """
200203 处理<think>标签内的内容
201204 :param answer
@@ -205,10 +208,10 @@ async def handle_think_tag(self, answer):
205208
206209 return think_content , remaining_content
207210
208- async def save_message (self , response , message , qa_context , conversation_id , message_id , task_id , qa_type ):
211+ @staticmethod
212+ async def _save_message (message , qa_context , conversation_id , message_id , task_id , qa_type ):
209213 """
210214 保存消息记录并发送SSE数据
211- :param response:
212215 :param message:
213216 :param qa_context:
214217 :param conversation_id:
@@ -226,9 +229,8 @@ async def save_message(self, response, message, qa_context, conversation_id, mes
226229 await add_question_record (
227230 qa_context .token , conversation_id , message_id , task_id , qa_context .chat_id , qa_context .question , "" , message , qa_type
228231 )
229- await response .write ("data:" + json .dumps (message , ensure_ascii = False ) + "\n \n " )
230232
231- async def send_message (self , response , qa_context , answer , message , qa_type , conversation_id , message_id , task_id ):
233+ async def send_message (self , response , answer , message ):
232234 """
233235 SSE 格式发送数据,每一行以 data: 开头
234236 """
@@ -241,10 +243,9 @@ async def send_message(self, response, qa_context, answer, message, qa_type, con
241243 "data" : {"messageType" : "continue" , "content" : "> " + think_content .replace ("\n " , "" ) + "\n \n " + remaining_content },
242244 "dataType" : "t02" ,
243245 }
244- await self .save_message (response , message , qa_context , conversation_id , message_id , task_id , qa_type )
245-
246+ await response .write ("data:" + json .dumps (message , ensure_ascii = False ) + "\n \n " )
246247 else :
247- await self . save_message ( response , message , qa_context , conversation_id , message_id , task_id , qa_type )
248+ await response . write ( "data:" + json . dumps ( message , ensure_ascii = False ) + " \n \n " )
248249
249250 @staticmethod
250251 async def res_begin (res , chat_id ):
0 commit comments