Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from dataclasses import replace
from typing import Any

from haystack import Document, component, default_from_dict, default_to_dict
Expand Down Expand Up @@ -195,10 +196,11 @@ def run(self, documents: list[Document]) -> dict[str, list[Document] | dict[str,
self.embedding_type,
)

new_documents = []
for doc, embeddings in zip(documents, all_embeddings, strict=True):
doc.embedding = embeddings
new_documents.append(replace(doc, embedding=embeddings))

return {"documents": documents, "meta": metadata}
return {"documents": new_documents, "meta": metadata}

@component.output_types(documents=list[Document], meta=dict[str, Any])
async def run_async(self, documents: list[Document]) -> dict[str, list[Document] | dict[str, Any]]:
Expand Down Expand Up @@ -228,7 +230,8 @@ async def run_async(self, documents: list[Document]) -> dict[str, list[Document]
embedding_type=self.embedding_type,
)

new_documents = []
for doc, embeddings in zip(documents, all_embeddings, strict=True):
doc.embedding = embeddings
new_documents.append(replace(doc, embedding=embeddings))

return {"documents": documents, "meta": metadata}
return {"documents": new_documents, "meta": metadata}
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import replace
from typing import Any

from haystack import Document, component, default_from_dict, default_to_dict, logging
Expand Down Expand Up @@ -162,6 +163,5 @@ def run(self, query: str, documents: list[Document], top_k: int | None = None) -
sorted_docs = []
for idx, score in zip(indices, scores, strict=True):
doc = documents[idx]
doc.score = score
sorted_docs.append(documents[idx])
sorted_docs.append(replace(doc, score=score))
return {"documents": sorted_docs}
45 changes: 45 additions & 0 deletions integrations/cohere/tests/test_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,51 @@ async def test_run_async(self, mock_get_response):
assert doc_with_embedding.meta == doc.meta
assert doc_with_embedding.embedding == embedding

@patch("haystack_integrations.components.embedders.cohere.document_embedder.get_response")
def test_run_does_not_modify_original_documents(self, mock_get_response):
embedder = CohereDocumentEmbedder(api_key=Secret.from_token("test-api-key"))

embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
mock_get_response.return_value = (embeddings, {"api_version": "1.0"})

docs = [
Document(content="I love cheese", meta={"topic": "Cuisine"}),
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
]

result = embedder.run(docs)

# Check that the original documents are not modified
for doc in docs:
assert doc.embedding is None

# Check that the returned documents have embeddings
for doc_with_embedding, embedding in zip(result["documents"], embeddings, strict=True):
assert doc_with_embedding.embedding == embedding

@pytest.mark.asyncio
@patch("haystack_integrations.components.embedders.cohere.document_embedder.get_async_response")
async def test_run_async_does_not_modify_original_documents(self, mock_get_response):
embedder = CohereDocumentEmbedder(api_key=Secret.from_token("test-api-key"))

embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
mock_get_response.return_value = (embeddings, {"api_version": "1.0"})

docs = [
Document(content="I love cheese", meta={"topic": "Cuisine"}),
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
]

result = await embedder.run_async(docs)

# Check that the original documents are not modified
for doc in docs:
assert doc.embedding is None

# Check that the returned documents have embeddings
for doc_with_embedding, embedding in zip(result["documents"], embeddings, strict=True):
assert doc_with_embedding.embedding == embedding

@pytest.mark.skipif(
not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None),
reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.",
Expand Down
21 changes: 21 additions & 0 deletions integrations/cohere/tests/test_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,27 @@ def test_run_topk_set_in_init(self, monkeypatch, mock_ranker_response): # noqa:
Document(id="efgh", content="doc2", score=0.95),
]

def test_run_does_not_modify_original_documents(self, monkeypatch, mock_ranker_response): # noqa: ARG002
monkeypatch.setenv("CO_API_KEY", "test-api-key")
ranker = CohereRanker(top_k=2)
query = "test"
documents = [
Document(id="abcd", content="doc1"),
Document(id="efgh", content="doc2"),
Document(id="ijkl", content="doc3"),
]

ranker_results = ranker.run(query, documents)

# Check that the original documents are not modified
for doc in documents:
assert doc.score is None

# Check that the returned documents have scores
reranked_docs = ranker_results["documents"]
for doc in reranked_docs:
assert doc.score is not None

@pytest.mark.skipif(
not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None),
reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.",
Expand Down