@@ -52,7 +52,6 @@ def __init__(self, **kwargs):
5252 self .gpu_lock = threading .Lock ()
5353 self .rag = None
5454 self .get_model_response = self .get_model_response_qianfan
55- pass
5655
5756 def get_model_response_deepseek (self , prompt ):
5857 # Please install OpenAI SDK first: `pip3 install openai`
@@ -156,9 +155,14 @@ def get_access_token():
156155
def preprocess(self, **kwargs):
    """Initialize the shared RAG pipeline once, before any queries run.

    Loading GovernmentRAG is expensive (it pulls in the embedding model and
    ChromaDB), so it is done a single time here instead of per-query inside
    process_query(); otherwise every worker thread would serialize on the
    gpu_lock while waiting for model loading to finish.

    Keyword Args:
        model_name: Path to the embedding model checkpoint.
        device: Device string for the embedding model (e.g. "cuda").
        persist_directory: ChromaDB persistence directory.
    """
    print("BaseModel preprocess")
    # The previous hard-coded values are kept as defaults for backward
    # compatibility, but callers may now override them through **kwargs.
    self.rag = GovernmentRAG(
        model_name=kwargs.get("model_name", "/home/icyfeather/models/bge-m3"),
        device=kwargs.get("device", "cuda"),
        persist_directory=kwargs.get("persist_directory", "./chroma_db"),
    )
    LOGGER.info("RAG initialized once for all queries")
def train(self, train_data, valid_data=None, **kwargs):
    """No-op: this model wraps a hosted LLM and has nothing to train locally."""
    print("BaseModel doesn't need to train")
@@ -168,42 +172,49 @@ def save(self, model_path):
168172 print ("BaseModel doesn't need to save" )
169173
def process_query(self, query: str, ground_truth: str, location: str, rag_type: str) -> str:
    """Answer a single query, optionally augmented with retrieved context.

    The RAG instance (self.rag) is initialized once in preprocess() and
    reused here. The gpu_lock only protects the brief vector similarity
    search, not any model loading, so worker threads can genuinely run the
    slow LLM API calls in parallel.

    Args:
        query: The user question.
        ground_truth: Reference answer, passed through for evaluation.
        location: Province tag, passed through for evaluation.
        rag_type: "[model]" for a plain LLM call; any other value triggers
            retrieval-augmented generation.

    Returns:
        A "||"-joined string: response||ground_truth||location||rag_type.

    Raises:
        Exception: re-raises whatever the retrieval or LLM call raised,
            after logging it.
    """
    try:
        if rag_type == "[model]":
            # No RAG needed, just ask the LLM directly.
            response = self.get_model_response(query)
        else:
            # Only the GPU-bound vector search is serialized; the lock is
            # released before the (slow) LLM API call below.
            with self.gpu_lock:
                relevant_docs = self.rag.query(query, k=1)

            # Build the augmented prompt with the retrieved context.
            augmented_prompt = (
                "在你回答问题之前,你被提供了以下可能相关的信息:"
                + relevant_docs
                + "\n现在请你回答问题:"
                + query
            )
            response = self.get_model_response(augmented_prompt)

        return response + "||" + ground_truth + "||" + location + "||" + rag_type
    except Exception as e:
        # Lazy %-args avoid building the message unless the record is
        # emitted; the bare raise preserves the original traceback
        # (``raise e`` would append this frame to it).
        LOGGER.error("Error in process_query: %s", e)
        raise
201205
202206 def predict (self , data , input_shape = None , ** kwargs ):
203207 print ("BaseModel predict" )
204208 LOGGER .info ("BaseModel predict" )
205- LOGGER .info (f"Dataset: { data .dataset_name } " )
206- LOGGER .info (f"Description: { data .description } " )
209+
210+ # Make sure the RAG system is ready before processing queries
211+ if self .rag is None :
212+ LOGGER .info ("RAG not initialized yet, loading now..." )
213+ self .rag = GovernmentRAG (
214+ model_name = "/home/icyfeather/models/bge-m3" ,
215+ device = "cuda" ,
216+ persist_directory = "./chroma_db"
217+ )
207218
208219 answer_list = []
209220
@@ -213,21 +224,19 @@ def predict(self, data, input_shape=None, **kwargs):
213224 # Create tasks for all queries
214225 tasks = []
215226 for i in range (len (data .x )):
216- # Add global task
217227 tasks .append ((data .x [i ], data .y [i ], current_dir , "[global]" ))
218- # Add local task
219228 tasks .append ((data .x [i ], data .y [i ], current_dir , "[local]" ))
220- # Add other task
221229 tasks .append ((data .x [i ], data .y [i ], current_dir , "[other]" ))
222- # Add model task
223230 tasks .append ((data .x [i ], data .y [i ], current_dir , "[model]" ))
224231
225- # Process tasks in parallel using ThreadPoolExecutor
226- with concurrent .futures .ThreadPoolExecutor (max_workers = 4 ) as executor : # Reduced number of workers
232+ # Process tasks in parallel using ThreadPoolExecutor.
233+ # Now that GovernmentRAG is loaded once and shared, the threads
234+ # only block briefly on the gpu_lock for vector search. The slow
235+ # LLM API calls happen outside the lock and truly run in parallel.
236+ with concurrent .futures .ThreadPoolExecutor (max_workers = 4 ) as executor :
227237 futures = [executor .submit (self .process_query , query , gt , loc , rag_type )
228238 for query , gt , loc , rag_type in tasks ]
229239
230- # Use tqdm to show progress
231240 for future in tqdm (concurrent .futures .as_completed (futures ), total = len (futures ), desc = "Processing queries" ):
232241 try :
233242 result = future .result ()
0 commit comments