Skip to content

Commit 797c20c

Browse files
authored
feat: Adding filters param to MostSimilarDocumentsPipeline run and run_batch (#3301)
* Adding filters param to MostSimilarDocumentsPipeline run and run_batch
* Adding index param to MostSimilarDocumentsPipeline run and run_batch
* Adding index param documentation to MostSimilarDocumentsPipeline run and run_batch
* Updated index param documentation to MostSimilarDocumentsPipeline run and run_batch. Updated type: ignore in run_batch
* Adding filters param to MostSimilarDocumentsPipeline run and run_batch
* Adding index param to MostSimilarDocumentsPipeline run and run_batch
* Adding index param documentation to MostSimilarDocumentsPipeline run and run_batch
* Updated index param documentation to MostSimilarDocumentsPipeline run and run_batch. Updated type: ignore in run_batch
1 parent b84a6b1 commit 797c20c

File tree

2 files changed

+85
-5
lines changed

2 files changed

+85
-5
lines changed

haystack/pipelines/standard_pipelines.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -717,27 +717,43 @@ def __init__(self, document_store: BaseDocumentStore):
717717
self.pipeline.add_node(component=document_store, name="DocumentStore", inputs=["Query"])
718718
self.document_store = document_store
719719

720-
def run(self, document_ids: List[str], top_k: int = 5):
720+
def run(
721+
self,
722+
document_ids: List[str],
723+
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
724+
top_k: int = 5,
725+
index: Optional[str] = None,
726+
):
721727
"""
722728
:param document_ids: IDs of the documents for which similar documents should be retrieved
729+
:param filters: Optional filters to narrow down the search space to documents whose metadata fulfill certain conditions
723730
:param top_k: How many similar documents to return for each input document
731+
:param index: Optional name of the index to query the documents from. If None, the DocumentStore's default index (self.index) is used.
724732
"""
725733
similar_documents: list = []
726734
self.document_store.return_embedding = True # type: ignore
727735

728-
for document in self.document_store.get_documents_by_id(ids=document_ids):
736+
for document in self.document_store.get_documents_by_id(ids=document_ids, index=index):
729737
similar_documents.append(
730738
self.document_store.query_by_embedding(
731-
query_emb=document.embedding, return_embedding=False, top_k=top_k
739+
query_emb=document.embedding, filters=filters, return_embedding=False, top_k=top_k, index=index
732740
)
733741
)
734742

735743
self.document_store.return_embedding = False # type: ignore
736744
return similar_documents
737745

738-
def run_batch(self, document_ids: List[str], top_k: int = 5): # type: ignore
746+
def run_batch( # type: ignore
747+
self,
748+
document_ids: List[str],
749+
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
750+
top_k: int = 5,
751+
index: Optional[str] = None,
752+
):
739753
"""
740754
:param document_ids: IDs of the documents for which similar documents should be retrieved
755+
:param filters: Optional filters to narrow down the search space to documents whose metadata fulfill certain conditions
741756
:param top_k: How many similar documents to return for each input document
757+
:param index: Optional name of the index to query the documents from. If None, the DocumentStore's default index (self.index) is used.
742758
"""
743-
return self.run(document_ids=document_ids, top_k=top_k)
759+
return self.run(document_ids=document_ids, filters=filters, top_k=top_k, index=index)

test/pipelines/test_standard_pipelines.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,39 @@ def test_most_similar_documents_pipeline(retriever, document_store):
200200
assert isinstance(document.content, str)
201201

202202

203+
@pytest.mark.parametrize(
204+
"retriever,document_store", [("embedding", "milvus1"), ("embedding", "elasticsearch")], indirect=True
205+
)
206+
def test_most_similar_documents_pipeline_with_filters(retriever, document_store):
207+
documents = [
208+
{"id": "a", "content": "Sample text for document-1", "meta": {"source": "wiki1"}},
209+
{"id": "b", "content": "Sample text for document-2", "meta": {"source": "wiki2"}},
210+
{"content": "Sample text for document-3", "meta": {"source": "wiki3"}},
211+
{"content": "Sample text for document-4", "meta": {"source": "wiki4"}},
212+
{"content": "Sample text for document-5", "meta": {"source": "wiki5"}},
213+
]
214+
215+
document_store.write_documents(documents)
216+
document_store.update_embeddings(retriever)
217+
218+
docs_id: list = ["a", "b"]
219+
filters = {"source": ["wiki3", "wiki4", "wiki5"]}
220+
pipeline = MostSimilarDocumentsPipeline(document_store=document_store)
221+
list_of_documents = pipeline.run(document_ids=docs_id, filters=filters)
222+
223+
assert len(list_of_documents[0]) > 1
224+
assert isinstance(list_of_documents, list)
225+
assert len(list_of_documents) == len(docs_id)
226+
227+
for another_list in list_of_documents:
228+
assert isinstance(another_list, list)
229+
for document in another_list:
230+
assert isinstance(document, Document)
231+
assert isinstance(document.id, str)
232+
assert isinstance(document.content, str)
233+
assert document.meta["source"] in ["wiki3", "wiki4", "wiki5"]
234+
235+
203236
@pytest.mark.parametrize("retriever,document_store", [("embedding", "memory")], indirect=True)
204237
def test_most_similar_documents_pipeline_batch(retriever, document_store):
205238
documents = [
@@ -229,6 +262,37 @@ def test_most_similar_documents_pipeline_batch(retriever, document_store):
229262
assert isinstance(document.content, str)
230263

231264

265+
@pytest.mark.parametrize("retriever,document_store", [("embedding", "memory")], indirect=True)
266+
def test_most_similar_documents_pipeline_with_filters_batch(retriever, document_store):
267+
documents = [
268+
{"id": "a", "content": "Sample text for document-1", "meta": {"source": "wiki1"}},
269+
{"id": "b", "content": "Sample text for document-2", "meta": {"source": "wiki2"}},
270+
{"content": "Sample text for document-3", "meta": {"source": "wiki3"}},
271+
{"content": "Sample text for document-4", "meta": {"source": "wiki4"}},
272+
{"content": "Sample text for document-5", "meta": {"source": "wiki5"}},
273+
]
274+
275+
document_store.write_documents(documents)
276+
document_store.update_embeddings(retriever)
277+
278+
docs_id: list = ["a", "b"]
279+
filters = {"source": ["wiki3", "wiki4", "wiki5"]}
280+
pipeline = MostSimilarDocumentsPipeline(document_store=document_store)
281+
list_of_documents = pipeline.run_batch(document_ids=docs_id, filters=filters)
282+
283+
assert len(list_of_documents[0]) > 1
284+
assert isinstance(list_of_documents, list)
285+
assert len(list_of_documents) == len(docs_id)
286+
287+
for another_list in list_of_documents:
288+
assert isinstance(another_list, list)
289+
for document in another_list:
290+
assert isinstance(document, Document)
291+
assert isinstance(document.id, str)
292+
assert isinstance(document.content, str)
293+
assert document.meta["source"] in ["wiki3", "wiki4", "wiki5"]
294+
295+
232296
@pytest.mark.integration
233297
@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
234298
def test_most_similar_documents_pipeline_save(tmpdir, document_store_with_docs):

0 commit comments

Comments
 (0)