Skip to content

Commit 2ce92a7

Browse files
Auto-merge PR #181 (refactor/storm) into integration for testing
2 parents f408289 + 9db29d2 commit 2ce92a7

File tree

3 files changed

+33
-16
lines changed

3 files changed

+33
-16
lines changed

akd/agents/storm/nodes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,13 +146,14 @@ async def write_sections(
146146
Dict: Updated research state.
147147
"""
148148
outline = state["outline"]
149-
sections = await section_writer(
149+
sections, references = await section_writer(
150150
outline=outline,
151151
sections=outline.sections,
152152
topic=state["topic"],
153153
long_context_llm=long_context_llm,
154154
retriever=retriever,
155155
)
156+
state["references"] = references
156157
return {
157158
**state,
158159
"sections": sections,

akd/agents/storm/storm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ class StormOutputSchema(OutputSchema):
6767
...,
6868
description="Outline of the article",
6969
)
70+
sections: list = Field(
71+
...,
72+
description="List of sections of the article",
73+
)
7074

7175

7276
class StormAgentConfig(BaseAgentConfig):
@@ -232,6 +236,7 @@ async def get_response_async(
232236
references=article_state["references"],
233237
search_results=article_state["search_results"],
234238
outline=article_state["outline"],
239+
sections=article_state["sections"],
235240
)
236241

237242
async def _arun(self, params: StormInputSchema, **kwargs) -> StormOutputSchema:

akd/agents/storm/tools.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from functools import partial
21
from typing import Dict, List
32

43
from langchain_core.output_parsers import StrOutputParser
4+
from langchain_core.runnables import chain as as_runnable
55
from langchain_core.vectorstores import VectorStoreRetriever
66
from langchain_openai import ChatOpenAI
77

@@ -118,13 +118,13 @@ async def retrieve(inputs: Dict, retriever: VectorStoreRetriever) -> Dict:
118118
Dict: Original input dictionary augmented with a formatted string of retrieved documents.
119119
"""
120120
docs = await retriever.ainvoke(inputs["topic"] + ": " + inputs["section"])
121-
formatted = "\n".join(
122-
[
123-
f'<Document href="{doc.metadata["source"]}"/>\n{doc.page_content}\n</Document>'
124-
for doc in docs
125-
],
126-
)
127-
return {"docs": formatted, **inputs}
121+
references = {}
122+
formatted_docs = ""
123+
for doc in docs:
124+
formatted_docs += f'<Document href="{doc.metadata["source"]}"/>\n{doc.page_content}\n</Document>'
125+
references.update({doc.metadata["source"]: doc.page_content})
126+
127+
return {"docs": formatted_docs, "references": references, **inputs}
128128

129129

130130
async def section_writer(
@@ -147,12 +147,17 @@ async def section_writer(
147147
Returns:
148148
List[ArticleSection]: A list of generated article sections.
149149
"""
150-
section_writer = (
151-
partial(retrieve, retriever=retriever)
152-
| SECTION_WRITER_PROMPT
153-
| long_context_llm.with_structured_output(ArticleSection)
154-
)
155-
sections = await section_writer.abatch(
150+
151+
@as_runnable
152+
async def section_writer(inputs: Dict) -> Dict:
153+
retrieved_data = await retrieve(inputs, retriever)
154+
section = await (
155+
SECTION_WRITER_PROMPT
156+
| long_context_llm.with_structured_output(ArticleSection)
157+
).ainvoke({**retrieved_data})
158+
return {"section": section, "references": retrieved_data["references"]}
159+
160+
output = await section_writer.abatch(
156161
[
157162
{
158163
"outline": outline.as_str,
@@ -162,7 +167,13 @@ async def section_writer(
162167
for section in sections
163168
],
164169
)
165-
return sections
170+
171+
sections, references = [], {}
172+
for out in output:
173+
sections.append(out["section"])
174+
references.update(out["references"])
175+
176+
return sections, references
166177

167178

168179
async def writer(topic: str, draft: str, long_context_llm: ChatOpenAI) -> str:

0 commit comments

Comments
 (0)