Skip to content

Commit 3170632

Browse files
committed
feat(chunks): add POST /v1/chunks endpoint
1 parent 4f8de1b commit 3170632

12 files changed

Lines changed: 470 additions & 98 deletions

File tree

Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
import hashlib
2+
3+
from elasticsearch import AsyncElasticsearch, helpers
4+
5+
from api.clients.vector_store._basevectorstoreclient import BaseVectorStoreClient
6+
from api.schemas.chunks import Chunk
7+
from api.schemas.search import Search, SearchMethod
8+
9+
10+
class ElasticsearchVectorStoreClient(BaseVectorStoreClient, AsyncElasticsearch):
11+
default_method = SearchMethod.HYBRID
12+
13+
def __init__(self, *args, **kwargs):
14+
kwargs.pop("type", None) # remove type from kwargs to avoid passing it to the super class
15+
self.number_of_shards = kwargs.pop("number_of_shards", 1) # remove number_of_shards from kwargs to avoid passing it to the super class
16+
self.number_of_replicas = kwargs.pop("number_of_replicas", 1) # remove number_of_replicas from kwargs to avoid passing it to the super class
17+
AsyncElasticsearch.__init__(self, *args, **kwargs)
18+
19+
async def check(self) -> bool:
20+
try:
21+
await self.ping()
22+
return True
23+
except Exception:
24+
return False
25+
26+
async def close(self) -> None:
27+
await super(AsyncElasticsearch, self).transport.close()
28+
29+
async def create_collection(self, collection_id: int, vector_size: int) -> None:
30+
if await self.indices.exists(index=str(collection_id)):
31+
return
32+
33+
settings = {
34+
"number_of_shards": self.number_of_shards,
35+
"number_of_replicas": self.number_of_replicas,
36+
"similarity": {"default": {"type": "BM25"}},
37+
"analysis": {
38+
"filter": {
39+
"french_stop": {"type": "stop", "stopwords": "_french_"},
40+
"french_stemmer": {"type": "stemmer", "language": "light_french"},
41+
},
42+
"analyzer": {
43+
"french_analyzer": {
44+
"tokenizer": "standard",
45+
"filter": ["lowercase", "french_stop", "french_stemmer"],
46+
}
47+
},
48+
},
49+
}
50+
51+
mappings = {
52+
"dynamic_templates": [
53+
{"metadata_objects_disabled": {"path_match": "metadata.*", "match_mapping_type": "object", "mapping": {"enabled": False}}},
54+
{
55+
"metadata_dates_by_name": {
56+
"path_match": "metadata.*",
57+
"match_pattern": "regex",
58+
"match": "(?i).*(_at|_date|date)$",
59+
"mapping": {
60+
"type": "date",
61+
"ignore_malformed": True,
62+
"format": "strict_date_optional_time||strict_date_time||yyyy-MM-dd'T'HH:mm:ssZ||epoch_millis",
63+
},
64+
}
65+
},
66+
{"metadata_bools": {"path_match": "metadata.*", "match_mapping_type": "boolean", "mapping": {"type": "boolean"}}},
67+
{
68+
"metadata_numbers_long": {
69+
"path_match": "metadata.*",
70+
"match_mapping_type": "long",
71+
"mapping": {"type": "long", "ignore_malformed": True, "coerce": True},
72+
}
73+
},
74+
{
75+
"metadata_numbers_double": {
76+
"path_match": "metadata.*",
77+
"match_mapping_type": "double",
78+
"mapping": {"type": "double", "ignore_malformed": True, "coerce": True},
79+
}
80+
},
81+
{
82+
"metadata_strings": {
83+
"path_match": "metadata.*",
84+
"match_mapping_type": "string",
85+
"mapping": {"type": "keyword", "ignore_above": 1024},
86+
}
87+
},
88+
],
89+
"date_detection": False,
90+
"numeric_detection": False,
91+
"properties": {
92+
"id": {"type": "integer"},
93+
"embedding": {"type": "dense_vector", "dims": vector_size},
94+
"content": {"type": "text", "analyzer": "french_analyzer"},
95+
"metadata": {"type": "object", "dynamic": True},
96+
},
97+
}
98+
99+
await self.indices.create(index=str(collection_id), mappings=mappings, settings=settings)
100+
101+
async def delete_collection(self, collection_id: int) -> None:
102+
if not await self.indices.exists(index=str(collection_id)):
103+
return
104+
105+
await self.indices.delete(index=str(collection_id))
106+
107+
async def get_collections(self) -> list[int]:
108+
collections = await self.indices.get_alias()
109+
return [int(collection) for collection in collections]
110+
111+
async def get_chunk_count(self, collection_id: int, document_id: int) -> int | None:
112+
try:
113+
body = {"query": {"match": {"metadata.document_id": document_id}}}
114+
result = await AsyncElasticsearch.count(self, index=str(collection_id), body=body)
115+
return result["count"]
116+
except Exception:
117+
return None
118+
119+
async def delete_document(self, collection_id: int, document_id: int) -> None:
120+
body = {"query": {"match": {"metadata.document_id": document_id}}}
121+
await AsyncElasticsearch.delete_by_query(self, index=str(collection_id), body=body)
122+
await self.indices.refresh(index=str(collection_id))
123+
124+
async def get_chunks(self, collection_id: int, document_id: int, offset: int = 0, limit: int = 10, chunk_id: int | None = None) -> list[Chunk]:
125+
body = {"query": {"bool": {"must": [{"match": {"metadata.document_id": document_id}}]}}, "_source": ["id", "content", "metadata"]}
126+
if chunk_id is not None:
127+
body["query"]["bool"]["must"].append({"term": {"id": chunk_id}})
128+
129+
results = await AsyncElasticsearch.search(self, index=str(collection_id), body=body, from_=offset, size=limit)
130+
chunks = []
131+
for hit in results["hits"]["hits"]:
132+
chunks.append(
133+
Chunk(
134+
id=hit["_source"]["id"],
135+
document=document_id,
136+
collection=collection_id,
137+
content=hit["_source"]["content"],
138+
metadata=hit["_source"]["metadata"],
139+
)
140+
)
141+
return chunks
142+
143+
async def upsert(self, collection_id: int, chunks: list[Chunk], embeddings: list[list[float]]) -> None:
144+
actions = [
145+
{
146+
"_index": str(collection_id),
147+
"_id": hashlib.sha256(f"{chunk.collection}|{chunk.document}|{chunk.id}".encode()).hexdigest(),
148+
"_source": {
149+
"id": chunk.id,
150+
"content": chunk.content,
151+
"embedding": embedding,
152+
"metadata": chunk.metadata,
153+
},
154+
}
155+
for chunk, embedding in zip(chunks, embeddings)
156+
]
157+
158+
await helpers.async_bulk(client=self, actions=actions, index=collection_id)
159+
await self.indices.refresh(index=str(collection_id))
160+
161+
async def search(
162+
self,
163+
method: SearchMethod,
164+
collection_ids: list[int],
165+
query_prompt: str,
166+
query_vector: list[float],
167+
limit: int,
168+
offset: int,
169+
rff_k: int | None = 20,
170+
score_threshold: float = 0.0,
171+
) -> list[Search]:
172+
if method == SearchMethod.SEMANTIC:
173+
searches = await self._semantic_search(
174+
query_vector=query_vector, collection_ids=collection_ids, limit=limit, offset=offset, score_threshold=score_threshold
175+
)
176+
177+
elif method == SearchMethod.LEXICAL:
178+
searches = await self._lexical_search(
179+
query_prompt=query_prompt, collection_ids=collection_ids, limit=limit, offset=offset, score_threshold=score_threshold
180+
)
181+
182+
else: # method == SearchMethod.HYBRID
183+
searches = await self._hybrid_search(
184+
query_prompt=query_prompt, query_vector=query_vector, collection_ids=collection_ids, limit=limit, offset=offset, rff_k=rff_k
185+
)
186+
187+
return searches
188+
189+
async def _lexical_search(
190+
self, query_prompt: str, collection_ids: list[int], limit: int, offset: int, score_threshold: float = 0.0
191+
) -> list[Search]:
192+
collection_ids = [str(x) for x in collection_ids]
193+
fuzziness = {"fuzziness": "AUTO"} if len(query_prompt.split()) < 25 else {}
194+
body = {
195+
"query": {"multi_match": {"query": query_prompt, **fuzziness}},
196+
"size": limit,
197+
"from": offset,
198+
"_source": {"excludes": ["embedding"]},
199+
}
200+
results = await AsyncElasticsearch.search(self, index=collection_ids, body=body)
201+
hits = [hit for hit in results["hits"]["hits"] if hit]
202+
searches = [
203+
Search(
204+
method=SearchMethod.LEXICAL.value,
205+
score=hit["_score"],
206+
chunk=Chunk(id=hit["_source"]["id"], content=hit["_source"]["content"], metadata=hit["_source"]["metadata"]),
207+
)
208+
for hit in hits
209+
]
210+
211+
searches = [search for search in searches if search.score >= score_threshold]
212+
searches = sorted(searches, key=lambda x: x.score, reverse=True)[:limit]
213+
214+
return searches
215+
216+
async def _semantic_search(
217+
self, query_vector: list[float], collection_ids: list[int], limit: int, offset: int, score_threshold: float = 0.0
218+
) -> list[Search]:
219+
collection_ids = [str(x) for x in collection_ids]
220+
body = {
221+
"knn": {"field": "embedding", "query_vector": query_vector, "k": limit, "num_candidates": max(limit * 10, 100)},
222+
"size": limit,
223+
"from": offset,
224+
"_source": {"excludes": ["embedding"]},
225+
}
226+
results = await AsyncElasticsearch.search(self, index=collection_ids, body=body)
227+
hits = [hit for hit in results["hits"]["hits"] if hit]
228+
searches = [
229+
Search(
230+
method=SearchMethod.SEMANTIC.value,
231+
score=hit["_score"],
232+
chunk=Chunk(id=hit["_source"]["id"], content=hit["_source"]["content"], metadata=hit["_source"]["metadata"]),
233+
)
234+
for hit in hits
235+
]
236+
237+
searches = [search for search in searches if search.score >= score_threshold]
238+
searches = sorted(searches, key=lambda x: x.score, reverse=True)[:limit]
239+
240+
return searches
241+
242+
async def _hybrid_search(
243+
self, query_prompt: str, query_vector: list[float], collection_ids: list[int], limit: int, offset: int, rff_k: int, expansion_factor: int = 2
244+
) -> list[Search]:
245+
"""
246+
Hybrid search combines lexical and semantic search results using Reciprocal Rank Fusion (RRF).
247+
248+
Args:
249+
query_prompt (str): The search prompt
250+
query_vector (list[float]): The query vector
251+
collection_ids (List[int]): The collection ids
252+
k (int): The number of results to return
253+
rff_k (int): The constant k in the RRF formula
254+
expansion_factor (int): The factor that increases the number of results to search in each method before reranking
255+
256+
Returns:
257+
A combined list of searches with updated scores
258+
"""
259+
lexical_searches = await self._lexical_search(
260+
query_prompt=query_prompt, collection_ids=collection_ids, limit=int(limit * expansion_factor), offset=offset
261+
)
262+
semantic_searches = await self._semantic_search(
263+
query_vector=query_vector, collection_ids=collection_ids, limit=int(limit * expansion_factor), offset=offset
264+
)
265+
266+
combined_scores = {}
267+
search_map = {}
268+
for searches in [lexical_searches, semantic_searches]:
269+
for rank, search in enumerate(searches):
270+
chunk_id = search.chunk.metadata.get("document_id") + search.chunk.id
271+
if chunk_id not in combined_scores:
272+
combined_scores[chunk_id] = 0
273+
search_map[chunk_id] = search
274+
search_map[chunk_id].method = SearchMethod.HYBRID
275+
combined_scores[chunk_id] += 1 / (rff_k + rank + 1)
276+
277+
ranked_scores = sorted(combined_scores.items(), key=lambda item: item[1], reverse=True)
278+
reranked_searches = []
279+
for chunk_id, rrf_score in ranked_scores:
280+
search = search_map[chunk_id]
281+
search.score = rrf_score
282+
reranked_searches.append(search)
283+
284+
searches = reranked_searches[:limit]
285+
286+
return searches
287+
288+
async def last_chunk_id(self, collection_id: int, document_id: int) -> int | None:
289+
result = await AsyncElasticsearch.search(
290+
self,
291+
index=str(collection_id),
292+
size=0,
293+
query={"match": {"metadata.document_id": document_id}},
294+
aggs={"id_max": {"max": {"field": "id"}}},
295+
)
296+
value = result["aggregations"]["id_max"]["value"]
297+
value = int(value) if value is not None else None
298+
299+
return value

api/endpoints/chunks.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,54 @@
22

33
from elasticsearch import AsyncElasticsearch
44
from fastapi import APIRouter, Depends, Path, Query, Request, Security
5+
from redis.asyncio import Redis as AsyncRedis
56
from sqlalchemy.ext.asyncio import AsyncSession
67

78
from api.helpers._accesscontroller import AccessController
89
from api.helpers._elasticsearchvectorstore import ElasticsearchVectorStore
9-
from api.schemas.chunks import Chunk, Chunks
10+
from api.helpers.models import ModelRegistry
11+
from api.schemas.chunks import Chunk, Chunks, ChunksResponse, CreateChunks
1012
from api.utils.context import global_context, request_context
11-
from api.utils.dependencies import get_elasticsearch_client, get_elasticsearch_vector_store, get_postgres_session
12-
from api.utils.exceptions import ChunkNotFoundException
13+
from api.utils.dependencies import (
14+
get_elasticsearch_client,
15+
get_elasticsearch_vector_store,
16+
get_model_registry,
17+
get_postgres_session,
18+
get_redis_client,
19+
)
20+
from api.utils.exceptions import ChunkNotFoundException, CollectionNotFoundException
1321
from api.utils.variables import ENDPOINT__CHUNKS, ROUTER__CHUNKS
1422

1523
router = APIRouter(prefix="/v1", tags=[ROUTER__CHUNKS.title()])
1624

1725

26+
@router.post(path=ENDPOINT__CHUNKS, dependencies=[Security(dependency=AccessController())], status_code=201)
27+
async def create_chunks(
28+
request: Request,
29+
body: CreateChunks,
30+
postgres_session: AsyncSession = Depends(get_postgres_session),
31+
redis_client: AsyncRedis = Depends(get_redis_client),
32+
model_registry: ModelRegistry = Depends(get_model_registry),
33+
) -> ChunksResponse:
34+
"""
35+
Fill document with chunks.
36+
"""
37+
if not global_context.document_manager: # no vector store available
38+
raise CollectionNotFoundException()
39+
40+
chunk_ids = await global_context.document_manager.create_chunks(
41+
postgres_session=postgres_session,
42+
document_id=body.document,
43+
chunks=body.chunks,
44+
user_id=request_context.get().user_info.id,
45+
redis_client=redis_client,
46+
model_registry=model_registry,
47+
request_context=request_context,
48+
)
49+
50+
return ChunksResponse(ids=chunk_ids, status_code=201)
51+
52+
1853
@router.get(path=ENDPOINT__CHUNKS + "/{document:path}/{chunk:path}", dependencies=[Security(dependency=AccessController())], status_code=200)
1954
async def get_chunk(
2055
request: Request,

0 commit comments

Comments
 (0)