Skip to content

Commit f107ba7

Browse files
authored
Merge pull request #46 from lsst-dm/tickets/SP-2363
SP-2363: Update CustomWeaviateVectorStore
2 parents dc8eda9 + 43f45a0 commit f107ba7

File tree

2 files changed

+29
-25
lines changed

2 files changed

+29
-25
lines changed

python/rubin/rag/chatbot.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,19 +85,14 @@ def configure_client() -> WeaviateClient:
8585

8686
def configure_retriever() -> VectorStoreRetriever:
8787
"""Configure the Weaviate retriever."""
88-
search_kwargs = {
89-
"k": 6,
90-
"return_metadata": ["score"],
91-
}
92-
9388
selected_sources = [
9489
source.lower() for source in st.session_state["required_sources"]
9590
]
9691
if selected_sources:
9792
filters = Filter.by_property("source_key").contains_any(
9893
selected_sources
9994
)
100-
search_kwargs["filters"] = filters
95+
search_kwargs = {"k": 6, "where_filter": filters}
10196

10297
return CustomWeaviateVectorStore(
10398
client=configure_client(),

python/rubin/rag/custom_weaviate_vector_store.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
from collections.abc import Callable
2828
from typing import Any
2929

30+
from langchain_core.documents.base import Document
3031
from langchain_weaviate.vectorstores import WeaviateVectorStore
32+
from weaviate.classes.query import MetadataQuery
3133

3234

3335
class CustomWeaviateVectorStore(WeaviateVectorStore):
@@ -47,6 +49,11 @@ def __init__(
4749
if use_multi_tenancy is None:
4850
use_multi_tenancy = False
4951

52+
self.client = client
53+
self.index_name = index_name
54+
self.text_key = text_key
55+
self.embedding = embedding
56+
5057
super().__init__(
5158
client=client,
5259
index_name=index_name,
@@ -57,27 +64,29 @@ def __init__(
5764
use_multi_tenancy=use_multi_tenancy,
5865
)
5966

60-
def similarity_search(self, query: str, k: int = 4, **kwargs: Any) -> list:
67+
def similarity_search(
68+
self, query: str, k: int = 4, **kwargs: Any
69+
) -> list[Document]:
6170
"""
62-
Perform a similarity search and return documents
63-
along with their similarity scores.
64-
65-
Args:
66-
query (str): The query text to search for.
67-
k (int): The number of results to return (default: 4).
68-
**kwargs: Additional keyword arguments to pass.
69-
70-
Returns
71-
-------
72-
List[Tuple[Document, float]]: A list of tuples
73-
where each tuple contains a
74-
document and its corresponding similarity score.
71+
Return list of documents most similar to the query text and their
72+
score. A higher score means more similarity, with a max of 1.
7573
"""
76-
docs = self._perform_search(query, k, return_score=True, **kwargs)
74+
where_filter = kwargs.get("where_filter")
75+
collection = self.client.collections.get(self.index_name)
76+
response = collection.query.hybrid(
77+
query=query,
78+
limit=k,
79+
filters=where_filter,
80+
alpha=1,
81+
return_metadata=MetadataQuery(score=True, explain_score=True),
82+
)
7783

7884
results = []
79-
for doc in docs:
80-
doc[0].metadata["score"] = doc[1]
81-
results.append(doc[0])
82-
85+
for obj in response.objects:
86+
text = obj.properties.get("page_content", "")
87+
metadata = obj.properties.copy() if obj.properties else {}
88+
metadata["score"] = (
89+
obj.metadata.score
90+
) # Inject the score into metadata
91+
results.append(Document(page_content=text, metadata=metadata))
8392
return results

0 commit comments

Comments
 (0)