
Commit b2878e5

Author: Dorin POMIAN (committed)

fix: Insert documents in batches for Azure AI Search vector store

Parent: 181c42c

2 files changed: 64 additions & 16 deletions
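The fix replaces a single upload_documents call with fixed-size batches of at most 50 documents, since Azure AI Search rejects upload requests that exceed its per-batch limits. As a rough standalone sketch of the same pattern (the endpoint, key, index name, and documents below are hypothetical placeholders, not values from this repository):

    from azure.core.credentials import AzureKeyCredential
    from azure.search.documents import SearchClient

    # Hypothetical connection details, for illustration only.
    client = SearchClient(
        endpoint="https://<service>.search.windows.net",
        index_name="<index>",
        credential=AzureKeyCredential("<api-key>"),
    )

    documents = [{"id": str(i), "content": f"doc {i}"} for i in range(120)]

    batch_size = 50  # mirrors the cap introduced by this commit
    results = []
    for start in range(0, len(documents), batch_size):
        batch = documents[start : start + batch_size]
        # upload_documents returns one IndexingResult per document
        results.extend(client.upload_documents(documents=batch))

    print(sum(r.succeeded for r in results), "of", len(results), "succeeded")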


src/docs2vecs/subcommands/indexer/skills/ada002_embedding_skill.py (16 additions & 5 deletions)
@@ -1,5 +1,4 @@
-from typing import List
-from typing import Optional
+from typing import List, Optional
 
 from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
 
@@ -22,13 +21,25 @@ def az_ada002_embeddings(self, content: str):
         return embed_model.get_query_embedding(content)
 
     def run(self, input: Optional[List[Document]] = None) -> Optional[List[Document]]:
-        self.logger.info("Running AzureAda002EmbeddingSkill...")
-        self.logger.info(f"Number of documents: {len(input)}")
+        self.logger.info(
+            f"Running Azure Embedding Skill with deployment name: {self._config['deployment_name']}..."
+        )
+
+        docs_count = len(input)
+        chunks_count = sum(len(doc.chunks) for doc in input)
+
+        self.logger.info(
+            f"Processing a total of documents: {docs_count}. Total number of chunks: {chunks_count}"
+        )
 
         for doc in input:
             self.logger.debug(f"Processing document: {doc.filename}")
             for chunk in doc.chunks:
                 self.logger.debug(f"Creating embedding for chunk: {chunk.chunk_id}")
-                chunk.embedding = "" if not chunk.content else self.az_ada002_embeddings(chunk.content)
+                chunk.embedding = (
+                    ""
+                    if not chunk.content
+                    else self.az_ada002_embeddings(chunk.content)
+                )
 
         return input
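For reference, a minimal sketch of the embedding call this skill wraps, using llama_index's AzureOpenAIEmbedding; the deployment and connection values below are hypothetical placeholders:

    from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding

    # Hypothetical Azure OpenAI connection values, for illustration only.
    embed_model = AzureOpenAIEmbedding(
        model="text-embedding-ada-002",
        deployment_name="<deployment-name>",
        api_key="<api-key>",
        azure_endpoint="https://<resource>.openai.azure.com/",
        api_version="2024-02-01",
    )

    embedding = embed_model.get_query_embedding("hello world")
    print(len(embedding))  # ada-002 embeddings have 1536 dimensions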

src/docs2vecs/subcommands/indexer/skills/azure_vector_store_skill.py (48 additions & 11 deletions)
@@ -1,5 +1,4 @@
-from typing import List
-from typing import Optional
+from typing import List, Optional
 
 from azure.core.credentials import AzureKeyCredential
 from azure.identity import DefaultAzureCredential
@@ -15,18 +14,34 @@
 
 
 class AzureVectorStoreSkill(IndexerSkill):
-    def __init__(self, config: dict, global_config: Config, vector_store_tracker: VectorStoreTracker = None):
+    def __init__(
+        self,
+        config: dict,
+        global_config: Config,
+        vector_store_tracker: VectorStoreTracker = None,
+    ):
         super().__init__(config, global_config)
         self._vector_store_tracker = vector_store_tracker
         self._overwrite_index = self._config.get("overwrite_index", False)
 
-        az_credential = AzureKeyCredential(self._config.get("api_key", "")) if self._config.get("api_key", "") else DefaultAzureCredential()
+        az_credential = (
+            AzureKeyCredential(self._config.get("api_key", ""))
+            if self._config.get("api_key", "")
+            else DefaultAzureCredential()
+        )
         self._search_client = SearchClient(
             endpoint=self._config.get("endpoint"),
             index_name=self._config.get("index_name"),
             credential=az_credential,
         )
-        self._index_client = SearchIndexClient(endpoint=self._config.get("endpoint"), credential=az_credential)
+        self._index_client = SearchIndexClient(
+            endpoint=self._config.get("endpoint"), credential=az_credential
+        )
+
+        max_batch_size = 50
+        self._config["batch_size"] = min(
+            max(1, self._config.get("batch_size", max_batch_size)), max_batch_size
+        )
 
     def _upload_embeddings(self, chunks: List[Chunk]):
         field_mapping = self._config.get("field_mapping", {})
@@ -38,21 +53,38 @@ def _upload_embeddings(self, chunks: List[Chunk]):
         results = []
         if chunks:
             az_ai_search_documents = [
-                {field_mapping[key]: getattr(chunk, key) for key in field_mapping if hasattr(chunk, key)} for chunk in chunks
+                {
+                    field_mapping[key]: getattr(chunk, key)
+                    for key in field_mapping
+                    if hasattr(chunk, key)
+                }
+                for chunk in chunks
             ]
 
-            results = self._search_client.upload_documents(documents=az_ai_search_documents)
+            start_idx = 0
+            batch_size = self._config.get("batch_size")
+
+            while start_idx < len(az_ai_search_documents):
+                batch = az_ai_search_documents[start_idx : start_idx + batch_size]
+                results.extend(self._search_client.upload_documents(documents=batch))
+                start_idx += batch_size
 
         return results
 
     def _update_tracker(self, chunks: List[Chunk], results: List[IndexingResult]):
         if self._vector_store_tracker:
             self._vector_store_tracker.update_documents(chunks, results)
 
-    def _log_upload_results(self, chunk_id_list: List[str], results: List[IndexingResult]):
+    def _log_upload_results(
+        self, chunk_id_list: List[str], results: List[IndexingResult]
+    ):
         if self.logger:
             res = [
-                {"chunk_id": chunk_id, "succeeded": result.succeeded, "status_code": result.status_code}
+                {
+                    "chunk_id": chunk_id,
+                    "succeeded": result.succeeded,
+                    "status_code": result.status_code,
+                }
                 for chunk_id, result in zip(chunk_id_list, results)
             ]
             self.logger.debug(f"Azure AI Search upload results: {res}")
@@ -65,7 +97,9 @@ def _cleanup_index(self):
 
         # First search for all documents
         results = self._search_client.search(
-            search_text="*", select=[key_field], include_total_count=True  # Only get the key field as that's all we need for deletion
+            search_text="*",
+            select=[key_field],
+            include_total_count=True,  # Only get the key field as that's all we need for deletion
         )
 
         # Get all document IDs using the correct key field
@@ -84,7 +118,10 @@ def run(self, input: Optional[List[Document]] = None) -> List[Document]:
         chunks = {}
 
         if self._vector_store_tracker:
-            chunks = {chunk.document_id: chunk for chunk in self._vector_store_tracker.retrieve_failed_documents()}
+            chunks = {
+                chunk.document_id: chunk
+                for chunk in self._vector_store_tracker.retrieve_failed_documents()
+            }
 
         self.logger.debug(f"Going to process {len(input)} documents")
         for doc in input:
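The batch_size clamp in __init__ forces any configured value into the range [1, 50] and falls back to 50 when the key is absent. A quick illustration of the expression's behavior:

    max_batch_size = 50
    for requested in (0, 1, 25, 50, 500):
        clamped = min(max(1, requested), max_batch_size)
        print(requested, "->", clamped)
    # 0 -> 1, 1 -> 1, 25 -> 25, 50 -> 50, 500 -> 50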
