From 9523c5d3e60027a52de994549463fc3cda7d6e37 Mon Sep 17 00:00:00 2001 From: huamingjie Date: Wed, 11 Mar 2026 22:30:31 +0800 Subject: [PATCH 1/8] feat(kg): add dual-mode knowledge graph support (Zep Cloud + Graphiti) - Add kg_adapter for dual-mode knowledge graph (cloud/local) - Support switching between Zep Cloud and local Graphiti + Neo4j - Improve entity extraction and report agent robustness - Add test_kg_adapter.py with unit tests Co-Authored-By: Claude Opus 4.6 --- .env.example | 20 +- backend/app/api/graph.py | 16 +- backend/app/api/simulation.py | 16 +- backend/app/config.py | 35 +- backend/app/services/graph_builder.py | 242 ++++--- backend/app/services/kg_adapter.py | 674 ++++++++++++++++++ .../app/services/oasis_profile_generator.py | 31 +- backend/app/services/report_agent.py | 35 +- backend/app/services/zep_entity_reader.py | 152 ++-- .../app/services/zep_graph_memory_updater.py | 61 +- backend/app/services/zep_tools.py | 270 +++++-- backend/pyproject.toml | 18 +- backend/requirements.txt | 9 +- backend/tests/test_kg_adapter.py | 152 ++++ backend/uv.lock | 42 +- 15 files changed, 1465 insertions(+), 308 deletions(-) create mode 100644 backend/app/services/kg_adapter.py create mode 100644 backend/tests/test_kg_adapter.py diff --git a/.env.example b/.env.example index 78a3b72c..c7a906d6 100644 --- a/.env.example +++ b/.env.example @@ -5,10 +5,28 @@ LLM_API_KEY=your_api_key_here LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 LLM_MODEL_NAME=qwen-plus -# ===== ZEP记忆图谱配置 ===== +# 嵌入模型配置(用于 Graphiti local 模式,可独立配置) +# 如果不配置,则默认使用 LLM 的 API +# EMBEDDING_API_KEY=your_embedding_api_key # 可选 +# EMBEDDING_BASE_URL=your_embedding_base_url # 可选 +# EMBEDDING_MODEL=text-embedding-3-small +# EMBEDDING_DIM=1536 + +# ===== 知识图谱配置 ===== +# 模式选择: "cloud" (Zep Cloud) 或 "local" (Graphiti + Neo4j) +KNOWLEDGE_GRAPH_MODE=cloud + +# Zep Cloud 配置 (KNOWLEDGE_GRAPH_MODE=cloud 时需要) # 每月免费额度即可支撑简单使用:https://app.getzep.com/ ZEP_API_KEY=your_zep_api_key_here +# Graphiti / Neo4j 配置 (KNOWLEDGE_GRAPH_MODE=local 时需要) +NEO4J_URI=bolt://localhost:7687 +NEO4J_USER=neo4j +NEO4J_PASSWORD=your_neo4j_password +# 嵌入向量 API Key(可选,用 LLM_API_KEY 即可,会自动使用同一 base_url) +# OPENAI_API_KEY=your_openai_key_for_embedding + # ===== 加速 LLM 配置(可选)===== # 注意如果不使用加速配置,env文件中就不要出现下面的配置项 LLM_BOOST_API_KEY=your_api_key_here diff --git a/backend/app/api/graph.py b/backend/app/api/graph.py index 12ff1ba2..2f94c53a 100644 --- a/backend/app/api/graph.py +++ b/backend/app/api/graph.py @@ -282,9 +282,9 @@ def build_graph(): try: logger.info("=== 开始构建图谱 ===") - # 检查配置 + # 检查配置 (cloud 模式需要 Zep Cloud) errors = [] - if not Config.ZEP_API_KEY: + if Config.KNOWLEDGE_GRAPH_MODE == 'cloud' and not Config.ZEP_API_KEY: errors.append("ZEP_API_KEY未配置") if errors: logger.error(f"配置错误: {errors}") @@ -567,13 +567,13 @@ def get_graph_data(graph_id: str): 获取图谱数据(节点和边) """ try: - if not Config.ZEP_API_KEY: + if Config.KNOWLEDGE_GRAPH_MODE == 'cloud' and not Config.ZEP_API_KEY: return jsonify({ "success": False, "error": "ZEP_API_KEY未配置" }), 500 - - builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) + + builder = GraphBuilderService() graph_data = builder.get_graph_data(graph_id) return jsonify({ @@ -595,13 +595,13 @@ def delete_graph(graph_id: str): 删除Zep图谱 """ try: - if not Config.ZEP_API_KEY: + if Config.KNOWLEDGE_GRAPH_MODE == 'cloud' and not Config.ZEP_API_KEY: return jsonify({ "success": False, "error": "ZEP_API_KEY未配置" }), 500 - - builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) + + builder = GraphBuilderService() builder.delete_graph(graph_id) return jsonify({ diff --git a/backend/app/api/simulation.py b/backend/app/api/simulation.py index 3a0f6816..102a8428 100644 --- a/backend/app/api/simulation.py +++ b/backend/app/api/simulation.py @@ -56,18 +56,18 @@ def get_graph_entities(graph_id: str): enrich: 是否获取相关边信息(默认true) """ try: - if not Config.ZEP_API_KEY: + if Config.KNOWLEDGE_GRAPH_MODE == 'cloud' and not Config.ZEP_API_KEY: return jsonify({ "success": False, "error": "ZEP_API_KEY未配置" }), 500 - + entity_types_str = request.args.get('entity_types', '') entity_types = [t.strip() for t in entity_types_str.split(',') if t.strip()] if entity_types_str else None enrich = request.args.get('enrich', 'true').lower() == 'true' - + logger.info(f"获取图谱实体: graph_id={graph_id}, entity_types={entity_types}, enrich={enrich}") - + reader = ZepEntityReader() result = reader.filter_defined_entities( graph_id=graph_id, @@ -93,12 +93,12 @@ def get_graph_entities(graph_id: str): def get_entity_detail(graph_id: str, entity_uuid: str): """获取单个实体的详细信息""" try: - if not Config.ZEP_API_KEY: + if Config.KNOWLEDGE_GRAPH_MODE == 'cloud' and not Config.ZEP_API_KEY: return jsonify({ "success": False, "error": "ZEP_API_KEY未配置" }), 500 - + reader = ZepEntityReader() entity = reader.get_entity_with_context(graph_id, entity_uuid) @@ -126,12 +126,12 @@ def get_entity_detail(graph_id: str, entity_uuid: str): def get_entities_by_type(graph_id: str, entity_type: str): """获取指定类型的所有实体""" try: - if not Config.ZEP_API_KEY: + if Config.KNOWLEDGE_GRAPH_MODE == 'cloud' and not Config.ZEP_API_KEY: return jsonify({ "success": False, "error": "ZEP_API_KEY未配置" }), 500 - + enrich = request.args.get('enrich', 'true').lower() == 'true' reader = ZepEntityReader() diff --git a/backend/app/config.py b/backend/app/config.py index 953dfa50..05883b15 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -31,9 +31,27 @@ class Config: LLM_API_KEY = os.environ.get('LLM_API_KEY') LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1') LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini') + + # 嵌入模型配置(用于 Graphiti local 模式,可独立配置) + EMBEDDING_API_KEY = os.environ.get('EMBEDDING_API_KEY') # 可选,默认使用 LLM_API_KEY + EMBEDDING_BASE_URL = os.environ.get('EMBEDDING_BASE_URL') # 可选,默认使用 LLM_BASE_URL + EMBEDDING_MODEL = os.environ.get('EMBEDDING_MODEL', 'text-embedding-3-small') + EMBEDDING_DIM = int(os.environ.get('EMBEDDING_DIM', '1536')) - # Zep配置 + # 知识图谱模式配置 + # cloud: 使用 Zep Cloud (默认) + # local: 使用 Graphiti + Neo4j (本地部署) + KNOWLEDGE_GRAPH_MODE = os.environ.get('KNOWLEDGE_GRAPH_MODE', 'cloud') + + # Zep Cloud 配置 (KNOWLEDGE_GRAPH_MODE=cloud 时需要) ZEP_API_KEY = os.environ.get('ZEP_API_KEY') + + # Graphiti / Neo4j 配置 (KNOWLEDGE_GRAPH_MODE=local 时需要) + NEO4J_URI = os.environ.get('NEO4J_URI', 'bolt://localhost:7687') + NEO4J_USER = os.environ.get('NEO4J_USER', 'neo4j') + NEO4J_PASSWORD = os.environ.get('NEO4J_PASSWORD') + # OpenAI API 用于嵌入向量 (Graphiti 模式需要) + OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY') # 文件上传配置 MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB @@ -69,7 +87,18 @@ def validate(cls): errors = [] if not cls.LLM_API_KEY: errors.append("LLM_API_KEY 未配置") - if not cls.ZEP_API_KEY: - errors.append("ZEP_API_KEY 未配置") + + # 根据模式验证对应的配置 + if cls.KNOWLEDGE_GRAPH_MODE == 'cloud': + if not cls.ZEP_API_KEY: + errors.append("ZEP_API_KEY 未配置 (当前模式: cloud)") + elif cls.KNOWLEDGE_GRAPH_MODE == 'local': + if not cls.NEO4J_PASSWORD: + errors.append("NEO4J_PASSWORD 未配置 (当前模式: local)") + if not cls.LLM_API_KEY and not cls.OPENAI_API_KEY: + errors.append("LLM_API_KEY 或 OPENAI_API_KEY 未配置 (当前模式: local,用于嵌入向量)") + else: + errors.append(f"未知的 KNOWLEDGE_GRAPH_MODE: {cls.KNOWLEDGE_GRAPH_MODE}") + return errors diff --git a/backend/app/services/graph_builder.py b/backend/app/services/graph_builder.py index 0e0444bf..7c7a6ca6 100644 --- a/backend/app/services/graph_builder.py +++ b/backend/app/services/graph_builder.py @@ -1,6 +1,7 @@ """ 图谱构建服务 -接口2:使用Zep API构建Standalone Graph +接口2:使用知识图谱API构建图谱 +支持 Zep Cloud 和 Graphiti (本地) 两种模式 """ import os @@ -10,14 +11,19 @@ from typing import Dict, Any, List, Optional, Callable from dataclasses import dataclass -from zep_cloud.client import Zep -from zep_cloud import EpisodeData, EntityEdgeSourceTarget - from ..config import Config from ..models.task import TaskManager, TaskStatus -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges +from .kg_adapter import get_knowledge_graph_adapter from .text_processor import TextProcessor +# 保留原有的导入,用于动态类生成(兼容模式) +try: + from zep_cloud import EpisodeData, EntityEdgeSourceTarget + from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel + ZEP_CLOUD_AVAILABLE = True +except ImportError: + ZEP_CLOUD_AVAILABLE = False + @dataclass class GraphInfo: @@ -39,15 +45,14 @@ def to_dict(self) -> Dict[str, Any]: class GraphBuilderService: """ 图谱构建服务 - 负责调用Zep API构建知识图谱 + 负责调用知识图谱 API 构建图谱 + 支持 Zep Cloud 和 Graphiti 两种模式 """ - + def __init__(self, api_key: Optional[str] = None): - self.api_key = api_key or Config.ZEP_API_KEY - if not self.api_key: - raise ValueError("ZEP_API_KEY 未配置") - - self.client = Zep(api_key=self.api_key) + self.api_key = api_key # 保留参数兼容性 + # 使用适配器 + self.kg = get_knowledge_graph_adapter() self.task_manager = TaskManager() def build_graph_async( @@ -185,15 +190,14 @@ def _build_graph_worker( self.task_manager.fail_task(task_id, error_msg) def create_graph(self, name: str) -> str: - """创建Zep图谱(公开方法)""" + """创建图谱(公开方法)""" graph_id = f"mirofish_{uuid.uuid4().hex[:16]}" - - self.client.graph.create( + + self.kg.create_graph( graph_id=graph_id, name=name, - description="MiroFish Social Simulation Graph" ) - + return graph_id def set_ontology(self, graph_id: str, ontology: Dict[str, Any]): @@ -277,13 +281,14 @@ def safe_attr_name(attr_name: str) -> str: if source_targets: edge_definitions[name] = (edge_class, source_targets) - # 调用Zep API设置本体 + # 调用图谱API设置本体 if entity_types or edge_definitions: - self.client.graph.set_ontology( - graph_ids=[graph_id], - entities=entity_types if entity_types else None, - edges=edge_definitions if edge_definitions else None, - ) + # 封装为 ontology 格式 + ontology = { + "entities": entity_types if entity_types else None, + "edges": edge_definitions if edge_definitions else None, + } + self.kg.set_ontology(graph_id, ontology) def add_text_batches( self, @@ -308,34 +313,29 @@ def add_text_batches( progress ) - # 构建episode数据 - episodes = [ - EpisodeData(data=chunk, type="text") - for chunk in batch_chunks - ] - - # 发送到Zep + # 发送到图谱 try: - batch_result = self.client.graph.add_batch( + # 使用适配器的批量添加方法 + batch_result = self.kg.add_episodes_batch( graph_id=graph_id, - episodes=episodes + texts=batch_chunks ) - + # 收集返回的 episode uuid if batch_result and isinstance(batch_result, list): for ep in batch_result: ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None) if ep_uuid: episode_uuids.append(ep_uuid) - + # 避免请求过快 time.sleep(1) - + except Exception as e: if progress_callback: progress_callback(f"批次 {batch_num} 发送失败: {str(e)}", 0) raise - + return episode_uuids def _wait_for_episodes( @@ -370,13 +370,13 @@ def _wait_for_episodes( # 检查每个 episode 的处理状态 for ep_uuid in list(pending_episodes): try: - episode = self.client.graph.episode.get(uuid_=ep_uuid) + episode = self.kg.get_episode(ep_uuid) is_processed = getattr(episode, 'processed', False) - + if is_processed: pending_episodes.remove(ep_uuid) completed_count += 1 - + except Exception as e: # 忽略单个查询错误,继续 pass @@ -396,17 +396,18 @@ def _wait_for_episodes( def _get_graph_info(self, graph_id: str) -> GraphInfo: """获取图谱信息""" - # 获取节点(分页) - nodes = fetch_all_nodes(self.client, graph_id) + # 获取节点(使用适配器) + nodes = self.kg.get_nodes(graph_id, limit=2000) - # 获取边(分页) - edges = fetch_all_edges(self.client, graph_id) + # 获取边(使用适配器) + edges = self.kg.get_edges(graph_id, limit=2000) # 统计实体类型 entity_types = set() for node in nodes: - if node.labels: - for label in node.labels: + labels = node.labels if hasattr(node, 'labels') else node.get('labels', []) + if labels: + for label in labels: if label not in ["Entity", "Node"]: entity_types.add(label) @@ -420,72 +421,113 @@ def _get_graph_info(self, graph_id: str) -> GraphInfo: def get_graph_data(self, graph_id: str) -> Dict[str, Any]: """ 获取完整图谱数据(包含详细信息) - + Args: graph_id: 图谱ID - + Returns: 包含nodes和edges的字典,包括时间信息、属性等详细数据 """ - nodes = fetch_all_nodes(self.client, graph_id) - edges = fetch_all_edges(self.client, graph_id) + # 使用适配器获取节点和边 + nodes = self.kg.get_nodes(graph_id, limit=2000) + edges = self.kg.get_edges(graph_id, limit=2000) - # 创建节点映射用于获取节点名称 + # 创建节点映射用于获取节点名称(兼容对象和字典两种格式) node_map = {} for node in nodes: - node_map[node.uuid_] = node.name or "" - + if isinstance(node, dict): + node_map[node.get('uuid_', '')] = node.get('name', '') or "" + else: + node_map[getattr(node, 'uuid_', '')] = getattr(node, 'name', '') or "" + nodes_data = [] for node in nodes: - # 获取创建时间 - created_at = getattr(node, 'created_at', None) - if created_at: - created_at = str(created_at) - - nodes_data.append({ - "uuid": node.uuid_, - "name": node.name, - "labels": node.labels or [], - "summary": node.summary or "", - "attributes": node.attributes or {}, - "created_at": created_at, - }) - + # 兼容对象和字典两种格式 + if isinstance(node, dict): + created_at = node.get('created_at') + if created_at: + created_at = str(created_at) + nodes_data.append({ + "uuid": node.get('uuid_', ''), + "name": node.get('name', ''), + "labels": node.get('labels', []), + "summary": node.get('summary', ''), + "attributes": node.get('attributes', {}), + "created_at": created_at, + }) + else: + created_at = getattr(node, 'created_at', None) + if created_at: + created_at = str(created_at) + nodes_data.append({ + "uuid": getattr(node, 'uuid_', ''), + "name": getattr(node, 'name', ''), + "labels": getattr(node, 'labels', []), + "summary": getattr(node, 'summary', ''), + "attributes": getattr(node, 'attributes', {}), + "created_at": created_at, + }) + edges_data = [] for edge in edges: - # 获取时间信息 - created_at = getattr(edge, 'created_at', None) - valid_at = getattr(edge, 'valid_at', None) - invalid_at = getattr(edge, 'invalid_at', None) - expired_at = getattr(edge, 'expired_at', None) - - # 获取 episodes - episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None) - if episodes and not isinstance(episodes, list): - episodes = [str(episodes)] - elif episodes: - episodes = [str(e) for e in episodes] - - # 获取 fact_type - fact_type = getattr(edge, 'fact_type', None) or edge.name or "" - - edges_data.append({ - "uuid": edge.uuid_, - "name": edge.name or "", - "fact": edge.fact or "", - "fact_type": fact_type, - "source_node_uuid": edge.source_node_uuid, - "target_node_uuid": edge.target_node_uuid, - "source_node_name": node_map.get(edge.source_node_uuid, ""), - "target_node_name": node_map.get(edge.target_node_uuid, ""), - "attributes": edge.attributes or {}, - "created_at": str(created_at) if created_at else None, - "valid_at": str(valid_at) if valid_at else None, - "invalid_at": str(invalid_at) if invalid_at else None, - "expired_at": str(expired_at) if expired_at else None, - "episodes": episodes or [], - }) - + # 兼容对象和字典两种格式 + if isinstance(edge, dict): + created_at = edge.get('created_at') + valid_at = edge.get('valid_at') + invalid_at = edge.get('invalid_at') + expired_at = edge.get('expired_at') + episodes = edge.get('episodes', []) + fact_type = edge.get('fact_type', '') or edge.get('name', '') + edges_data.append({ + "uuid": edge.get('uuid_', ''), + "name": edge.get('name', ''), + "fact": edge.get('fact', ''), + "fact_type": fact_type, + "source_node_uuid": edge.get('source_node_uuid', ''), + "target_node_uuid": edge.get('target_node_uuid', ''), + "source_node_name": node_map.get(edge.get('source_node_uuid', ''), ''), + "target_node_name": node_map.get(edge.get('target_node_uuid', ''), ''), + "attributes": edge.get('attributes', {}), + "created_at": str(created_at) if created_at else None, + "valid_at": str(valid_at) if valid_at else None, + "invalid_at": str(invalid_at) if invalid_at else None, + "expired_at": str(expired_at) if expired_at else None, + "episodes": episodes if isinstance(episodes, list) else [], + }) + else: + # 获取时间信息 + created_at = getattr(edge, 'created_at', None) + valid_at = getattr(edge, 'valid_at', None) + invalid_at = getattr(edge, 'invalid_at', None) + expired_at = getattr(edge, 'expired_at', None) + + # 获取 episodes + episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None) + if episodes and not isinstance(episodes, list): + episodes = [str(episodes)] + elif episodes: + episodes = [str(e) for e in episodes] + + # 获取 fact_type + fact_type = getattr(edge, 'fact_type', None) or getattr(edge, 'name', '') or "" + + edges_data.append({ + "uuid": getattr(edge, 'uuid_', ''), + "name": getattr(edge, 'name', ''), + "fact": getattr(edge, 'fact', ''), + "fact_type": fact_type, + "source_node_uuid": getattr(edge, 'source_node_uuid', ''), + "target_node_uuid": getattr(edge, 'target_node_uuid', ''), + "source_node_name": node_map.get(getattr(edge, 'source_node_uuid', ''), ''), + "target_node_name": node_map.get(getattr(edge, 'target_node_uuid', ''), ''), + "attributes": getattr(edge, 'attributes', {}), + "created_at": str(created_at) if created_at else None, + "valid_at": str(valid_at) if valid_at else None, + "invalid_at": str(invalid_at) if invalid_at else None, + "expired_at": str(expired_at) if expired_at else None, + "episodes": episodes or [], + }) + return { "graph_id": graph_id, "nodes": nodes_data, @@ -496,5 +538,5 @@ def get_graph_data(self, graph_id: str) -> Dict[str, Any]: def delete_graph(self, graph_id: str): """删除图谱""" - self.client.graph.delete(graph_id=graph_id) + self.kg.delete(graph_id=graph_id) diff --git a/backend/app/services/kg_adapter.py b/backend/app/services/kg_adapter.py new file mode 100644 index 00000000..b8941316 --- /dev/null +++ b/backend/app/services/kg_adapter.py @@ -0,0 +1,674 @@ +""" +知识图谱适配器 +支持 Zep Cloud 和 Graphiti (本地) 两种模式 + +使用方式: + from app.services.kg_adapter import get_knowledge_graph_adapter + + kg = get_knowledge_graph_adapter() + kg.add_episode(graph_id="xxx", text="hello") + kg.search(graph_id="xxx", query="hello") +""" + +from abc import ABC, abstractmethod +from typing import Dict, Any, List, Optional +import logging + +from ..config import Config + +logger = logging.getLogger(__name__) + + +class KnowledgeGraphAdapter(ABC): + """知识图谱适配器抽象基类""" + + @abstractmethod + def create_graph(self, graph_id: str, name: str = None) -> Any: + """创建图谱""" + pass + + @abstractmethod + def add_episode(self, graph_id: str, text: str, **kwargs) -> Any: + """添加单条内容""" + pass + + @abstractmethod + def add_episodes_batch(self, graph_id: str, texts: List[str]) -> List[Any]: + """批量添加内容""" + pass + + @abstractmethod + def get_episode(self, episode_uuid: str) -> Any: + """获取单个 episode""" + pass + + @abstractmethod + def search(self, graph_id: str, query: str, limit: int = 10) -> List[Dict]: + """搜索""" + pass + + @abstractmethod + def get_nodes(self, graph_id: str, limit: int = 100, cursor: str = None) -> List[Any]: + """获取节点""" + pass + + @abstractmethod + def get_node(self, node_uuid: str) -> Any: + """获取单个节点""" + pass + + @abstractmethod + def get_node_edges(self, node_uuid: str) -> List[Dict]: + """获取节点的所有边""" + pass + + @abstractmethod + def get_edges(self, graph_id: str, limit: int = 100, cursor: str = None) -> List[Any]: + """获取边""" + pass + + @abstractmethod + def delete(self, graph_id: str) -> bool: + """删除图谱""" + pass + + @abstractmethod + def set_ontology(self, graph_id: str, ontology: Dict) -> bool: + """设置本体""" + pass + + @abstractmethod + def get_graph_info(self, graph_id: str) -> Dict: + """获取图谱信息""" + pass + + +class ZepCloudAdapter(KnowledgeGraphAdapter): + """Zep Cloud 适配器""" + + def __init__(self, api_key: str = None): + from zep_cloud.client import Zep + self.api_key = api_key or Config.ZEP_API_KEY + if not self.api_key: + raise ValueError("ZEP_API_KEY 未配置") + self.client = Zep(api_key=self.api_key) + logger.info("ZepCloudAdapter 初始化完成") + + def create_graph(self, graph_id: str, name: str = None) -> Any: + return self.client.graph.create(graph_id=graph_id, name=name or graph_id) + + def add_episode(self, graph_id: str, text: str, **kwargs) -> Any: + return self.client.graph.add(graph_id=graph_id, type="text", data=text) + + def add_episodes_batch(self, graph_id: str, texts: List[str]) -> List[Any]: + from zep_cloud.types import EpisodeData + episodes = [EpisodeData(data=t, type="text") for t in texts] + return self.client.graph.add_batch(episodes=episodes, graph_id=graph_id) + + def get_episode(self, episode_uuid: str) -> Any: + return self.client.graph.episode.get(uuid_=episode_uuid) + + def search(self, graph_id: str, query: str, limit: int = 10) -> List[Dict]: + result = self.client.graph.search(graph_id=graph_id, query=query, limit=limit) + return [r.model_dump() for r in result.results] if hasattr(result, 'results') else [] + + def get_nodes(self, graph_id: str, limit: int = 100, cursor: str = None) -> List[Any]: + kwargs = {"limit": limit} + if cursor: + kwargs["uuid_cursor"] = cursor + return self.client.graph.node.get_by_graph_id(graph_id=graph_id, **kwargs) + + def get_node(self, node_uuid: str) -> Any: + return self.client.graph.node.get(uuid_=node_uuid) + + def get_node_edges(self, node_uuid: str) -> List[Dict]: + edges = self.client.graph.node.get_entity_edges(node_uuid=node_uuid) + return [e.model_dump() if hasattr(e, 'model_dump') else e for e in edges] + + def get_edges(self, graph_id: str, limit: int = 100, cursor: str = None) -> List[Any]: + kwargs = {"limit": limit} + if cursor: + kwargs["uuid_cursor"] = cursor + return self.client.graph.edge.get_by_graph_id(graph_id=graph_id, **kwargs) + + def delete(self, graph_id: str) -> bool: + self.client.graph.delete(graph_id=graph_id) + return True + + def set_ontology(self, graph_id: str, ontology: Dict) -> bool: + entities = ontology.get('entities', {}) + edges = ontology.get('edges', {}) + self.client.graph.set_ontology( + entities=entities, + edges=edges, + graph_ids=[graph_id] + ) + return True + + def get_graph_info(self, graph_id: str) -> Dict: + # Zep Cloud 没有直接的图谱信息 API,返回基本信息 + return {"graph_id": graph_id} + + +class GraphitiAdapter(KnowledgeGraphAdapter): + """Graphiti 适配器 - 本地部署""" + + def __init__(self): + import os + import asyncio + from graphiti_core import Graphiti + from graphiti_core.llm_client.config import LLMConfig + from graphiti_core.llm_client.openai_client import OpenAIClient + from graphiti_core.embedder import OpenAIEmbedder, OpenAIEmbedderConfig + + if not all([Config.NEO4J_URI, Config.NEO4J_USER, Config.NEO4J_PASSWORD]): + raise ValueError("Neo4j 配置不完整,请检查 NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD") + + # 获取 API Key(优先使用独立配置,其次使用 LLM 配置) + api_key = Config.LLM_API_KEY or Config.OPENAI_API_KEY + llm_base_url = Config.LLM_BASE_URL + + # 嵌入模型独立配置 + embedding_api_key = Config.EMBEDDING_API_KEY or api_key + embedding_base_url = Config.EMBEDDING_BASE_URL or llm_base_url + + if not api_key: + raise ValueError("请配置 LLM_API_KEY") + + # 设置环境变量(Graphiti 内部组件会读取) + os.environ['OPENAI_API_KEY'] = api_key + os.environ['OPENAI_BASE_URL'] = llm_base_url + + # 配置 LLM 客户端(支持 OpenAI 兼容 API) + llm_config = LLMConfig( + api_key=api_key, + base_url=llm_base_url, + model=Config.LLM_MODEL_NAME, + small_model=Config.LLM_MODEL_NAME, # 使用相同模型 + ) + llm_client = OpenAIClient(config=llm_config) + + # 配置 Embedder 客户端(可独立配置) + embedder_config = OpenAIEmbedderConfig( + api_key=embedding_api_key, + base_url=embedding_base_url, + embedding_model=Config.EMBEDDING_MODEL, + embedding_dim=Config.EMBEDDING_DIM, + ) + embedder_client = OpenAIEmbedder(config=embedder_config) + + self.client = Graphiti( + uri=Config.NEO4J_URI, + user=Config.NEO4J_USER, + password=Config.NEO4J_PASSWORD, + llm_client=llm_client, + embedder=embedder_client, + cross_encoder=None, # 禁用 reranker,需要时可配置 + ) + # graph_id 到 group 的映射(Graphiti 使用 group 区分不同的图) + self._graph_id_to_group: Dict[str, str] = {} + + # 使用同步驱动避免 asyncio 事件循环冲突 + from neo4j import GraphDatabase + self._sync_driver = GraphDatabase.driver( + Config.NEO4J_URI, + auth=(Config.NEO4J_USER, Config.NEO4J_PASSWORD) + ) + + # 初始化数据库索引 + try: + import asyncio + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self.client.build_indices_and_constraints()) + loop.close() + logger.info("Graphiti 数据库索引初始化完成") + except Exception as e: + logger.warning(f"数据库索引初始化警告: {e}") + + logger.info("GraphitiAdapter 初始化完成") + + def _run_async(self, coro): + """同步调用异步方法的包装器,使用持久化事件循环""" + import asyncio + + # 创建持久化的事件循环(线程级别) + if not hasattr(self, '_loop') or self._loop.is_closed(): + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + + result = self._loop.run_until_complete(coro) + return result + + def _get_group(self, graph_id: str) -> str: + """获取或创建 group""" + if graph_id not in self._graph_id_to_group: + self._graph_id_to_group[graph_id] = graph_id + return self._graph_id_to_group[graph_id] + + def create_graph(self, graph_id: str, name: str = None) -> Any: + # Graphiti 不需要预创建图,通过 group 区分 + self._graph_id_to_group[graph_id] = graph_id + + # 创建 Group 节点 + with self._sync_driver.session() as session: + session.run(""" + MERGE (g:Group {name: $name}) + SET g.created_at = datetime() + """, name=graph_id) + + logger.info(f"Graphiti: 标记图谱 {graph_id}") + return {"status": "ok", "graph_id": graph_id} + + def add_episode(self, graph_id: str, text: str, **kwargs) -> Any: + """使用同步驱动添加 episode""" + import uuid + from datetime import datetime, timezone + + group = self._get_group(graph_id) + episode_uuid = str(uuid.uuid4()) + now = datetime.now(timezone.utc) + + # 直接使用同步驱动创建 episode + with self._sync_driver.session() as session: + query = """ + CREATE (e:Episodic { + uuid: $uuid, + name: $name, + content: $content, + created_at: $created_at, + valid_at: $valid_at, + group_id: $group_id, + source: 'text', + episode_type: 'text' + }) + RETURN e + """ + result = session.run( + query, + uuid=episode_uuid, + name=f"episode_{now.strftime('%Y%m%d%H%M%S')}", + content=text, + created_at=now, + valid_at=now, + group_id=group + ) + record = result.single() + return {"uuid": episode_uuid, "name": record["e"]["name"]} if record else None + + def add_episodes_batch(self, graph_id: str, texts: List[str], batch_size: int = 10) -> List[Any]: + """批量添加内容,使用同步驱动,并提取实体""" + results = [] + group = self._get_group(graph_id) + + # 直接添加 episodes + for i, text in enumerate(texts): + result = self.add_episode(graph_id, text) + results.append(result) + if (i + 1) % batch_size == 0: + logger.info(f"已添加 {i + 1}/{len(texts)} 条内容") + + # 提取实体 + self._extract_entities_from_texts(graph_id, texts) + + logger.info(f"图谱实体全部构建完成: {graph_id}, 共 {len(results)} 条") + return results + + def _extract_entities_from_texts(self, graph_id: str, texts: List[str]): + """从文本中提取实体并存储到 Neo4j""" + import uuid + + # 合并所有文本进行实体提取 + combined_text = "\n\n".join(texts) + + # 使用 LLM 提取实体 + entities_json = self._call_llm_for_entities(combined_text) + + if not entities_json: + logger.warning("未能从文本中提取到实体") + return + + # 解析实体 + try: + import json + import re + + # 尝试直接解析 + try: + entities_data = json.loads(entities_json) + except json.JSONDecodeError: + # 尝试提取 JSON 数组 + match = re.search(r'\[.*\]', entities_json, re.DOTALL) + if match: + entities_data = json.loads(match.group()) + else: + raise ValueError("No JSON array found") + + if not entities_data or not isinstance(entities_data, list): + logger.warning("未提取到实体数据") + return + + logger.info(f"提取到 {len(entities_data)} 个实体") + except Exception as e: + logger.error(f"解析实体数据失败: {e}, 内容: {entities_json[:300]}") + return + + # 存储实体到 Neo4j + with self._sync_driver.session() as session: + for entity in entities_data: + entity_name = entity.get("name") + entity_type = entity.get("type", "Entity") + description = entity.get("description", "") + relationships = entity.get("relationships", []) + + if not entity_name: + continue + + # 创建实体节点并关联到 Group + entity_uuid = str(uuid.uuid4()) + query = """ + MERGE (g:Group {name: $group_id}) + MERGE (e:Entity {name: $name, group_id: $group_id}) + SET e.uuid = $uuid, + e.summary = $summary, + e.created_at = datetime(), + e.entity_type = $type + MERGE (e)-[:MEMBER_OF]->(g) + RETURN e + """ + session.run( + query, + name=entity_name, + uuid=entity_uuid, + summary=description, + type=entity_type, + group_id=graph_id + ) + + # 创建关系 + for rel in relationships: + target = rel.get("target") + rel_type = rel.get("type", "RELATED_TO") + fact = rel.get("fact", "") + + if target: + rel_query = """ + MATCH (e1:Entity {name: $source, group_id: $group_id}) + MATCH (e2:Entity {name: $target, group_id: $group_id}) + MERGE (e1)-[r:RELATED {fact: $fact, fact_type: $type}]->(e2) + SET r.created_at = datetime() + """ + session.run( + rel_query, + source=entity_name, + target=target, + group_id=graph_id, + fact=fact, + type=rel_type + ) + + logger.info(f"实体提取完成: {len(entities_data)} 个实体") + + def _call_llm_for_entities(self, text: str) -> str: + """调用 LLM 提取实体""" + from openai import OpenAI + + client = OpenAI( + api_key=Config.LLM_API_KEY, + base_url=Config.LLM_BASE_URL + ) + + prompt = f"""从以下文本中提取实体和关系。 + +重要:直接返回JSON数组,不要任何markdown格式,不要```标记。 + +要求返回格式: +[ + {{ + "name": "实体名", + "type": "实体类型", + "description": "描述", + "relationships": [ + {{"target": "目标实体", "type": "关系类型", "fact": "事实描述"}} + ] + }} +] + +文本内容: +{text[:3000]} + +直接返回JSON数组:""" + + try: + response = client.chat.completions.create( + model=Config.LLM_MODEL_NAME, + messages=[ + {"role": "system", "content": "你是一个实体关系提取助手。"}, + {"role": "user", "content": prompt} + ], + temperature=0.1, + max_tokens=2000 + ) + # 清理 JSON(去除 markdown 代码块) + content = response.choices[0].message.content + content = content.strip() + # 去除 ```json 和 ``` 标记 + if content.startswith("```"): + content = content.split("\n", 1)[1] if "\n" in content else content + if content.endswith("```"): + content = content.rsplit("```", 1)[0] + return content.strip() + except Exception as e: + logger.error(f"LLM 实体提取失败: {e}") + return "[]" + + def get_episode(self, episode_uuid: str) -> Any: + """使用同步驱动获取 episode""" + with self._sync_driver.session() as session: + query = """ + MATCH (e:Episodic {uuid: $uuid}) + RETURN e.content as content, e.created_at as created_at, + e.valid_at as valid_at, e.uuid as uuid, + e.name as name, e.group_id as group_id + """ + result = session.run(query, uuid=episode_uuid) + record = result.single() + if record: + return dict(record) + return None + + def search(self, graph_id: str, query: str, limit: int = 10) -> List[Dict]: + """使用同步驱动搜索(简单实现:搜索 episodes 内容)""" + group = self._get_group(graph_id) + with self._sync_driver.session() as session: + # 简单的文本搜索:匹配 episodes 内容 + query_cypher = """ + MATCH (e:Episodic {group_id: $group}) + WHERE e.content CONTAINS $search_text + RETURN e.content as content, e.uuid as uuid, e.name as name + LIMIT $limit + """ + result = session.run( + query_cypher, + group=group, + search_text=query, + limit=limit + ) + return [{"content": r["content"], "uuid": r["uuid"], "name": r["name"]} for r in result] + + def get_nodes(self, graph_id: str, limit: int = 100, cursor: str = None) -> List[Any]: + """通过同步驱动查询实体节点""" + with self._sync_driver.session() as session: + query = """ + MATCH (e:Entity)-[:MEMBER_OF]->(g:Group {name: $group}) + RETURN e.uuid as uuid_, e.name as name, labels(e) as labels, + e.summary as summary, e.created_at as created_at, + e.entity_type as entity_type + LIMIT $limit + """ + result = session.run(query, group=graph_id, limit=limit) + nodes = [dict(record) for record in result] + # 转换格式以兼容前端,将 entity_type 放入 attributes + for node in nodes: + if 'attributes' not in node: + node['attributes'] = {} + if node.get('entity_type'): + node['attributes']['entity_type'] = node['entity_type'] + return nodes + + def get_node(self, node_uuid: str) -> Any: + """通过同步驱动获取单个节点""" + with self._sync_driver.session() as session: + query = """ + MATCH (e:Entity {uuid: $uuid}) + RETURN e.uuid as uuid_, e.name as name, labels(e) as labels, + e.summary as summary, e.created_at as created_at + """ + result = session.run(query, uuid=node_uuid) + record = result.single() + if record: + node = dict(record) + if 'attributes' not in node: + node['attributes'] = {} + return node + return None + + def get_node_edges(self, node_uuid: str) -> List[Dict]: + """通过同步驱动获取节点的所有边""" + with self._sync_driver.session() as session: + query = """ + MATCH (e1:Entity {uuid: $uuid})-[r]-(e2:Entity) + RETURN r.uuid as uuid_, type(r) as name, r.fact as fact, + r.fact_type as fact_type, + e1.uuid as source_node_uuid, e2.uuid as target_node_uuid, + e1.name as source_node_name, e2.name as target_node_name, + r.created_at as created_at + """ + result = session.run(query, uuid=node_uuid) + edges = [dict(record) for record in result] + return edges + + def get_edges(self, graph_id: str, limit: int = 100, cursor: str = None) -> List[Any]: + """通过同步驱动查询边""" + with self._sync_driver.session() as session: + query = """ + MATCH (e1:Entity)-[r]-(e2:Entity) + WHERE e1.group_id = $group OR e2.group_id = $group + RETURN r.uuid as uuid_, type(r) as name, r.fact as fact, + r.fact_type as fact_type, + e1.uuid as source_node_uuid, e2.uuid as target_node_uuid, + e1.name as source_node_name, e2.name as target_node_name, + r.created_at as created_at, r.valid_at as valid_at, + r.invalid_at as invalid_at, r.expired_at as expired_at + LIMIT $limit + """ + result = session.run(query, group=graph_id, limit=limit) + edges = [dict(record) for record in result] + # 兼容前端格式 + for edge in edges: + if 'attributes' not in edge: + edge['attributes'] = {} + if 'episodes' not in edge: + edge['episodes'] = [] + return edges + + def delete(self, graph_id: str) -> bool: + """使用同步驱动删除图谱""" + with self._sync_driver.session() as session: + # 删除关联边 + session.run(""" + MATCH (e1:Entity)-[r]-(e2:Entity) + WHERE e1.group = $group OR e2.group = $group + DELETE r + """, group=graph_id) + # 删除实体节点 + session.run(""" + MATCH (e:Entity)-[:MEMBER_OF]->(g:Group {name: $group}) + DELETE e + """, group=graph_id) + # 删除组节点 + session.run(""" + MATCH (g:Group {name: $group}) + DELETE g + """, group=graph_id) + + if graph_id in self._graph_id_to_group: + del self._graph_id_to_group[graph_id] + + logger.info(f"Graphiti: 删除图谱 {graph_id}") + return True + + def set_ontology(self, graph_id: str, ontology: Dict) -> bool: + # Graphiti 通过 Pydantic 模型定义实体类型 + # 这里简化处理:仅记录 ontology 配置 + logger.info(f"Graphiti: 设置本体 {graph_id}, ontology types: {list(ontology.keys())}") + # 实际使用时需要动态创建实体类 + return True + + def get_graph_info(self, graph_id: str) -> Dict: + """使用同步驱动获取图谱信息""" + with self._sync_driver.session() as session: + # 统计节点数量 + node_result = session.run(""" + MATCH (e:Entity)-[:MEMBER_OF]->(g:Group {name: $group}) + RETURN count(e) as count + """, group=graph_id) + node_count = node_result.single()["count"] if node_result.single() else 0 + + # 统计边数量 + edge_result = session.run(""" + MATCH (e1:Entity)-[r]-(e2:Entity) + WHERE e1.group = $group OR e2.group = $group + RETURN count(r) as count + """, group=graph_id) + edge_count = edge_result.single()["count"] if edge_result.single() else 0 + + return { + "graph_id": graph_id, + "node_count": node_count, + "edge_count": edge_count, + } + + def _result_to_dict(self, result) -> Dict: + if hasattr(result, 'model_dump'): + return result.model_dump() + elif hasattr(result, 'dict'): + return result.dict() + return {} + + +# 全局缓存 +_adapter_cache: Optional[KnowledgeGraphAdapter] = None + + +def get_knowledge_graph_adapter(force_refresh: bool = False) -> KnowledgeGraphAdapter: + """ + 获取知识图谱适配器实例 + + Args: + force_refresh: 是否强制刷新缓存 + + Returns: + KnowledgeGraphAdapter: 适配器实例 + """ + global _adapter_cache + + if _adapter_cache is not None and not force_refresh: + return _adapter_cache + + mode = Config.KNOWLEDGE_GRAPH_MODE + + if mode == 'local': + _adapter_cache = GraphitiAdapter() + elif mode == 'cloud': + _adapter_cache = ZepCloudAdapter() + else: + raise ValueError(f"未知的 KNOWLEDGE_GRAPH_MODE: {mode}") + + return _adapter_cache + + +def reset_adapter(): + """重置适配器缓存""" + global _adapter_cache + _adapter_cache = None diff --git a/backend/app/services/oasis_profile_generator.py b/backend/app/services/oasis_profile_generator.py index 57836c53..692e139f 100644 --- a/backend/app/services/oasis_profile_generator.py +++ b/backend/app/services/oasis_profile_generator.py @@ -1,9 +1,10 @@ """ OASIS Agent Profile生成器 -将Zep图谱中的实体转换为OASIS模拟平台所需的Agent Profile格式 +将图谱中的实体转换为OASIS模拟平台所需的Agent Profile格式 +支持 Zep Cloud 和 Graphiti 两种模式 优化改进: -1. 调用Zep检索功能二次丰富节点信息 +1. 调用图谱检索功能二次丰富节点信息 2. 优化提示词生成非常详细的人设 3. 区分个人实体和抽象群体实体 """ @@ -16,11 +17,11 @@ from datetime import datetime from openai import OpenAI -from zep_cloud.client import Zep from ..config import Config from ..utils.logger import get_logger from .zep_entity_reader import EntityNode, ZepEntityReader +from .kg_adapter import get_knowledge_graph_adapter logger = get_logger('mirofish.oasis_profile') @@ -196,17 +197,15 @@ def __init__( api_key=self.api_key, base_url=self.base_url ) - - # Zep客户端用于检索丰富上下文 - self.zep_api_key = zep_api_key or Config.ZEP_API_KEY - self.zep_client = None + + # 图谱客户端用于检索丰富上下文 self.graph_id = graph_id - - if self.zep_api_key: - try: - self.zep_client = Zep(api_key=self.zep_api_key) - except Exception as e: - logger.warning(f"Zep客户端初始化失败: {e}") + # 使用适配器 + try: + self.kg = get_knowledge_graph_adapter() + except Exception as e: + logger.warning(f"图谱客户端初始化失败: {e}") + self.kg = None def generate_profile_from_entity( self, @@ -297,7 +296,7 @@ def _search_zep_for_entity(self, entity: EntityNode) -> Dict[str, Any]: """ import concurrent.futures - if not self.zep_client: + if not self.kg: return {"facts": [], "node_summaries": [], "context": ""} entity_name = entity.name @@ -323,7 +322,7 @@ def search_edges(): for attempt in range(max_retries): try: - return self.zep_client.graph.search( + return self.kg.search( query=comprehensive_query, graph_id=self.graph_id, limit=30, @@ -348,7 +347,7 @@ def search_nodes(): for attempt in range(max_retries): try: - return self.zep_client.graph.search( + return self.kg.search( query=comprehensive_query, graph_id=self.graph_id, limit=20, diff --git a/backend/app/services/report_agent.py b/backend/app/services/report_agent.py index 02ca5bdc..442b02a4 100644 --- a/backend/app/services/report_agent.py +++ b/backend/app/services/report_agent.py @@ -1114,14 +1114,37 @@ def _is_valid_tool_call(self, data: dict) -> bool: """校验解析出的 JSON 是否是合法的工具调用""" # 支持 {"name": ..., "parameters": ...} 和 {"tool": ..., "params": ...} 两种键名 tool_name = data.get("name") or data.get("tool") - if tool_name and tool_name in self.VALID_TOOL_NAMES: - # 统一键名为 name / parameters - if "tool" in data: - data["name"] = data.pop("tool") - if "params" in data and "parameters" not in data: - data["parameters"] = data.pop("params") + if not tool_name: + return False + + # 精确匹配 + if tool_name in self.VALID_TOOL_NAMES: + self._normalize_tool_call(data) return True + + # 容错匹配:处理常见格式错误 + # interviewagents -> interview_agents + # quicksearch -> quick_search + # panoramasearch -> panorama_search + # insightforge -> insight_forge + normalized = tool_name.replace("search", "_search").replace("forge", "_forge") + normalized = normalized.replace("agents", "_agents") + + # 尝试匹配 + for valid_name in self.VALID_TOOL_NAMES: + if normalized == valid_name or tool_name == valid_name: + data["name"] = valid_name # 修正为正确名称 + self._normalize_tool_call(data) + return True + return False + + def _normalize_tool_call(self, data: dict): + """统一键名格式""" + if "tool" in data: + data["name"] = data.pop("tool") + if "params" in data and "parameters" not in data: + data["parameters"] = data.pop("params") def _get_tools_description(self) -> str: """生成工具描述文本""" diff --git a/backend/app/services/zep_entity_reader.py b/backend/app/services/zep_entity_reader.py index 71661be4..40578c78 100644 --- a/backend/app/services/zep_entity_reader.py +++ b/backend/app/services/zep_entity_reader.py @@ -1,17 +1,16 @@ """ -Zep实体读取与过滤服务 -从Zep图谱中读取节点,筛选出符合预定义实体类型的节点 +图谱实体读取与过滤服务 +从图谱中读取节点,筛选出符合预定义实体类型的节点 +支持 Zep Cloud 和 Graphiti 两种模式 """ import time from typing import Dict, Any, List, Optional, Set, Callable, TypeVar from dataclasses import dataclass, field -from zep_cloud.client import Zep - from ..config import Config from ..utils.logger import get_logger -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges +from .kg_adapter import get_knowledge_graph_adapter logger = get_logger('mirofish.zep_entity_reader') @@ -70,20 +69,18 @@ def to_dict(self) -> Dict[str, Any]: class ZepEntityReader: """ - Zep实体读取与过滤服务 - + 图谱实体读取与过滤服务 + 主要功能: - 1. 从Zep图谱读取所有节点 + 1. 从图谱读取所有节点 2. 筛选出符合预定义实体类型的节点(Labels不只是Entity的节点) 3. 获取每个实体的相关边和关联节点信息 """ - + def __init__(self, api_key: Optional[str] = None): - self.api_key = api_key or Config.ZEP_API_KEY - if not self.api_key: - raise ValueError("ZEP_API_KEY 未配置") - - self.client = Zep(api_key=self.api_key) + self.api_key = api_key # 保留参数兼容性 + # 使用适配器 + self.kg = get_knowledge_graph_adapter() def _call_with_retry( self, @@ -136,15 +133,25 @@ def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]: """ logger.info(f"获取图谱 {graph_id} 的所有节点...") - nodes = fetch_all_nodes(self.client, graph_id) + # 使用适配器获取节点 + nodes = self.kg.get_nodes(graph_id, limit=2000) nodes_data = [] for node in nodes: - nodes_data.append({ - "uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), - "name": node.name or "", - "labels": node.labels or [], - "summary": node.summary or "", + if isinstance(node, dict): + nodes_data.append({ + "uuid": node.get('uuid_', '') or node.get('uuid', ''), + "name": node.get('name', ''), + "labels": node.get('labels', []), + "summary": node.get('summary', ''), + "attributes": node.get('attributes', {}), + }) + else: + nodes_data.append({ + "uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), + "name": node.name or "", + "labels": node.labels or [], + "summary": node.summary or "", "attributes": node.attributes or {}, }) @@ -163,18 +170,29 @@ def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]: """ logger.info(f"获取图谱 {graph_id} 的所有边...") - edges = fetch_all_edges(self.client, graph_id) + # 使用适配器获取边 + edges = self.kg.get_edges(graph_id, limit=2000) edges_data = [] for edge in edges: - edges_data.append({ - "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), - "name": edge.name or "", - "fact": edge.fact or "", - "source_node_uuid": edge.source_node_uuid, - "target_node_uuid": edge.target_node_uuid, - "attributes": edge.attributes or {}, - }) + if isinstance(edge, dict): + edges_data.append({ + "uuid": edge.get('uuid_', '') or edge.get('uuid', ''), + "name": edge.get('name', ''), + "fact": edge.get('fact', ''), + "source_node_uuid": edge.get('source_node_uuid', ''), + "target_node_uuid": edge.get('target_node_uuid', ''), + "attributes": edge.get('attributes', {}), + }) + else: + edges_data.append({ + "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), + "name": edge.name or "", + "fact": edge.fact or "", + "source_node_uuid": edge.source_node_uuid, + "target_node_uuid": edge.target_node_uuid, + "attributes": edge.attributes or {}, + }) logger.info(f"共获取 {len(edges_data)} 条边") return edges_data @@ -190,23 +208,33 @@ def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]: 边列表 """ try: - # 使用重试机制调用Zep API + # 使用重试机制调用图谱API edges = self._call_with_retry( - func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid), + func=lambda: self.kg.get_node_edges(node_uuid), operation_name=f"获取节点边(node={node_uuid[:8]}...)" ) - + edges_data = [] for edge in edges: - edges_data.append({ - "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), - "name": edge.name or "", - "fact": edge.fact or "", - "source_node_uuid": edge.source_node_uuid, - "target_node_uuid": edge.target_node_uuid, - "attributes": edge.attributes or {}, - }) - + if isinstance(edge, dict): + edges_data.append({ + "uuid": edge.get('uuid_', '') or edge.get('uuid', ''), + "name": edge.get('name', ''), + "fact": edge.get('fact', ''), + "source_node_uuid": edge.get('source_node_uuid', ''), + "target_node_uuid": edge.get('target_node_uuid', ''), + "attributes": edge.get('attributes', {}), + }) + else: + edges_data.append({ + "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), + "name": edge.name or "", + "fact": edge.fact or "", + "source_node_uuid": edge.source_node_uuid, + "target_node_uuid": edge.target_node_uuid, + "attributes": edge.attributes or {}, + }) + return edges_data except Exception as e: logger.warning(f"获取节点 {node_uuid} 的边失败: {str(e)}") @@ -251,23 +279,27 @@ def filter_defined_entities( for node in all_nodes: labels = node.get("labels", []) - - # 筛选逻辑:Labels必须包含除"Entity"和"Node"之外的标签 - custom_labels = [l for l in labels if l not in ["Entity", "Node"]] - - if not custom_labels: - # 只有默认标签,跳过 + + # 获取实体类型(优先从属性获取,其次从标签获取) + entity_type = None + if node.get("attributes"): + entity_type = node["attributes"].get("entity_type") + + if not entity_type: + # 筛选逻辑:Labels必须包含除"Entity"和"Node"之外的标签 + custom_labels = [l for l in labels if l not in ["Entity", "Node"]] + if custom_labels: + entity_type = custom_labels[0] + + if not entity_type: + # 没有实体类型,跳过 continue - + # 如果指定了预定义类型,检查是否匹配 if defined_entity_types: - matching_labels = [l for l in custom_labels if l in defined_entity_types] - if not matching_labels: + if entity_type not in defined_entity_types: continue - entity_type = matching_labels[0] - else: - entity_type = custom_labels[0] - + entity_types_found.add(entity_type) # 创建实体节点对象 @@ -341,27 +373,27 @@ def get_entity_with_context( Args: graph_id: 图谱ID entity_uuid: 实体UUID - + Returns: EntityNode或None """ try: # 使用重试机制获取节点 node = self._call_with_retry( - func=lambda: self.client.graph.node.get(uuid_=entity_uuid), + func=lambda: self.kg.get_node(entity_uuid), operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)" ) - + if not node: return None - + # 获取节点的边 edges = self.get_node_edges(entity_uuid) - + # 获取所有节点用于关联查找 all_nodes = self.get_all_nodes(graph_id) node_map = {n["uuid"]: n for n in all_nodes} - + # 处理相关边和节点 related_edges = [] related_node_uuids = set() diff --git a/backend/app/services/zep_graph_memory_updater.py b/backend/app/services/zep_graph_memory_updater.py index a8f3cecd..116f6d74 100644 --- a/backend/app/services/zep_graph_memory_updater.py +++ b/backend/app/services/zep_graph_memory_updater.py @@ -1,6 +1,7 @@ """ -Zep图谱记忆更新服务 -将模拟中的Agent活动动态更新到Zep图谱中 +图谱记忆更新服务 +将模拟中的Agent活动动态更新到图谱中 +支持 Zep Cloud 和 Graphiti 两种模式 """ import os @@ -12,10 +13,9 @@ from datetime import datetime from queue import Queue, Empty -from zep_cloud.client import Zep - from ..config import Config from ..utils.logger import get_logger +from .kg_adapter import get_knowledge_graph_adapter logger = get_logger('mirofish.zep_graph_memory_updater') @@ -200,49 +200,47 @@ def _describe_generic(self) -> str: class ZepGraphMemoryUpdater: """ - Zep图谱记忆更新器 - - 监控模拟的actions日志文件,将新的agent活动实时更新到Zep图谱中。 - 按平台分组,每累积BATCH_SIZE条活动后批量发送到Zep。 - - 所有有意义的行为都会被更新到Zep,action_args中会包含完整的上下文信息: + 图谱记忆更新器 + + 监控模拟的actions日志文件,将新的agent活动实时更新到图谱中。 + 按平台分组,每累积BATCH_SIZE条活动后批量发送。 + + 所有有意义的行为都会被更新到图谱,action_args中会包含完整的上下文信息: - 点赞/踩的帖子原文 - 转发/引用的帖子原文 - 关注/屏蔽的用户名 - 点赞/踩的评论原文 """ - + # 批量发送大小(每个平台累积多少条后发送) BATCH_SIZE = 5 - + # 平台名称映射(用于控制台显示) PLATFORM_DISPLAY_NAMES = { 'twitter': '世界1', 'reddit': '世界2', } - + # 发送间隔(秒),避免请求过快 SEND_INTERVAL = 0.5 - + # 重试配置 MAX_RETRIES = 3 RETRY_DELAY = 2 # 秒 - + def __init__(self, graph_id: str, api_key: Optional[str] = None): """ 初始化更新器 - + Args: - graph_id: Zep图谱ID - api_key: Zep API Key(可选,默认从配置读取) + graph_id: 图谱ID + api_key: 保留参数(兼容旧代码) """ self.graph_id = graph_id - self.api_key = api_key or Config.ZEP_API_KEY - - if not self.api_key: - raise ValueError("ZEP_API_KEY未配置") - - self.client = Zep(api_key=self.api_key) + self.api_key = api_key # 保留参数兼容性 + + # 使用适配器 + self.kg = get_knowledge_graph_adapter() # 活动队列 self._activity_queue: Queue = Queue() @@ -401,29 +399,28 @@ def _send_batch_activities(self, activities: List[AgentActivity], platform: str) # 将多条活动合并为一条文本,用换行分隔 episode_texts = [activity.to_episode_text() for activity in activities] combined_text = "\n".join(episode_texts) - + # 带重试的发送 for attempt in range(self.MAX_RETRIES): try: - self.client.graph.add( + self.kg.add_episode( graph_id=self.graph_id, - type="text", - data=combined_text + text=combined_text ) - + self._total_sent += 1 self._total_items_sent += len(activities) display_name = self._get_platform_display_name(platform) logger.info(f"成功批量发送 {len(activities)} 条{display_name}活动到图谱 {self.graph_id}") logger.debug(f"批量内容预览: {combined_text[:200]}...") return - + except Exception as e: if attempt < self.MAX_RETRIES - 1: - logger.warning(f"批量发送到Zep失败 (尝试 {attempt + 1}/{self.MAX_RETRIES}): {e}") + logger.warning(f"批量发送到图谱失败 (尝试 {attempt + 1}/{self.MAX_RETRIES}): {e}") time.sleep(self.RETRY_DELAY * (attempt + 1)) else: - logger.error(f"批量发送到Zep失败,已重试{self.MAX_RETRIES}次: {e}") + logger.error(f"批量发送到图谱失败,已重试{self.MAX_RETRIES}次: {e}") self._failed_count += 1 def _flush_remaining(self): diff --git a/backend/app/services/zep_tools.py b/backend/app/services/zep_tools.py index 384cf540..201ce9cb 100644 --- a/backend/app/services/zep_tools.py +++ b/backend/app/services/zep_tools.py @@ -1,6 +1,7 @@ """ -Zep检索工具服务 +图谱检索工具服务 封装图谱搜索、节点读取、边查询等工具,供Report Agent使用 +支持 Zep Cloud 和 Graphiti 两种模式 核心检索工具(优化后): 1. InsightForge(深度洞察检索)- 最强大的混合检索,自动生成子问题并多维度检索 @@ -13,12 +14,10 @@ from typing import Dict, Any, List, Optional from dataclasses import dataclass, field -from zep_cloud.client import Zep - from ..config import Config from ..utils.logger import get_logger from ..utils.llm_client import LLMClient -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges +from .kg_adapter import get_knowledge_graph_adapter logger = get_logger('mirofish.zep_tools') @@ -61,7 +60,24 @@ class NodeInfo: labels: List[str] summary: str attributes: Dict[str, Any] - + + def __init__(self, uuid: str = "", name: str = "", labels: List[str] = None, + summary: str = "", attributes: Dict[str, Any] = None): + # 如果传入的是 dict,转换为对象属性 + if isinstance(uuid, dict): + d = uuid + self.uuid = d.get('uuid_') or d.get('uuid', '') + self.name = d.get('name', '') + self.labels = d.get('labels', []) + self.summary = d.get('summary', '') + self.attributes = d.get('attributes', {}) + else: + self.uuid = uuid or '' + self.name = name or '' + self.labels = labels or [] + self.summary = summary or '' + self.attributes = attributes or {} + def to_dict(self) -> Dict[str, Any]: return { "uuid": self.uuid, @@ -70,7 +86,7 @@ def to_dict(self) -> Dict[str, Any]: "summary": self.summary, "attributes": self.attributes } - + def to_text(self) -> str: """转换为文本格式""" entity_type = next((l for l in self.labels if l not in ["Entity", "Node"]), "未知类型") @@ -92,7 +108,39 @@ class EdgeInfo: valid_at: Optional[str] = None invalid_at: Optional[str] = None expired_at: Optional[str] = None - + + def __init__(self, uuid: str = "", name: str = "", fact: str = "", + source_node_uuid: str = "", target_node_uuid: str = "", + source_node_name: str = None, target_node_name: str = None, + created_at: str = None, valid_at: str = None, + invalid_at: str = None, expired_at: str = None): + # 如果传入的是 dict,转换为对象属性 + if isinstance(uuid, dict): + d = uuid + self.uuid = d.get('uuid_') or d.get('uuid', '') + self.name = d.get('name', 'RELATED') + self.fact = d.get('fact', '') + self.source_node_uuid = d.get('source_node_uuid', '') + self.target_node_uuid = d.get('target_node_uuid', '') + self.source_node_name = d.get('source_node_name') + self.target_node_name = d.get('target_node_name') + self.created_at = d.get('created_at') + self.valid_at = d.get('valid_at') + self.invalid_at = d.get('invalid_at') + self.expired_at = d.get('expired_at') + else: + self.uuid = uuid or '' + self.name = name or 'RELATED' + self.fact = fact or '' + self.source_node_uuid = source_node_uuid or '' + self.target_node_uuid = target_node_uuid or '' + self.source_node_name = source_node_name + self.target_node_name = target_node_name + self.created_at = created_at + self.valid_at = valid_at + self.invalid_at = invalid_at + self.expired_at = expired_at + def to_dict(self) -> Dict[str, Any]: return { "uuid": self.uuid, @@ -422,11 +470,9 @@ class ZepToolsService: RETRY_DELAY = 2.0 def __init__(self, api_key: Optional[str] = None, llm_client: Optional[LLMClient] = None): - self.api_key = api_key or Config.ZEP_API_KEY - if not self.api_key: - raise ValueError("ZEP_API_KEY 未配置") - - self.client = Zep(api_key=self.api_key) + self.api_key = api_key # 保留参数兼容性 + # 使用适配器 + self.kg = get_knowledge_graph_adapter() # LLM客户端用于InsightForge生成子问题 self._llm_client = llm_client logger.info("ZepToolsService 初始化完成") @@ -485,15 +531,13 @@ def search_graph( """ logger.info(f"图谱搜索: graph_id={graph_id}, query={query[:50]}...") - # 尝试使用Zep Cloud Search API + # 尝试使用图谱搜索 API try: search_results = self._call_with_retry( - func=lambda: self.client.graph.search( + func=lambda: self.kg.search( graph_id=graph_id, query=query, limit=limit, - scope=scope, - reranker="cross_encoder" ), operation_name=f"图谱搜索(graph={graph_id})" ) @@ -501,33 +545,62 @@ def search_graph( facts = [] edges = [] nodes = [] - - # 解析边搜索结果 - if hasattr(search_results, 'edges') and search_results.edges: - for edge in search_results.edges: - if hasattr(edge, 'fact') and edge.fact: - facts.append(edge.fact) - edges.append({ - "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), - "name": getattr(edge, 'name', ''), - "fact": getattr(edge, 'fact', ''), - "source_node_uuid": getattr(edge, 'source_node_uuid', ''), - "target_node_uuid": getattr(edge, 'target_node_uuid', ''), - }) - - # 解析节点搜索结果 - if hasattr(search_results, 'nodes') and search_results.nodes: - for node in search_results.nodes: - nodes.append({ - "uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), - "name": getattr(node, 'name', ''), - "labels": getattr(node, 'labels', []), - "summary": getattr(node, 'summary', ''), - }) - # 节点摘要也算作事实 - if hasattr(node, 'summary') and node.summary: - facts.append(f"[{node.name}]: {node.summary}") - + + # 解析搜索结果(兼容对象和字典格式) + # 适配器返回的是 List[Dict],需要处理 + if isinstance(search_results, list): + for result in search_results: + # 判断是边还是节点 + if isinstance(result, dict): + if result.get('source_node_uuid') and result.get('target_node_uuid'): + # 边 + fact = result.get('fact', '') + if fact: + facts.append(fact) + edges.append({ + "uuid": result.get('uuid_', '') or result.get('uuid', ''), + "name": result.get('name', ''), + "fact": fact, + "source_node_uuid": result.get('source_node_uuid', ''), + "target_node_uuid": result.get('target_node_uuid', ''), + }) + else: + # 节点 + name = result.get('name', '') + summary = result.get('summary', '') + nodes.append({ + "uuid": result.get('uuid_', '') or result.get('uuid', ''), + "name": name, + "labels": result.get('labels', []), + "summary": summary, + }) + if summary: + facts.append(f"[{name}]: {summary}") + else: + # 原始对象格式(保留兼容性) + if hasattr(search_results, 'edges') and search_results.edges: + for edge in search_results.edges: + if hasattr(edge, 'fact') and edge.fact: + facts.append(edge.fact) + edges.append({ + "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), + "name": getattr(edge, 'name', ''), + "fact": getattr(edge, 'fact', ''), + "source_node_uuid": getattr(edge, 'source_node_uuid', ''), + "target_node_uuid": getattr(edge, 'target_node_uuid', ''), + }) + + if hasattr(search_results, 'nodes') and search_results.nodes: + for node in search_results.nodes: + nodes.append({ + "uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), + "name": getattr(node, 'name', ''), + "labels": getattr(node, 'labels', []), + "summary": getattr(node, 'summary', ''), + }) + if hasattr(node, 'summary') and node.summary: + facts.append(f"[{node.name}]: {node.summary}") + logger.info(f"搜索完成: 找到 {len(facts)} 条相关事实") return SearchResult( @@ -659,7 +732,38 @@ def get_all_nodes(self, graph_id: str) -> List[NodeInfo]: """ logger.info(f"获取图谱 {graph_id} 的所有节点...") - nodes = fetch_all_nodes(self.client, graph_id) + # 使用适配器分页获取所有节点 + # Zep Cloud API 使用 uuid_cursor 分页,但响应不返回 cursor + # 通过返回数量判断是否有更多:< limit 则说明是最后一页 + nodes = [] + cursor = None + max_pages = 100 # 最多获取 100 页,防止无限循环 + page_count = 0 + + while page_count < max_pages: + page = self.kg.get_nodes(graph_id, limit=100, cursor=cursor) + if not page: + break + # 将 dict 转换为 NodeInfo 对象 + for item in page: + if isinstance(item, dict): + nodes.append(NodeInfo(item)) + else: + nodes.append(item) + page_count += 1 + + # 如果返回数量 < limit,说明是最后一页 + if len(page) < 100: + break + + # 尝试获取下一页 - Zep Cloud 使用 uuid_cursor 参数 + # 由于 API 不返回 next_cursor,我们需要用最后一条的 uuid 作为 cursor + last_item = page[-1] + cursor = getattr(last_item, 'uuid_', None) or getattr(last_item, 'uuid', None) + if not cursor: + break + + logger.info(f"分页获取完成,共 {page_count} 页,{len(nodes)} 个节点") result = [] for node in nodes: @@ -688,7 +792,33 @@ def get_all_edges(self, graph_id: str, include_temporal: bool = True) -> List[Ed """ logger.info(f"获取图谱 {graph_id} 的所有边...") - edges = fetch_all_edges(self.client, graph_id) + # 使用适配器分页获取所有边 + edges = [] + cursor = None + max_pages = 100 + page_count = 0 + + while page_count < max_pages: + page = self.kg.get_edges(graph_id, limit=100, cursor=cursor) + if not page: + break + # 将 dict 转换为 EdgeInfo 对象 + for item in page: + if isinstance(item, dict): + edges.append(EdgeInfo(item)) + else: + edges.append(item) + page_count += 1 + + if len(page) < 100: + break + + last_item = page[-1] + cursor = getattr(last_item, 'uuid_', None) or (last_item.get('uuid_') if isinstance(last_item, dict) else None) or (last_item.get('uuid') if isinstance(last_item, dict) else None) + if not cursor: + break + + logger.info(f"分页获取完成,共 {page_count} 页,{len(edges)} 条边") result = [] for edge in edges: @@ -724,23 +854,33 @@ def get_node_detail(self, node_uuid: str) -> Optional[NodeInfo]: 节点信息或None """ logger.info(f"获取节点详情: {node_uuid[:8]}...") - + try: node = self._call_with_retry( - func=lambda: self.client.graph.node.get(uuid_=node_uuid), + func=lambda: self.kg.get_node(node_uuid), operation_name=f"获取节点详情(uuid={node_uuid[:8]}...)" ) - + if not node: return None - - return NodeInfo( - uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), - name=node.name or "", - labels=node.labels or [], - summary=node.summary or "", - attributes=node.attributes or {} - ) + + # 兼容对象和字典格式 + if isinstance(node, dict): + return NodeInfo( + uuid=node.get('uuid_', '') or node.get('uuid', ''), + name=node.get('name', ''), + labels=node.get('labels', []), + summary=node.get('summary', ''), + attributes=node.get('attributes', {}) + ) + else: + return NodeInfo( + uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), + name=node.name or "", + labels=node.labels or [], + summary=node.summary or "", + attributes=node.attributes or {} + ) except Exception as e: logger.error(f"获取节点详情失败: {str(e)}") return None @@ -863,22 +1003,24 @@ def get_graph_statistics(self, graph_id: str) -> Dict[str, Any]: 统计信息 """ logger.info(f"获取图谱 {graph_id} 的统计信息...") - + nodes = self.get_all_nodes(graph_id) edges = self.get_all_edges(graph_id) - - # 统计实体类型分布 + + # 统计实体类型分布(兼容 dict 和对象) entity_types = {} for node in nodes: - for label in node.labels: + labels = node.labels if hasattr(node, 'labels') else node.get('labels', []) + for label in labels: if label not in ["Entity", "Node"]: entity_types[label] = entity_types.get(label, 0) + 1 - - # 统计关系类型分布 + + # 统计关系类型分布(兼容 dict 和对象) relation_types = {} for edge in edges: - relation_types[edge.name] = relation_types.get(edge.name, 0) + 1 - + edge_name = edge.name if hasattr(edge, 'name') else edge.get('name', 'RELATED') + relation_types[edge_name] = relation_types.get(edge_name, 0) + 1 + return { "graph_id": graph_id, "total_nodes": len(nodes), diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 4f5361d5..83d2dc6d 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -12,23 +12,27 @@ dependencies = [ # 核心框架 "flask>=3.0.0", "flask-cors>=6.0.0", - + # LLM 相关 "openai>=1.0.0", - - # Zep Cloud - "zep-cloud==3.13.0", - + + # 知识图谱 - 根据 KNOWLEDGE_GRAPH_MODE 选择使用 + # cloud 模式: zep-cloud + # local 模式: graphiti-core + neo4j + "zep-cloud>=3.13.0", + "graphiti-core>=0.5.0", + "neo4j>=5.0.0", + # OASIS 社交媒体模拟 "camel-oasis==0.2.5", "camel-ai==0.2.78", - + # 文件处理 "PyMuPDF>=1.24.0", # 编码检测(支持非UTF-8编码的文本文件) "charset-normalizer>=3.0.0", "chardet>=5.0.0", - + # 工具库 "python-dotenv>=1.0.0", "pydantic>=2.0.0", diff --git a/backend/requirements.txt b/backend/requirements.txt index 4f146296..e23821b2 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -13,8 +13,13 @@ flask-cors>=6.0.0 # OpenAI SDK(统一使用 OpenAI 格式调用 LLM) openai>=1.0.0 -# ============= Zep Cloud ============= -zep-cloud==3.13.0 +# ============= 知识图谱 ============= +# 根据 KNOWLEDGE_GRAPH_MODE 选择使用 +# cloud 模式: zep-cloud +# local 模式: graphiti-core + neo4j +zep-cloud>=3.13.0 +graphiti-core>=0.5.0 +neo4j>=5.0.0 # ============= OASIS 社交媒体模拟 ============= # OASIS 社交模拟框架 diff --git a/backend/tests/test_kg_adapter.py b/backend/tests/test_kg_adapter.py new file mode 100644 index 00000000..132705f4 --- /dev/null +++ b/backend/tests/test_kg_adapter.py @@ -0,0 +1,152 @@ +""" +Knowledge Graph Adapter Unit Tests + +Tests the kg_adapter module API signatures and configuration. +Run with: uv run pytest tests/test_kg_adapter.py -v +""" +import pytest +from unittest.mock import Mock, patch +import os + + +class TestZepCloudAdapterAPI: + """Test ZepCloudAdapter API calls match Zep Cloud SDK""" + + def test_create_graph_signature(self): + """Test create_graph accepts graph_id and name""" + from app.services.kg_adapter import ZepCloudAdapter + import inspect + + sig = inspect.signature(ZepCloudAdapter.create_graph) + params = list(sig.parameters.keys()) + assert 'self' in params + assert 'graph_id' in params + assert 'name' in params + + def test_add_episode_signature(self): + """Test add_episode accepts graph_id and text""" + from app.services.kg_adapter import ZepCloudAdapter + import inspect + + sig = inspect.signature(ZepCloudAdapter.add_episode) + params = list(sig.parameters.keys()) + assert 'graph_id' in params + assert 'text' in params + + def test_add_episodes_batch_signature(self): + """Test add_episodes_batch accepts graph_id and texts""" + from app.services.kg_adapter import ZepCloudAdapter + import inspect + + sig = inspect.signature(ZepCloudAdapter.add_episodes_batch) + params = list(sig.parameters.keys()) + assert 'graph_id' in params + assert 'texts' in params + + def test_set_ontology_signature(self): + """Test set_ontology accepts graph_id and ontology""" + from app.services.kg_adapter import ZepCloudAdapter + import inspect + + sig = inspect.signature(ZepCloudAdapter.set_ontology) + params = list(sig.parameters.keys()) + assert 'graph_id' in params + assert 'ontology' in params + + def test_search_signature(self): + """Test search accepts graph_id, query and limit""" + from app.services.kg_adapter import ZepCloudAdapter + import inspect + + sig = inspect.signature(ZepCloudAdapter.search) + params = list(sig.parameters.keys()) + assert 'graph_id' in params + assert 'query' in params + assert 'limit' in params + + def test_get_nodes_signature(self): + """Test get_nodes accepts graph_id, limit and cursor""" + from app.services.kg_adapter import ZepCloudAdapter + import inspect + + sig = inspect.signature(ZepCloudAdapter.get_nodes) + params = list(sig.parameters.keys()) + assert 'graph_id' in params + assert 'limit' in params + assert 'cursor' in params + + def test_get_edges_signature(self): + """Test get_edges accepts graph_id, limit and cursor""" + from app.services.kg_adapter import ZepCloudAdapter + import inspect + + sig = inspect.signature(ZepCloudAdapter.get_edges) + params = list(sig.parameters.keys()) + assert 'graph_id' in params + assert 'limit' in params + assert 'cursor' in params + + +class TestGraphitiAdapterAPI: + """Test GraphitiAdapter API signatures""" + + def test_create_graph_signature(self): + """Test create_graph accepts graph_id and name""" + from app.services.kg_adapter import GraphitiAdapter + import inspect + + sig = inspect.signature(GraphitiAdapter.create_graph) + params = list(sig.parameters.keys()) + assert 'graph_id' in params + + def test_add_episode_signature(self): + """Test add_episode accepts graph_id and text""" + from app.services.kg_adapter import GraphitiAdapter + import inspect + + sig = inspect.signature(GraphitiAdapter.add_episode) + params = list(sig.parameters.keys()) + assert 'graph_id' in params + assert 'text' in params + + def test_add_episodes_batch_signature(self): + """Test add_episodes_batch accepts graph_id and texts""" + from app.services.kg_adapter import GraphitiAdapter + import inspect + + sig = inspect.signature(GraphitiAdapter.add_episodes_batch) + params = list(sig.parameters.keys()) + assert 'graph_id' in params + assert 'texts' in params + + def test_search_signature(self): + """Test search accepts graph_id, query and limit""" + from app.services.kg_adapter import GraphitiAdapter + import inspect + + sig = inspect.signature(GraphitiAdapter.search) + params = list(sig.parameters.keys()) + assert 'graph_id' in params + assert 'query' in params + assert 'limit' in params + + +class TestAdapterFactory: + """Test adapter factory function""" + + def test_factory_returns_adapter(self): + """Test factory returns an adapter""" + from app.services.kg_adapter import get_knowledge_graph_adapter + + adapter = get_knowledge_graph_adapter() + assert adapter is not None + + def test_cloud_mode(self): + """Test cloud mode is default""" + from app.config import Config + + assert Config.KNOWLEDGE_GRAPH_MODE == 'cloud' + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/backend/uv.lock b/backend/uv.lock index f1ce4b60..cb34fa5c 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -475,6 +475,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604, upload-time = "2021-03-08T10:59:24.45Z" }, ] +[[package]] +name = "diskcache" +version = "5.6.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3f/21/1c1ffc1a039ddcc459db43cc108658f32c57d271d7289a2794e401d0fdb6/diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc", size = 67916, upload-time = "2023-08-31T06:12:00.316Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/27/4570e78fc0bf5ea0ca45eb1de3818a23787af9b390c0b0a0033a1b8236f9/diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19", size = 45550, upload-time = "2023-08-31T06:11:58.822Z" }, +] + [[package]] name = "distlib" version = "0.4.0" @@ -592,6 +601,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/51/c7/b64cae5dba3a1b138d7123ec36bb5ccd39d39939f18454407e5468f4763f/fsspec-2025.12.0-py3-none-any.whl", hash = "sha256:8bf1fe301b7d8acfa6e8571e3b1c3d158f909666642431cc78a1b7b4dbc5ec5b", size = 201422, upload-time = "2025-12-03T15:23:41.434Z" }, ] +[[package]] +name = "graphiti-core" +version = "0.11.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "diskcache" }, + { name = "neo4j" }, + { name = "numpy" }, + { name = "openai" }, + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "tenacity" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/30/94/3f84400e5f02ea8e9dc79784202de4173cbc16f4b3ad1bd4302da888e4d8/graphiti_core-0.11.6.tar.gz", hash = "sha256:31d26621834d7d4b8865059ab749feb18af15937b59c69598a640a5dfabea331", size = 71928, upload-time = "2025-05-15T17:58:02.304Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ac/2e/c8f22f01585bf173d1c82f6d4615511aebc75aeda764c69aa394446fa93c/graphiti_core-0.11.6-py3-none-any.whl", hash = "sha256:6ec4807a884f5ea88b942d0c8b7bcd2e107c7358ab4f98ef2a2092c229929707", size = 111001, upload-time = "2025-05-15T17:58:00.542Z" }, +] + [[package]] name = "h11" version = "0.16.0" @@ -1248,6 +1275,8 @@ dependencies = [ { name = "charset-normalizer" }, { name = "flask" }, { name = "flask-cors" }, + { name = "graphiti-core" }, + { name = "neo4j" }, { name = "openai" }, { name = "pydantic" }, { name = "pymupdf" }, @@ -1276,6 +1305,8 @@ requires-dist = [ { name = "charset-normalizer", specifier = ">=3.0.0" }, { name = "flask", specifier = ">=3.0.0" }, { name = "flask-cors", specifier = ">=6.0.0" }, + { name = "graphiti-core", specifier = ">=0.5.0" }, + { name = "neo4j", specifier = ">=5.0.0" }, { name = "openai", specifier = ">=1.0.0" }, { name = "pipreqs", marker = "extra == 'dev'", specifier = ">=0.5.0" }, { name = "pydantic", specifier = ">=2.0.0" }, @@ -1283,7 +1314,7 @@ requires-dist = [ { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23.0" }, { name = "python-dotenv", specifier = ">=1.0.0" }, - { name = "zep-cloud", specifier = "==3.13.0" }, + { name = "zep-cloud", specifier = ">=3.13.0" }, ] provides-extras = ["dev"] @@ -2987,6 +3018,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252, upload-time = "2022-10-06T17:21:44.262Z" }, ] +[[package]] +name = "tenacity" +version = "9.1.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/47/c6/ee486fd809e357697ee8a44d3d69222b344920433d3b6666ccd9b374630c/tenacity-9.1.4.tar.gz", hash = "sha256:adb31d4c263f2bd041081ab33b498309a57c77f9acf2db65aadf0898179cf93a", size = 49413, upload-time = "2026-02-07T10:45:33.841Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d7/c1/eb8f9debc45d3b7918a32ab756658a0904732f75e555402972246b0b8e71/tenacity-9.1.4-py3-none-any.whl", hash = "sha256:6095a360c919085f28c6527de529e76a06ad89b23659fa881ae0649b867a9d55", size = 28926, upload-time = "2026-02-07T10:45:32.24Z" }, +] + [[package]] name = "texttable" version = "1.7.0" From 5e05e7e290d1245696a3c23964f44efc9d2eae95 Mon Sep 17 00:00:00 2001 From: huamingjie Date: Wed, 11 Mar 2026 23:27:00 +0800 Subject: [PATCH 2/8] feat(report): change tool call format from JSON to XML - Update prompt template to use XML format: value - Rewrite _parse_tool_calls to support new XML format - Add fallback support for old JSON format - Update response cleaning regex to handle new format Co-Authored-By: Claude Opus 4.6 --- backend/app/services/report_agent.py | 53 +++++++++++++++++++++------- 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/backend/app/services/report_agent.py b/backend/app/services/report_agent.py index 442b02a4..d06c0e35 100644 --- a/backend/app/services/report_agent.py +++ b/backend/app/services/report_agent.py @@ -721,8 +721,8 @@ def to_dict(self) -> Dict[str, Any]: 选项A - 调用工具: 输出你的思考,然后用以下格式调用一个工具: - -{{"name": "工具名称", "parameters": {{"参数名": "参数值"}}}} + + 参数值 系统会执行工具并把结果返回给你。你不需要也不能自己编写工具返回结果。 @@ -844,8 +844,8 @@ def to_dict(self) -> Dict[str, Any]: {tools_description} 【工具调用格式】 - -{{"name": "工具名称", "parameters": {{"参数名": "参数值"}}}} + + 参数值 【回答风格】 @@ -1068,25 +1068,52 @@ def _parse_tool_calls(self, response: str) -> List[Dict[str, Any]]: 从LLM响应中解析工具调用 支持的格式(按优先级): - 1. {"name": "tool_name", "parameters": {...}} - 2. 裸 JSON(响应整体或单行就是一个工具调用 JSON) + 1. XML格式(标准格式): value + 2. JSON格式(兜底): {"name": "tool_name", "parameters": {...}} + 3. 裸 JSON(兜底): 响应整体或单行就是一个工具调用 JSON """ tool_calls = [] - # 格式1: XML风格(标准格式) - xml_pattern = r'\s*(\{.*?\})\s*' + # 格式1: XML格式(标准格式)- 优先匹配 + # + # 参数值 + # + xml_pattern = r']*>(.*?)' for match in re.finditer(xml_pattern, response, re.DOTALL): + tool_name = match.group(1) + params_content = match.group(2) + + # 提取所有 value 标签 + param_pattern = r'([^<]*)' + params = {} + for param_match in re.finditer(param_pattern, params_content): + param_name = param_match.group(1) + param_value = param_match.group(2) + params[param_name] = param_value + + if tool_name: + tool_calls.append({ + "name": tool_name, + "parameters": params + }) + + if tool_calls: + return tool_calls + + # 格式2: JSON格式(旧格式兼容)- {"name": ...} + json_xml_pattern = r'\s*(\{.*?\})\s*' + for match in re.finditer(json_xml_pattern, response, re.DOTALL): try: call_data = json.loads(match.group(1)) - tool_calls.append(call_data) + if self._is_valid_tool_call(call_data): + tool_calls.append(call_data) except json.JSONDecodeError: pass if tool_calls: return tool_calls - # 格式2: 兜底 - LLM 直接输出裸 JSON(没包 标签) - # 只在格式1未匹配时尝试,避免误匹配正文中的 JSON + # 格式3: 兜底 - LLM 直接输出裸 JSON(没包 标签) stripped = response.strip() if stripped.startswith('{') and stripped.endswith('}'): try: @@ -1857,7 +1884,7 @@ def chat( if not tool_calls: # 没有工具调用,直接返回响应 - clean_response = re.sub(r'.*?', '', response, flags=re.DOTALL) + clean_response = re.sub(r']*>.*?', '', response, flags=re.DOTALL) clean_response = re.sub(r'\[TOOL_CALL\].*?\)', '', clean_response) return { @@ -1893,7 +1920,7 @@ def chat( ) # 清理响应 - clean_response = re.sub(r'.*?', '', final_response, flags=re.DOTALL) + clean_response = re.sub(r']*>.*?', '', final_response, flags=re.DOTALL) clean_response = re.sub(r'\[TOOL_CALL\].*?\)', '', clean_response) return { From 48ff6593aa688325fe346c802da7e9112c862d99 Mon Sep 17 00:00:00 2001 From: huamingjie Date: Wed, 11 Mar 2026 23:30:57 +0800 Subject: [PATCH 3/8] fix(report): add title style constraints to PLAN_SYSTEM_PROMPT - Add title requirements section with clear guidelines: - Titles should be concise and directly reflect user needs - Avoid academic or abstract expressions - Add concrete examples of good vs bad titles Co-Authored-By: Claude Opus 4.6 --- backend/app/services/report_agent.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/backend/app/services/report_agent.py b/backend/app/services/report_agent.py index d06c0e35..7b646eb4 100644 --- a/backend/app/services/report_agent.py +++ b/backend/app/services/report_agent.py @@ -573,6 +573,16 @@ def to_dict(self) -> Dict[str, Any]: - 内容要精炼,聚焦于核心预测发现 - 章节结构由你根据预测结果自主设计 +【标题要求】 +- 报告标题应简洁明了,直接体现用户的预测需求 +- 章节标题应具体、开门见山,直接概括该章节的核心预测发现 +- 避免使用过于学术化、抽象或晦涩的表达 +- 避免使用"深度分析"、"全面解读"等泛泛的套话 +- 示例: + ✅ "2025年新能源汽车销量预测报告" + ✅ "大学生就业趋势分析" + ❌ "新能源汽车行业发展趋势深度研究与战略分析" + 请输出JSON格式的报告大纲,格式如下: { "title": "报告标题", From dada9e64e618043010b053503bad4ad138c4574c Mon Sep 17 00:00:00 2001 From: huamingjie Date: Thu, 12 Mar 2026 09:05:55 +0800 Subject: [PATCH 4/8] Revert "fix(report): add title style constraints to PLAN_SYSTEM_PROMPT" This reverts commit 48ff6593aa688325fe346c802da7e9112c862d99. --- backend/app/services/report_agent.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/backend/app/services/report_agent.py b/backend/app/services/report_agent.py index 7b646eb4..d06c0e35 100644 --- a/backend/app/services/report_agent.py +++ b/backend/app/services/report_agent.py @@ -573,16 +573,6 @@ def to_dict(self) -> Dict[str, Any]: - 内容要精炼,聚焦于核心预测发现 - 章节结构由你根据预测结果自主设计 -【标题要求】 -- 报告标题应简洁明了,直接体现用户的预测需求 -- 章节标题应具体、开门见山,直接概括该章节的核心预测发现 -- 避免使用过于学术化、抽象或晦涩的表达 -- 避免使用"深度分析"、"全面解读"等泛泛的套话 -- 示例: - ✅ "2025年新能源汽车销量预测报告" - ✅ "大学生就业趋势分析" - ❌ "新能源汽车行业发展趋势深度研究与战略分析" - 请输出JSON格式的报告大纲,格式如下: { "title": "报告标题", From cdf4fcb49c74ff3c65c1adf763d4ec9259df26f7 Mon Sep 17 00:00:00 2001 From: huamingjie Date: Thu, 12 Mar 2026 09:35:56 +0800 Subject: [PATCH 5/8] feat(report): add regenerate report button and title style constraints MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add style and theme constraints to PLAN_SYSTEM_PROMPT to ensure report titles are specific and relevant to prediction conditions - Add "重新分析" (regenerate) button next to "进入深度互动" button in report page - Button only shows when report generation is complete (isComplete = true) - Reuses existing /api/report/generate endpoint with force_regenerate=true Co-Authored-By: Claude Opus 4.6 --- backend/app/services/report_agent.py | 17 ++++++ frontend/src/components/Step4Report.vue | 74 +++++++++++++++++++++---- 2 files changed, 80 insertions(+), 11 deletions(-) diff --git a/backend/app/services/report_agent.py b/backend/app/services/report_agent.py index d06c0e35..cc78585d 100644 --- a/backend/app/services/report_agent.py +++ b/backend/app/services/report_agent.py @@ -567,6 +567,23 @@ def to_dict(self) -> Dict[str, Any]: - ❌ 不是对现实世界现状的分析 - ❌ 不是泛泛而谈的舆情综述 +【风格与主题约束】(重要!必须遵守) +- 报告标题必须紧密围绕「模拟需求」展开,直接体现预测的具体条件和结果 +- 章节标题要简洁明了,让人一眼能看出该章节讨论什么预测内容 +- 避免使用过于晦涩、抽象、诗意化的表达(如"未来的迷雾""命运交响") +- 避免泛泛而谈,每个章节都应该能回答"在什么条件下,会发生什么" +- 报告是给人看的预测分析,不是文学作品 + +✅ 好的标题示例: + - 「AI助手普及后用户交互行为预测」 + - 「2025年新能源汽车市场渗透率预测」 + - 「远程办公趋势下城市房价走势分析」 + +❌ 差的标题示例: + - 「未来的迷雾」 + - 「智能时代的命运交响」 + - 「探索未知的边界」 + 【章节数量限制】 - 最少2个章节,最多5个章节 - 不需要子章节,每个章节直接撰写完整内容 diff --git a/frontend/src/components/Step4Report.vue b/frontend/src/components/Step4Report.vue index 22f2bdcf..d11e81d1 100644 --- a/frontend/src/components/Step4Report.vue +++ b/frontend/src/components/Step4Report.vue @@ -127,14 +127,24 @@ - - + +
+ + +
@@ -392,7 +402,7 @@