Skip to content

Commit 137685a

Browse files
committed
Allow empty strings for reranking query and improve test cases
1 parent 9ec1461 commit 137685a

4 files changed

Lines changed: 59 additions & 38 deletions

File tree

examples/sap_hanavector.ipynb

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@
164164
},
165165
{
166166
"cell_type": "code",
167-
"execution_count": null,
167+
"execution_count": 4,
168168
"metadata": {},
169169
"outputs": [
170170
{
@@ -944,20 +944,26 @@
944944
"name": "stdout",
945945
"output_type": "stream",
946946
"text": [
947-
"Reranked documents:\n",
947+
"Compressed documents:\n",
948+
"--------------------------------------------------------------------------------\n",
949+
"Content: Deep learning applications\n",
950+
"Relevance score: 0.2188\n",
948951
"--------------------------------------------------------------------------------\n",
949952
"Content: Advanced machine learning techniques\n",
950-
"Relevance score: 0.0748353898525238\n",
953+
"Relevance score: 0.0748\n",
951954
"--------------------------------------------------------------------------------\n",
952955
"Content: Introduction to neural networks\n",
953-
"Relevance score: 0.007919907569885254\n",
956+
"Relevance score: 0.0079\n",
957+
"--------------------------------------------------------------------------------\n",
958+
"Content: Natural language processing techniques\n",
959+
"Relevance score: 0.0040\n",
954960
"--------------------------------------------------------------------------------\n",
955-
"Content: Python programming basics\n",
956-
"Relevance score: 0.002642408013343811\n",
961+
"Content: Reinforcement learning strategies\n",
962+
"Relevance score: 0.0036\n",
957963
"\n",
958964
"Top 2 reranked results:\n",
959965
" [1] Score: 0.4351 - Advanced machine learning techniques\n",
960-
" [2] Score: 0.0262 - Introduction to neural networks\n"
966+
" [3] Score: 0.0342 - Deep learning applications\n"
961967
]
962968
}
963969
],
@@ -975,17 +981,20 @@
975981
" Document(page_content=\"Python programming basics\"),\n",
976982
" Document(page_content=\"Advanced machine learning techniques\"),\n",
977983
" Document(page_content=\"Introduction to neural networks\"),\n",
984+
" Document(page_content=\"Deep learning applications\"),\n",
985+
" Document(page_content=\"Reinforcement learning strategies\"),\n",
986+
" Document(page_content=\"Natural language processing techniques\"),\n",
978987
"]\n",
979988
"\n",
980-
"# Rerank documents based on a query using compress_documents (returns top 5)\n",
989+
"# Rerank documents based on a query using compress_documents (returns top min(5, len(documents)))\n",
981990
"# Reranking scores will be added to the metadata of each document under the key \"relevance_score\"\n",
982991
"reranked_docs = reranker.compress_documents(documents, query=\"AI and deep learning\")\n",
983992
"\n",
984-
"print(\"Reranked documents:\")\n",
993+
"print(\"Compressed documents:\")\n",
985994
"for doc in reranked_docs:\n",
986995
" print(\"-\" * 80)\n",
987996
" print(f\"Content: {doc.page_content}\")\n",
988-
" print(f\"Relevance score: {doc.metadata.get('relevance_score')}\")\n",
997+
" print(f\"Relevance score: {doc.metadata.get('relevance_score'):.4f}\")\n",
989998
"\n",
990999
"# Or use the rerank method for more control over top_n\n",
9911000
"results = reranker.rerank(documents, query=\"machine learning\", top_n=2)\n",

langchain_hana/vectorstores/hana_db.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -840,8 +840,8 @@ def from_texts( # type: ignore[no-untyped-def, override]
840840
def _validate_rerank_config(rerank_config: dict) -> None:
841841
if not isinstance(rerank_config, dict):
842842
raise ValueError("rerank_config must be a dictionary")
843-
if "query" not in rerank_config or not isinstance(rerank_config["query"], str) or not rerank_config["query"]:
844-
raise ValueError("rerank_config must contain 'query' and it must be a non-empty string")
843+
if "query" not in rerank_config or not isinstance(rerank_config["query"], str):
844+
raise ValueError("rerank_config must contain 'query' and it must be a string")
845845
if "top_n" in rerank_config and (not isinstance(rerank_config["top_n"], int) or rerank_config["top_n"] <= 0):
846846
raise ValueError("rerank_config 'top_n' must be a positive integer")
847847
if "rank_fields" in rerank_config:
@@ -924,7 +924,7 @@ def similarity_search_with_score(
924924
else:
925925
if rerank_config:
926926
rerank_config_copy = {**rerank_config}
927-
if not rerank_config.get("query"):
927+
if rerank_config.get("query") is None:
928928
rerank_config_copy["query"] = query # Use the original query if no specific rerank query is provided
929929
else:
930930
rerank_config_copy = None
@@ -1007,7 +1007,7 @@ def similarity_search_with_score_and_vector_by_query(
10071007

10081008
if rerank_config:
10091009
rerank_config_copy = {**rerank_config}
1010-
if not rerank_config.get("query"):
1010+
if rerank_config.get("query") is None:
10111011
rerank_config_copy["query"] = query # Use the original query if no specific rerank query is provided
10121012
else:
10131013
rerank_config_copy = None

tests/integration_tests/test_hana_db.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -129,16 +129,22 @@ def build_rerank_config(base_config: dict[str, Any] | None, top_n: int, query: s
129129

130130

131131
@pytest.fixture(params=[
132-
({"query": 5}, "rerank_config must contain 'query' and it must be a non-empty string"),
132+
({"query": 5}, "rerank_config must contain 'query' and it must be a string"),
133133
({"top_n": "not_an_int"}, "rerank_config 'top_n' must be a positive integer"),
134134
({"model_id": 5}, "rerank_config 'model_id' must be a non-empty string"),
135+
({"model_id": ""}, "rerank_config 'model_id' must be a non-empty string"),
135136
({"rank_fields": "not_a_list"}, "rerank_config 'rank_fields' must be a list of strings"),
136137
({"rank_fields": [1, 2, 3]}, "rerank_config 'rank_fields' must be a list of strings")
137-
], ids=["query_not_str", "top_n_not_int", "model_id_not_str", "rank_fields_not_list", "rank_fields_not_str"])
138+
], ids=["query_not_str", "top_n_not_int", "model_id_not_str", "model_id_empty_str", "rank_fields_not_list", "rank_fields_not_str"])
138139
def invalid_rerank_config_with_error_message(request):
139140
return request.param
140141

141142

143+
@pytest.fixture
144+
def invalid_rerank_config_non_existent_model_id():
145+
return {"query": "test_query", "model_id": "non_existing_model"}
146+
147+
142148
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
143149
def test_hanavector_non_existing_table(table_name_with_cleanup) -> None:
144150
"""Test end to end construction and search."""
@@ -365,9 +371,9 @@ def test_hanavector_similarity_search_simple_invalid_rerank_config(vectorDB, inv
365371

366372

367373
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
368-
def test_hanavector_similarity_search_simple_invalid_rerank_model_id(vectorDB) -> None:
374+
def test_hanavector_similarity_search_simple_invalid_rerank_model_id(vectorDB, invalid_rerank_config_non_existent_model_id) -> None:
369375
with pytest.raises(dbapi.Error):
370-
vectorDB.similarity_search(HanaTestConstants.TEXTS[0], 1, rerank_config={"model_id": "non_existing_model"})
376+
vectorDB.similarity_search(HanaTestConstants.TEXTS[0], 1, rerank_config=invalid_rerank_config_non_existent_model_id)
371377

372378

373379
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
@@ -408,9 +414,9 @@ def test_hanavector_similarity_search_by_vector_simple_invalid_rerank_config(vec
408414

409415

410416
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
411-
def test_hanavector_similarity_search_by_vector_simple_invalid_rerank_model_id(vectorDB) -> None:
417+
def test_hanavector_similarity_search_by_vector_simple_invalid_rerank_model_id(vectorDB, invalid_rerank_config_non_existent_model_id) -> None:
412418
with pytest.raises(dbapi.Error):
413-
vectorDB.similarity_search_by_vector(embedding.embed_query(HanaTestConstants.TEXTS[0]), 1, rerank_config={"query": HanaTestConstants.TEXTS[0], "model_id": "non_existing_model"})
419+
vectorDB.similarity_search_by_vector(embedding.embed_query(HanaTestConstants.TEXTS[0]), 1, rerank_config=invalid_rerank_config_non_existent_model_id)
414420

415421

416422
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
@@ -456,7 +462,7 @@ def test_hanavector_similarity_search_simple_euclidean_distance_invalid_rerank_c
456462

457463

458464
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
459-
def test_hanavector_similarity_search_simple_euclidean_distance_invalid_rerank_model_id(table_name_with_cleanup) -> None:
465+
def test_hanavector_similarity_search_simple_euclidean_distance_invalid_rerank_model_id(table_name_with_cleanup, invalid_rerank_config_non_existent_model_id) -> None:
460466
table_name = table_name_with_cleanup
461467

462468
# Check if table is created
@@ -469,7 +475,7 @@ def test_hanavector_similarity_search_simple_euclidean_distance_invalid_rerank_m
469475
)
470476

471477
with pytest.raises(dbapi.Error):
472-
vectorDB.similarity_search(HanaTestConstants.TEXTS[0], 1, rerank_config={"model_id": "non_existing_model"})
478+
vectorDB.similarity_search(HanaTestConstants.TEXTS[0], 1, rerank_config=invalid_rerank_config_non_existent_model_id)
473479

474480

475481
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
@@ -502,9 +508,9 @@ def test_hanavector_similarity_search_with_metadata_invalid_rerank_config(
502508

503509

504510
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
505-
def test_hanavector_similarity_search_with_metadata_invalid_rerank_model_id(vectorDB) -> None:
511+
def test_hanavector_similarity_search_with_metadata_invalid_rerank_model_id(vectorDB, invalid_rerank_config_non_existent_model_id) -> None:
506512
with pytest.raises(dbapi.Error):
507-
vectorDB.similarity_search(HanaTestConstants.TEXTS[0], 3, rerank_config={"model_id": "non_existing_model"})
513+
vectorDB.similarity_search(HanaTestConstants.TEXTS[0], 3, rerank_config=invalid_rerank_config_non_existent_model_id)
508514

509515

510516
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
@@ -550,9 +556,9 @@ def test_hanavector_similarity_search_with_metadata_filter_invalid_rerank_config
550556

551557

552558
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
553-
def test_hanavector_similarity_search_with_metadata_filter_invalid_rerank_model_id(vectorDB) -> None:
559+
def test_hanavector_similarity_search_with_metadata_filter_invalid_rerank_model_id(vectorDB, invalid_rerank_config_non_existent_model_id) -> None:
554560
with pytest.raises(dbapi.Error):
555-
vectorDB.similarity_search(HanaTestConstants.TEXTS[0], 3, filter={"start": 100}, rerank_config={"model_id": "non_existing_model"})
561+
vectorDB.similarity_search(HanaTestConstants.TEXTS[0], 3, filter={"start": 100}, rerank_config=invalid_rerank_config_non_existent_model_id)
556562

557563

558564
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
@@ -583,9 +589,9 @@ def test_hanavector_similarity_search_with_metadata_filter_string_invalid_rerank
583589

584590

585591
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
586-
def test_hanavector_similarity_search_with_metadata_filter_string_invalid_rerank_model_id(vectorDB) -> None:
592+
def test_hanavector_similarity_search_with_metadata_filter_string_invalid_rerank_model_id(vectorDB, invalid_rerank_config_non_existent_model_id) -> None:
587593
with pytest.raises(dbapi.Error):
588-
vectorDB.similarity_search(HanaTestConstants.TEXTS[0], 3, filter={"quality": "bad"}, rerank_config={"model_id": "non_existing_model"})
594+
vectorDB.similarity_search(HanaTestConstants.TEXTS[0], 3, filter={"quality": "bad"}, rerank_config=invalid_rerank_config_non_existent_model_id)
589595

590596

591597
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
@@ -616,9 +622,9 @@ def test_hanavector_similarity_search_with_metadata_filter_bool_invalid_rerank_c
616622

617623

618624
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
619-
def test_hanavector_similarity_search_with_metadata_filter_bool_invalid_rerank_model_id(vectorDB) -> None:
625+
def test_hanavector_similarity_search_with_metadata_filter_bool_invalid_rerank_model_id(vectorDB, invalid_rerank_config_non_existent_model_id) -> None:
620626
with pytest.raises(dbapi.Error):
621-
vectorDB.similarity_search(HanaTestConstants.TEXTS[0], 3, filter={"ready": False}, rerank_config={"model_id": "non_existing_model"})
627+
vectorDB.similarity_search(HanaTestConstants.TEXTS[0], 3, filter={"ready": False}, rerank_config=invalid_rerank_config_non_existent_model_id)
622628

623629

624630
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
@@ -667,9 +673,9 @@ def test_hanavector_similarity_search_with_score_invalid_rerank_config(vectorDB,
667673

668674

669675
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
670-
def test_hanavector_similarity_search_with_score_invalid_rerank_model_id(vectorDB) -> None:
676+
def test_hanavector_similarity_search_with_score_invalid_rerank_model_id(vectorDB, invalid_rerank_config_non_existent_model_id) -> None:
671677
with pytest.raises(dbapi.Error):
672-
vectorDB.similarity_search_with_score(HanaTestConstants.TEXTS[0], 3, rerank_config={"model_id": "non_existing_model"})
678+
vectorDB.similarity_search_with_score(HanaTestConstants.TEXTS[0], 3, rerank_config=invalid_rerank_config_non_existent_model_id)
673679

674680

675681
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")

tests/integration_tests/test_hana_db_internal_embeddings.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,16 +84,22 @@ def build_rerank_config(base_config: dict[str, Any] | None, top_n: int, query: s
8484

8585

8686
@pytest.fixture(params=[
87-
({"query": 5}, "rerank_config must contain 'query' and it must be a non-empty string"),
87+
({"query": 5}, "rerank_config must contain 'query' and it must be a string"),
8888
({"top_n": "not_an_int"}, "rerank_config 'top_n' must be a positive integer"),
8989
({"model_id": 5}, "rerank_config 'model_id' must be a non-empty string"),
90+
({"model_id": ""}, "rerank_config 'model_id' must be a non-empty string"),
9091
({"rank_fields": "not_a_list"}, "rerank_config 'rank_fields' must be a list of strings"),
9192
({"rank_fields": [1, 2, 3]}, "rerank_config 'rank_fields' must be a list of strings")
92-
], ids=["query_not_str", "top_n_not_int", "model_id_not_str", "rank_fields_not_list", "rank_fields_not_str"])
93+
], ids=["query_not_str", "top_n_not_int", "model_id_not_str", "model_id_empty_str", "rank_fields_not_list", "rank_fields_not_str"])
9394
def invalid_rerank_config_with_error_message(request):
9495
return request.param
9596

9697

98+
@pytest.fixture
99+
def invalid_rerank_config_non_existent_model_id():
100+
return {"query": "test_query", "model_id": "non_existing_model"}
101+
102+
97103
@pytest.fixture(scope="module", params=[{
98104
"internal_embedding_model_id": os.environ["HANA_DB_EMBEDDING_MODEL_ID"],
99105
}, {
@@ -247,9 +253,9 @@ def test_hanavector_similarity_search_with_metadata_filter_invalid_rerank_config
247253
vectorDB.similarity_search(HanaTestConstants.TEXTS[0], 3, filter={"start": 100}, rerank_config=invalid_rerank_config)
248254

249255

250-
def test_hanavector_similarity_search_with_metadata_filter_invalid_rerank_model_id(vectorDB) -> None:
256+
def test_hanavector_similarity_search_with_metadata_filter_invalid_rerank_model_id(vectorDB, invalid_rerank_config_non_existent_model_id) -> None:
251257
with pytest.raises(dbapi.Error):
252-
vectorDB.similarity_search(HanaTestConstants.TEXTS[0], 1, filter={"start": 100}, rerank_config={"model_id": "non_existing_model"})
258+
vectorDB.similarity_search(HanaTestConstants.TEXTS[0], 1, filter={"start": 100}, rerank_config=invalid_rerank_config_non_existent_model_id)
253259

254260

255261
def test_hanavector_similarity_search_simple(vectorDB, rerank_config_param) -> None:
@@ -281,9 +287,9 @@ def test_hanavector_similarity_search_simple_invalid_rerank_config(vectorDB, inv
281287
vectorDB.similarity_search(HanaTestConstants.TEXTS[0], 1, rerank_config=invalid_rerank_config)
282288

283289

284-
def test_hanavector_similarity_search_simple_invalid_rerank_model_id(vectorDB) -> None:
290+
def test_hanavector_similarity_search_simple_invalid_rerank_model_id(vectorDB, invalid_rerank_config_non_existent_model_id) -> None:
285291
with pytest.raises(dbapi.Error):
286-
vectorDB.similarity_search(HanaTestConstants.TEXTS[0], 1, rerank_config={"model_id": "non_existing_model"})
292+
vectorDB.similarity_search(HanaTestConstants.TEXTS[0], 1, rerank_config=invalid_rerank_config_non_existent_model_id)
287293

288294

289295
def test_hanavector_max_marginal_relevance_search(vectorDB) -> None:

0 commit comments

Comments
 (0)