-from typing import List
-from typing import Optional
+from typing import List, Optional

 from azure.core.credentials import AzureKeyCredential
 from azure.identity import DefaultAzureCredential


 class AzureVectorStoreSkill(IndexerSkill):
-    def __init__(self, config: dict, global_config: Config, vector_store_tracker: VectorStoreTracker = None):
+    def __init__(
+        self,
+        config: dict,
+        global_config: Config,
+        vector_store_tracker: VectorStoreTracker = None,
+    ):
         super().__init__(config, global_config)
         self._vector_store_tracker = vector_store_tracker
         self._overwrite_index = self._config.get("overwrite_index", False)

-        az_credential = AzureKeyCredential(self._config.get("api_key", "")) if self._config.get("api_key", "") else DefaultAzureCredential()
+        az_credential = (
+            AzureKeyCredential(self._config.get("api_key", ""))
+            if self._config.get("api_key", "")
+            else DefaultAzureCredential()
+        )
         self._search_client = SearchClient(
             endpoint=self._config.get("endpoint"),
             index_name=self._config.get("index_name"),
             credential=az_credential,
         )
-        self._index_client = SearchIndexClient(endpoint=self._config.get("endpoint"), credential=az_credential)
+        self._index_client = SearchIndexClient(
+            endpoint=self._config.get("endpoint"), credential=az_credential
+        )
+
+        max_batch_size = 50
+        self._config["batch_size"] = min(
+            max(1, self._config.get("batch_size", max_batch_size)), max_batch_size
+        )

     def _upload_embeddings(self, chunks: List[Chunk]):
         field_mapping = self._config.get("field_mapping", {})
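The clamp added in `__init__` forces any configured `batch_size` into the inclusive range [1, 50]; the 50-document cap is this change's own conservative choice, not an SDK constant. A minimal standalone sketch of the same pattern, with hypothetical names:

```python
from typing import Optional

# Sketch of the clamping logic introduced in __init__ above; the function
# name and signature are illustrative, not the class's real API.
def clamp_batch_size(configured: Optional[int], max_batch_size: int = 50) -> int:
    if configured is None:  # mirrors dict.get's fallback to max_batch_size
        configured = max_batch_size
    return min(max(1, configured), max_batch_size)

assert clamp_batch_size(0) == 1     # too small: floored to 1
assert clamp_batch_size(10) == 10   # in range: unchanged
assert clamp_batch_size(500) == 50  # too large: capped at 50
```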
@@ -38,21 +53,38 @@ def _upload_embeddings(self, chunks: List[Chunk]):
         results = []
         if chunks:
             az_ai_search_documents = [
-                {field_mapping[key]: getattr(chunk, key) for key in field_mapping if hasattr(chunk, key)} for chunk in chunks
+                {
+                    field_mapping[key]: getattr(chunk, key)
+                    for key in field_mapping
+                    if hasattr(chunk, key)
+                }
+                for chunk in chunks
             ]

-            results = self._search_client.upload_documents(documents=az_ai_search_documents)
+            start_idx = 0
+            batch_size = self._config.get("batch_size")
+
+            while start_idx < len(az_ai_search_documents):
+                batch = az_ai_search_documents[start_idx : start_idx + batch_size]
+                results.extend(self._search_client.upload_documents(documents=batch))
+                start_idx += batch_size

         return results

     def _update_tracker(self, chunks: List[Chunk], results: List[IndexingResult]):
         if self._vector_store_tracker:
             self._vector_store_tracker.update_documents(chunks, results)

-    def _log_upload_results(self, chunk_id_list: List[str], results: List[IndexingResult]):
+    def _log_upload_results(
+        self, chunk_id_list: List[str], results: List[IndexingResult]
+    ):
         if self.logger:
             res = [
-                {"chunk_id": chunk_id, "succeeded": result.succeeded, "status_code": result.status_code}
+                {
+                    "chunk_id": chunk_id,
+                    "succeeded": result.succeeded,
+                    "status_code": result.status_code,
+                }
                 for chunk_id, result in zip(chunk_id_list, results)
             ]
             self.logger.debug(f"Azure AI Search upload results: {res}")
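The while-loop above is the core of this change: one `upload_documents` call per batch instead of a single call with every document. A behaviorally equivalent standalone sketch using `range` with a step, which avoids the manual index bookkeeping (`search_client` and `docs` are assumed inputs, not names from this PR):

```python
# Equivalent to the batched upload above, assuming `search_client` is an
# azure.search.documents.SearchClient and `docs` is a list of dicts;
# upload_documents returns a list of per-document IndexingResult objects.
def upload_in_batches(search_client, docs, batch_size=50):
    results = []
    for start in range(0, len(docs), batch_size):
        batch = docs[start : start + batch_size]
        results.extend(search_client.upload_documents(documents=batch))
    return results
```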
@@ -65,7 +97,9 @@ def _cleanup_index(self):

         # First search for all documents
         results = self._search_client.search(
-            search_text="*", select=[key_field], include_total_count=True  # Only get the key field as that's all we need for deletion
+            search_text="*",
+            select=[key_field],
+            include_total_count=True,  # Only get the key field as that's all we need for deletion
         )

         # Get all document IDs using the correct key field
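The hunk ends before the deletion itself. For context, a plausible continuation under the assumption that `key_field` holds the index's key field name; `SearchClient.delete_documents` is a real SDK method, but this body is a sketch, not the PR's actual code:

```python
# Hypothetical continuation of _cleanup_index: collect each document's key
# from the search results, then delete by key. Only `results`, `key_field`,
# and `self._search_client` come from the diff above.
doc_ids = [doc[key_field] for doc in results]
if doc_ids:
    self._search_client.delete_documents(
        documents=[{key_field: doc_id} for doc_id in doc_ids]
    )
```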
@@ -84,7 +118,10 @@ def run(self, input: Optional[List[Document]] = None) -> List[Document]:
         chunks = {}

         if self._vector_store_tracker:
-            chunks = {chunk.document_id: chunk for chunk in self._vector_store_tracker.retrieve_failed_documents()}
+            chunks = {
+                chunk.document_id: chunk
+                for chunk in self._vector_store_tracker.retrieve_failed_documents()
+            }

         self.logger.debug(f"Going to process {len(input)} documents")
         for doc in input:
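Keying the retry map by `chunk.document_id` also deduplicates: if the tracker returns several failed chunks with the same id, the last one wins. A tiny illustration with a stand-in `Chunk` (the project's real type is not shown in this diff):

```python
from dataclasses import dataclass

# Hypothetical stand-in for the project's Chunk type.
@dataclass
class Chunk:
    document_id: str
    text: str

# Stand-in for retrieve_failed_documents(): two entries share an id.
failed = [Chunk("a", "v1"), Chunk("a", "v2"), Chunk("b", "x")]
chunks = {chunk.document_id: chunk for chunk in failed}
assert chunks["a"].text == "v2"  # later duplicates overwrite earlier ones
```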