Skip to content

Commit 1b58460

Browse files
hwchase17baskaryan
andauthored
update keys for chain (#5164)
Co-authored-by: Bagatur <baskaryan@gmail.com>
1 parent aca8cb5 commit 1b58460

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

libs/langchain/langchain/chains/combine_documents/stuff.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,13 @@ def get_default_document_variable_name(cls, values: Dict) -> Dict:
100100
)
101101
return values
102102

103+
@property
104+
def input_keys(self) -> List[str]:
105+
extra_keys = [
106+
k for k in self.llm_chain.input_keys if k != self.document_variable_name
107+
]
108+
return super().input_keys + extra_keys
109+
103110
def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
104111
"""Construct inputs from kwargs and docs.
105112

libs/langchain/tests/unit_tests/chains/test_combine_documents.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
_collapse_docs,
1010
_split_list_of_docs,
1111
)
12+
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
1213
from langchain.docstore.document import Document
1314
from langchain.schema import format_document
15+
from tests.unit_tests.llms.fake_llm import FakeLLM
1416

1517

1618
def _fake_docs_len_func(docs: List[Document]) -> int:
@@ -21,6 +23,11 @@ def _fake_combine_docs_func(docs: List[Document], **kwargs: Any) -> str:
2123
return "".join([d.page_content for d in docs])
2224

2325

26+
def test_multiple_input_keys() -> None:
27+
chain = load_qa_with_sources_chain(FakeLLM(), chain_type="stuff")
28+
assert chain.input_keys == ["input_documents", "question"]
29+
30+
2431
def test__split_list_long_single_doc() -> None:
2532
"""Test splitting of a long single doc."""
2633
docs = [Document(page_content="foo" * 100)]

0 commit comments

Comments
 (0)