Skip to content

Commit 5d8577f

Browse files
fix(review): address PR feedback regarding RAG concurrency and config defaults
Signed-off-by: Aryan Patel <aryan.patel7291@gmail.com>
1 parent f870166 commit 5d8577f

File tree

1 file changed

+9
-7
lines changed
  • examples/government_rag/singletask_learning_bench/testalgorithms/basemodel.py

1 file changed

+9
-7
lines changed

examples/government_rag/singletask_learning_bench/testalgorithms/basemodel.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def get_access_token():
157157
def preprocess(self, **kwargs):
158158
print("BaseModel preprocess")
159159
# input('stop here preprocess')
160-
self.rag = GovernmentRAG(base_path=Context.get_parameters("base_path", "/path/ianvs/dataset/gov_rag"), model_name="/home/icyfeather/models/bge-m3", device="cuda", persist_directory="./chroma_db")
160+
self.rag = GovernmentRAG(base_path=Context.get_parameters("base_path"), model_name=Context.get_parameters("model_name", "BAAI/bge-large-zh-v1.5"), device="cuda", persist_directory="./chroma_db")
161161
LOGGER.info("RAG initialized")
162162

163163
def train(self, train_data, valid_data=None, **kwargs):
@@ -174,16 +174,18 @@ def process_query(self, query: str, ground_truth: str, location: str, rag_type:
174174
response = self.get_model_response(query)
175175
else:
176176
with self.gpu_lock:
177+
base_path = Context.get_parameters("base_path")
178+
model_name = Context.get_parameters("model_name", "BAAI/bge-large-zh-v1.5")
179+
177180
if rag_type == "[global]":
178-
if self.rag is None:
179-
self.rag = GovernmentRAG(base_path=Context.get_parameters("base_path", "/path/ianvs/dataset/gov_rag"), model_name="/home/icyfeather/models/bge-m3", device="cuda", persist_directory="./chroma_db")
181+
rag = GovernmentRAG(base_path=base_path, model_name=model_name, device="cuda", persist_directory="./chroma_db")
180182
elif rag_type == "[local]":
181-
self.rag = GovernmentRAG(base_path=Context.get_parameters("base_path", "/path/ianvs/dataset/gov_rag"), model_name="/home/icyfeather/models/bge-m3", device="cuda", persist_directory="./chroma_db", provinces=[location])
183+
rag = GovernmentRAG(base_path=base_path, model_name=model_name, device="cuda", persist_directory="./chroma_db", provinces=[location])
182184
else: # [other]
183-
all_locations = set(self.all_locations)
184-
self.rag = GovernmentRAG(base_path=Context.get_parameters("base_path", "/path/ianvs/dataset/gov_rag"), model_name="/home/icyfeather/models/bge-m3", device="cuda", persist_directory="./chroma_db", provinces=list(all_locations - set([location])))
185+
all_locations = set(getattr(self, "all_locations", []))
186+
rag = GovernmentRAG(base_path=base_path, model_name=model_name, device="cuda", persist_directory="./chroma_db", provinces=list(all_locations - set([location])))
185187

186-
relevant_docs = self.rag.query(query, k=1)
188+
relevant_docs = rag.query(query, k=1)
187189

188190
# Clear GPU cache after query
189191
if torch.cuda.is_available():

0 commit comments

Comments (0)