Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def get_top_k_mmr_embeddings(
A mmr_threshold of 1 will check similarity the query and ignore previous results.

"""
threshold = mmr_threshold or 0.5
threshold = 0.5 if mmr_threshold is None else mmr_threshold
similarity_fn = similarity_fn or default_similarity_fn

if embedding_ids is None or embedding_ids == []:
Expand Down
12 changes: 12 additions & 0 deletions llama-index-core/tests/indices/query/test_embedding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,15 @@ def test_get_top_k_mmr_embeddings() -> None:
result_similarities_no_mmr, result_similarities
):
assert np.isclose(result_no_mmr, result_with_mmr, atol=0.00001)


def test_get_top_k_mmr_embeddings_respects_zero_threshold() -> None:
"""Test that an explicit zero threshold does not use the MMR default."""
query_embedding = [1.0, 0.0]
embeddings = [[1.0, 0.0], [0.9, 0.0], [0.0, 1.0]]

_, result_ids = get_top_k_mmr_embeddings(
query_embedding, embeddings, mmr_threshold=0
)

assert result_ids == [0, 2, 1]