Skip to content

Commit f04e3a9

Browse files
fix: eliminate redundant RAG model reloading that defeated parallelization
Signed-off-by: Aryan Patel <aryan.patel7291@gmail.com>
1 parent 0a1acb9 commit f04e3a9

File tree

1 file changed

+40
-31
lines changed
  • examples/government_rag/singletask_learning_bench/testalgorithms/basemodel.py

1 file changed

+40
-31
lines changed

examples/government_rag/singletask_learning_bench/testalgorithms/basemodel.py

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ def __init__(self, **kwargs):
5252
self.gpu_lock = threading.Lock()
5353
self.rag = None
5454
self.get_model_response = self.get_model_response_qianfan
55-
pass
5655

5756
def get_model_response_deepseek(self, prompt):
5857
# Please install OpenAI SDK first: `pip3 install openai`
@@ -156,9 +155,14 @@ def get_access_token():
156155

157156
def preprocess(self, **kwargs):
    """One-time setup: build the RAG instance shared by every query.

    Loading the embedding model and the ChromaDB vector store is the
    expensive part, so it happens exactly once here rather than inside
    process_query().  When it ran per query, every worker thread ended
    up serialized on the gpu_lock waiting for the reload, which made
    the ThreadPoolExecutor in predict() effectively sequential.
    """
    print("BaseModel preprocess")
    self.rag = GovernmentRAG(
        model_name="/home/icyfeather/models/bge-m3",
        device="cuda",
        persist_directory="./chroma_db",
    )
    LOGGER.info("RAG initialized once for all queries")
162166

163167
def train(self, train_data, valid_data=None, **kwargs):
164168
print("BaseModel doesn't need to train")
@@ -168,42 +172,49 @@ def save(self, model_path):
168172
print("BaseModel doesn't need to save")
169173

170174
def process_query(self, query: str, ground_truth: str, location: str, rag_type: str) -> str:
    """Process a single query with the specified RAG type.

    rag_type selects the retrieval corpus:
      - "[model]":  no retrieval; the LLM is queried directly.
      - "[global]": the shared, unfiltered RAG built in preprocess().
      - "[local]":  documents from `location` only.
      - "[other]":  documents from every province except `location`.

    RAG instances are created at most once per configuration and cached
    (see _get_rag), so worker threads never pay the model-loading cost
    per query.  The gpu_lock protects only the brief vector similarity
    search and any one-off lazy load; the slow LLM API calls happen
    outside the lock and run in parallel.

    Returns:
        "response||ground_truth||location||rag_type".

    Raises:
        Re-raises any exception after logging it.
    """
    try:
        if rag_type == "[model]":
            # No retrieval needed: ask the LLM directly.
            response = self.get_model_response(query)
        else:
            rag = self._get_rag(rag_type, location)
            # Quick vector similarity search -- not a model reload.
            with self.gpu_lock:
                relevant_docs = rag.query(query, k=1)

            # Build the augmented prompt with the retrieved context.
            augmented_prompt = (
                "在你回答问题之前,你被提供了以下可能相关的信息:"
                + relevant_docs
                + "\n现在请你回答问题:"
                + query
            )
            response = self.get_model_response(augmented_prompt)

        return response + "||" + ground_truth + "||" + location + "||" + rag_type
    except Exception as e:
        LOGGER.error(f"Error in process_query: {str(e)}")
        raise e

def _get_rag(self, rag_type: str, location: str):
    """Return the (cached) GovernmentRAG matching rag_type/location.

    Fixes a regression: when the per-query reload was removed, the
    province filtering went with it, so "[local]" and "[other]"
    silently fell back to the same unfiltered global index as
    "[global]" -- collapsing the three benchmark conditions into one.
    Each filtered instance is now built at most once, under the
    gpu_lock, and reused by all subsequent queries, preserving the
    no-reload parallelization win.
    """
    with self.gpu_lock:
        # Lazy cache; created on first use so __init__ needs no change.
        if not hasattr(self, "_rag_cache"):
            self._rag_cache = {}
        # "[global]" ignores location; filtered types key on it.
        key = (rag_type, None if rag_type == "[global]" else location)
        rag = self._rag_cache.get(key)
        if rag is None:
            if rag_type == "[local]":
                provinces = [location]
            elif rag_type == "[other]":
                provinces = list(set(self.all_locations) - {location})
            else:  # "[global]" (or unknown type -> unfiltered fallback)
                provinces = None
            if provinces is None:
                # Reuse the shared instance built in preprocess() if present.
                if self.rag is None:
                    self.rag = GovernmentRAG(
                        model_name="/home/icyfeather/models/bge-m3",
                        device="cuda",
                        persist_directory="./chroma_db",
                    )
                rag = self.rag
            else:
                rag = GovernmentRAG(
                    model_name="/home/icyfeather/models/bge-m3",
                    device="cuda",
                    persist_directory="./chroma_db",
                    provinces=provinces,
                )
            self._rag_cache[key] = rag
        return rag
201205

202206
def predict(self, data, input_shape=None, **kwargs):
203207
print("BaseModel predict")
204208
LOGGER.info("BaseModel predict")
205-
LOGGER.info(f"Dataset: {data.dataset_name}")
206-
LOGGER.info(f"Description: {data.description}")
209+
210+
# Make sure the RAG system is ready before processing queries
211+
if self.rag is None:
212+
LOGGER.info("RAG not initialized yet, loading now...")
213+
self.rag = GovernmentRAG(
214+
model_name="/home/icyfeather/models/bge-m3",
215+
device="cuda",
216+
persist_directory="./chroma_db"
217+
)
207218

208219
answer_list = []
209220

@@ -213,21 +224,19 @@ def predict(self, data, input_shape=None, **kwargs):
213224
# Create tasks for all queries
214225
tasks = []
215226
for i in range(len(data.x)):
216-
# Add global task
217227
tasks.append((data.x[i], data.y[i], current_dir, "[global]"))
218-
# Add local task
219228
tasks.append((data.x[i], data.y[i], current_dir, "[local]"))
220-
# Add other task
221229
tasks.append((data.x[i], data.y[i], current_dir, "[other]"))
222-
# Add model task
223230
tasks.append((data.x[i], data.y[i], current_dir, "[model]"))
224231

225-
# Process tasks in parallel using ThreadPoolExecutor
226-
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: # Reduced number of workers
232+
# Process tasks in parallel using ThreadPoolExecutor.
233+
# Now that GovernmentRAG is loaded once and shared, the threads
234+
# only block briefly on the gpu_lock for vector search. The slow
235+
# LLM API calls happen outside the lock and truly run in parallel.
236+
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
227237
futures = [executor.submit(self.process_query, query, gt, loc, rag_type)
228238
for query, gt, loc, rag_type in tasks]
229239

230-
# Use tqdm to show progress
231240
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing queries"):
232241
try:
233242
result = future.result()

0 commit comments

Comments
 (0)