1616from asreview .models .balancers import Balanced
1717from asreview .models .queriers import Max
1818
19+ from balancers import DynamicLossBalancer
1920from classifiers import classifier_params , classifiers
2021from feature_extractors import feature_extractor_params , feature_extractors
2122
2223# Study variables
2324VERSION = 1
2425METRIC = "loss" # Options: "loss", "ndcg"
25- STUDY_SET = "full "
26+ STUDY_SET = "demo "
2627CLASSIFIER_TYPE = "svm" # Options: "nb", "log", "svm", "rf"
27- FEATURE_EXTRACTOR_TYPE = "e5 " # Options: "tfidf", "onehot", "labse", "bge-m3", "stella", "mxbai", "gist", "e5", "gte", "kalm", "lajavaness", "snowflake"
28+ FEATURE_EXTRACTOR_TYPE = "tfidf " # Options: "tfidf", "onehot", "labse", "bge-m3", "stella", "mxbai", "gist", "e5", "gte", "kalm", "lajavaness", "snowflake"
2829PICKLE_FOLDER_PATH = Path ("synergy-dataset" , f"pickles_{ FEATURE_EXTRACTOR_TYPE } " )
29- PRE_PROCESSED_FMS = True # False = on the fly
30+ PRE_PROCESSED_FMS = False # False = on the fly
3031PARALLELIZE_OBJECTIVE = True
31- AUTO_SHUTDOWN = False
32+ AUTO_SHUTDOWN = True
3233
3334# Optuna variables
3435OPTUNA_N_TRIALS = 500
@@ -111,11 +112,12 @@ def run_sequential(studies, *args, **kwargs):
111112
112113
113114# Function to process each row
114- def process_row (row , clf_params , fe_params , ratio ):
115+ def process_row (row , clf_params , fe_params , ratio , a , activation , window_size ):
115116 priors = row ["prior_inclusions" ] + row ["prior_exclusions" ]
116117
117118 # Create balancer with the Optuna-suggested hyperparameters
118- blc = Balanced (ratio = ratio )
119+ #blc = DynamicLossBalancer(ratio=ratio)
120+ blc = DynamicLossBalancer (ratio = ratio , a = a , activation = activation , window_size = window_size )
119121
120122 # Create classifier and feature extractor with params
121123 clf = classifiers [CLASSIFIER_TYPE ](** clf_params )
@@ -171,6 +173,9 @@ def objective_report(report_order):
171173 def objective (trial ):
172174 # Sample ratio on a linear (non-log) scale, since the ratio effect is linear
173175 ratio = trial .suggest_float ("ratio" , 1.0 , 10.0 )
176+ a = trial .suggest_float ("a" , 1.0 , 10.0 )
177+ activation = trial .suggest_categorical ("activation" , ["linear" , "sigmoid" , "tanh" ])
178+ window_size = trial .suggest_int ("window_size" , 10 , 100 )
174179
175180 clf_params = classifier_params [CLASSIFIER_TYPE ](trial )
176181 fe_params = (
@@ -181,11 +186,11 @@ def objective(trial):
181186
182187 if PARALLELIZE_OBJECTIVE :
183188 metric_values = run_parallel (
184- studies , clf_params = clf_params , fe_params = fe_params , ratio = ratio
189+ studies , clf_params = clf_params , fe_params = fe_params , ratio = ratio , a = a , activation = activation , window_size = window_size
185190 )
186191 else :
187192 metric_values = run_sequential (
188- studies , clf_params = clf_params , fe_params = fe_params , ratio = ratio
193+ studies , clf_params = clf_params , fe_params = fe_params , ratio = ratio , a = a , activation = activation , window_size = window_size
189194 )
190195
191196 all_metric_values = []
@@ -268,7 +273,7 @@ def download_pickles(report_order):
268273 storage = os .getenv (
269274 "DB_URI" , "sqlite:///db.sqlite3"
270275 ), # Specify the storage URL here.
271- study_name = f"ASReview2_0b4 -{ CLASSIFIER_TYPE } -{ FEATURE_EXTRACTOR_TYPE } -{ STUDY_SET } -{ VERSION } " ,
276+ study_name = f"ASReview2_1_1_1 -{ CLASSIFIER_TYPE } -{ FEATURE_EXTRACTOR_TYPE } -{ STUDY_SET } -{ VERSION } " ,
272277 direction = "minimize" if METRIC == "loss" else "maximize" ,
273278 sampler = sampler ,
274279 load_if_exists = True ,
0 commit comments