Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions litellm/llms/azure_ai/vector_stores/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def transform_search_vector_store_request(
api_base: str,
litellm_logging_obj: LiteLLMLoggingObj,
litellm_params: dict,
extra_body: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict[str, Any]]:
"""
Transform search request for Azure AI Search API
Expand Down
3 changes: 3 additions & 0 deletions litellm/llms/base_llm/vector_store/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def transform_search_vector_store_request(
api_base: str,
litellm_logging_obj: LiteLLMLoggingObj,
litellm_params: dict,
extra_body: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict]:
pass

Expand All @@ -70,6 +71,7 @@ async def atransform_search_vector_store_request(
api_base: str,
litellm_logging_obj: LiteLLMLoggingObj,
litellm_params: dict,
extra_body: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict]:
"""
Optional async version of transform_search_vector_store_request.
Expand All @@ -84,6 +86,7 @@ async def atransform_search_vector_store_request(
api_base=api_base,
litellm_logging_obj=litellm_logging_obj,
litellm_params=litellm_params,
extra_body=extra_body,
)

@abstractmethod
Expand Down
43 changes: 34 additions & 9 deletions litellm/llms/bedrock/vector_stores/transformation.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
from urllib.parse import urlparse

import httpx

from litellm._logging import verbose_logger
from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
from litellm.types.integrations.rag.bedrock_knowledgebase import (
BedrockKBContent,
BedrockKBResponse,
BedrockKBRetrievalConfiguration,
BedrockKBResponse,
BedrockKBRetrievalQuery,
)
from litellm.types.router import GenericLiteLLMParams
Expand Down Expand Up @@ -202,6 +204,7 @@ def transform_search_vector_store_request(
api_base: str,
litellm_logging_obj: LiteLLMLoggingObj,
litellm_params: dict,
extra_body: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict]:
if isinstance(query, list):
query = " ".join(query)
Expand All @@ -213,24 +216,46 @@ def transform_search_vector_store_request(
}

retrieval_config: Dict[str, Any] = {}

if isinstance(extra_body, dict):
retrieval_config = deepcopy(
extra_body.get("retrievalConfiguration")
or extra_body.get("retrieval_configuration")
or {}
)
max_results = vector_store_search_optional_params.get("max_num_results")
if max_results is not None:
existing_number_of_results = retrieval_config.get(
"vectorSearchConfiguration", {}
).get("numberOfResults")
if (
existing_number_of_results is not None
and existing_number_of_results != max_results
):
verbose_logger.debug(
"Overriding extra_body retrievalConfiguration.vectorSearchConfiguration.numberOfResults (%s) with max_num_results=%s",
existing_number_of_results,
max_results,
)
retrieval_config.setdefault("vectorSearchConfiguration", {})[
"numberOfResults"
] = max_results
filters = vector_store_search_optional_params.get("filters")
if filters is not None:
existing_filter = retrieval_config.get("vectorSearchConfiguration", {}).get(
"filter"
)
if existing_filter is not None and existing_filter != filters:
verbose_logger.debug(
"Overriding extra_body retrievalConfiguration.vectorSearchConfiguration.filter with filters from vector_store_search_optional_params"
)
retrieval_config.setdefault("vectorSearchConfiguration", {})[
"filter"
] = filters
Comment thread
Sameerlite marked this conversation as resolved.
Comment thread
Sameerlite marked this conversation as resolved.
if retrieval_config:
# Create a properly typed retrieval configuration
typed_retrieval_config: BedrockKBRetrievalConfiguration = {}
if "vectorSearchConfiguration" in retrieval_config:
typed_retrieval_config["vectorSearchConfiguration"] = retrieval_config[
"vectorSearchConfiguration"
]
request_body["retrievalConfiguration"] = typed_retrieval_config
request_body["retrievalConfiguration"] = cast(
BedrockKBRetrievalConfiguration, retrieval_config
)

litellm_logging_obj.model_call_details["query"] = query
return url, request_body
Expand Down
3 changes: 3 additions & 0 deletions litellm/llms/custom_httpx/llm_http_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8585,6 +8585,7 @@ async def async_vector_store_search_handler(
api_base=api_base,
litellm_logging_obj=logging_obj,
litellm_params=dict(litellm_params),
extra_body=extra_body,
)
else:
(
Expand All @@ -8597,6 +8598,7 @@ async def async_vector_store_search_handler(
api_base=api_base,
litellm_logging_obj=logging_obj,
litellm_params=dict(litellm_params),
extra_body=extra_body,
)
all_optional_params: Dict[str, Any] = dict(litellm_params)
all_optional_params.update(vector_store_search_optional_params or {})
Expand Down Expand Up @@ -8697,6 +8699,7 @@ def vector_store_search_handler(
api_base=api_base,
litellm_logging_obj=logging_obj,
litellm_params=dict(litellm_params),
extra_body=extra_body,
)

all_optional_params: Dict[str, Any] = dict(litellm_params)
Expand Down
1 change: 1 addition & 0 deletions litellm/llms/gemini/vector_stores/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def transform_search_vector_store_request(
api_base: str,
litellm_logging_obj: LiteLLMLoggingObj,
litellm_params: dict,
extra_body: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict]:
"""
Transform search request to Gemini's generateContent format.
Expand Down
1 change: 1 addition & 0 deletions litellm/llms/milvus/vector_stores/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def transform_search_vector_store_request(
api_base: str,
litellm_logging_obj: LiteLLMLoggingObj,
litellm_params: dict,
extra_body: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict[str, Any]]:
"""
Transform search request for Azure AI Search API
Expand Down
1 change: 1 addition & 0 deletions litellm/llms/openai/vector_stores/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def transform_search_vector_store_request(
api_base: str,
litellm_logging_obj: LiteLLMLoggingObj,
litellm_params: dict,
extra_body: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict]:
url = f"{api_base}/{vector_store_id}/search"
typed_request_body = VectorStoreSearchRequest(
Expand Down
2 changes: 2 additions & 0 deletions litellm/llms/pg_vector/vector_stores/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def transform_search_vector_store_request(
api_base: str,
litellm_logging_obj: LiteLLMLoggingObj,
litellm_params: dict,
extra_body: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict]:
url = f"{api_base}/{vector_store_id}/search"
_, request_body = super().transform_search_vector_store_request(
Expand All @@ -89,5 +90,6 @@ def transform_search_vector_store_request(
api_base=api_base,
litellm_logging_obj=litellm_logging_obj,
litellm_params=litellm_params,
extra_body=extra_body,
)
return url, request_body
1 change: 1 addition & 0 deletions litellm/llms/ragflow/vector_stores/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def transform_search_vector_store_request(
api_base: str,
litellm_logging_obj: LiteLLMLoggingObj,
litellm_params: dict,
extra_body: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict]:
"""RAGFlow vector stores are management-only, search is not supported."""
raise NotImplementedError(
Expand Down
2 changes: 2 additions & 0 deletions litellm/llms/s3_vectors/vector_stores/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def transform_search_vector_store_request(
api_base: str,
litellm_logging_obj: LiteLLMLoggingObj,
litellm_params: dict,
extra_body: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict]:
"""Sync version - generates embedding synchronously."""
# For S3 Vectors, vector_store_id should be in format: bucket_name:index_name
Expand Down Expand Up @@ -140,6 +141,7 @@ async def atransform_search_vector_store_request(
api_base: str,
litellm_logging_obj: LiteLLMLoggingObj,
litellm_params: dict,
extra_body: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict]:
"""Async version - generates embedding asynchronously."""
# For S3 Vectors, vector_store_id should be in format: bucket_name:index_name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def transform_search_vector_store_request(
api_base: str,
litellm_logging_obj: LiteLLMLoggingObj,
litellm_params: dict,
extra_body: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict[str, Any]]:
"""
Transform search request for Vertex AI RAG API
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def transform_search_vector_store_request(
api_base: str,
litellm_logging_obj: LiteLLMLoggingObj,
litellm_params: dict,
extra_body: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict[str, Any]]:
"""
Transform search request for Vertex AI RAG API
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ async def fake_async_vector_store_search_handler(
api_base=api_base,
litellm_logging_obj=logging_obj,
litellm_params=litellm_params_dict,
extra_body=None,
)
)
captured_request_body["url"] = url
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,112 @@ def test_transform_search_request():
api_base="https://bedrock-agent-runtime.us-west-2.amazonaws.com/knowledgebases",
litellm_logging_obj=mock_log,
litellm_params={},
extra_body=None,
)

assert url.endswith("/kb123/retrieve")
assert body["retrievalQuery"].get("text") == "hello"


def test_transform_search_request_uses_only_retrieval_config_from_extra_body():
    """Only the retrievalConfiguration key of extra_body is forwarded to Bedrock.

    Any other top-level keys in extra_body (e.g. "unrelatedField") must not
    appear in the outgoing request body.
    """
    kb_config = BedrockVectorStoreConfig()
    logging_stub = MagicMock()
    logging_stub.model_call_details = {}

    # extra_body mixes a legitimate retrievalConfiguration with an unrelated key.
    request_extra_body = {
        "retrievalConfiguration": {
            "vectorSearchConfiguration": {
                "overrideSearchType": "HYBRID",
                "numberOfResults": 8,
            }
        },
        "unrelatedField": {"should_not": "be_forwarded"},
    }

    url, body = kb_config.transform_search_vector_store_request(
        vector_store_id="kb123",
        query="hello",
        vector_store_search_optional_params={},
        api_base="https://bedrock-agent-runtime.us-west-2.amazonaws.com/knowledgebases",
        litellm_logging_obj=logging_stub,
        litellm_params={},
        extra_body=request_extra_body,
    )

    assert url.endswith("/kb123/retrieve")
    assert body["retrievalQuery"].get("text") == "hello"
    # The nested vector-search settings survive the transformation untouched.
    vector_search = body["retrievalConfiguration"]["vectorSearchConfiguration"]
    assert vector_search["overrideSearchType"] == "HYBRID"
    # Unknown top-level keys are dropped, not forwarded.
    assert "unrelatedField" not in body


def test_transform_search_request_does_not_mutate_extra_body_and_overrides_number_of_results():
    """max_num_results overrides extra_body's numberOfResults without mutating it.

    The caller-supplied extra_body dict must keep its original value (8) even
    though the request body carries the override (10).
    """
    kb_config = BedrockVectorStoreConfig()
    logging_stub = MagicMock()
    logging_stub.model_call_details = {}
    extra_body = {
        "retrievalConfiguration": {
            "vectorSearchConfiguration": {
                "overrideSearchType": "HYBRID",
                "numberOfResults": 8,
            }
        }
    }

    _, body = kb_config.transform_search_vector_store_request(
        vector_store_id="kb123",
        query="hello",
        vector_store_search_optional_params={"max_num_results": 10},
        api_base="https://bedrock-agent-runtime.us-west-2.amazonaws.com/knowledgebases",
        litellm_logging_obj=logging_stub,
        litellm_params={},
        extra_body=extra_body,
    )

    sent_config = body["retrievalConfiguration"]["vectorSearchConfiguration"]
    original_config = extra_body["retrievalConfiguration"]["vectorSearchConfiguration"]
    # Explicit optional param wins in the outgoing request...
    assert sent_config["numberOfResults"] == 10
    # ...while the caller's extra_body is left untouched (deep-copied internally).
    assert original_config["numberOfResults"] == 8


def test_transform_search_request_overrides_filter_without_mutating_extra_body():
    """filters from optional params overrides extra_body's filter non-destructively.

    The request body must carry the new filter while the caller's extra_body
    retains its original filter value.
    """
    kb_config = BedrockVectorStoreConfig()
    logging_stub = MagicMock()
    logging_stub.model_call_details = {}
    extra_body = {
        "retrievalConfiguration": {
            "vectorSearchConfiguration": {
                "filter": {"equals": {"key": "tenant", "value": "a"}}
            }
        }
    }
    new_filter = {"equals": {"key": "tenant", "value": "b"}}

    _, body = kb_config.transform_search_vector_store_request(
        vector_store_id="kb123",
        query="hello",
        vector_store_search_optional_params={"filters": new_filter},
        api_base="https://bedrock-agent-runtime.us-west-2.amazonaws.com/knowledgebases",
        litellm_logging_obj=logging_stub,
        litellm_params={},
        extra_body=extra_body,
    )

    sent_filter = body["retrievalConfiguration"]["vectorSearchConfiguration"]["filter"]
    original_filter = extra_body["retrievalConfiguration"]["vectorSearchConfiguration"]["filter"]
    # The explicit filters param replaces the extra_body filter in the request...
    assert sent_filter == new_filter
    # ...and the caller's extra_body still holds the original tenant value.
    assert original_filter["equals"]["value"] == "a"
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def test_transform_search_request_invalid_vector_store_id(self):
api_base="https://s3vectors.us-west-2.api.aws",
litellm_logging_obj=mock_logging_obj,
litellm_params={},
extra_body=None,
)

def test_transform_search_response(self):
Expand Down
1 change: 1 addition & 0 deletions tests/vector_store_tests/test_ragflow_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ def test_transform_search_vector_store_request_not_implemented(self):
api_base="http://localhost:9380",
litellm_logging_obj=logging_obj,
litellm_params={},
extra_body=None,
)

def test_transform_search_vector_store_response_not_implemented(self):
Expand Down
Loading