Skip to content

Commit 32db3d8

Browse files
authored
fix long reply guard without discard reroll (#1206)
1 parent a3e246e commit 32db3d8

2 files changed

Lines changed: 364 additions & 22 deletions

File tree

main_logic/omni_offline_client.py

Lines changed: 111 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,10 +1163,22 @@ async def stream_text(self, text: str) -> None:
11631163
fence_triggered = False # 围栏是否已触发
11641164
guard_triggered = False
11651165
discard_reason = None
1166+
length_guard_recovery_text = ""
1167+
length_guard_persisted_prefix = ""
1168+
length_guard_original_tokens = 0
11661169
chunk_usage = None
11671170
prefix_buffer = ""
11681171
prefix_checked = not bool(self._prefix_buffer_size)
11691172

1173+
def _has_unpersisted_recovery_suffix(recovery_text: str) -> bool:  # True when recovery_text holds non-whitespace text beyond length_guard_persisted_prefix (read from the enclosing scope at call time)
1174+
if not recovery_text:  # empty candidate: nothing to check
1175+
return False
1176+
if not length_guard_persisted_prefix:  # no prefix recorded, so any non-empty candidate counts as new text
1177+
return True
1178+
if not recovery_text.startswith(length_guard_persisted_prefix):  # candidate does not extend the recorded prefix, so no suffix can be derived from it
1179+
return False
1180+
return bool(recovery_text[len(length_guard_persisted_prefix):].strip())  # a whitespace-only remainder is treated as nothing new
1181+
11701182
# Tool-aware streaming: ``_astream_with_tools`` runs
11711183
# the multi-turn tool loop inside (executing tools and
11721184
# appending results to ``_conversation_history`` IN
@@ -1187,6 +1199,7 @@ async def stream_text(self, text: str) -> None:
11871199
# 第二次写进 history。``_total`` 不重置——重复检测
11881200
# / token 长度 guard 仍要看完整一轮的实际文本量。
11891201
if getattr(chunk, "tool_round_persisted", False):
1202+
length_guard_persisted_prefix = assistant_message_total
11901203
assistant_message = ""
11911204
# 重置围栏 / prefix buffer:下一段是新的语义
11921205
# 单元(模型基于 tool 结果重新出文本),不应
@@ -1240,23 +1253,48 @@ async def stream_text(self, text: str) -> None:
12401253
break
12411254

12421255
if truncated_content and truncated_content.strip():
1243-
assistant_message += truncated_content
1244-
assistant_message_total += truncated_content
1245-
if self.on_text_delta:
1246-
await self.on_text_delta(truncated_content, is_first_chunk)
1247-
is_first_chunk = False
1248-
1256+
emit_content = truncated_content
12491257
if self.enable_response_guard:
1250-
# 长度 guard 看完整一轮(含 pre-tool)的 token 量,
1251-
# 不能只看 final-segment,否则 tool 轮前长篇大论 +
1252-
# tool 轮后短短一句也能逃过 guard。
1253-
current_length = count_tokens(assistant_message_total)
1258+
# 长度 guard 看完整一轮(含 pre-tool)的 token 量。
1259+
# 必须在 on_text_delta 前裁剪本 chunk,否则 UI/TTS
1260+
# 会先收到超限尾巴,而 history 只保存截断文本。
1261+
candidate_total = assistant_message_total + truncated_content
1262+
current_length = count_tokens(candidate_total)
12541263
if current_length > self.max_response_length:
12551264
guard_triggered = True
12561265
discard_reason = f"length>{self.max_response_length}"
1257-
logger.info(f"OmniOfflineClient: 检测到长回复 ({current_length} tokens),准备重试")
1266+
length_guard_original_tokens = current_length
1267+
logger.info(f"OmniOfflineClient: 检测到长回复 ({current_length} tokens),准备停止生成")
12581268
self._is_responding = False
1259-
break
1269+
emit_content = ""
1270+
if not _is_gibberish_response(candidate_total):
1271+
capped = truncate_to_tokens(
1272+
candidate_total, self.max_response_length,
1273+
)
1274+
candidate_recovery = _truncate_to_last_sentence_end(capped)
1275+
if candidate_recovery:
1276+
if candidate_recovery.startswith(assistant_message_total):
1277+
recovery_suffix = candidate_recovery[len(assistant_message_total):]
1278+
if recovery_suffix.strip():
1279+
emit_content = recovery_suffix
1280+
length_guard_recovery_text = candidate_recovery
1281+
elif (
1282+
assistant_message_total
1283+
and _has_unpersisted_recovery_suffix(assistant_message_total)
1284+
):
1285+
# 已流式发出的前缀无法撤回;保持 history 与
1286+
# UI/TTS 一致,避免可见文本和上下文分叉。
1287+
length_guard_recovery_text = assistant_message_total
1288+
1289+
if emit_content and emit_content.strip():
1290+
assistant_message += emit_content
1291+
assistant_message_total += emit_content
1292+
if self.on_text_delta:
1293+
await self.on_text_delta(emit_content, is_first_chunk)
1294+
is_first_chunk = False
1295+
1296+
if guard_triggered:
1297+
break
12601298
elif content and not content.strip():
12611299
logger.debug(f"OmniOfflineClient: 过滤空白内容 - content_repr: {repr(content)[:100]}")
12621300

@@ -1283,17 +1321,38 @@ async def stream_text(self, text: str) -> None:
12831321
fence_triggered = True
12841322
break
12851323
if flush_text and flush_text.strip():
1286-
assistant_message += flush_text
1287-
assistant_message_total += flush_text
1288-
if self.on_text_delta:
1289-
await self.on_text_delta(flush_text, is_first_chunk)
1290-
is_first_chunk = False
1324+
emit_flush_text = flush_text
12911325
if self.enable_response_guard:
1292-
# 长度 guard 看整轮(含 pre-tool),与上方主累加块对偶
1293-
current_length = count_tokens(assistant_message_total)
1326+
# 长度 guard 看整轮(含 pre-tool),与上方主累加块对偶。
1327+
candidate_total = assistant_message_total + flush_text
1328+
current_length = count_tokens(candidate_total)
12941329
if current_length > self.max_response_length:
12951330
guard_triggered = True
12961331
discard_reason = f"length>{self.max_response_length}"
1332+
length_guard_original_tokens = current_length
1333+
emit_flush_text = ""
1334+
if not _is_gibberish_response(candidate_total):
1335+
capped = truncate_to_tokens(
1336+
candidate_total, self.max_response_length,
1337+
)
1338+
candidate_recovery = _truncate_to_last_sentence_end(capped)
1339+
if candidate_recovery:
1340+
if candidate_recovery.startswith(assistant_message_total):
1341+
recovery_suffix = candidate_recovery[len(assistant_message_total):]
1342+
if recovery_suffix.strip():
1343+
emit_flush_text = recovery_suffix
1344+
length_guard_recovery_text = candidate_recovery
1345+
elif (
1346+
assistant_message_total
1347+
and _has_unpersisted_recovery_suffix(assistant_message_total)
1348+
):
1349+
length_guard_recovery_text = assistant_message_total
1350+
if emit_flush_text and emit_flush_text.strip():
1351+
assistant_message += emit_flush_text
1352+
assistant_message_total += emit_flush_text
1353+
if self.on_text_delta:
1354+
await self.on_text_delta(emit_flush_text, is_first_chunk)
1355+
is_first_chunk = False
12971356

12981357
if guard_triggered:
12991358
guard_attempt += 1
@@ -1305,6 +1364,34 @@ async def stream_text(self, text: str) -> None:
13051364
# / max_attempts 进度条要 1/2 → 2/2 才合理。
13061365
total_attempts = self.max_response_rerolls + 1
13071366

1367+
recovery_text = length_guard_recovery_text
1368+
if discard_reason and "length>" in discard_reason:
1369+
# 长回复若是正常可读文本,直接按已发出的截断文本
1370+
# 收尾,不 reroll,避免 UI/TTS 和 history 分叉。
1371+
if not recovery_text and not _is_gibberish_response(assistant_message_total):
1372+
capped = truncate_to_tokens(
1373+
assistant_message_total, self.max_response_length,
1374+
)
1375+
candidate_recovery = _truncate_to_last_sentence_end(capped)
1376+
if _has_unpersisted_recovery_suffix(candidate_recovery):
1377+
recovery_text = candidate_recovery
1378+
1379+
if recovery_text and _has_unpersisted_recovery_suffix(recovery_text):
1380+
history_recovery_text = assistant_message
1381+
original_tokens = length_guard_original_tokens or count_tokens(assistant_message_total)
1382+
logger.info(
1383+
"OmniOfflineClient: 长回复已流式输出,停止生成并按最后句末入历史 "
1384+
"(原 %d tokens → 截断后 %d tokens)",
1385+
original_tokens, count_tokens(recovery_text),
1386+
)
1387+
if history_recovery_text:
1388+
self._conversation_history.append(AIMessage(content=history_recovery_text))
1389+
await self._check_repetition(recovery_text)
1390+
assistant_message = history_recovery_text
1391+
guard_exhausted = True
1392+
break
1393+
recovery_text = ""
1394+
13081395
if will_retry:
13091396
# 还能 retry:发 will_retry 通知,循环继续。前端
13101397
# 收到 response_discarded(will_retry=True, message=None)
@@ -1339,7 +1426,6 @@ async def stream_text(self, text: str) -> None:
13391426
# max_response_length 再找句末,否则截出来的句末仍
13401427
# 可能在 token 上限之外(比如最后一个句号在 950 token
13411428
# 处但 cap 是 300)。
1342-
recovery_text = ""
13431429
if discard_reason and "length>" in discard_reason:
13441430
# 整轮判定:gibberish / 截断必须看 _total,否则
13451431
# tool 轮 sentinel 把 final-segment 清空之后整段
@@ -1348,13 +1434,16 @@ async def stream_text(self, text: str) -> None:
13481434
capped = truncate_to_tokens(
13491435
assistant_message_total, self.max_response_length,
13501436
)
1351-
recovery_text = _truncate_to_last_sentence_end(capped)
1437+
candidate_recovery = _truncate_to_last_sentence_end(capped)
1438+
if _has_unpersisted_recovery_suffix(candidate_recovery):
1439+
recovery_text = candidate_recovery
13521440

13531441
if recovery_text:
1442+
original_tokens = length_guard_original_tokens or count_tokens(assistant_message_total)
13541443
logger.info(
13551444
"OmniOfflineClient: guard 重试耗尽,截断至最后句末 "
13561445
"(原 %d tokens → 截断后 %d tokens)",
1357-
count_tokens(assistant_message_total), count_tokens(recovery_text),
1446+
original_tokens, count_tokens(recovery_text),
13581447
)
13591448
truncate_msg = json.dumps({
13601449
"code": "RESPONSE_LENGTH_TRUNCATED",

0 commit comments

Comments
 (0)