Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 44 additions & 38 deletions apex/services/deep_research/deep_research_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ async def invoke(
collected_sources: list[dict[str, str]] = []
seen_urls: set[str] = set()

agent_chain = self._build_agent_chain()
agent_chain = await self._build_agent_chain()

while step_index < max_iterations:
logger.debug(f"Starting deep researcher {step_index + 1}/{max_iterations} step")
Expand Down Expand Up @@ -202,17 +202,18 @@ async def invoke(

# Final answer branch
if "final_answer" in parsed:
logger.debug("Early-stopping deep research due to the final answer")
final_answer = str(parsed.get("final_answer", ""))
reasoning_traces.append(
{
"step": f"iteration-{step_index}",
"model": getattr(self.research_model, "model_name", "unknown"),
"thought": thought,
"final_answer": final_answer,
"observation": final_answer,
}
)
return final_answer, self.tool_history, reasoning_traces
logger.debug("Early-stopping deep research due to the final answer")
# return final_answer, self.tool_history, reasoning_traces
break

# Action branch (only websearch supported)
action = parsed.get("action") or {}
Expand Down Expand Up @@ -255,7 +256,7 @@ async def invoke(
"model": getattr(self.research_model, "model_name", "unknown"),
"thought": thought,
"action": {"tool": "websearch", "query": query, "max_results": max_results},
"observation": observation_text[:1000],
"observation": observation_text[:1200],
}
)
continue
Expand Down Expand Up @@ -289,7 +290,7 @@ async def invoke(
"model": getattr(self.research_model, "model_name", "unknown"),
"thought": thought,
"action": {"tool": "python_repl", "code": code[:1000]},
"observation": observation_text[:1000],
"observation": observation_text[:1200],
}
)
continue
Expand All @@ -305,19 +306,9 @@ async def invoke(
)
notes.append("Agent returned an unsupported action. Use the websearch tool or provide final_answer.")

# Fallback: if loop ends without final answer, ask final model to synthesize from notes
# If loop ends without final answer, ask final model to synthesize from notes.
logger.debug("Generating final answer")
final_prompt = PromptTemplate(
input_variables=["question", "notes", "sources"],
template=(
self._FINAL_ANSWER_INST + "Do NOT use JSON, or any other structured data format.\n"
"Question:\n{question}\n\n"
"Notes:\n{notes}\n\n"
"Sources:\n{sources}\n\n"
"Research Report:"
),
)
final_chain = final_prompt | self.final_model | StrOutputParser()
final_chain = await self._build_final_chain()

final_report: str = await self._try_invoke(
final_chain,
Expand All @@ -336,6 +327,19 @@ async def invoke(
)
return final_report, self.tool_history, reasoning_traces

async def _build_final_chain(self) -> RunnableSerializable[dict[str, Any], str]:
    """Build the LCEL chain that synthesizes the final research report.

    The chain pipes a prompt (final-answer instructions plus the question,
    accumulated notes, and collected sources) through ``self.final_model``
    and parses the LLM output to a plain string.

    Returns:
        A runnable expecting a dict with keys ``question``, ``notes`` and
        ``sources``, producing the report as plain text.
    """
    final_prompt = PromptTemplate(
        input_variables=["question", "notes", "sources"],
        template=(
            # NOTE: the instruction must be a complete sentence — a dangling
            # fragment here is injected verbatim into the LLM prompt.
            self._FINAL_ANSWER_INST
            + "Do NOT use JSON, or any other structured data format. "
            "Provide plain text only.\n"
            "Question:\n{question}\n\n"
            "Notes:\n{notes}\n\n"
            "Sources:\n{sources}\n\n"
            "Research Report:"
        ),
    )
    return final_prompt | self.final_model | StrOutputParser()

def _render_sources(self, collected_sources: list[dict[str, str]], max_items: int = 12) -> str:
if not collected_sources:
return "(none)"
Expand All @@ -352,47 +356,49 @@ def _render_notes(self, notes: list[str], max_items: int = 8) -> str:
clipped = notes[-max_items:]
return "\n".join(f"- {item}" for item in clipped)

def _build_agent_chain(self) -> RunnableSerializable[dict[str, Any], str]:
async def _build_agent_chain(self) -> RunnableSerializable[dict[str, Any], str]:
prompt = PromptTemplate(
input_variables=["question", "notes", "sources"],
template=(
"You are DeepResearcher, a meticulous, tool-using research agent.\n"
"You can use exactly these tools: websearch, python_repl.\n\n"
"Tool: websearch\n"
"- description: Search the web for relevant information.\n"
"- args: keys: 'query' (string), 'max_results' (integer <= 10)\n\n"
" - description: Search the web for relevant information.\n"
" - args: keys: 'query' (string), 'max_results' (integer <= 10)\n\n"
"Tool: python_repl\n"
"- description: A Python shell for executing Python commands.\n"
"- note: Print values to see output, e.g., `print(...)`.\n"
"- args: keys: 'code' (string: valid python command).\n\n"
" - description: A Python shell for executing Python commands.\n"
" - note: Print values to see output, e.g., `print(...)`.\n"
" - args: keys: 'code' (string: valid python command).\n\n"
"Follow an iterative think-act-observe loop. "
"Prefer rich internal reasoning over issuing many tool calls.\n"
"Spend time thinking: produce substantial, explicit reasoning in each 'thought'.\n"
"Avoid giving a final answer too early. Aim for at least 6 detailed thoughts before finalizing,\n"
"unless the question is truly trivial. "
"If no tool use is needed in a step, still provide a reflective 'thought'\n"
"that evaluates evidence, identifies gaps, and plans the next step.\n\n"
"Always respond in strict JSON. Use one of the two schemas:\n\n"
"1) Action step (JSON keys shown with dot-paths):\n"
"- thought: string\n"
"- action.tool: 'websearch' | 'python_repl'\n"
"- action.input: for websearch -> {{query: string, max_results: integer}}\n"
"- action.input: for python_repl -> {{code: string}}\n\n"
"2) Final answer step:\n"
"- thought: string\n"
"- final_answer: string (use plain text for final answer, not a JSON)\n\n"
"Always respond in strict JSON for deep research steps (do not use JSON for final answer). "
"Use one of the two schemas:\n\n"
"1. Action step (JSON keys shown with dot-paths):\n"
" - thought: string\n"
" - action.tool: 'websearch' | 'python_repl'\n"
" - action.input: for websearch -> {{query: string, max_results: integer}}\n"
" - action.input: for python_repl -> {{code: string}}\n\n"
"2. Final answer step:\n"
" - thought: string\n"
" - final_answer: string\n\n"
"In every step, make 'thought' a detailed paragraph (120-200 words) that:\n"
"- Summarizes what is known and unknown so far\n"
"- Justifies the chosen next action or decision not to act\n"
"- Evaluates evidence quality and cites source numbers when applicable\n"
"- Identifies risks, uncertainties, and alternative hypotheses\n\n"
" - Summarizes what is known and unknown so far\n"
" - Justifies the chosen next action or decision not to act\n"
" - Evaluates evidence quality and cites source numbers when applicable\n"
" - Identifies risks, uncertainties, and alternative hypotheses\n\n"
"Respond with JSON only during deep research steps, "
"final answer must be always in a plain text formatted as a research report, with sections:\n"
"Executive Summary, Key Findings, Evidence, Limitations, Conclusion.\n"
"Use inline numeric citations like [1], [2] that refer to Sources.\n"
"Include a final section titled 'Sources' listing the numbered citations.\n\n"
"Question:\n{question}\n\n"
"Notes and observations so far:\n{notes}\n\n"
"Sources (use these for citations):\n{sources}\n\n"
"Respond with JSON always, except for final_answer (use plain text)."
),
)
return prompt | self.research_model | StrOutputParser()
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "apex"
version = "3.0.5"
version = "3.0.6"
description = "Bittensor Subnet 1: Apex"
readme = "README.md"
requires-python = "~=3.11"
Expand Down
81 changes: 62 additions & 19 deletions tests/services/deep_research/test_deep_research_langchain.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
Expand Down Expand Up @@ -76,17 +77,26 @@ async def test_invoke_with_documents_in_body(deep_research_langchain, mock_webse
body = {"documents": [{"page_content": "doc1"}, {"page_content": "doc2"}]}

with (
patch("apex.services.deep_research.deep_research_langchain.PromptTemplate") as mock_prompt_template,
patch("apex.services.deep_research.deep_research_langchain.StrOutputParser"),
patch(
"apex.services.deep_research.deep_research_langchain.DeepResearchLangchain._build_agent_chain"
) as mock_build_agent_chain,
patch(
"apex.services.deep_research.deep_research_langchain.DeepResearchLangchain._build_final_chain"
) as mock_build_final_chain,
):
agent_chain = AsyncMock()
agent_chain.ainvoke.return_value = '{"thought": "enough info", "final_answer": "final_report"}'
mock_prompt_template.return_value.__or__.return_value.__or__.return_value = agent_chain
return_value = json.dumps({"thought": "enough info", "final_answer": "final_report"})
agent_chain.ainvoke.return_value = return_value
mock_build_agent_chain.return_value = agent_chain

final_chain_mock = AsyncMock()
final_chain_mock.ainvoke.return_value = return_value
mock_build_final_chain.return_value = final_chain_mock

result = await deep_research_langchain.invoke(messages, body)

mock_websearch.search.assert_not_called()
assert result[0] == "final_report"
assert result[0] == return_value


@pytest.mark.asyncio
Expand All @@ -96,8 +106,12 @@ async def test_invoke_with_websearch(deep_research_langchain, mock_websearch):
mock_websearch.search.return_value = [MagicMock(content="web_doc", url="http://a.com", title="A")]

with (
patch("apex.services.deep_research.deep_research_langchain.PromptTemplate") as mock_prompt_template,
patch("apex.services.deep_research.deep_research_langchain.StrOutputParser"),
patch(
"apex.services.deep_research.deep_research_langchain.DeepResearchLangchain._build_agent_chain"
) as mock_build_agent_chain,
patch(
"apex.services.deep_research.deep_research_langchain.DeepResearchLangchain._build_final_chain"
) as mock_build_final_chain,
):
agent_chain = AsyncMock()
agent_chain.ainvoke.side_effect = [
Expand All @@ -107,7 +121,11 @@ async def test_invoke_with_websearch(deep_research_langchain, mock_websearch):
),
'{"thought": "done", "final_answer": "final_answer"}',
]
mock_prompt_template.return_value.__or__.return_value.__or__.return_value = agent_chain
mock_build_agent_chain.return_value = agent_chain

final_chain_mock = AsyncMock()
final_chain_mock.ainvoke.return_value = "final_answer"
mock_build_final_chain.return_value = final_chain_mock

result = await deep_research_langchain.invoke(messages)

Expand All @@ -121,17 +139,26 @@ async def test_invoke_no_websearch_needed_final_answer(deep_research_langchain,
messages = [{"role": "user", "content": "test question"}]

with (
patch("apex.services.deep_research.deep_research_langchain.PromptTemplate") as mock_prompt_template,
patch("apex.services.deep_research.deep_research_langchain.StrOutputParser"),
patch(
"apex.services.deep_research.deep_research_langchain.DeepResearchLangchain._build_agent_chain"
) as mock_build_agent_chain,
patch(
"apex.services.deep_research.deep_research_langchain.DeepResearchLangchain._build_final_chain"
) as mock_build_final_chain,
):
agent_chain = AsyncMock()
agent_chain.ainvoke.return_value = '{"thought": "clear", "final_answer": "final_report"}'
mock_prompt_template.return_value.__or__.return_value.__or__.return_value = agent_chain
return_value = json.dumps({"thought": "enough info", "final_answer": "final_report"})
agent_chain.ainvoke.return_value = return_value
mock_build_agent_chain.return_value = agent_chain

final_chain_mock = AsyncMock()
final_chain_mock.ainvoke.return_value = return_value
mock_build_final_chain.return_value = final_chain_mock

result = await deep_research_langchain.invoke(messages)

mock_websearch.search.assert_not_called()
assert result[0] == "final_report"
assert result[0] == return_value


@pytest.mark.asyncio
Expand All @@ -149,8 +176,12 @@ async def test_full_invoke_flow_with_multiple_actions(deep_research_langchain, m
]

with (
patch("apex.services.deep_research.deep_research_langchain.PromptTemplate") as mock_prompt_template,
patch("apex.services.deep_research.deep_research_langchain.StrOutputParser"),
patch(
"apex.services.deep_research.deep_research_langchain.DeepResearchLangchain._build_agent_chain"
) as mock_build_agent_chain,
patch(
"apex.services.deep_research.deep_research_langchain.DeepResearchLangchain._build_final_chain"
) as mock_build_final_chain,
):
agent_chain = AsyncMock()
agent_chain.ainvoke.side_effect = [
Expand All @@ -164,7 +195,11 @@ async def test_full_invoke_flow_with_multiple_actions(deep_research_langchain, m
),
'{"thought": "complete", "final_answer": "final_report"}',
]
mock_prompt_template.return_value.__or__.return_value.__or__.return_value = agent_chain
mock_build_agent_chain.return_value = agent_chain

final_chain_mock = AsyncMock()
final_chain_mock.ainvoke.return_value = "final_report"
mock_build_final_chain.return_value = final_chain_mock

result = await deep_research_langchain.invoke(messages)

Expand All @@ -186,15 +221,23 @@ async def test_full_invoke_flow_with_multiple_actions(deep_research_langchain, m
async def test_invoke_with_python_repl(deep_research_langchain):
"""Agent chooses python_repl then produces final answer."""
with (
patch("apex.services.deep_research.deep_research_langchain.PromptTemplate") as mock_prompt_template,
patch("apex.services.deep_research.deep_research_langchain.StrOutputParser"),
patch(
"apex.services.deep_research.deep_research_langchain.DeepResearchLangchain._build_agent_chain"
) as mock_build_agent_chain,
patch(
"apex.services.deep_research.deep_research_langchain.DeepResearchLangchain._build_final_chain"
) as mock_build_final_chain,
):
agent_chain = AsyncMock()
agent_chain.ainvoke.side_effect = [
('{"thought": "compute needed", "action": {"tool": "python_repl", "input": {"code": "print(1+1)"}}}'),
'{"thought": "done", "final_answer": "final_answer"}',
]
mock_prompt_template.return_value.__or__.return_value.__or__.return_value = agent_chain
mock_build_agent_chain.return_value = agent_chain

final_chain_mock = AsyncMock()
final_chain_mock.ainvoke.return_value = "final_answer"
mock_build_final_chain.return_value = final_chain_mock

result = await deep_research_langchain.invoke([{"role": "user", "content": "q"}])

Expand Down