Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions graphrag_agent/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,25 @@ class AgentState(TypedDict):
# 编译图
self.graph = workflow.compile(checkpointer=self.memory)

@staticmethod
def _find_last_human_message(messages):
"""
从 messages 列表中查找最后一条 HumanMessage。

使用类型匹配代替硬编码索引 messages[-3],
避免在多工具调用或多轮对话场景下取到错误内容。

参数:
messages: 消息列表

返回:
HumanMessage 或 None
"""
for msg in reversed(messages):
if isinstance(msg, HumanMessage):
return msg
return None

async def _stream_process(self, inputs: Dict[str, Any], config: Dict[str, Any]) -> AsyncGenerator[str, None]:
"""
执行流式处理的默认实现
Expand Down
4 changes: 2 additions & 2 deletions graphrag_agent/agents/deep_research_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ def _generate_node(self, state):
messages = state["messages"]

# 安全地获取问题和检索结果
human_msg = self._find_last_human_message(messages)
question = human_msg.content if human_msg else "未找到问题"
try:
# 原始问题在倒数第三个消息
question = messages[-3].content if len(messages) >= 3 else "未找到问题"
# 检索结果在最后一个消息
retrieval_result = messages[-1].content if messages[-1] else "未找到相关信息"
except Exception as e:
Expand Down
20 changes: 13 additions & 7 deletions graphrag_agent/agents/graph_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,14 @@ def _grade_documents(self, state) -> str:
return "reduce"

# 获取问题和文档内容
human_msg = self._find_last_human_message(messages)
if not human_msg:
print("文档评分出错: 未找到用户问题")
return "generate"
question = human_msg.content
try:
question = messages[-3].content
docs = messages[-1].content
except Exception as e:
# 如果出错,默认为 generate 模式
print(f"文档评分出错: {e}")
return "generate"

Expand All @@ -143,8 +146,8 @@ def _grade_documents(self, state) -> str:

# 从问题中提取关键词
keywords = []
if hasattr(messages[-3], 'additional_kwargs') and messages[-3].additional_kwargs:
kw_data = messages[-3].additional_kwargs.get("keywords", {})
if human_msg and hasattr(human_msg, 'additional_kwargs') and human_msg.additional_kwargs:
kw_data = human_msg.additional_kwargs.get("keywords", {})
if isinstance(kw_data, dict):
keywords = kw_data.get("low_level", []) + kw_data.get("high_level", [])

Expand All @@ -171,7 +174,8 @@ def _grade_documents(self, state) -> str:
def _generate_node(self, state):
"""生成回答节点逻辑"""
messages = state["messages"]
question = messages[-3].content
human_msg = self._find_last_human_message(messages)
question = human_msg.content if human_msg else "未找到问题"
docs = messages[-1].content

# 首先尝试全局缓存
Expand Down Expand Up @@ -221,7 +225,8 @@ def _generate_node(self, state):
def _reduce_node(self, state):
"""处理全局搜索的Reduce节点逻辑"""
messages = state["messages"]
question = messages[-3].content
human_msg = self._find_last_human_message(messages)
question = human_msg.content if human_msg else "未找到问题"
docs = messages[-1].content

# 检查缓存
Expand Down Expand Up @@ -259,7 +264,8 @@ async def _generate_node_stream(self, state):

# 安全获取问题和文档内容
try:
question = messages[-3].content if len(messages) >= 3 else "未找到问题"
human_msg = self._find_last_human_message(messages)
question = human_msg.content if human_msg else "未找到问题"
docs = messages[-1].content if messages[-1] else "未找到相关信息"
except Exception as e:
yield f"**获取问题或文档时出错**: {str(e)}"
Expand Down
12 changes: 4 additions & 8 deletions graphrag_agent/agents/hybrid_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,8 @@ def _generate_node(self, state):
messages = state["messages"]

# 安全地获取问题内容
try:
question = messages[-3].content if len(messages) >= 3 else "未找到问题"
except Exception:
question = "无法获取问题"
human_msg = self._find_last_human_message(messages)
question = human_msg.content if human_msg else "未找到问题"

# 安全地获取文档内容
try:
Expand Down Expand Up @@ -143,10 +141,8 @@ async def _generate_node_stream(self, state):
messages = state["messages"]

# 安全地获取问题内容
try:
question = messages[-3].content if len(messages) >= 3 else "未找到问题"
except Exception:
question = "无法获取问题"
human_msg = self._find_last_human_message(messages)
question = human_msg.content if human_msg else "未找到问题"

# 安全地获取文档内容
try:
Expand Down
6 changes: 2 additions & 4 deletions graphrag_agent/agents/naive_rag_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,8 @@ def _generate_node(self, state):
messages = state["messages"]

# 安全地获取问题和检索结果
try:
question = messages[-3].content if len(messages) >= 3 else "未找到问题"
except Exception:
question = "无法获取问题"
human_msg = self._find_last_human_message(messages)
question = human_msg.content if human_msg else "未找到问题"

try:
docs = messages[-1].content if messages[-1] else "未找到相关信息"
Expand Down