1616)
1717from pathlib import Path
1818from sklearn .model_selection import RandomizedSearchCV , PredefinedSplit
19- from sklearn .utils .class_weight import compute_sample_weight
2019
2120
2221# Three ways to address the heavy positive-class imbalance in the training set.
23- # "none" — train on raw data, no correction.
24- # "downsample" — load_embedding_data(balance_classes=True): equal pos/neg by
25- # resampling; preserves the original pipeline behaviour.
26- # "sample_weight" — train on the full data and apply class-balanced weights;
27- # no information loss, cuML RF accepts sample_weight in fit().
28- BALANCE_METHODS = ("none" , "downsample" , "sample_weight" )
22+ # "none" — train on raw data, no correction.
23+ # "downsample" — load_embedding_data(balance_classes=True): equal pos/neg by
24+ # resampling; preserves the original pipeline behaviour.
25+ # "oversample" — replicate minority-class rows to match majority size; keeps
26+ # all original data and is equivalent to balanced integer
27+ # sample weights. cuML RandomForestClassifier.fit() does not
28+ # accept a sample_weight kwarg, so we materialise the weights
29+ # as duplicated rows instead.
30+ BALANCE_METHODS = ("none" , "downsample" , "oversample" )
31+
32+
33+ def _oversample_minority (x , y , seed ):
34+ """Replicate minority-class rows so class counts match. Returns (x, y)."""
35+ rng = np .random .default_rng (seed )
36+ y_arr = np .asarray (y ).astype (np .int32 )
37+ classes , counts = np .unique (y_arr , return_counts = True )
38+ if len (classes ) < 2 :
39+ return x , y_arr
40+ majority_count = counts .max ()
41+ parts_x = [x ]
42+ parts_y = [y_arr ]
43+ for cls , cnt in zip (classes , counts ):
44+ if cnt >= majority_count :
45+ continue
46+ idx = np .where (y_arr == cls )[0 ]
47+ need = majority_count - cnt
48+ pick = rng .choice (idx , size = need , replace = True )
49+ parts_x .append (x [pick ])
50+ parts_y .append (y_arr [pick ])
51+ x_out = np .concatenate (parts_x , axis = 0 )
52+ y_out = np .concatenate (parts_y , axis = 0 )
53+ perm = rng .permutation (len (y_out ))
54+ return x_out [perm ], y_out [perm ]
2955
3056
3157def _load_train_with_balance (
3258 args , balance_method : str , samples_per_ddi : int , seed : int
3359):
34- """Load training arrays under one of the three balance strategies.
35-
36- Returns (x_train, y_train, sample_weight). sample_weight is None except for
37- balance_method == 'sample_weight'.
38- """
60+ """Load training arrays under one of the three balance strategies."""
3961 if balance_method not in BALANCE_METHODS :
4062 raise ValueError (f"Unknown balance_method: { balance_method } " )
4163 downsample = balance_method == "downsample"
@@ -48,12 +70,9 @@ def _load_train_with_balance(
4870 balance_classes = downsample ,
4971 samples_per_ddi = samples_per_ddi ,
5072 )
51- sw = None
52- if balance_method == "sample_weight" :
53- sw = compute_sample_weight ("balanced" , y_train .astype (np .int32 )).astype (
54- np .float32
55- )
56- return x_train , y_train , sw
73+ if balance_method == "oversample" :
74+ x_train , y_train = _oversample_minority (x_train , y_train , seed )
75+ return x_train , y_train
5776
5877
5978def main ():
@@ -176,7 +195,7 @@ def train_model(args):
176195 f"[grid] balance_method={ balance_method } "
177196 f"samples_per_ddi={ protein_sample_per_ddi_train_set } "
178197 )
179- x_train , y_train , sw_train = _load_train_with_balance (
198+ x_train , y_train = _load_train_with_balance (
180199 args , balance_method , protein_sample_per_ddi_train_set , args .seed
181200 )
182201
@@ -199,14 +218,7 @@ def train_model(args):
199218 verbose = 2 ,
200219 scoring = "average_precision" ,
201220 )
202- # sample_weight is sliced per fold by sklearn; pass dummy 1.0 weights
203- # for the opt rows so the array length matches x/y.
204- fit_kwargs = {}
205- if sw_train is not None :
206- fit_kwargs ["sample_weight" ] = np .concatenate (
207- [sw_train , np .ones (len (x_opt ), dtype = np .float32 )]
208- )
209- grid_search .fit (x , y , ** fit_kwargs )
221+ grid_search .fit (x , y )
210222
211223 best_model_parameters_and_performance .append (
212224 (
@@ -218,7 +230,7 @@ def train_model(args):
218230 )
219231
220232 # B3: drop per-iter buffers before next outer iter
221- del x , y , x_train , y_train , sw_train , classifier , grid_search
233+ del x , y , x_train , y_train , classifier , grid_search
222234 gc .collect ()
223235
224236 best_model_parameters_and_performance .sort (key = lambda x : x [1 ], reverse = True )
@@ -238,7 +250,7 @@ def train_model(args):
238250 clear_load_cache ()
239251 gc .collect ()
240252
241- x_train , y_train , sw_train = _load_train_with_balance (
253+ x_train , y_train = _load_train_with_balance (
242254 args , balance_method , protein_sample_per_ddi_train_set , args .seed
243255 )
244256 classifier = RandomForestClassifier (** params )
@@ -247,13 +259,10 @@ def train_model(args):
247259 y_train_i32 = y_train .astype (np .int32 )
248260 del x_train , y_train
249261 gc .collect ()
250- if sw_train is not None :
251- classifier .fit (x_train_f32 , y_train_i32 , sample_weight = sw_train )
252- else :
253- classifier .fit (x_train_f32 , y_train_i32 )
262+ classifier .fit (x_train_f32 , y_train_i32 )
254263
255264 # Free training buffers before allocating x_opt again.
256- del x_train_f32 , y_train_i32 , sw_train
265+ del x_train_f32 , y_train_i32
257266 gc .collect ()
258267
259268 random .seed (args .seed )
0 commit comments