Skip to content

Commit 267784c

Browse files
committed
Set up new balancer test
1 parent 6d308e9 commit 267784c

3 files changed

Lines changed: 74 additions & 11 deletions

File tree

src/balancers.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import numpy as np
2+
3+
from sklearn.base import BaseEstimator
4+
from sklearn.utils.class_weight import compute_sample_weight as _compute_sample_weight
5+
6+
class DynamicLossBalancer(BaseEstimator):
    """Dynamic sample-weight balancer for binary (0/1) classification.

    Up-weights class 0 by a factor that depends on the fraction of positive
    labels seen in the most recent ``window_size`` samples, shaped by an
    activation function.

    Parameters
    ----------
    ratio : float, default=1.0
        Maximum weighting factor applied to the minority class.
    window_size : int, default=10
        Number of recent samples used for the pseudo-derivative (the
        fraction of positive labels in that window).
    activation : str, default='linear'
        Activation function shaping the derivative response.
        One of: 'linear', 'sigmoid', 'tanh'.
    a : float, default=1.0
        Additive offset inside the class-0 weight formula
        ``ratio * (a + slope * activation(base_weight))``.
    """

    name = "dynamic balancer"
    label = "Dynamic Balanced Sample Weight"

    def __init__(self, ratio=1.0, window_size=10, activation='linear', a=1.0):
        self.ratio = ratio
        self.window_size = window_size
        self.activation = activation
        self.a = a

    def _activation_fn(self, x):
        """Apply the configured activation to ``x``.

        Raises
        ------
        ValueError
            If ``self.activation`` is not one of the supported names.
        """
        if self.activation == 'linear':
            return x
        elif self.activation == 'sigmoid':
            return 1 / (1 + np.exp(-x))
        elif self.activation == 'tanh':
            return np.tanh(x)
        else:
            raise ValueError(f"Unsupported activation: {self.activation}")

    def compute_sample_weight(self, y):
        """Return per-sample weights for binary labels ``y``.

        Parameters
        ----------
        y : array-like of shape (n_samples,)
            Binary labels; exactly two distinct values (0 and 1) required.

        Returns
        -------
        numpy.ndarray
            Weights normalized so they sum to ``len(y)``.

        Raises
        ------
        ValueError
            If ``y`` does not contain exactly two classes.
        """
        # Coerce to ndarray: the element-wise comparisons below would
        # silently evaluate to False on a plain list (and then divide by 0).
        y = np.asarray(y)
        if np.unique(y).size != 2:
            raise ValueError("Only binary classification is supported.")

        # Fraction of positives in the most recent window ("pseudo-derivative").
        window = min(self.window_size, len(y))
        slope = np.sum(y[-window:]) / window

        # Class-imbalance base weight for class 0, shaped by the activation.
        base_class_0_weight = np.sum(y == 1) / (self.ratio * np.sum(y == 0))
        act_base_class_0_weight = self._activation_fn(base_class_0_weight)

        # Dynamic class-0 weight; class 1 is the reference class.
        # NOTE(review): the comment in the original claimed this interpolates
        # from 1.0 to `ratio`, but the formula is unbounded in `slope*act` —
        # confirm intended range with the author.
        class_1_weight = 1.0
        class_0_weight = self.ratio * (self.a + (slope * act_base_class_0_weight))

        weights = _compute_sample_weight(
            class_weight={0: class_0_weight, 1: class_1_weight}, y=y
        )
        # Normalize so the total weight equals the sample count.
        return weights * (len(y) / np.sum(weights))

src/feature_extractors.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44

55
def tfidf_params(trial: optuna.trial.FrozenTrial):
    """Sample TF-IDF vectorizer hyperparameters from an Optuna trial.

    Returns a kwargs dict for a TF-IDF feature extractor with the sampled
    ``max_df``, ``min_df`` and n-gram upper bound; ``sublinear_tf`` is fixed.
    """
    # `ngram_max` is the upper bound of the n-gram range (lower bound fixed at 1).
    ngram_max = trial.suggest_int("tfidf__ngram_range", 1, 3)
    return {
        "max_df": trial.suggest_float("tfidf__max_df", 0.5, 1.0),
        "min_df": trial.suggest_int("tfidf__min_df", 1, 10),
        "ngram_range": (1, ngram_max),
        "sublinear_tf": True,
    }
1416

src/main.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,20 @@
1616
from asreview.models.balancers import Balanced
1717
from asreview.models.queriers import Max
1818

19+
from balancers import DynamicLossBalancer
1920
from classifiers import classifier_params, classifiers
2021
from feature_extractors import feature_extractor_params, feature_extractors
2122

2223
# Study variables
2324
VERSION = 1
2425
METRIC = "loss" # Options: "loss", "ndcg"
25-
STUDY_SET = "full"
26+
STUDY_SET = "demo"
2627
CLASSIFIER_TYPE = "svm" # Options: "nb", "log", "svm", "rf"
27-
FEATURE_EXTRACTOR_TYPE = "e5" # Options: "tfidf", "onehot", "labse", "bge-m3", "stella", "mxbai", "gist", "e5", "gte", "kalm", "lajavaness", "snowflake"
28+
FEATURE_EXTRACTOR_TYPE = "tfidf" # Options: "tfidf", "onehot", "labse", "bge-m3", "stella", "mxbai", "gist", "e5", "gte", "kalm", "lajavaness", "snowflake"
2829
PICKLE_FOLDER_PATH = Path("synergy-dataset", f"pickles_{FEATURE_EXTRACTOR_TYPE}")
29-
PRE_PROCESSED_FMS = True # False = on the fly
30+
PRE_PROCESSED_FMS = False # False = on the fly
3031
PARALLELIZE_OBJECTIVE = True
31-
AUTO_SHUTDOWN = False
32+
AUTO_SHUTDOWN = True
3233

3334
# Optuna variables
3435
OPTUNA_N_TRIALS = 500
@@ -111,11 +112,12 @@ def run_sequential(studies, *args, **kwargs):
111112

112113

113114
# Function to process each row
114-
def process_row(row, clf_params, fe_params, ratio):
115+
def process_row(row, clf_params, fe_params, ratio, a, activation, window_size):
115116
priors = row["prior_inclusions"] + row["prior_exclusions"]
116117

117118
# Create balancer with optuna value
118-
blc = Balanced(ratio=ratio)
119+
#blc = DynamicLossBalancer(ratio=ratio)
120+
blc = DynamicLossBalancer(ratio=ratio, a=a, activation=activation, window_size=window_size)
119121

120122
# Create classifier and feature extractor with params
121123
clf = classifiers[CLASSIFIER_TYPE](**clf_params)
@@ -171,6 +173,9 @@ def objective_report(report_order):
171173
def objective(trial):
172174
# Use normal distribution for ratio (ratio effect is linear)
173175
ratio = trial.suggest_float("ratio", 1.0, 10.0)
176+
a = trial.suggest_float("a", 1.0, 10.0)
177+
activation = trial.suggest_categorical("activation", ["linear", "sigmoid", "tanh"])
178+
window_size = trial.suggest_int("window_size", 10, 100)
174179

175180
clf_params = classifier_params[CLASSIFIER_TYPE](trial)
176181
fe_params = (
@@ -181,11 +186,11 @@ def objective(trial):
181186

182187
if PARALLELIZE_OBJECTIVE:
183188
metric_values = run_parallel(
184-
studies, clf_params=clf_params, fe_params=fe_params, ratio=ratio
189+
studies, clf_params=clf_params, fe_params=fe_params, ratio=ratio, a=a, activation=activation, window_size=window_size
185190
)
186191
else:
187192
metric_values = run_sequential(
188-
studies, clf_params=clf_params, fe_params=fe_params, ratio=ratio
193+
studies, clf_params=clf_params, fe_params=fe_params, ratio=ratio, a=a, activation=activation, window_size=window_size
189194
)
190195

191196
all_metric_values = []
@@ -268,7 +273,7 @@ def download_pickles(report_order):
268273
storage=os.getenv(
269274
"DB_URI", "sqlite:///db.sqlite3"
270275
), # Specify the storage URL here.
271-
study_name=f"ASReview2_0b4-{CLASSIFIER_TYPE}-{FEATURE_EXTRACTOR_TYPE}-{STUDY_SET}-{VERSION}",
276+
study_name=f"ASReview2_1_1_1-{CLASSIFIER_TYPE}-{FEATURE_EXTRACTOR_TYPE}-{STUDY_SET}-{VERSION}",
272277
direction="minimize" if METRIC == "loss" else "maximize",
273278
sampler=sampler,
274279
load_if_exists=True,

0 commit comments

Comments
 (0)