Skip to content

Commit 1fa9966

Browse files
authored
Let query parameters take precedence for retriever tool (llamaindex/openai) (#183)
Signed-off-by: Ann Zhang <ann.zhang@databricks.com>
1 parent 4c70dc5 commit 1fa9966

File tree

5 files changed

+60
-4
lines changed

5 files changed

+60
-4
lines changed

integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,3 +330,18 @@ def test_kwargs_are_passed_through() -> None:
330330
score_threshold=0.5,
331331
extra_param="something random",
332332
)
333+
334+
335+
def test_kwargs_override_both_num_results_and_query_type() -> None:
336+
vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, num_results=10, query_type="ANN")
337+
vector_search_tool._vector_store.similarity_search = MagicMock()
338+
339+
vector_search_tool.invoke(
340+
{"query": "what cities are in Germany", "k": 3, "query_type": "HYBRID"},
341+
)
342+
vector_search_tool._vector_store.similarity_search.assert_called_once_with(
343+
query="what cities are in Germany",
344+
k=3, # Should use overridden value
345+
query_type="HYBRID", # Should use overridden value
346+
filter={},
347+
)

integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,15 +117,19 @@ def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[floa
117117
kwargs = {**kwargs, **(self.model_extra or {})}
118118
kwargs = {k: v for k, v in kwargs.items() if k in signature.parameters}
119119

120+
# Allow kwargs to override the default values upon invocation
121+
num_results = kwargs.pop("num_results", self.num_results)
122+
query_type = kwargs.pop("query_type", self.query_type)
123+
120124
# Ensure that we don't have duplicate keys
121125
kwargs.update(
122126
{
123127
"query_text": query_text,
124128
"query_vector": query_vector,
125129
"columns": self.columns,
126130
"filters": combined_filters,
127-
"num_results": self.num_results,
128-
"query_type": self.query_type,
131+
"num_results": num_results,
132+
"query_type": query_type,
129133
}
130134
)
131135
search_resp = self._index.similarity_search(**kwargs)

integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,3 +242,18 @@ def test_filters_are_combined() -> None:
242242
query_type=vector_search_tool.query_type,
243243
query_vector=None,
244244
)
245+
246+
247+
def test_kwargs_override_both_num_results_and_query_type() -> None:
248+
vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, num_results=10, query_type="ANN")
249+
vector_search_tool._index = create_autospec(VectorSearchIndex, instance=True)
250+
251+
vector_search_tool.call(query="what cities are in Germany", num_results=3, query_type="HYBRID")
252+
vector_search_tool._index.similarity_search.assert_called_once_with(
253+
columns=vector_search_tool.columns,
254+
query_text="what cities are in Germany",
255+
filters={},
256+
num_results=3, # Should use overridden value
257+
query_type="HYBRID", # Should use overridden value
258+
query_vector=None,
259+
)

integrations/openai/src/databricks_openai/vector_search_retriever_tool.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,14 +232,19 @@ def execute(
232232
signature = inspect.signature(self._index.similarity_search)
233233
kwargs = {**kwargs, **(self.model_extra or {})}
234234
kwargs = {k: v for k, v in kwargs.items() if k in signature.parameters}
235+
236+
# Allow kwargs to override the default values upon invocation
237+
num_results = kwargs.pop("num_results", self.num_results)
238+
query_type = kwargs.pop("query_type", self.query_type)
239+
235240
kwargs.update(
236241
{
237242
"query_text": query_text,
238243
"query_vector": query_vector,
239244
"columns": self.columns,
240245
"filters": combined_filters,
241-
"num_results": self.num_results,
242-
"query_type": self.query_type,
246+
"num_results": num_results,
247+
"query_type": query_type,
243248
}
244249
)
245250
search_resp = self._index.similarity_search(**kwargs)

integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,3 +363,20 @@ def test_filters_are_combined() -> None:
363363
query_type=vector_search_tool.query_type,
364364
query_vector=None,
365365
)
366+
367+
368+
def test_kwargs_override_both_num_results_and_query_type() -> None:
369+
vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, num_results=10, query_type="ANN")
370+
vector_search_tool._index = create_autospec(VectorSearchIndex, instance=True)
371+
372+
vector_search_tool.execute(
373+
query="what cities are in Germany", num_results=3, query_type="HYBRID"
374+
)
375+
vector_search_tool._index.similarity_search.assert_called_once_with(
376+
columns=vector_search_tool.columns,
377+
query_text="what cities are in Germany",
378+
filters={},
379+
num_results=3, # Should use overridden value
380+
query_type="HYBRID", # Should use overridden value
381+
query_vector=None,
382+
)

0 commit comments

Comments
 (0)