Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions src/A_memorix/core/runtime/sdk_memory_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,17 +1182,36 @@ async def summarize_chat_stream(
) -> Dict[str, Any]:
await self.initialize()
assert self.summary_importer
success, detail = await self.summary_importer.import_from_stream(
import_result = await self.summary_importer.import_from_stream(
stream_id=str(chat_id or "").strip(),
context_length=context_length,
include_personality=include_personality,
time_end=time_end,
metadata=metadata,
)
success = bool(getattr(import_result, "success", False))
detail = str(getattr(import_result, "detail", "") or "")
paragraph_hash = str(getattr(import_result, "paragraph_hash", "") or "").strip()
source = (
str(getattr(import_result, "source", "") or "").strip()
Comment on lines +1192 to +1196
or self._build_source("chat_summary", chat_id, [])
)
stored_ids: List[str] = []
episode_pending_ids: List[str] = []
if success:
await self.rebuild_episodes_for_sources([self._build_source("chat_summary", chat_id, [])])
if not paragraph_hash:
raise RuntimeError("聊天摘要导入成功但未返回 paragraph_hash,无法执行 Episode 增量入队")
assert self.metadata_store is not None
self.metadata_store.enqueue_episode_pending(paragraph_hash, source=source)
stored_ids.append(paragraph_hash)
episode_pending_ids.append(paragraph_hash)
self._persist()
return {"success": bool(success), "detail": detail}
payload = {"success": success, "detail": detail}
if stored_ids:
payload["stored_ids"] = stored_ids
if episode_pending_ids:
payload["episode_pending_ids"] = episode_pending_ids
return payload

async def ingest_summary(
self,
Expand Down
55 changes: 39 additions & 16 deletions src/A_memorix/core/utils/summary_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,19 @@
导入到 A_memorix 的存储组件中。
"""

from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, Tuple

import json
import re
import time
import traceback

from src.common.logger import get_logger
from src.services import llm_service as llm_api
from src.services import message_service as message_api
from src.config.config import config_manager, global_config
from src.config.model_configs import TaskConfig
from src.services import llm_service as llm_api
from src.services import message_service as message_api

from ..storage import (
KnowledgeType,
Expand All @@ -37,6 +37,23 @@

logger = get_logger("A_Memorix.SummaryImporter")


@dataclass(frozen=True)
class SummaryImportResult:
"""聊天摘要导入结果。

保留二元组解包兼容:旧调用方仍可使用 ``success, detail = result``。
"""

success: bool
detail: str
paragraph_hash: str = ""
source: str = ""

def __iter__(self) -> Iterator[bool | str]:
yield self.success
yield self.detail

# 默认总结提示词模版
SUMMARY_PROMPT_TEMPLATE = """
你是 {bot_name}。{personality_context}
Expand Down Expand Up @@ -414,7 +431,7 @@ async def import_from_stream(
include_personality: Optional[bool] = None,
time_end: Optional[float] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Tuple[bool, str]:
) -> SummaryImportResult:
"""
从指定的聊天流中提取记录并执行总结导入

Expand All @@ -425,12 +442,12 @@ async def import_from_stream(
time_end: 用于截取聊天记录的时间上界(闭区间)

Returns:
Tuple[bool, str]: (是否成功, 结果消息)
SummaryImportResult: 导入结果,包含本次新增摘要段落 hash。
"""
try:
self_check_ok, self_check_msg = await self._ensure_runtime_self_check()
if not self_check_ok:
return False, f"导入前自检失败: {self_check_msg}"
return SummaryImportResult(False, f"导入前自检失败: {self_check_msg}")

# 1. 获取配置
if context_length is None:
Expand All @@ -450,7 +467,7 @@ async def import_from_stream(
)

if not messages:
return False, "未找到有效的聊天记录进行总结"
return SummaryImportResult(False, "未找到有效的聊天记录进行总结")

# 转换为可读文本
chat_history_text = message_api.build_readable_messages(messages)
Expand Down Expand Up @@ -478,7 +495,7 @@ async def import_from_stream(

resolved_model = self._resolve_summary_model_config()
if resolved_model is None:
return False, "未找到可用的总结模型配置"
return SummaryImportResult(False, "未找到可用的总结模型配置")
task_name_to_use, model_config_to_use = resolved_model

logger.info(f"正在为流 {stream_id} 执行总结,消息条数: {len(messages)}")
Expand All @@ -498,16 +515,16 @@ async def import_from_stream(
response = str(result.completion.response or "")

if not success or not response:
return False, "LLM 生成总结失败"
return SummaryImportResult(False, "LLM 生成总结失败")

# 5. 解析结果
data = self._parse_llm_response(response)
if not data or "summary" not in data:
return False, "解析 LLM 响应失败或总结为空"
return SummaryImportResult(False, "解析 LLM 响应失败或总结为空")

summary_text = str(data["summary"] or "").strip()
if not summary_text:
return False, "解析 LLM 响应失败或总结为空"
return SummaryImportResult(False, "解析 LLM 响应失败或总结为空")
entities = _normalize_entity_items(data.get("entities"))
relations = _normalize_relation_items(data.get("relations"))
msg_times = [timestamp for msg in messages if (timestamp := _message_timestamp(msg)) is not None]
Expand All @@ -521,7 +538,7 @@ async def import_from_stream(
}

# 6. 执行导入
await self._execute_import(
paragraph_hash = await self._execute_import(
summary_text,
entities,
relations,
Expand All @@ -540,11 +557,16 @@ async def import_from_stream(
f"📌 提取实体: {len(entities)}\n"
f"🔗 提取关系: {len(relations)}"
)
return True, result_msg
return SummaryImportResult(
True,
result_msg,
paragraph_hash=paragraph_hash,
source=f"chat_summary:{stream_id}",
)

except Exception as e:
logger.error(f"总结导入过程中出错: {e}\n{traceback.format_exc()}")
return False, f"错误: {str(e)}"
return SummaryImportResult(False, f"错误: {str(e)}")

async def _ensure_runtime_self_check(self) -> Tuple[bool, str]:
plugin_instance = self.plugin_config.get("plugin_instance") if isinstance(self.plugin_config, dict) else None
Expand Down Expand Up @@ -595,7 +617,7 @@ async def _execute_import(
stream_id: str,
time_meta: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, Any]] = None,
):
) -> str:
"""将数据写入存储"""
# 获取默认知识类型
type_str = self.plugin_config.get("summarization", {}).get("default_knowledge_type", "narrative")
Expand Down Expand Up @@ -675,3 +697,4 @@ async def _execute_import(
pass

logger.info(f"总结导入完成: hash={hash_value[:8]}")
return hash_value
Loading