Skip to content

Commit b274d95

Browse files
committed
fixed random forest
1 parent 23a648f commit b274d95

2 files changed

Lines changed: 44 additions & 35 deletions

File tree

assets/RandomForest.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"balance_positive_and_negative_interactions_opt_set": true
1010
},
1111
"model_parameters": {
12-
"balance_method": ["none", "downsample", "sample_weight"],
12+
"balance_method": ["none", "downsample", "oversample"],
1313
"protein_sample_per_ddi_train_set": [1,2,5,10],
1414
"n_estimators": [100, 200, 500],
1515
"max_depth": [10, 20, 50],

bin/random_forest.py

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,48 @@
1616
)
1717
from pathlib import Path
1818
from sklearn.model_selection import RandomizedSearchCV, PredefinedSplit
19-
from sklearn.utils.class_weight import compute_sample_weight
2019

2120

2221
# Three ways to address the heavy positive-class imbalance in the training set.
23-
# "none" — train on raw data, no correction.
24-
# "downsample" — load_embedding_data(balance_classes=True): equal pos/neg by
25-
# resampling; preserves the original pipeline behaviour.
26-
# "sample_weight" — train on the full data and apply class-balanced weights;
27-
# no information loss, cuML RF accepts sample_weight in fit().
28-
BALANCE_METHODS = ("none", "downsample", "sample_weight")
22+
# "none" — train on raw data, no correction.
23+
# "downsample" — load_embedding_data(balance_classes=True): equal pos/neg by
24+
# resampling; preserves the original pipeline behaviour.
25+
# "oversample" — replicate minority-class rows to match majority size; keeps
26+
# all original data and is equivalent to balanced integer
27+
# sample weights. cuML RandomForestClassifier.fit() does not
28+
# accept a sample_weight kwarg, so we materialise the weights
29+
# as duplicated rows instead.
30+
BALANCE_METHODS = ("none", "downsample", "oversample")
31+
32+
33+
def _oversample_minority(x, y, seed):
34+
"""Replicate minority-class rows so class counts match. Returns (x, y)."""
35+
rng = np.random.default_rng(seed)
36+
y_arr = np.asarray(y).astype(np.int32)
37+
classes, counts = np.unique(y_arr, return_counts=True)
38+
if len(classes) < 2:
39+
return x, y_arr
40+
majority_count = counts.max()
41+
parts_x = [x]
42+
parts_y = [y_arr]
43+
for cls, cnt in zip(classes, counts):
44+
if cnt >= majority_count:
45+
continue
46+
idx = np.where(y_arr == cls)[0]
47+
need = majority_count - cnt
48+
pick = rng.choice(idx, size=need, replace=True)
49+
parts_x.append(x[pick])
50+
parts_y.append(y_arr[pick])
51+
x_out = np.concatenate(parts_x, axis=0)
52+
y_out = np.concatenate(parts_y, axis=0)
53+
perm = rng.permutation(len(y_out))
54+
return x_out[perm], y_out[perm]
2955

3056

3157
def _load_train_with_balance(
3258
args, balance_method: str, samples_per_ddi: int, seed: int
3359
):
34-
"""Load training arrays under one of the three balance strategies.
35-
36-
Returns (x_train, y_train, sample_weight). sample_weight is None except for
37-
balance_method == 'sample_weight'.
38-
"""
60+
"""Load training arrays under one of the three balance strategies."""
3961
if balance_method not in BALANCE_METHODS:
4062
raise ValueError(f"Unknown balance_method: {balance_method}")
4163
downsample = balance_method == "downsample"
@@ -48,12 +70,9 @@ def _load_train_with_balance(
4870
balance_classes=downsample,
4971
samples_per_ddi=samples_per_ddi,
5072
)
51-
sw = None
52-
if balance_method == "sample_weight":
53-
sw = compute_sample_weight("balanced", y_train.astype(np.int32)).astype(
54-
np.float32
55-
)
56-
return x_train, y_train, sw
73+
if balance_method == "oversample":
74+
x_train, y_train = _oversample_minority(x_train, y_train, seed)
75+
return x_train, y_train
5776

5877

5978
def main():
@@ -176,7 +195,7 @@ def train_model(args):
176195
f"[grid] balance_method={balance_method} "
177196
f"samples_per_ddi={protein_sample_per_ddi_train_set}"
178197
)
179-
x_train, y_train, sw_train = _load_train_with_balance(
198+
x_train, y_train = _load_train_with_balance(
180199
args, balance_method, protein_sample_per_ddi_train_set, args.seed
181200
)
182201

@@ -199,14 +218,7 @@ def train_model(args):
199218
verbose=2,
200219
scoring="average_precision",
201220
)
202-
# sample_weight is sliced per fold by sklearn; pass dummy 1.0 weights
203-
# for the opt rows so the array length matches x/y.
204-
fit_kwargs = {}
205-
if sw_train is not None:
206-
fit_kwargs["sample_weight"] = np.concatenate(
207-
[sw_train, np.ones(len(x_opt), dtype=np.float32)]
208-
)
209-
grid_search.fit(x, y, **fit_kwargs)
221+
grid_search.fit(x, y)
210222

211223
best_model_parameters_and_performance.append(
212224
(
@@ -218,7 +230,7 @@ def train_model(args):
218230
)
219231

220232
# B3: drop per-iter buffers before next outer iter
221-
del x, y, x_train, y_train, sw_train, classifier, grid_search
233+
del x, y, x_train, y_train, classifier, grid_search
222234
gc.collect()
223235

224236
best_model_parameters_and_performance.sort(key=lambda x: x[1], reverse=True)
@@ -238,7 +250,7 @@ def train_model(args):
238250
clear_load_cache()
239251
gc.collect()
240252

241-
x_train, y_train, sw_train = _load_train_with_balance(
253+
x_train, y_train = _load_train_with_balance(
242254
args, balance_method, protein_sample_per_ddi_train_set, args.seed
243255
)
244256
classifier = RandomForestClassifier(**params)
@@ -247,13 +259,10 @@ def train_model(args):
247259
y_train_i32 = y_train.astype(np.int32)
248260
del x_train, y_train
249261
gc.collect()
250-
if sw_train is not None:
251-
classifier.fit(x_train_f32, y_train_i32, sample_weight=sw_train)
252-
else:
253-
classifier.fit(x_train_f32, y_train_i32)
262+
classifier.fit(x_train_f32, y_train_i32)
254263

255264
# Free training buffers before allocating x_opt again.
256-
del x_train_f32, y_train_i32, sw_train
265+
del x_train_f32, y_train_i32
257266
gc.collect()
258267

259268
random.seed(args.seed)

0 commit comments

Comments
 (0)