Skip to content

Commit 3eab5a1

Browse files
committed
fixed typos
sklearn version in setup is defined issues #24 #23 #22
1 parent 962625d commit 3eab5a1

File tree

4 files changed

+34
-32
lines changed

4 files changed

+34
-32
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ new_train2, new_target2 = GANGenerator().generate_data_pipe(train, target, test,
3333
# example with all params defined
3434
new_train3, new_target3 = GANGenerator(gen_x_times=1.1, cat_cols=None,
3535
bot_filter_quantile=0.001, top_filter_quantile=0.999, is_post_process=True,
36-
adversaial_model_params={
36+
adversarial_model_params={
3737
"metrics": "AUC", "max_depth": 2, "max_bin": 100,
3838
"learning_rate": 0.02, "random_state": 42, "n_estimators": 500,
3939
}, pregeneration_frac=2, only_generated_data=False,
@@ -50,7 +50,7 @@ Both samplers `OriginalGenerator` and `GANGenerator` have same input parameters:
5050
* **top_filter_quantile**: float = 0.999 - top quantile for postprocess filtering
5151
* **is_post_process**: bool = True - perform or not post-filtering, if false bot_filter_quantile and top_filter_quantile
5252
ignored
53-
* **adversaial_model_params**: dict params for adversarial filtering model, default values for binary task
53+
* **adversarial_model_params**: dict params for adversarial filtering model, default values for binary task
5454
* **pregeneration_frac**: float = 2 - for generation step gen_x_times * pregeneration_frac amount of data will
5555
be generated. However, in postprocessing (1 + gen_x_times) % of original data will be returned
5656
* **gan_params**: dict params for GAN training

pip_desc.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ new_train1, new_target1 = GANGenerator().generate_data_pipe(train, target, test,
3939
# example with all params defined
4040
new_train3, new_target3 = GANGenerator(gen_x_times=1.1, cat_cols=None,
4141
bot_filter_quantile=0.001, top_filter_quantile=0.999, is_post_process=True,
42-
adversaial_model_params={
42+
adversarial_model_params={
4343
"metrics": "AUC", "max_depth": 2, "max_bin": 100,
4444
"learning_rate": 0.02, "random_state": 42, "n_estimators": 500,
4545
}, pregeneration_frac=2, only_generated_data=False,
@@ -56,9 +56,9 @@ adversarial filtering
5656
* **top_filter_quantile**: float = 0.999 - top quantile for postprocess filtering
5757
* **is_post_process**: bool = True - perform or not postfiltering, if false bot_filter_quantile
5858
and top_filter_quantile ignored
59-
* **adversaial_model_params**: dict params for adversarial filtering model, default values for binary task
59+
* **adversarial_model_params**: dict params for adversarial filtering model, default values for binary task
6060
* **pregeneration_frac**: float = 2 - for generation step gen_x_times * pregeneration_frac amount of data
61-
will generated. However in postprocessing (1 + gen_x_times) % of original data will be returned
61+
will be generated. However, in postprocessing (1 + gen_x_times) % of original data will be returned
6262
* **gan_params**: dict params for GAN training
6363

6464

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ install_requires =
4242
category_encoders
4343
torch
4444
lightgbm
45-
scikit_learn
45+
scikit_learn==0.23.2
4646
torchvision
4747
python-dateutil
4848
tqdm

src/tabgan/sampler.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -45,23 +45,23 @@ def get_object_generator(self) -> Sampler:
4545

4646
class SamplerOriginal(Sampler):
4747
def __init__(
48-
self,
49-
gen_x_times: float = 1.1,
50-
cat_cols: list = None,
51-
bot_filter_quantile: float = 0.001,
52-
top_filter_quantile: float = 0.999,
53-
is_post_process: bool = True,
54-
adversaial_model_params: dict = {
55-
"metrics": "AUC",
56-
"max_depth": 2,
57-
"max_bin": 100,
58-
"n_estimators": 500,
59-
"learning_rate": 0.02,
60-
"random_state": 42,
61-
},
62-
pregeneration_frac: float = 2,
63-
only_generated_data: bool = False,
64-
gan_params: dict = {'batch_size': 500, 'patience': 25, "epochs" : 500,}
48+
self,
49+
gen_x_times: float = 1.1,
50+
cat_cols: list = None,
51+
bot_filter_quantile: float = 0.001,
52+
top_filter_quantile: float = 0.999,
53+
is_post_process: bool = True,
54+
adversarial_model_params: dict = {
55+
"metrics": "AUC",
56+
"max_depth": 2,
57+
"max_bin": 100,
58+
"n_estimators": 500,
59+
"learning_rate": 0.02,
60+
"random_state": 42,
61+
},
62+
pregeneration_frac: float = 2,
63+
only_generated_data: bool = False,
64+
gan_params: dict = {'batch_size': 500, 'patience': 25, "epochs": 500, }
6565
):
6666
"""
6767
@@ -75,7 +75,8 @@ def __init__(
7575
@param adversarial_model_params: dict params for adversarial filtering model, default values for binary task
7676
@param pregeneration_frac: float = 2 - for generation step gen_x_times * pregeneration_frac amount of data
7777
will be generated. However, in postprocessing (1 + gen_x_times) % of original data will be returned
78-
@param only_generated_data: bool = False If True after generation get only newly generated, without concating input train dataframe.
78+
@param only_generated_data: bool = False If True after generation get only newly generated, without
79+
concating input train dataframe.
7980
@param gan_params: dict params for GAN training
8081
Only works for SamplerGAN.
8182
"""
@@ -84,13 +85,14 @@ def __init__(
8485
self.is_post_process = is_post_process
8586
self.bot_filter_quantile = bot_filter_quantile
8687
self.top_filter_quantile = top_filter_quantile
87-
self.adversarial_model_params = adversaial_model_params
88+
self.adversarial_model_params = adversarial_model_params
8889
self.pregeneration_frac = pregeneration_frac
8990
self.only_generated_data = only_generated_data
9091
self.gan_params = gan_params
9192
self.TEMP_TARGET = "TEMP_TARGET"
9293

93-
def preprocess_data_df(self, df) -> pd.DataFrame:
94+
@staticmethod
95+
def preprocess_data_df(df) -> pd.DataFrame:
9496
logging.info("Input shape: {}".format(df.shape))
9597
if isinstance(df, pd.DataFrame) is False:
9698
raise ValueError(
@@ -99,7 +101,7 @@ def preprocess_data_df(self, df) -> pd.DataFrame:
99101
return df
100102

101103
def preprocess_data(
102-
self, train, target, test_df
104+
self, train, target, test_df
103105
) -> Tuple[pd.DataFrame, pd.DataFrame]:
104106
train = self.preprocess_data_df(train)
105107
target = self.preprocess_data_df(target)
@@ -119,10 +121,10 @@ def preprocess_data(
119121
return train, target, test_df
120122

121123
def generate_data(
122-
self, train_df, target, test_df, only_generated_data
124+
self, train_df, target, test_df, only_generated_data
123125
) -> Tuple[pd.DataFrame, pd.DataFrame]:
124126
if only_generated_data:
125-
Warning.warn(
127+
Warning(
126128
"For SamplerOriginal setting only_generated_data doesn't change anything, "
127129
"because generated data sampled from the train!"
128130
)
@@ -158,7 +160,7 @@ def postprocess_data(self, train_df, target, test_df):
158160
max_val = test_df[num_col].quantile(self.top_filter_quantile)
159161
filtered_df = train_df.loc[
160162
(train_df[num_col] >= min_val) & (train_df[num_col] <= max_val)
161-
]
163+
]
162164
if filtered_df.shape[0] < 10:
163165
raise ValueError(
164166
"After post-processing generated data's shape less than 10. For columns {} test "
@@ -236,7 +238,7 @@ def _validate_data(train_df, target, test_df):
236238

237239
class SamplerGAN(SamplerOriginal):
238240
def generate_data(
239-
self, train_df, target, test_df, only_generated_data: bool
241+
self, train_df, target, test_df, only_generated_data: bool
240242
) -> Tuple[pd.DataFrame, pd.DataFrame]:
241243
self._validate_data(train_df, target, test_df)
242244
if target is not None:
@@ -298,7 +300,7 @@ def _sampler(creator: SampleData, in_train, in_target, in_test) -> None:
298300

299301
def _drop_col_if_exist(df, col_to_drop) -> pd.DataFrame:
300302
"""
301-
Drops col_to_drop from input dataframe df if sucj column exists
303+
Drops col_to_drop from input dataframe df if such column exists
302304
"""
303305
if col_to_drop in df.columns:
304306
return df.drop(col_to_drop, axis=1)

0 commit comments

Comments
 (0)