|
| 1 | +import hashlib |
| 2 | + |
| 3 | +from elasticsearch import AsyncElasticsearch, helpers |
| 4 | + |
| 5 | +from api.clients.vector_store._basevectorstoreclient import BaseVectorStoreClient |
| 6 | +from api.schemas.chunks import Chunk |
| 7 | +from api.schemas.search import Search, SearchMethod |
| 8 | + |
| 9 | + |
class ElasticsearchVectorStoreClient(BaseVectorStoreClient, AsyncElasticsearch):
    """Vector store client backed by Elasticsearch.

    Every collection is stored in its own index whose name is the stringified
    collection id. Chunks are indexed with their text content, their embedding
    and their metadata; lexical, semantic (kNN) and hybrid (RRF) searches are
    supported.
    """

    default_method = SearchMethod.HYBRID

    def __init__(self, *args, **kwargs):
        """Build the client.

        ``type``, ``number_of_shards`` and ``number_of_replicas`` are consumed
        here so they are not forwarded to ``AsyncElasticsearch.__init__``; all
        remaining arguments are passed through unchanged.
        """
        kwargs.pop("type", None)
        self.number_of_shards = kwargs.pop("number_of_shards", 1)
        self.number_of_replicas = kwargs.pop("number_of_replicas", 1)
        AsyncElasticsearch.__init__(self, *args, **kwargs)

    async def check(self) -> bool:
        """Return True when the cluster answers a ping, False otherwise."""
        try:
            # ping() reports failure through its return value, so it must be
            # propagated instead of unconditionally returning True.
            return bool(await self.ping())
        except Exception:
            return False

    async def close(self) -> None:
        """Close the underlying transport and its connections."""
        # Call the parent implementation explicitly (this class overrides close),
        # in the same way the other overridden/shadowed methods are invoked.
        await AsyncElasticsearch.close(self)

    async def create_collection(self, collection_id: int, vector_size: int) -> None:
        """Create the index of a collection if it does not exist yet.

        Args:
            collection_id: Collection id; its string form is the index name.
            vector_size: Dimension of the ``embedding`` dense vector field.
        """
        index = str(collection_id)
        if await self.indices.exists(index=index):
            return

        settings = {
            "number_of_shards": self.number_of_shards,
            "number_of_replicas": self.number_of_replicas,
            "similarity": {"default": {"type": "BM25"}},
            # Custom analyzer applied to the ``content`` field: lowercase, French stop words, light French stemming.
            "analysis": {
                "filter": {
                    "french_stop": {"type": "stop", "stopwords": "_french_"},
                    "french_stemmer": {"type": "stemmer", "language": "light_french"},
                },
                "analyzer": {
                    "french_analyzer": {
                        "tokenizer": "standard",
                        "filter": ["lowercase", "french_stop", "french_stemmer"],
                    }
                },
            },
        }

        mappings = {
            # Dynamic templates controlling how sub-fields of ``metadata`` are mapped.
            "dynamic_templates": [
                # Nested objects under metadata are stored but not indexed.
                {"metadata_objects_disabled": {"path_match": "metadata.*", "match_mapping_type": "object", "mapping": {"enabled": False}}},
                # Fields whose name ends with _at / _date / date are mapped as dates.
                {
                    "metadata_dates_by_name": {
                        "path_match": "metadata.*",
                        "match_pattern": "regex",
                        "match": "(?i).*(_at|_date|date)$",
                        "mapping": {
                            "type": "date",
                            "ignore_malformed": True,
                            "format": "strict_date_optional_time||strict_date_time||yyyy-MM-dd'T'HH:mm:ssZ||epoch_millis",
                        },
                    }
                },
                {"metadata_bools": {"path_match": "metadata.*", "match_mapping_type": "boolean", "mapping": {"type": "boolean"}}},
                {
                    "metadata_numbers_long": {
                        "path_match": "metadata.*",
                        "match_mapping_type": "long",
                        "mapping": {"type": "long", "ignore_malformed": True, "coerce": True},
                    }
                },
                {
                    "metadata_numbers_double": {
                        "path_match": "metadata.*",
                        "match_mapping_type": "double",
                        "mapping": {"type": "double", "ignore_malformed": True, "coerce": True},
                    }
                },
                # Strings are mapped as keywords (exact match), ignored above 1024 characters.
                {
                    "metadata_strings": {
                        "path_match": "metadata.*",
                        "match_mapping_type": "string",
                        "mapping": {"type": "keyword", "ignore_above": 1024},
                    }
                },
            ],
            "date_detection": False,
            "numeric_detection": False,
            "properties": {
                "id": {"type": "integer"},
                "embedding": {"type": "dense_vector", "dims": vector_size},
                "content": {"type": "text", "analyzer": "french_analyzer"},
                "metadata": {"type": "object", "dynamic": True},
            },
        }

        await self.indices.create(index=index, mappings=mappings, settings=settings)

    async def delete_collection(self, collection_id: int) -> None:
        """Delete the index of a collection; do nothing if it does not exist."""
        index = str(collection_id)
        if not await self.indices.exists(index=index):
            return

        await self.indices.delete(index=index)

    async def get_collections(self) -> list[int]:
        """Return the ids of the collections that have an index in the cluster."""
        collections = await self.indices.get_alias()
        # Only purely numeric index names are collections; anything else
        # (e.g. indices not created by this client) would make int() raise.
        return [int(collection) for collection in collections if collection.isdecimal()]

    async def get_chunk_count(self, collection_id: int, document_id: int) -> int | None:
        """Return the number of chunks of a document, or None if the count request fails."""
        try:
            body = {"query": {"match": {"metadata.document_id": document_id}}}
            result = await AsyncElasticsearch.count(self, index=str(collection_id), body=body)
            return result["count"]
        except Exception:
            return None

    async def delete_document(self, collection_id: int, document_id: int) -> None:
        """Delete every chunk of a document, then refresh the index."""
        index = str(collection_id)
        body = {"query": {"match": {"metadata.document_id": document_id}}}
        await AsyncElasticsearch.delete_by_query(self, index=index, body=body)
        await self.indices.refresh(index=index)

    async def get_chunks(self, collection_id: int, document_id: int, offset: int = 0, limit: int = 10, chunk_id: int | None = None) -> list[Chunk]:
        """Return a page of chunks of a document, ordered by chunk id.

        Args:
            collection_id: Collection holding the document.
            document_id: Document whose chunks are requested.
            offset: Number of chunks to skip.
            limit: Maximum number of chunks to return.
            chunk_id: When given, only the chunk with this id is returned.
        """
        must = [{"match": {"metadata.document_id": document_id}}]
        if chunk_id is not None:
            must.append({"term": {"id": chunk_id}})

        body = {
            "query": {"bool": {"must": must}},
            "_source": ["id", "content", "metadata"],
            # An explicit sort keeps from/size pagination stable: without it all hits
            # share the same score and their relative order is not guaranteed.
            "sort": [{"id": {"order": "asc", "unmapped_type": "integer"}}],
        }

        results = await AsyncElasticsearch.search(self, index=str(collection_id), body=body, from_=offset, size=limit)
        return [
            Chunk(
                id=hit["_source"]["id"],
                document=document_id,
                collection=collection_id,
                content=hit["_source"]["content"],
                metadata=hit["_source"]["metadata"],
            )
            for hit in results["hits"]["hits"]
        ]

    async def upsert(self, collection_id: int, chunks: list[Chunk], embeddings: list[list[float]]) -> None:
        """Index (or overwrite) chunks with their embeddings, then refresh the index.

        The Elasticsearch document id is a SHA-256 of collection, document and
        chunk ids, so re-sending the same chunk overwrites the previous version.

        Raises:
            ValueError: If ``chunks`` and ``embeddings`` do not have the same length.
        """
        if len(chunks) != len(embeddings):
            # zip() would silently drop the surplus items and lose data.
            raise ValueError(f"Got {len(chunks)} chunks but {len(embeddings)} embeddings.")

        index = str(collection_id)
        actions = [
            {
                "_index": index,
                "_id": hashlib.sha256(f"{chunk.collection}|{chunk.document}|{chunk.id}".encode()).hexdigest(),
                "_source": {
                    "id": chunk.id,
                    "content": chunk.content,
                    "embedding": embedding,
                    "metadata": chunk.metadata,
                },
            }
            for chunk, embedding in zip(chunks, embeddings)
        ]

        # The index name is always passed as a string, like everywhere else in this client.
        await helpers.async_bulk(client=self, actions=actions, index=index)
        await self.indices.refresh(index=index)

    async def search(
        self,
        method: SearchMethod,
        collection_ids: list[int],
        query_prompt: str,
        query_vector: list[float],
        limit: int,
        offset: int,
        rff_k: int | None = 20,
        score_threshold: float = 0.0,
    ) -> list[Search]:
        """Search the given collections with the requested method.

        ``score_threshold`` is applied to semantic and lexical searches only;
        ``rff_k`` is only used by the hybrid search. Any method other than
        SEMANTIC or LEXICAL is handled as HYBRID.
        """
        if method == SearchMethod.SEMANTIC:
            searches = await self._semantic_search(
                query_vector=query_vector, collection_ids=collection_ids, limit=limit, offset=offset, score_threshold=score_threshold
            )

        elif method == SearchMethod.LEXICAL:
            searches = await self._lexical_search(
                query_prompt=query_prompt, collection_ids=collection_ids, limit=limit, offset=offset, score_threshold=score_threshold
            )

        else:  # method == SearchMethod.HYBRID
            searches = await self._hybrid_search(
                query_prompt=query_prompt, query_vector=query_vector, collection_ids=collection_ids, limit=limit, offset=offset, rff_k=rff_k
            )

        return searches

    @staticmethod
    def _hits_to_searches(hits: list, method: SearchMethod, limit: int, score_threshold: float) -> list[Search]:
        """Convert raw Elasticsearch hits to Search objects, filtered by score and sorted by decreasing score."""
        searches = [
            Search(
                method=method.value,
                score=hit["_score"],
                chunk=Chunk(id=hit["_source"]["id"], content=hit["_source"]["content"], metadata=hit["_source"]["metadata"]),
            )
            for hit in hits
            if hit
        ]
        searches = [search for search in searches if search.score >= score_threshold]

        return sorted(searches, key=lambda x: x.score, reverse=True)[:limit]

    async def _lexical_search(
        self, query_prompt: str, collection_ids: list[int], limit: int, offset: int, score_threshold: float = 0.0
    ) -> list[Search]:
        """Run a full-text ``multi_match`` query; fuzzy matching is enabled for prompts shorter than 25 words."""
        indices = [str(x) for x in collection_ids]
        fuzziness = {"fuzziness": "AUTO"} if len(query_prompt.split()) < 25 else {}
        body = {
            "query": {"multi_match": {"query": query_prompt, **fuzziness}},
            "size": limit,
            "from": offset,
            "_source": {"excludes": ["embedding"]},
        }
        results = await AsyncElasticsearch.search(self, index=indices, body=body)

        return self._hits_to_searches(hits=results["hits"]["hits"], method=SearchMethod.LEXICAL, limit=limit, score_threshold=score_threshold)

    async def _semantic_search(
        self, query_vector: list[float], collection_ids: list[int], limit: int, offset: int, score_threshold: float = 0.0
    ) -> list[Search]:
        """Run an approximate kNN query on the ``embedding`` field."""
        indices = [str(x) for x in collection_ids]
        # ``from`` is applied on top of the k nearest neighbours, so k must cover the
        # skipped results too; otherwise every page after the first is truncated or empty.
        k = limit + offset
        body = {
            # num_candidates is capped because Elasticsearch rejects values above 10 000.
            "knn": {"field": "embedding", "query_vector": query_vector, "k": k, "num_candidates": min(max(k * 10, 100), 10_000)},
            "size": limit,
            "from": offset,
            "_source": {"excludes": ["embedding"]},
        }
        results = await AsyncElasticsearch.search(self, index=indices, body=body)

        return self._hits_to_searches(hits=results["hits"]["hits"], method=SearchMethod.SEMANTIC, limit=limit, score_threshold=score_threshold)

    async def _hybrid_search(
        self, query_prompt: str, query_vector: list[float], collection_ids: list[int], limit: int, offset: int, rff_k: int, expansion_factor: int = 2
    ) -> list[Search]:
        """
        Hybrid search combines lexical and semantic search results using Reciprocal Rank Fusion (RRF).

        Args:
            query_prompt (str): The search prompt
            query_vector (list[float]): The query vector
            collection_ids (list[int]): The collection ids
            limit (int): The number of results to return
            offset (int): The number of fused results to skip
            rff_k (int): The constant k in the RRF formula (falls back to 20 when None)
            expansion_factor (int): The factor that increases the number of results to search in each method before reranking

        Returns:
            A combined list of searches with updated scores
        """
        if rff_k is None:  # the public search() signature allows None
            rff_k = 20

        # Both rankings are fetched from rank 0 and the offset is applied after the fusion:
        # paginating each method separately would restart the ranks at 0 on every page.
        window = int((offset + limit) * expansion_factor)
        lexical_searches = await self._lexical_search(query_prompt=query_prompt, collection_ids=collection_ids, limit=window, offset=0)
        semantic_searches = await self._semantic_search(query_vector=query_vector, collection_ids=collection_ids, limit=window, offset=0)

        combined_scores = {}
        search_map = {}
        for searches in (lexical_searches, semantic_searches):
            for rank, search in enumerate(searches):
                # A (document id, chunk id) tuple identifies a chunk without the collisions
                # an integer sum produces (e.g. document 1/chunk 2 vs document 2/chunk 1).
                key = ((search.chunk.metadata or {}).get("document_id"), search.chunk.id)
                if key not in combined_scores:
                    combined_scores[key] = 0
                    search_map[key] = search
                    search_map[key].method = SearchMethod.HYBRID
                combined_scores[key] += 1 / (rff_k + rank + 1)

        ranked_scores = sorted(combined_scores.items(), key=lambda item: item[1], reverse=True)
        reranked_searches = []
        for key, rrf_score in ranked_scores[offset : offset + limit]:
            search = search_map[key]
            search.score = rrf_score
            reranked_searches.append(search)

        return reranked_searches

    async def last_chunk_id(self, collection_id: int, document_id: int) -> int | None:
        """Return the highest chunk id of a document, or None when the document has no chunk."""
        result = await AsyncElasticsearch.search(
            self,
            index=str(collection_id),
            size=0,
            query={"match": {"metadata.document_id": document_id}},
            aggs={"id_max": {"max": {"field": "id"}}},
        )
        # The max aggregation yields a float, or None when no document matched.
        value = result["aggregations"]["id_max"]["value"]

        return int(value) if value is not None else None
0 commit comments