1- from functools import partial
21from typing import Dict , List
32
43from langchain_core .output_parsers import StrOutputParser
4+ from langchain_core .runnables import chain as as_runnable
55from langchain_core .vectorstores import VectorStoreRetriever
66from langchain_openai import ChatOpenAI
77
@@ -118,13 +118,13 @@ async def retrieve(inputs: Dict, retriever: VectorStoreRetriever) -> Dict:
118118 Dict: Original input dictionary augmented with a formatted string of retrieved documents.
119119 """
120120 docs = await retriever .ainvoke (inputs ["topic" ] + ": " + inputs ["section" ])
121- formatted = " \n " . join (
122- [
123- f'<Document href=" { doc . metadata [ "source" ] } "/> \n { doc . page_content } \n </Document>'
124- for doc in docs
125- ],
126- )
127- return {"docs" : formatted , ** inputs }
121+ references = {}
122+ formatted_docs = ""
123+ for doc in docs :
124+ formatted_docs += f'<Document href=" { doc . metadata [ "source" ] } "/> \n { doc . page_content } \n </Document>'
125+ references . update ({ doc . metadata [ "source" ]: doc . page_content })
126+
127+ return {"docs" : formatted_docs , "references" : references , ** inputs }
128128
129129
130130async def section_writer (
@@ -147,12 +147,17 @@ async def section_writer(
147147 Returns:
148148 List[ArticleSection]: A list of generated article sections.
149149 """
150- section_writer = (
151- partial (retrieve , retriever = retriever )
152- | SECTION_WRITER_PROMPT
153- | long_context_llm .with_structured_output (ArticleSection )
154- )
155- sections = await section_writer .abatch (
150+
151+ @as_runnable
152+ async def section_writer (inputs : Dict ) -> Dict :
153+ retrieved_data = await retrieve (inputs , retriever )
154+ section = await (
155+ SECTION_WRITER_PROMPT
156+ | long_context_llm .with_structured_output (ArticleSection )
157+ ).ainvoke ({** retrieved_data })
158+ return {"section" : section , "references" : retrieved_data ["references" ]}
159+
160+ output = await section_writer .abatch (
156161 [
157162 {
158163 "outline" : outline .as_str ,
@@ -162,7 +167,13 @@ async def section_writer(
162167 for section in sections
163168 ],
164169 )
165- return sections
170+
171+ sections , references = [], {}
172+ for out in output :
173+ sections .append (out ["section" ])
174+ references .update (out ["references" ])
175+
176+ return sections , references
166177
167178
168179async def writer (topic : str , draft : str , long_context_llm : ChatOpenAI ) -> str :
0 commit comments