Skip to content

Commit c7db4ce

Browse files
committed
修改对话记录添加bug
1 parent 4ab1ce5 commit c7db4ce

File tree

1 file changed

+30
-29
lines changed

1 file changed

+30
-29
lines changed

services/dify_service.py

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ async def exec_query(self, res):
7676
# 构建请求参数
7777
dify_service_url, body_params, headers = self._build_request(chat_id, cleaned_query, app_key, qa_type)
7878

79+
# 收集流式输出结果
80+
t02_answer_data = []
81+
# 收集业务数据流式输出结果
82+
t04_answer_data = {}
83+
7984
async with aiohttp.ClientSession(read_bufsize=1024 * 16) as session:
8085
async with session.post(
8186
dify_service_url,
@@ -114,43 +119,28 @@ async def exec_query(self, res):
114119
if data_type == DataTypeEnum.ANSWER.value[0]:
115120
await self.send_message(
116121
res,
117-
qa_context,
118122
answer,
119123
{"data": {"messageType": "begin"}, "dataType": data_type},
120-
qa_type,
121-
conversation_id,
122-
message_id,
123-
task_id,
124124
)
125125
elif event_list[1] == "1":
126126
# 输出结束
127127
data_type = event_list[2]
128128
if data_type == DataTypeEnum.ANSWER.value[0]:
129129
await self.send_message(
130130
res,
131-
qa_context,
132131
answer,
133132
{"data": {"messageType": "end"}, "dataType": data_type},
134-
qa_type,
135-
conversation_id,
136-
message_id,
137-
task_id,
138133
)
139134

140135
# 输出业务数据
141136
elif bus_data and data_type == DataTypeEnum.BUS_DATA.value[0]:
142137
res_data = process(json.loads(bus_data)["data"])
143-
# logging.info(f"chart_data: {res_data}")
144138
await self.send_message(
145139
res,
146-
qa_context,
147140
answer,
148141
{"data": res_data, "dataType": data_type},
149-
qa_type,
150-
conversation_id,
151-
message_id,
152-
task_id,
153142
)
143+
t04_answer_data = {"data": res_data, "dataType": data_type}
154144

155145
data_type = ""
156146

@@ -159,15 +149,12 @@ async def exec_query(self, res):
159149
if data_type == DataTypeEnum.ANSWER.value[0]:
160150
await self.send_message(
161151
res,
162-
qa_context,
163152
answer,
164153
{"data": {"messageType": "continue", "content": answer}, "dataType": data_type},
165-
qa_type,
166-
conversation_id,
167-
message_id,
168-
task_id,
169154
)
170155

156+
t02_answer_data.append(answer)
157+
171158
# 这里设置业务数据
172159
if data_type == DataTypeEnum.BUS_DATA.value[0]:
173160
bus_data = answer
@@ -188,14 +175,30 @@ async def exec_query(self, res):
188175
+ "\n\n"
189176
)
190177

178+
elif DiFyCodeEnum.MESSAGE_END.value[0] == event_name:
179+
t02_message_json = {
180+
"data": {"messageType": "continue", "content": "".join(t02_answer_data)},
181+
"dataType": DataTypeEnum.ANSWER.value[0],
182+
}
183+
print(t02_message_json)
184+
185+
if t02_message_json:
186+
await self._save_message(t02_message_json, qa_context, conversation_id, message_id, task_id, qa_type)
187+
if t04_answer_data:
188+
await self._save_message(t04_answer_data, qa_context, conversation_id, message_id, task_id, qa_type)
189+
190+
t02_answer_data = []
191+
t04_answer_data = {}
192+
191193
except Exception as e:
192194
logging.error(f"Error during get_answer: {e}")
193195
traceback.print_exception(e)
194196
return {"error": str(e)} # 返回错误信息作为字典
195197
finally:
196198
await self.res_end(res)
197199

198-
async def handle_think_tag(self, answer):
200+
@staticmethod
201+
async def handle_think_tag(answer):
199202
"""
200203
处理<think>标签内的内容
201204
:param answer
@@ -205,10 +208,10 @@ async def handle_think_tag(self, answer):
205208

206209
return think_content, remaining_content
207210

208-
async def save_message(self, response, message, qa_context, conversation_id, message_id, task_id, qa_type):
211+
@staticmethod
212+
async def _save_message(message, qa_context, conversation_id, message_id, task_id, qa_type):
209213
"""
210214
保存消息记录并发送SSE数据
211-
:param response:
212215
:param message:
213216
:param qa_context:
214217
:param conversation_id:
@@ -226,9 +229,8 @@ async def save_message(self, response, message, qa_context, conversation_id, mes
226229
await add_question_record(
227230
qa_context.token, conversation_id, message_id, task_id, qa_context.chat_id, qa_context.question, "", message, qa_type
228231
)
229-
await response.write("data:" + json.dumps(message, ensure_ascii=False) + "\n\n")
230232

231-
async def send_message(self, response, qa_context, answer, message, qa_type, conversation_id, message_id, task_id):
233+
async def send_message(self, response, answer, message):
232234
"""
233235
SSE 格式发送数据,每一行以 data: 开头
234236
"""
@@ -241,10 +243,9 @@ async def send_message(self, response, qa_context, answer, message, qa_type, con
241243
"data": {"messageType": "continue", "content": "> " + think_content.replace("\n", "") + "\n\n" + remaining_content},
242244
"dataType": "t02",
243245
}
244-
await self.save_message(response, message, qa_context, conversation_id, message_id, task_id, qa_type)
245-
246+
await response.write("data:" + json.dumps(message, ensure_ascii=False) + "\n\n")
246247
else:
247-
await self.save_message(response, message, qa_context, conversation_id, message_id, task_id, qa_type)
248+
await response.write("data:" + json.dumps(message, ensure_ascii=False) + "\n\n")
248249

249250
@staticmethod
250251
async def res_begin(res, chat_id):

0 commit comments

Comments
 (0)