Skip to content

Commit 811a76d

Browse files
authored
fix: 优化 Embedding Provider 初始化逻辑,添加缺失提供者警告,确保 Milvus 连接任务按需启动 (#103)
* fix: 优化 Embedding Provider 初始化逻辑,添加缺失提供者警告,确保 Milvus 连接任务按需启动 * fix: 优化异常处理和后台任务创建逻辑,确保在无事件循环时安全启动任务
1 parent bbd4b7f commit 811a76d

1 file changed

Lines changed: 136 additions & 12 deletions

File tree

main.py

Lines changed: 136 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ def __init__(self, context: Context, config: AstrBotConfig):
7171
self._initialized_components = []
7272
self._embedding_provider_ready = False
7373
self._migrated_sessions: set[str] = set() # 用于记录已迁移的会话
74+
self._warned_missing_provider_ids: set[str] = set()
75+
self._post_load_tasks_started = False
76+
self._ensure_milvus_connection_task: asyncio.Task | None = None
7477

7578
logger.info("开始初始化 Mnemosyne 插件...")
7679
# 启动后台异步初始化,但不包括 Embedding Provider 的初始化
@@ -97,7 +100,29 @@ def _initialize_embedding_provider(
97100
# 优先级 1: 从配置指定的 Provider ID 获取
98101
emb_id = self.config.get("embedding_provider_id")
99102
if emb_id:
100-
provider = self.context.get_provider_by_id(emb_id)
103+
# 避免在 AstrBot 尚未完成启动时调用 context.get_provider_by_id 产生多余的 WARN 日志。
104+
provider = None
105+
try:
106+
provider_manager = getattr(self.context, "provider_manager", None)
107+
inst_map = getattr(provider_manager, "inst_map", None)
108+
if isinstance(inst_map, dict):
109+
provider = inst_map.get(emb_id)
110+
except (AttributeError, TypeError) as exc:
111+
# 仅在访问 provider_manager / inst_map 出现属性或类型问题时处理,避免掩盖真实错误。
112+
if silent:
113+
logger.debug(
114+
f"无法从 context 读取 Embedding Provider '{emb_id}'(provider_manager/inst_map 不可用): {exc}"
115+
)
116+
else:
117+
logger.warning(
118+
f"无法从 context 读取 Embedding Provider '{emb_id}'(provider_manager/inst_map 不可用): {exc}"
119+
)
120+
provider = None
121+
122+
# 兼容旧版本:如果无法访问 provider_manager,再回退到官方 API
123+
if provider is None and not hasattr(self.context, "provider_manager"):
124+
provider = self.context.get_provider_by_id(emb_id)
125+
101126
# 安全地检查 provider 是否为 EmbeddingProvider 类型
102127
if provider:
103128
# 检查 provider 是否具有 EmbeddingProvider 的关键方法
@@ -113,6 +138,16 @@ def _initialize_embedding_provider(
113138
logger.warning(
114139
f"获取的 Provider {emb_id} 不是有效的 EmbeddingProvider 类型"
115140
)
141+
else:
142+
if (
143+
not silent
144+
and emb_id not in self._warned_missing_provider_ids
145+
and self._are_providers_initialized()
146+
):
147+
logger.warning(
148+
f"未找到配置的 Embedding Provider: {emb_id},请检查 AstrBot 提供商配置是否已加载/启用。"
149+
)
150+
self._warned_missing_provider_ids.add(emb_id)
116151

117152
# 优先级 2: 使用框架默认的第一个 Embedding Provider
118153
# 使用 context 提供的方法获取所有 embedding providers
@@ -196,6 +231,83 @@ async def _initialize_embedding_provider_async(
196231
)
197232
return False
198233

234+
def _are_providers_initialized(self) -> bool:
235+
"""
236+
判断 AstrBot 的 ProviderManager 是否已完成 Provider 加载。
237+
238+
说明:不要依赖 context.get_provider_by_id 的 WARN 日志来判断是否加载完成。
239+
"""
240+
provider_manager = getattr(self.context, "provider_manager", None)
241+
inst_map = getattr(provider_manager, "inst_map", None)
242+
if isinstance(inst_map, dict) and len(inst_map) > 0:
243+
return True
244+
245+
# 兜底:部分版本可能不暴露 inst_map(或初始化时机不同)
246+
try:
247+
if self.context.get_all_providers():
248+
return True
249+
except (AttributeError, TypeError) as exc:
250+
logger.warning(f"检查 Providers 初始化状态失败(get_all_providers 不可用): {exc}")
251+
return False
252+
253+
def _create_background_task(self, coro: Any, name: str) -> asyncio.Task | None:
254+
"""
255+
安全创建后台任务:确保存在运行中的事件循环,避免在同步/无 loop 场景下直接抛异常。
256+
"""
257+
try:
258+
loop = asyncio.get_running_loop()
259+
except RuntimeError as exc:
260+
logger.warning(f"无法启动后台任务 '{name}': 当前没有运行中的事件循环: {exc}")
261+
return None
262+
263+
# 额外提示:如果不是在 Task 上下文中创建,未来改动更容易定位。
264+
try:
265+
if asyncio.current_task() is None:
266+
logger.debug(f"启动后台任务 '{name}':当前不在 asyncio.Task 上下文中")
267+
except RuntimeError:
268+
# 极少数情况下 current_task() 也可能因 loop 上下文问题抛错,忽略即可。
269+
pass
270+
271+
try:
272+
return loop.create_task(coro, name=name) # type: ignore[arg-type]
273+
except TypeError:
274+
# 兼容:部分运行环境可能不支持 name 参数
275+
return loop.create_task(coro) # type: ignore[arg-type]
276+
277+
def _start_post_load_tasks(self):
278+
"""在 AstrBot 启动完成/Providers 可用后启动需要依赖 Providers 的后台任务。"""
279+
if (
280+
self._post_load_tasks_started
281+
and self._embedding_provider_task
282+
and not self._embedding_provider_task.done()
283+
and self._ensure_milvus_connection_task
284+
and not self._ensure_milvus_connection_task.done()
285+
):
286+
return
287+
288+
self._post_load_tasks_started = True
289+
290+
# 启动 Embedding Provider 后台加载任务(静默模式)
291+
if not self._embedding_provider_task or self._embedding_provider_task.done():
292+
task = self._create_background_task(
293+
self._initialize_embedding_provider_async(max_wait=10.0),
294+
name="mnemosyne.embedding_provider_init",
295+
)
296+
if task:
297+
self._embedding_provider_task = task
298+
299+
# 启动 Milvus 连接后台任务(在 Embedding Provider 加载后执行)
300+
if (
301+
not self._ensure_milvus_connection_task
302+
or self._ensure_milvus_connection_task.done()
303+
):
304+
task = self._create_background_task(
305+
self._ensure_milvus_connection_async(),
306+
name="mnemosyne.ensure_milvus_connection",
307+
)
308+
if task:
309+
self._ensure_milvus_connection_task = task
310+
199311
async def _ensure_milvus_connection_async(self):
200312
"""
201313
在 Embedding Provider 加载完成后,确保 Milvus 连接已建立
@@ -207,10 +319,15 @@ async def _ensure_milvus_connection_async(self):
207319
logger.debug("等待 Embedding Provider 加载完成后再连接 Milvus...")
208320
await self._embedding_provider_task
209321

210-
# 检查 Milvus Manager 是否已初始化
322+
# 等待 Milvus Manager 初始化完成(插件初始化可能仍在进行)
211323
if not self.milvus_manager:
212-
logger.warning("Milvus Manager 未初始化,跳过自动连接")
213-
return
324+
wait_start = time.time()
325+
while not self.milvus_manager and time.time() - wait_start < 10.0:
326+
await asyncio.sleep(0.5)
327+
328+
if not self.milvus_manager:
329+
logger.warning("Milvus Manager 未初始化,跳过自动连接")
330+
return
214331

215332
# 检查是否已连接
216333
if self.milvus_manager.is_connected():
@@ -260,14 +377,11 @@ async def _initialize_plugin_async(self):
260377
# 1. Embedding Provider 采用延迟初始化策略
261378
# 在后台静默尝试加载,但不阻塞插件启动
262379
logger.info("Embedding Provider 采用延迟初始化策略")
263-
264-
# 启动 Embedding Provider 后台加载任务(静默模式)
265-
self._embedding_provider_task = asyncio.create_task(
266-
self._initialize_embedding_provider_async(max_wait=10.0)
267-
)
268-
269-
# 启动 Milvus 连接后台任务(在 Embedding Provider 加载后执行)
270-
asyncio.create_task(self._ensure_milvus_connection_async())
380+
if self._are_providers_initialized():
381+
# 兼容:插件被热重载/晚加载时,Providers 可能已经就绪
382+
self._start_post_load_tasks()
383+
else:
384+
logger.info("等待 AstrBot 启动完成后再加载 Embedding Provider")
271385

272386
# 2. 继续初始化其他组件
273387
try:
@@ -410,6 +524,16 @@ async def _initialize_plugin_async(self):
410524
raise
411525

412526
# --- 事件处理钩子 (调用 memory_operations.py 中的实现) ---
527+
@filter.on_astrbot_loaded()
528+
async def on_astrbot_loaded(self):
529+
"""[事件钩子] AstrBot 初始化完成后再加载依赖 Providers 的组件。"""
530+
try:
531+
logger.info("AstrBot 初始化完成,开始加载 Embedding Provider...")
532+
self._start_post_load_tasks()
533+
except Exception as e:
534+
logger.error(f"处理 on_astrbot_loaded 钩子时发生捕获异常: {e}", exc_info=True)
535+
return
536+
413537
@filter.on_llm_request()
414538
async def query_memory(self, event: AstrMessageEvent, req: ProviderRequest):
415539
"""[事件钩子] 在 LLM 请求前,查询并注入长期记忆。"""

0 commit comments

Comments
 (0)