Skip to content

Commit 049d962

Browse files
committed
updated pip_desc.md
1 parent 1f5ad21 commit 049d962

File tree

1 file changed

+73
-55
lines changed

1 file changed

+73
-55
lines changed

pip_desc.md

Lines changed: 73 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,20 @@
11
[![CodeFactor](https://www.codefactor.io/repository/github/diyago/gan-for-tabular-data/badge)](https://www.codefactor.io/repository/github/diyago/gan-for-tabular-data)
22
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
3-
# GANs for tabular data
4-
<img src="https://raw.githubusercontent.com/Diyago/GAN-for-tabular-data/e5a4d437655261755de962b9779c73203611d921/images/logo%20tabular%20gan.svg" height="15%" width="15%">
3+
[![Downloads](https://pepy.tech/badge/tabgan)](https://pepy.tech/project/tabgan)
54

6-
We well know GANs for success in the realistic image generation. However, they can be applied in tabular data generation. We will review and examine some recent papers about tabular GANs in action.
5+
# GANs and Diffusions for tabular data
6+
7+
<img src="./images/tabular_gan.png" height="15%" width="15%">
8+
Generative Adversarial Networks (GANs) are well-known for their success in realistic image generation. However, they can also be applied to generate tabular data. Here will give opportunity to try some of them.
79

8-
* Github project: ["GAN-for-tabular-data"](https://github.com/Diyago/GAN-for-tabular-data)
910
* Arxiv article: ["Tabular GANs for uneven distribution"](https://arxiv.org/abs/2010.00638)
1011
* Medium post: [GANs for tabular data](https://towardsdatascience.com/review-of-gans-for-tabular-data-a30a2199342)
1112

12-
### Library goal
13-
14-
Let say we have **T_train** and **T_test** (train and test set respectively).
15-
We need to train the model on **T_train** and make predictions on **T_test**.
16-
However, we will increase the train by generating new data by GAN,
17-
somehow similar to **T_test**, without using ground truth labels.
18-
19-
### How to use library
13+
## How to use library
2014

2115
* Installation: `pip install tabgan`
22-
* To generate new data to train by sampling and then filtering by adversarial
23-
training call `GANGenerator().generate_data_pipe`:
16+
* To generate new data to train by sampling and then filtering by adversarial training
17+
call `GANGenerator().generate_data_pipe`:
2418

2519
``` python
2620
from tabgan.sampler import OriginalGenerator, GANGenerator
@@ -34,10 +28,11 @@ test = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list("ABCD
3428

3529
# generate data
3630
new_train1, new_target1 = OriginalGenerator().generate_data_pipe(train, target, test, )
37-
new_train1, new_target1 = GANGenerator().generate_data_pipe(train, target, test, )
31+
new_train2, new_target2 = GANGenerator().generate_data_pipe(train, target, test, )
32+
new_train3, new_target3 = ForestDiffusionGenerator().generate_data_pipe(train, target, test, )
3833

3934
# example with all params defined
40-
new_train3, new_target3 = GANGenerator(gen_x_times=1.1, cat_cols=None,
35+
new_train4, new_target4 = GANGenerator(gen_x_times=1.1, cat_cols=None,
4136
bot_filter_quantile=0.001, top_filter_quantile=0.999, is_post_process=True,
4237
adversarial_model_params={
4338
"metrics": "AUC", "max_depth": 2, "max_bin": 100,
@@ -47,21 +42,23 @@ new_train3, new_target3 = GANGenerator(gen_x_times=1.1, cat_cols=None,
4742
test, deep_copy=True, only_adversarial=False, use_adversarial=True)
4843
```
4944

50-
Both samplers `OriginalGenerator` and `GANGenerator` have same input parameters:
45+
All samplers `OriginalGenerator`, `ForestDiffusionGenerator` and `GANGenerator` have same input parameters.
46+
47+
1. **GANGenerator** based on **CTGAN**
48+
2. **ForestDiffusionGenerator** based on **Forest Diffusion**
5149

5250
* **gen_x_times**: float = 1.1 - how much data to generate, output might be less because of postprocessing and
53-
adversarial filtering
51+
adversarial filtering
5452
* **cat_cols**: list = None - categorical columns
5553
* **bot_filter_quantile**: float = 0.001 - bottom quantile for postprocess filtering
5654
* **top_filter_quantile**: float = 0.999 - top quantile for postprocess filtering
57-
* **is_post_process**: bool = True - perform or not postfiltering, if false bot_filter_quantile
58-
and top_filter_quantile ignored
55+
* **is_post_process**: bool = True - perform or not post-filtering, if false bot_filter_quantile and top_filter_quantile
56+
ignored
5957
* **adversarial_model_params**: dict params for adversarial filtering model, default values for binary task
60-
* **pregeneration_frac**: float = 2 - for generataion step gen_x_times * pregeneration_frac amount of data
61-
will be generated. However, in postprocessing (1 + gen_x_times) % of original data will be returned
58+
* **pregeneration_frac**: float = 2 - for generataion step gen_x_times * pregeneration_frac amount of data will
59+
generated. However in postprocessing (1 + gen_x_times) % of original data will be returned
6260
* **gen_params**: dict params for GAN training
6361

64-
6562
For `generate_data_pipe` methods params:
6663

6764
* **train_df**: pd.DataFrame Train dataframe which has separate target
@@ -72,8 +69,7 @@ For `generate_data_pipe` methods params:
7269
* **use_adversarial**: bool = True - perform or not adversarial filtering
7370
* **only_generated_data**: bool = False - After generation get only newly generated, without
7471
concating input train dataframe.
75-
76-
* **@return**: -> Tuple[pd.DataFrame, pd.DataFrame] - Newly generated train dataframe and test data
72+
* **@return**: -> Tuple[pd.DataFrame, pd.DataFrame] - Newly generated train dataframe and test data
7773

7874
Thus, you may use this library to improve your dataset quality:
7975

@@ -83,20 +79,19 @@ def fit_predict(clf, X_train, y_train, X_test, y_test):
8379
return sklearn.metrics.roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1])
8480

8581

86-
if __name__ == "__main__":
87-
dataset = sklearn.datasets.load_breast_cancer()
88-
clf = sklearn.ensemble.RandomForestClassifier(n_estimators=25, max_depth=6)
89-
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
90-
pd.DataFrame(dataset.data), pd.DataFrame(dataset.target, columns=["target"]), test_size=0.33, random_state=42)
91-
print("initial metric", fit_predict(clf, X_train, y_train, X_test, y_test))
9282

93-
new_train1, new_target1 = OriginalGenerator().generate_data_pipe(X_train, y_train, X_test, )
94-
print("OriginalGenerator metric", fit_predict(clf, new_train1, new_target1, X_test, y_test))
83+
dataset = sklearn.datasets.load_breast_cancer()
84+
clf = sklearn.ensemble.RandomForestClassifier(n_estimators=25, max_depth=6)
85+
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
86+
pd.DataFrame(dataset.data), pd.DataFrame(dataset.target, columns=["target"]), test_size=0.33, random_state=42)
87+
print("initial metric", fit_predict(clf, X_train, y_train, X_test, y_test))
9588

96-
new_train1, new_target1 = GANGenerator().generate_data_pipe(X_train, y_train, X_test, )
97-
print("GANGenerator metric", fit_predict(clf, new_train1, new_target1, X_test, y_test))
98-
```
89+
new_train1, new_target1 = OriginalGenerator().generate_data_pipe(X_train, y_train, X_test, )
90+
print("OriginalGenerator metric", fit_predict(clf, new_train1, new_target1, X_test, y_test))
9991

92+
new_train1, new_target1 = GANGenerator().generate_data_pipe(X_train, y_train, X_test, )
93+
print("GANGenerator metric", fit_predict(clf, new_train1, new_target1, X_test, y_test))
94+
```
10095
## Timeseries GAN generation TimeGAN
10196

10297
You can easily adjust code to generate multidimensional timeseries data.
@@ -134,37 +129,60 @@ new_train = collect_dates(new_train)
134129
**Running experiment**
135130

136131
To run experiment follow these steps:
132+
137133
1. Clone the repository. All required dataset are stored in `./Research/data` folder
138134
2. Install requirements `pip install -r requirements.txt`
139-
4. Run all experiments `python ./Research/run_experiment.py`. Run all experiments `python run_experiment.py`. You may add more datasets, adjust validation type and categorical encoders.
140-
5. Observe metrics across all experiment in console or
141-
in `./Research/results/fit_predict_scores.txt`
142-
135+
4. Run all experiments `python ./Research/run_experiment.py`. Run all experiments `python run_experiment.py`. You may
136+
add more datasets, adjust validation type and categorical encoders.
137+
5. Observe metrics across all experiment in console or in `./Research/results/fit_predict_scores.txt`
143138

144139

145-
## Acknowledgments
140+
**Experiment design**
146141

147-
The author would like to thank Open Data Science community [7] for many
148-
valuable discussions and educational help in the growing field of machine and
149-
deep learning. Also, special big thanks to Sber [8] for allowing solving
150-
such tasks and providing computational resources.
142+
![Experiment design and workflow](./images/workflow.png?raw=true)
151143

152-
## References
144+
**Picture 1.1** Experiment design and workflow
153145

154-
[1] Jonathan Hui. GAN — What is Generative Adversarial Networks GAN? (2018), medium article
146+
## Results
147+
To determine the best sampling strategy, ROC AUC scores of each dataset were scaled (min-max scale) and then averaged
148+
among the dataset.
155149

156-
[2]Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, Yoshua Bengio. Generative Adversarial Networks (2014). arXiv:1406.2661
150+
**Table 1.2** Different sampling results across the dataset, higher is better (100% - maximum per dataset ROC AUC)
157151

158-
[3] Lei Xu LIDS, Kalyan Veeramachaneni. Synthesizing Tabular Data using Generative Adversarial Networks (2018). arXiv:1811.11264v1 [cs.LG]
152+
| dataset_name | None | gan | sample_original |
153+
|:-----------------------|-------------------:|------------------:|------------------------------:|
154+
| credit | 0.997 | **0.998** | 0.997 |
155+
| employee | **0.986** | 0.966 | 0.972 |
156+
| mortgages | 0.984 | 0.964 | **0.988** |
157+
| poverty_A | 0.937 | **0.950** | 0.933 |
158+
| taxi | 0.966 | 0.938 | **0.987** |
159+
| adult | 0.995 | 0.967 | **0.998** |
159160

160-
[4] Lei Xu, Maria Skoularidou, Alfredo Cuesta-Infante, Kalyan Veeramachaneni. Modeling Tabular Data using Conditional GAN (2019). arXiv:1907.00503v2 [cs.LG]
161+
## Acknowledgments
161162

162-
[5] Denis Vorotyntsev. Benchmarking Categorical Encoders (2019). Medium post
163+
The author would like to thank Open Data Science community [7] for many valuable discussions and educational help in the
164+
growing field of machine and deep learning.
165+
166+
## Citation
167+
168+
If you use **GAN-for-tabular-data** in a scientific publication, we would appreciate references to the following BibTex entry:
169+
arxiv publication:
170+
```bibtex
171+
@misc{ashrapov2020tabular,
172+
title={Tabular GANs for uneven distribution},
173+
author={Insaf Ashrapov},
174+
year={2020},
175+
eprint={2010.00638},
176+
archivePrefix={arXiv},
177+
primaryClass={cs.LG}
178+
}
179+
```
163180

164-
[6] Insaf Ashrapov. GAN-for-tabular-data (2020). Github repository.
181+
## References
165182

166-
[7] Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen, Timo Aila. Analyzing and Improving the Image Quality of StyleGAN (2019) arXiv:1912.04958v2 [cs.CV]
183+
[1] Lei Xu LIDS, Kalyan Veeramachaneni. Synthesizing Tabular Data using Generative Adversarial Networks (2018). arXiv:
184+
1811.11264v1 [cs.LG]
167185

168-
[8] ODS.ai: Open data science (2020), https://ods.ai/
186+
[2] Alexia Jolicoeur-Martineau and Kilian Fatras and Tal Kachman. Generating and Imputing Tabular Data via Diffusion and Flow-based Gradient-Boosted Trees ((2023) https://github.com/SamsungSAILMontreal/ForestDiffusion [cs.LG]
169187

170-
[9] Sber (2020), https://www.sberbank.ru/
188+
[3] Lei Xu, Maria Skoularidou, Alfredo Cuesta-Infante, Kalyan Veeramachaneni. Modeling Tabular data using Conditional GAN. NeurIPS, (2019)

0 commit comments

Comments
 (0)