Skip to content

Commit 9ee2c66

Browse files
committed
move experiment to research folder
1 parent 35168d9 commit 9ee2c66

File tree

26 files changed

+2024
-0
lines changed

26 files changed

+2024
-0
lines changed

Research/.deepsource.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
version = 1
2+
3+
[[analyzers]]
4+
name = "python"
5+
enabled = true
6+
7+
[analyzers.meta]
8+
runtime_version = "3.x.x"

Research/ctgan/README.MD

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
REFERENCE (initial code): https://github.com/sdv-dev/CTGAN
2+
3+
<p align="left">
4+
<img width=15% src="https://dai.lids.mit.edu/wp-content/uploads/2018/06/Logo_DAI_highres.png" alt=“sdv-dev” />
5+
<i>An open source project from Data to AI Lab at MIT.</i>
6+
</p>
7+
8+
[![Development Status](https://img.shields.io/badge/Development%20Status-2%20--%20Pre--Alpha-yellow)](https://pypi.org/search/?c=Development+Status+%3A%3A+2+-+Pre-Alpha)
9+
[![PyPI Shield](https://img.shields.io/pypi/v/ctgan.svg)](https://pypi.python.org/pypi/ctgan)
10+
[![Travis CI Shield](https://travis-ci.org/sdv-dev/CTGAN.svg?branch=master)](https://travis-ci.org/sdv-dev/CTGAN)
11+
[![Downloads](https://pepy.tech/badge/ctgan)](https://pepy.tech/project/ctgan)
12+
[![Coverage Status](https://codecov.io/gh/sdv-dev/CTGAN/branch/master/graph/badge.svg)](https://codecov.io/gh/sdv-dev/CTGAN)
13+
14+
# CTGAN
15+
16+
Implementation of our NeurIPS paper [Modeling Tabular data using Conditional GAN](https://arxiv.org/abs/1907.00503).
17+
18+
CTGAN is a GAN-based data synthesizer that can generate synthetic tabular data with high fidelity.
19+
20+
* License: [MIT](https://github.com/sdv-dev/CTGAN/blob/master/LICENSE)
21+
* Development Status: [Pre-Alpha](https://pypi.org/search/?c=Development+Status+%3A%3A+2+-+Pre-Alpha)
22+
* Documentation: https://sdv-dev.github.io/CTGAN
23+
* Homepage: https://github.com/sdv-dev/CTGAN
24+
25+
## Overview
26+
27+
Based on previous work ([TGAN](https://github.com/sdv-dev/TGAN)) on synthetic data generation,
28+
we develop a new model called CTGAN. Several major differences make CTGAN outperform TGAN.
29+
30+
- **Preprocessing**: CTGAN uses more sophisticated Variational Gaussian Mixture Model to detect
31+
modes of continuous columns.
32+
- **Network structure**: TGAN uses LSTM to generate synthetic data column by column. CTGAN uses
33+
Fully-connected networks which is more efficient.
34+
- **Features to prevent mode collapse**: We design a conditional generator and resample the
35+
training data to prevent model collapse on discrete columns. We use WGANGP and PacGAN to
36+
stabilize the training of GAN.
37+
38+
39+
# Install
40+
41+
## Requirements
42+
43+
**CTGAN** has been developed and tested on [Python 3.5, 3.6 and 3.7](https://www.python.org/downloads/)
44+
45+
## Install from PyPI
46+
47+
The recommended way to installing **CTGAN** is using [pip](https://pip.pypa.io/en/stable/):
48+
49+
```bash
50+
pip install ctgan
51+
```
52+
53+
This will pull and install the latest stable release from [PyPI](https://pypi.org/).
54+
55+
If you want to install from source or contribute to the project please read the
56+
[Contributing Guide](https://sdv-dev.github.io/CTGAN/contributing.html#get-started).
57+
58+
# Data Format
59+
60+
**CTGAN** expects the input data to be a table given as either a `numpy.ndarray` or a
61+
`pandas.DataFrame` object with two types of columns:
62+
63+
* **Continuous Columns**: Columns that contain numerical values and which can take any value.
64+
* **Discrete columns**: Columns that only contain a finite number of possible values, wether
65+
these are string values or not.
66+
67+
This is an example of a table with 4 columns:
68+
69+
* A continuous column with float values
70+
* A continuous column with integer values
71+
* A discrete column with string values
72+
* A discrete column with integer values
73+
74+
| | A | B | C | D |
75+
|---|------|-----|-----|---|
76+
| 0 | 0.1 | 100 | 'a' | 1 |
77+
| 1 | -1.3 | 28 | 'b' | 2 |
78+
| 2 | 0.3 | 14 | 'a' | 2 |
79+
| 3 | 1.4 | 87 | 'a' | 3 |
80+
| 4 | -0.1 | 69 | 'b' | 2 |
81+
82+
83+
**NOTE**: CTGAN does not distinguish between float and integer columns, which means that it will
84+
sample float values in all cases. If integer values are required, the outputted float values
85+
must be rounded to integers in a later step, outside of CTGAN.
86+
87+
# Python Quickstart
88+
89+
In this short tutorial we will guide you through a series of steps that will help you
90+
getting started with **CTGAN**.
91+
92+
## 1. Model the data
93+
94+
### Step 1: Prepare your data
95+
96+
Before being able to use CTGAN you will need to prepare your data as specified above.
97+
98+
For this example, we will be loading some data using the `ctgan.load_demo` function.
99+
100+
```python
101+
from ctgan import load_demo
102+
103+
data = load_demo()
104+
```
105+
106+
This will download a copy of the [Adult Census Dataset](https://archive.ics.uci.edu/ml/datasets/adult) as a dataframe:
107+
108+
| age | workclass | fnlwgt | ... | hours-per-week | native-country | income |
109+
|-------|------------------|----------|-----|------------------|------------------|----------|
110+
| 39 | State-gov | 77516 | ... | 40 | United-States | <=50K |
111+
| 50 | Self-emp-not-inc | 83311 | ... | 13 | United-States | <=50K |
112+
| 38 | Private | 215646 | ... | 40 | United-States | <=50K |
113+
| 53 | Private | 234721 | ... | 40 | United-States | <=50K |
114+
| 28 | Private | 338409 | ... | 40 | Cuba | <=50K |
115+
| ... | ... | ... | ... | ... | ... | ... |
116+
117+
118+
Aside from the table itself, you will need to create a list with the names of the discrete
119+
variables.
120+
121+
For this example:
122+
123+
```python
124+
discrete_columns = [
125+
'workclass',
126+
'education',
127+
'marital-status',
128+
'occupation',
129+
'relationship',
130+
'race',
131+
'sex',
132+
'native-country',
133+
'income'
134+
]
135+
```
136+
137+
### Step 2: Fit CTGAN to your data
138+
139+
Once you have the data ready, you need to import and create an instance of the `CTGANSynthesizer`
140+
class and fit it passing your data and the list of discrete columns.
141+
142+
```python
143+
from ctgan import CTGANSynthesizer
144+
145+
ctgan = CTGANSynthesizer()
146+
ctgan.fit(data, discrete_columns)
147+
```
148+
149+
This process is likely to take a long time to run.
150+
If you want to make the process shorter, or longer, you can control the number of training epochs
151+
that the model will be performing by adding it to the `fit` call:
152+
153+
```python
154+
ctgan.fit(data, discrete_columns, epochs=5)
155+
```
156+
157+
## 2. Generate synthetic data
158+
159+
Once the process has finished, all you need to do is call the `sample` method of your
160+
`CTGANSynthesizer` instance indicating the number of rows that you want to generate.
161+
162+
```python
163+
samples = ctgan.sample(1000)
164+
```
165+
166+
The output will be a table with the exact same format as the input and filled with the synthetic
167+
data generated by the model.
168+
169+
| age | workclass | fnlwgt | ... | hours-per-week | native-country | income |
170+
|---------|--------------|-----------|-----|------------------|------------------|----------|
171+
| 26.3191 | Private | 124079 | ... | 40.1557 | United-States | <=50K |
172+
| 39.8558 | Private | 133996 | ... | 40.2507 | United-States | <=50K |
173+
| 38.2477 | Self-emp-inc | 135955 | ... | 40.1124 | Ecuador | <=50K |
174+
| 29.6468 | Private | 3331.86 | ... | 27.012 | United-States | <=50K |
175+
| 20.9853 | Private | 120637 | ... | 40.0238 | United-States | <=50K |
176+
| ... | ... | ... | ... | ... | ... | ... |
177+
178+
179+
# Join our community
180+
181+
1. If you would like to try more dataset examples, please have a look at the [examples folder](
182+
https://github.com/sdv-dev/CTGAN/tree/master/examples) of the repository. Please contact us
183+
if you have a usage example that you would want to share with the community.
184+
2. If you want to contribute to the project code, please head to the [Contributing Guide](
185+
https://sdv-dev.github.io/CTGAN/contributing.html#get-started) for more details about how to do it.
186+
3. If you have any doubts, feature requests or detect an error, please [open an issue on github](
187+
https://github.com/sdv-dev/CTGAN/issues)
188+
4. Also do not forget to check the [project documentation site](https://sdv-dev.github.io/CTGAN/)!
189+
190+
191+
# Citing TGAN
192+
193+
If you use CTGAN, please cite the following work:
194+
195+
- *Lei Xu, Maria Skoularidou, Alfredo Cuesta-Infante, Kalyan Veeramachaneni.* **Modeling Tabular data using Conditional GAN**. NeurIPS, 2019.
196+
197+
```LaTeX
198+
@inproceedings{xu2019modeling,
199+
title={Modeling Tabular data using Conditional GAN},
200+
author={Xu, Lei and Skoularidou, Maria and Cuesta-Infante, Alfredo and Veeramachaneni, Kalyan},
201+
booktitle={Advances in Neural Information Processing Systems},
202+
year={2019}
203+
}
204+
```
205+
206+
# Related Projects
207+
Please note that these libraries are external contributions and are not maintained nor supervised by
208+
the MIT DAI-Lab team.
209+
210+
## R interface for CTGAN
211+
212+
A wrapper around **CTGAN** has been implemented by Kevin Kuo @kevinykuo, bringing the functionalities
213+
of **CTGAN** to **R** users.
214+
215+
More details can be found in the corresponding repository: https://github.com/kasaai/ctgan
216+
217+
## CTGAN Server CLI
218+
219+
A package to easily deploy **CTGAN** onto a remote server. This package is developed by Timothy Pillow @oregonpillow.
220+
221+
More details can be found in the corresponding repository: https://github.com/oregonpillow/ctgan-server-cli

Research/ctgan/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# -*- coding: utf-8 -*-
2+
3+
"""Top-level package for ctgan."""
4+
5+
__author__ = 'MIT Data To AI Lab'
6+
__email__ = '[email protected]'
7+
__version__ = '0.2.1'
8+
9+
from ctgan.demo import load_demo
10+
from ctgan.synthesizer import CTGANSynthesizer
11+
12+
__all__ = (
13+
'CTGANSynthesizer',
14+
'load_demo'
15+
)

Research/ctgan/__main__.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import argparse
2+
3+
from ctgan.data import read_csv, read_tsv, write_tsv
4+
from ctgan.synthesizer import CTGANSynthesizer
5+
6+
7+
def _parse_args():
8+
parser = argparse.ArgumentParser(description='CTGAN Command Line Interface')
9+
parser.add_argument('-e', '--epochs', default=300, type=int,
10+
help='Number of training epochs')
11+
parser.add_argument('-t', '--tsv', action='store_true',
12+
help='Load data in TSV format instead of CSV')
13+
parser.add_argument('--no-header', dest='header', action='store_false',
14+
help='The CSV file has no header. Discrete columns will be indices.')
15+
16+
parser.add_argument('-m', '--metadata', help='Path to the metadata')
17+
parser.add_argument('-d', '--discrete',
18+
help='Comma separated list of discrete columns, no whitespaces')
19+
20+
parser.add_argument('-n', '--num-samples', type=int,
21+
help='Number of rows to sample. Defaults to the training data size')
22+
23+
parser.add_argument('data', help='Path to training data')
24+
parser.add_argument('output', help='Path of the output file')
25+
26+
return parser.parse_args()
27+
28+
29+
def main():
30+
args = _parse_args()
31+
32+
if args.tsv:
33+
data, discrete_columns = read_tsv(args.data, args.metadata)
34+
else:
35+
data, discrete_columns = read_csv(args.data, args.metadata, args.header, args.discrete)
36+
37+
model = CTGANSynthesizer()
38+
model.fit(data, discrete_columns, args.epochs)
39+
40+
num_samples = args.num_samples or len(data)
41+
sampled = model.sample(num_samples)
42+
43+
if args.tsv:
44+
write_tsv(sampled, args.metadata, args.output)
45+
else:
46+
sampled.to_csv(args.output, index=False)

Research/ctgan/conditional.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import numpy as np
2+
3+
4+
class ConditionalGenerator(object):
5+
def __init__(self, data, output_info, log_frequency):
6+
self.model = []
7+
8+
start = 0
9+
skip = False
10+
max_interval = 0
11+
counter = 0
12+
for item in output_info:
13+
if item[1] == 'tanh':
14+
start += item[0]
15+
skip = True
16+
continue
17+
18+
elif item[1] == 'softmax':
19+
if skip:
20+
skip = False
21+
start += item[0]
22+
continue
23+
24+
end = start + item[0]
25+
max_interval = max(max_interval, end - start)
26+
counter += 1
27+
self.model.append(np.argmax(data[:, start:end], axis=-1))
28+
start = end
29+
30+
else:
31+
assert 0
32+
33+
assert start == data.shape[1]
34+
35+
self.interval = []
36+
self.n_col = 0
37+
self.n_opt = 0
38+
skip = False
39+
start = 0
40+
self.p = np.zeros((counter, max_interval))
41+
for item in output_info:
42+
if item[1] == 'tanh':
43+
skip = True
44+
start += item[0]
45+
continue
46+
elif item[1] == 'softmax':
47+
if skip:
48+
start += item[0]
49+
skip = False
50+
continue
51+
end = start + item[0]
52+
tmp = np.sum(data[:, start:end], axis=0)
53+
if log_frequency:
54+
tmp = np.log(tmp + 1)
55+
tmp = tmp / np.sum(tmp)
56+
self.p[self.n_col, :item[0]] = tmp
57+
self.interval.append((self.n_opt, item[0]))
58+
self.n_opt += item[0]
59+
self.n_col += 1
60+
start = end
61+
else:
62+
assert 0
63+
64+
self.interval = np.asarray(self.interval)
65+
66+
def random_choice_prob_index(self, idx):
67+
a = self.p[idx]
68+
r = np.expand_dims(np.random.rand(a.shape[0]), axis=1)
69+
return (a.cumsum(axis=1) > r).argmax(axis=1)
70+
71+
def sample(self, batch):
72+
if self.n_col == 0:
73+
return None
74+
75+
batch = batch
76+
idx = np.random.choice(np.arange(self.n_col), batch)
77+
78+
vec1 = np.zeros((batch, self.n_opt), dtype='float32')
79+
mask1 = np.zeros((batch, self.n_col), dtype='float32')
80+
mask1[np.arange(batch), idx] = 1
81+
opt1prime = self.random_choice_prob_index(idx)
82+
opt1 = self.interval[idx, 0] + opt1prime
83+
vec1[np.arange(batch), opt1] = 1
84+
85+
return vec1, mask1, idx, opt1prime
86+
87+
def sample_zero(self, batch):
88+
if self.n_col == 0:
89+
return None
90+
91+
vec = np.zeros((batch, self.n_opt), dtype='float32')
92+
idx = np.random.choice(np.arange(self.n_col), batch)
93+
for i in range(batch):
94+
col = idx[i]
95+
pick = int(np.random.choice(self.model[col]))
96+
vec[i, pick + self.interval[col, 0]] = 1
97+
98+
return vec

0 commit comments

Comments
 (0)