@@ -157,7 +157,7 @@ def get_access_token():
157157 def preprocess (self , ** kwargs ):
158158 print ("BaseModel preprocess" )
159159 # input('stop here preprocess')
160- self .rag = GovernmentRAG (base_path = Context .get_parameters ("base_path" , "/path/ianvs/dataset/gov_rag" ), model_name = "/home/icyfeather/models/ bge-m3" , device = "cuda" , persist_directory = "./chroma_db" )
160+ self .rag = GovernmentRAG (base_path = Context .get_parameters ("base_path" ), model_name = Context . get_parameters ( "model_name" , "BAAI/ bge-large-zh-v1.5" ) , device = "cuda" , persist_directory = "./chroma_db" )
161161 LOGGER .info ("RAG initialized" )
162162
163163 def train (self , train_data , valid_data = None , ** kwargs ):
@@ -174,16 +174,18 @@ def process_query(self, query: str, ground_truth: str, location: str, rag_type:
174174 response = self .get_model_response (query )
175175 else :
176176 with self .gpu_lock :
177+ base_path = Context .get_parameters ("base_path" )
178+ model_name = Context .get_parameters ("model_name" , "BAAI/bge-large-zh-v1.5" )
179+
177180 if rag_type == "[global]" :
178- if self .rag is None :
179- self .rag = GovernmentRAG (base_path = Context .get_parameters ("base_path" , "/path/ianvs/dataset/gov_rag" ), model_name = "/home/icyfeather/models/bge-m3" , device = "cuda" , persist_directory = "./chroma_db" )
181+ rag = GovernmentRAG (base_path = base_path , model_name = model_name , device = "cuda" , persist_directory = "./chroma_db" )
180182 elif rag_type == "[local]" :
181- self . rag = GovernmentRAG (base_path = Context . get_parameters ( " base_path" , "/path/ianvs/dataset/gov_rag" ), model_name = "/home/icyfeather/models/bge-m3" , device = "cuda" , persist_directory = "./chroma_db" , provinces = [location ])
183+ rag = GovernmentRAG (base_path = base_path , model_name = model_name , device = "cuda" , persist_directory = "./chroma_db" , provinces = [location ])
182184 else : # [other]
183- all_locations = set (self . all_locations )
184- self . rag = GovernmentRAG (base_path = Context . get_parameters ( " base_path" , "/path/ianvs/dataset/gov_rag" ), model_name = "/home/icyfeather/models/bge-m3" , device = "cuda" , persist_directory = "./chroma_db" , provinces = list (all_locations - set ([location ])))
185+ all_locations = set (getattr ( self , " all_locations" , []) )
186+ rag = GovernmentRAG (base_path = base_path , model_name = model_name , device = "cuda" , persist_directory = "./chroma_db" , provinces = list (all_locations - set ([location ])))
185187
186- relevant_docs = self . rag .query (query , k = 1 )
188+ relevant_docs = rag .query (query , k = 1 )
187189
188190 # Clear GPU cache after query
189191 if torch .cuda .is_available ():
0 commit comments