Skip to content

Commit 85ec7a9

Browse files
feat: Enhance Reply chain handling for Record components (#8527)
* Enhance Reply chain handling for Record components: added processing for Record components within Reply chains, including WAV conversion and STT functionality.
* Refactor STT processing for Record components
* Add STT record function for voice-to-text processing
* Update stage.py
* Update stage.py
* Update stage.py
1 parent 9a648eb commit 85ec7a9

1 file changed

Lines changed: 66 additions & 22 deletions

File tree

  • astrbot/core/pipeline/preprocess_stage

astrbot/core/pipeline/preprocess_stage/stage.py

Lines changed: 66 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import AsyncGenerator
55

66
from astrbot.core import logger
7-
from astrbot.core.message.components import Image, Plain, Record
7+
from astrbot.core.message.components import Image, Plain, Record, Reply
88
from astrbot.core.platform.astr_message_event import AstrMessageEvent
99
from astrbot.core.utils.media_utils import ensure_wav
1010

@@ -80,6 +80,24 @@ async def process(
8080
except Exception as e:
8181
logger.warning(f"Voice processing failed: {e}")
8282

83+
# Also process Record components inside Reply chains (wav conversion)
84+
for component in event.get_messages():
85+
if isinstance(component, Reply) and component.chain:
86+
for idx, reply_comp in enumerate(component.chain):
87+
if isinstance(reply_comp, Record):
88+
try:
89+
original_path = await reply_comp.convert_to_file_path()
90+
record_path = await ensure_wav(original_path)
91+
if record_path != original_path:
92+
event.track_temporary_local_file(record_path)
93+
reply_comp.file = record_path
94+
reply_comp.path = record_path
95+
component.chain[idx] = reply_comp
96+
except Exception as e:
97+
logger.warning(
98+
f"Voice processing in reply chain failed: {e}"
99+
)
100+
83101
# STT
84102
if self.stt_settings.get("enable", False):
85103
# TODO: 独立
@@ -90,27 +108,53 @@ async def process(
90108
f"会话 {event.unified_msg_origin} 未配置语音转文本模型。",
91109
)
92110
return
111+
112+
async def _stt_record(record_comp: Record, is_reply: bool = False):
    """Run speech-to-text on a single Record component.

    Args:
        record_comp: The voice component to transcribe.
        is_reply: True when the component comes from a quoted (Reply)
            chain; it only changes the wording of the log messages.

    Returns:
        A Plain component holding the recognised text, or None when the
        audio path cannot be resolved, the provider returns an empty
        result, the file never becomes available, or transcription fails.
    """
    prefix = "引用消息" if is_reply else ""
    suffix = "(引用消息)" if is_reply else ""

    try:
        path = await record_comp.convert_to_file_path()
    except Exception as e:
        logger.warning(f"获取{prefix}语音路径失败: {e}")
        return None

    retry = 5
    for i in range(retry):
        try:
            result = await stt_provider.get_text(audio_url=path)
            if result:
                logger.info(f"语音转文本{suffix}结果: " + result)
                return Plain(result)
            # An empty result is final; retrying would not help.
            break
        except FileNotFoundError:
            # napcat workaround: file may not be ready immediately
            logger.debug(f"文件尚未就绪 ({path}),重试 {i + 1}/{retry}")
            await asyncio.sleep(0.5)
            continue
        except Exception as e:
            # Deliberately not BaseException: asyncio.CancelledError,
            # KeyboardInterrupt and SystemExit must propagate so that task
            # cancellation and shutdown are not silently swallowed here.
            logger.error(traceback.format_exc())
            logger.error(f"语音转文本{suffix}失败: {e}")
            break
    else:
        # Reached only when every attempt raised FileNotFoundError.
        logger.warning(
            f"语音转文本{suffix}失败: 重试 {retry} 次后文件仍未就绪 ({path})"
        )
    return None
141+
93142
message_chain = event.get_messages()
94143
for idx, component in enumerate(message_chain):
95144
if isinstance(component, Record):
96-
path = await component.convert_to_file_path()
97-
retry = 5
98-
for i in range(retry):
99-
try:
100-
result = await stt_provider.get_text(audio_url=path)
101-
if result:
102-
logger.info("语音转文本结果: " + result)
103-
message_chain[idx] = Plain(result)
104-
event.message_str += result
105-
event.message_obj.message_str += result
106-
break
107-
except FileNotFoundError as e:
108-
# napcat workaround
109-
logger.warning(e)
110-
logger.warning(f"重试中: {i + 1}/{retry}")
111-
await asyncio.sleep(0.5)
112-
continue
113-
except BaseException as e:
114-
logger.error(traceback.format_exc())
115-
logger.error(f"语音转文本失败: {e}")
116-
break
145+
plain_comp = await _stt_record(component)
146+
if plain_comp:
147+
message_chain[idx] = plain_comp
148+
event.message_str += plain_comp.text
149+
event.message_obj.message_str += plain_comp.text
150+
151+
# Also STT for Record components inside Reply chains
152+
for component in event.get_messages():
153+
if isinstance(component, Reply) and component.chain:
154+
for idx, reply_comp in enumerate(component.chain):
155+
if isinstance(reply_comp, Record):
156+
plain_comp = await _stt_record(reply_comp, is_reply=True)
157+
if plain_comp:
158+
component.chain[idx] = plain_comp
159+
event.message_str += plain_comp.text
160+
event.message_obj.message_str += plain_comp.text

0 commit comments

Comments
 (0)