@@ -45,23 +45,23 @@ def get_object_generator(self) -> Sampler:
4545
4646class SamplerOriginal (Sampler ):
4747 def __init__ (
48- self ,
49- gen_x_times : float = 1.1 ,
50- cat_cols : list = None ,
51- bot_filter_quantile : float = 0.001 ,
52- top_filter_quantile : float = 0.999 ,
53- is_post_process : bool = True ,
54- adversaial_model_params : dict = {
55- "metrics" : "AUC" ,
56- "max_depth" : 2 ,
57- "max_bin" : 100 ,
58- "n_estimators" : 500 ,
59- "learning_rate" : 0.02 ,
60- "random_state" : 42 ,
61- },
62- pregeneration_frac : float = 2 ,
63- only_generated_data : bool = False ,
64- gan_params : dict = {'batch_size' : 500 , 'patience' : 25 , "epochs" : 500 ,}
48+ self ,
49+ gen_x_times : float = 1.1 ,
50+ cat_cols : list = None ,
51+ bot_filter_quantile : float = 0.001 ,
52+ top_filter_quantile : float = 0.999 ,
53+ is_post_process : bool = True ,
54+ adversarial_model_params : dict = {
55+ "metrics" : "AUC" ,
56+ "max_depth" : 2 ,
57+ "max_bin" : 100 ,
58+ "n_estimators" : 500 ,
59+ "learning_rate" : 0.02 ,
60+ "random_state" : 42 ,
61+ },
62+ pregeneration_frac : float = 2 ,
63+ only_generated_data : bool = False ,
64+ gan_params : dict = {'batch_size' : 500 , 'patience' : 25 , "epochs" : 500 , }
6565 ):
6666 """
6767
@@ -75,7 +75,8 @@ def __init__(
7575 @param adversarial_model_params: dict params for adversarial filtering model, default values for binary task
7676 @param pregeneration_frac: float = 2 - for generation step gen_x_times * pregeneration_frac amount of data
7777 will be generated. However, in postprocessing (1 + gen_x_times) % of original data will be returned
78- @param only_generated_data: bool = False If True after generation get only newly generated, without concating input train dataframe.
78+ @param only_generated_data: bool = False. If True, after generation get only the newly generated data, without
79+ concatenating the input train dataframe.
7980 @param gan_params: dict params for GAN training
8081 Only works for SamplerGAN.
8182 """
@@ -84,13 +85,14 @@ def __init__(
8485 self .is_post_process = is_post_process
8586 self .bot_filter_quantile = bot_filter_quantile
8687 self .top_filter_quantile = top_filter_quantile
87- self .adversarial_model_params = adversaial_model_params
88+ self .adversarial_model_params = adversarial_model_params
8889 self .pregeneration_frac = pregeneration_frac
8990 self .only_generated_data = only_generated_data
9091 self .gan_params = gan_params
9192 self .TEMP_TARGET = "TEMP_TARGET"
9293
93- def preprocess_data_df (self , df ) -> pd .DataFrame :
94+ @staticmethod
95+ def preprocess_data_df (df ) -> pd .DataFrame :
9496 logging .info ("Input shape: {}" .format (df .shape ))
9597 if isinstance (df , pd .DataFrame ) is False :
9698 raise ValueError (
@@ -99,7 +101,7 @@ def preprocess_data_df(self, df) -> pd.DataFrame:
99101 return df
100102
101103 def preprocess_data (
102- self , train , target , test_df
104+ self , train , target , test_df
103105 ) -> Tuple [pd .DataFrame , pd .DataFrame ]:
104106 train = self .preprocess_data_df (train )
105107 target = self .preprocess_data_df (target )
@@ -119,10 +121,10 @@ def preprocess_data(
119121 return train , target , test_df
120122
121123 def generate_data (
122- self , train_df , target , test_df , only_generated_data
124+ self , train_df , target , test_df , only_generated_data
123125 ) -> Tuple [pd .DataFrame , pd .DataFrame ]:
124126 if only_generated_data :
125- Warning . warn (
127+ Warning (
126128 "For SamplerOriginal setting only_generated_data doesn't change anything, "
127129 "because generated data sampled from the train!"
128130 )
@@ -158,7 +160,7 @@ def postprocess_data(self, train_df, target, test_df):
158160 max_val = test_df [num_col ].quantile (self .top_filter_quantile )
159161 filtered_df = train_df .loc [
160162 (train_df [num_col ] >= min_val ) & (train_df [num_col ] <= max_val )
161- ]
163+ ]
162164 if filtered_df .shape [0 ] < 10 :
163165 raise ValueError (
164166 "After post-processing generated data's shape less than 10. For columns {} test "
@@ -236,7 +238,7 @@ def _validate_data(train_df, target, test_df):
236238
237239class SamplerGAN (SamplerOriginal ):
238240 def generate_data (
239- self , train_df , target , test_df , only_generated_data : bool
241+ self , train_df , target , test_df , only_generated_data : bool
240242 ) -> Tuple [pd .DataFrame , pd .DataFrame ]:
241243 self ._validate_data (train_df , target , test_df )
242244 if target is not None :
@@ -298,7 +300,7 @@ def _sampler(creator: SampleData, in_train, in_target, in_test) -> None:
298300
299301def _drop_col_if_exist (df , col_to_drop ) -> pd .DataFrame :
300302 """
301- Drops col_to_drop from input dataframe df if sucj column exists
303+ Drops col_to_drop from input dataframe df if such column exists
302304 """
303305 if col_to_drop in df .columns :
304306 return df .drop (col_to_drop , axis = 1 )
0 commit comments