Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apex/services/deep_research/deep_research_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
class DeepResearchBase(LLMBase):
    """Abstract interface for deep-research backends.

    Concrete subclasses override :meth:`invoke`; this base implementation only
    raises ``NotImplementedError``.
    """

    async def invoke(
        self, messages: list[dict[str, str]], body: dict[str, Any] | None = None
    ) -> tuple[str, list[dict[str, str]], list[dict[str, Any]]]:
        """Run the backend on ``messages`` and return its result triple.

        Args:
            messages: Chat messages, each a ``str -> str`` mapping.
            body: Optional extra request data; ``None`` when not supplied.

        Returns:
            A 3-tuple of ``(answer, tool_history, reasoning_traces)``.

        Raises:
            NotImplementedError: Always, until a subclass overrides this method.
        """
        raise NotImplementedError
33 changes: 28 additions & 5 deletions apex/services/deep_research/deep_research_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,10 @@ def _create_research_chain(self) -> RunnableSerializable[dict[str, Any], str]:

async def invoke(
self, messages: list[dict[str, str]], body: dict[str, Any] | None = None
) -> tuple[str, list[dict[str, str]]]: # type: ignore[override]
) -> tuple[str, list[dict[str, str]], list[dict[str, Any]]]: # type: ignore[override]
# Clear tool history for each new invocation
self.tool_history = []
reasoning_traces: list[dict[str, Any]] = []
question = messages[-1]["content"]
documents: list[Document] = []
if body and "documents" in body and body["documents"]:
Expand All @@ -137,11 +138,11 @@ async def invoke(
documents.append(Document(page_content=str(website.content), metadata={"url": website.url}))

if not documents:
return "Could not find any information on the topic.", self.tool_history
return "Could not find any information on the topic.", self.tool_history, reasoning_traces

retriever = await self._create_vector_store(documents)
if not retriever:
return "Could not create a vector store from the documents.", self.tool_history
return "Could not create a vector store from the documents.", self.tool_history, reasoning_traces

compression_retriever = self._create_compression_retriever(retriever)

Expand All @@ -151,8 +152,22 @@ async def invoke(
compressed_docs: list[Document] = await compression_retriever.ainvoke(question)

summary: str = await summary_chain.ainvoke({"context": compressed_docs, "question": question})
reasoning_traces.append(
{
"step": "summary",
"model": getattr(self.summary_model, "model_name", "unknown"),
"output": summary,
}
)

research_report: str = await research_chain.ainvoke({"context": compressed_docs, "question": question})
reasoning_traces.append(
{
"step": "research",
"model": getattr(self.research_model, "model_name", "unknown"),
"output": research_report,
}
)

final_prompt = PromptTemplate(
input_variables=["summary", "research_report", "question"],
Expand All @@ -168,7 +183,14 @@ async def invoke(
final_answer: str = await final_chain.ainvoke(
{"summary": summary, "research_report": research_report, "question": question}
)
return final_answer, self.tool_history
reasoning_traces.append(
{
"step": "final",
"model": getattr(self.final_model, "model_name", "unknown"),
"output": final_answer,
}
)
return final_answer, self.tool_history, reasoning_traces


class _CustomEmbeddings(Embeddings): # type: ignore
Expand Down Expand Up @@ -207,9 +229,10 @@ async def aembed_query(self, text: str) -> list[float]:
# Run the invoke method.
async def main() -> None:
    """Run one deep-research invocation and log its outputs and elapsed time.

    Uses ``deep_researcher``, ``dummy_messages`` and ``dummy_body`` from the
    surrounding scope (defined outside the visible part of the file).
    """
    timer_start = time.perf_counter()
    result, tool_history, reasoning_traces = await deep_researcher.invoke(dummy_messages, dummy_body)
    # Interpolate the values into the message itself: a positional argument
    # passed next to a message without a placeholder is never rendered.
    logger.debug(f"Answer: {result}")
    logger.debug(f"Tool History: {tool_history}")
    logger.debug(f"Reasoning Traces: {reasoning_traces}")
    timer_end = time.perf_counter()
    logger.debug(f"Time elapsed: {timer_end - timer_start:.2f}s")

Expand Down
5 changes: 3 additions & 2 deletions apex/services/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(self, base_url: str, model: str, key: str):

async def invoke(
self, messages: list[dict[str, str]], body: dict[str, Any] | None = None
) -> tuple[str, list[dict[str, str]]]:
) -> tuple[str, list[dict[str, str]], list[dict[str, Any]]]:
headers = {
"Authorization": "Bearer " + self._key,
"Content-Type": "application/json",
Expand All @@ -35,7 +35,8 @@ async def invoke(

data = await response.json()
content = data.get("choices", [{}])[0].get("message", {}).get("content", "")
return str(content), []
# This base LLM does not build multi-step chains; return empty reasoning_traces
return str(content), [], []

def __str__(self) -> str:
    """Return ``ClassName(base_url, model)`` for this instance.

    Only ``self._base_url`` and ``self._model`` are included in the text.
    """
    return f"{self.__class__.__name__}({self._base_url}, {self._model})"
2 changes: 1 addition & 1 deletion apex/services/llm/llm_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
class LLMBase:
async def invoke(
self, messages: list[dict[str, str]], body: dict[str, Any] | None = None
) -> tuple[str, list[dict[str, str]]]:
) -> tuple[str, list[dict[str, str]], list[dict[str, Any]]]:
raise NotImplementedError
2 changes: 1 addition & 1 deletion apex/validator/generate_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ async def generate_query(llm: LLMBase, websearch: WebSearchBase) -> str:
search_website = random.choice(search_results)
search_content = search_website.content
query = QUERY_PROMPT_TEMPLATE.format(context=search_content)
query_response, _ = await llm.invoke([{"role": "user", "content": query}])
query_response, _, _ = await llm.invoke([{"role": "user", "content": query}])
logger.debug(f"Generated query.\nPrompt: '{query}'\nResponse: '{query_response}'")
return query_response
10 changes: 7 additions & 3 deletions apex/validator/generate_reference.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from typing import Any

from loguru import logger

from apex.services.deep_research.deep_research_base import DeepResearchBase


async def generate_reference(llm: DeepResearchBase, query: str) -> tuple[str, list[dict[str, str]]]:
async def generate_reference(
llm: DeepResearchBase, query: str
) -> tuple[str, list[dict[str, str]], list[dict[str, Any]]]:
"""Generate a reference response for the given prompt.

Args:
Expand All @@ -29,6 +33,6 @@ async def generate_reference(llm: DeepResearchBase, query: str) -> tuple[str, li
),
}

response, tool_history = await llm.invoke([system_message, user_message])
response, tool_history, reasoning_traces = await llm.invoke([system_message, user_message])
logger.debug(f"Generated reference.\nPrompt: '{user_message}'\nResponse: '{response}'")
return response, tool_history
return response, tool_history, reasoning_traces
2 changes: 2 additions & 0 deletions apex/validator/logger_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,15 @@ async def log(
reference: str | None = None,
discriminator_results: MinerDiscriminatorResults | None = None,
tool_history: list[dict[str, str]] | None = None,
reasoning_traces: list[dict[str, Any]] | None = None,
) -> None:
"""Log an event to wandb."""
if self.run:
if discriminator_results:
processed_event = self.process_event(discriminator_results.model_dump())
processed_event["reference"] = reference
processed_event["tool_history"] = tool_history
processed_event["reasoning_trace"] = reasoning_traces
self.run.log(processed_event)

def process_event(self, event: Mapping[str, Any]) -> dict[str, Any]:
Expand Down
14 changes: 11 additions & 3 deletions apex/validator/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,15 @@ async def run_single(self, task: QueryTask) -> str:

reference = None
tool_history: list[dict[str, str]] = []
reasoning_traces: list[dict[str, Any]] = []
if random.random() < self.reference_rate:
try:
generator_results = None
ground_truth = 0
logger.debug(f"Generating task reference for query: {query[:20]}..")
reference, tool_history = await generate_reference(llm=self.deep_research, query=query)
reference, tool_history, reasoning_traces = await generate_reference(
llm=self.deep_research, query=query
)
except BaseException as exc:
logger.exception(f"Failed to generate reference: {exc}")

Expand All @@ -100,7 +103,9 @@ async def run_single(self, task: QueryTask) -> str:
if random.random() < self.redundancy_rate:
try:
logger.debug(f"Generating redundant task reference for query: {query[:20]}..")
reference, tool_history = await generate_reference(llm=self.deep_research, query=query)
reference, tool_history, reasoning_traces = await generate_reference(
llm=self.deep_research, query=query
)
except BaseException as exc:
logger.warning(f"Failed to generate redundant reference: {exc}")

Expand All @@ -111,7 +116,10 @@ async def run_single(self, task: QueryTask) -> str:

if self.logger_wandb:
await self.logger_wandb.log(
reference=reference, discriminator_results=discriminator_results, tool_history=tool_history
reference=reference,
discriminator_results=discriminator_results,
tool_history=tool_history,
reasoning_traces=reasoning_traces,
)

if self._debug:
Expand Down
8 changes: 6 additions & 2 deletions tests/services/deep_research/test_deep_research_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ async def test_invoke_no_documents_found(deep_research_langchain, mock_websearch

result = await deep_research_langchain.invoke(messages)

assert result == ("Could not find any information on the topic.", deep_research_langchain.tool_history)
assert result[0] == "Could not find any information on the topic."
assert result[1] == deep_research_langchain.tool_history
assert isinstance(result[2], list)


@pytest.mark.asyncio
Expand Down Expand Up @@ -208,4 +210,6 @@ async def test_full_invoke_flow(deep_research_langchain, mock_websearch):
final_chain.ainvoke.assert_called_once_with(
{"summary": summary, "research_report": research_report, "question": question}
)
assert result == (final_answer, deep_research_langchain.tool_history)
assert result[0] == final_answer
assert result[1] == deep_research_langchain.tool_history
assert isinstance(result[2], list)