
Commit c9a4c48 (parent: d8ecc89)
Author: glederrey

[UPD] Examples ready for v2

File tree: 15 files changed, +9027 -9593 lines

README.md

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@ Directed Acyclic Tabular GAN (**DATGAN**) for integrating expert knowledge in sy
 - Development Status: [Alpha](https://pypi.org/search/?q=&o=&c=Development+Status+%3A%3A+3+-+Alpha)
 - Homepage: https://github.com/glederrey/DATGAN
 
-> The preprint of the article for this model will be available on arXiv by the end of February/early March.
+> **The preprint of the article for this model should be available on arXiv by mid-March.**
 
 ## Overview
 
@@ -153,7 +153,7 @@ Once you have a **DATGAN** instance, you can call the method `fit` and passing t
 - `data`: the original DataFrame
 - `graph`: the `networkx` DAG
 - `continuous_columns`: the list of continuous columns
-- `preprocessed_data_path`: the path to the preprocessed data if done in Step 4.
+- `preprocessed_data_path`: the path to the preprocessed data if done in Step 4 or the path where to save them.
 ```python
 datgan.fit(df, graph, continuous_columns, preprocessed_data_path='./encoded_data')
 ```
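
For context, a minimal end-to-end sketch of the `fit` call this hunk documents, using the public `DATGAN` class; the DataFrame columns and DAG edges below are made up for the example:

```python
import networkx as nx
import pandas as pd
from datgan import DATGAN

# Toy data; the columns and values are illustrative only
df = pd.DataFrame({'age': [23, 45, 31, 52],
                   'income': [30000, 55000, 42000, 61000],
                   'mode': ['car', 'bus', 'car', 'bus']})

# Expert knowledge encoded as a networkx DAG over the columns
graph = nx.DiGraph()
graph.add_edges_from([('age', 'income'), ('income', 'mode')])

continuous_columns = ['age', 'income']

datgan = DATGAN()
# Per this commit's wording, preprocessed_data_path is reused if the
# encoded data already exist there, and used as the save location otherwise.
datgan.fit(df, graph, continuous_columns, preprocessed_data_path='./encoded_data')
```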

datgan/datgan.py

Lines changed: 1 addition & 1 deletion
@@ -397,7 +397,7 @@ def default_parameter_values(self, data, continuous_columns):
         if self.loss_function == 'SGAN':
             self.g_period = 1
         elif self.loss_function == 'WGAN':
-            self.g_period = 3
+            self.g_period = 2
         elif self.loss_function == 'WGGP':
             self.g_period = 5
 
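
If `g_period` behaves like the usual critic-to-generator update ratio in WGAN training, this hunk means the generator is now updated every 2nd rather than every 3rd discriminator step when the WGAN loss is selected. A sketch of how such a period typically gates updates in an alternating GAN loop; this loop is illustrative, not DATGAN's actual training code:

```python
def discriminator_step(step):
    # stand-in for a real discriminator (critic) update
    print(f"step {step}: discriminator update")

def generator_step(step):
    # stand-in for a real generator update
    print(f"step {step}: generator update")

g_period = 2  # new WGAN default in this commit (previously 3)

for step in range(1, 7):
    discriminator_step(step)
    # the generator is updated once every g_period discriminator updates
    if step % g_period == 0:
        generator_step(step)
```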

datgan/synthesizer/losses/WGGPLoss.py

Lines changed: 2 additions & 2 deletions
@@ -25,7 +25,7 @@ def __init__(self, metadata, var_order, name="WGGPLoss"):
         """
         super().__init__(metadata, var_order, name=name)
 
-        self.lambda_ = tf.Variable(10.0, dtype=tf.float32)
+        self.lambda_ = tf.constant(10.0, dtype=tf.float32)
 
     def gen_loss(self, synth_output, transformed_orig, transformed_synth, l2_reg):
         """
 
@@ -87,7 +87,7 @@ def discr_loss(self, orig_output, synth_output, interp_grad, l2_reg):
 
         # the gradient penalty loss
         interp_grad_norm = tf.sqrt(tf.reduce_sum(tf.square(interp_grad), axis=[1]))
-        grad_pen = self.lambda_ * tf.reduce_mean((interp_grad_norm - 1.0) ** 2)
+        grad_pen = tf.multiply(self.lambda_, tf.reduce_mean((interp_grad_norm - 1.0) ** 2))
 
         # Full loss
         loss = fake_loss - real_loss + grad_pen + l2_reg
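
Both hunks touch the WGAN-GP gradient penalty, λ·E[(‖∇D(x̂)‖₂ − 1)²] with λ = 10 as proposed by Gulrajani et al. (2017); swapping `tf.Variable` for `tf.constant` also makes it explicit that λ is a fixed hyperparameter, not a trainable tensor. A self-contained sketch of how such a penalty is computed from interpolated samples in plain TensorFlow 2; the one-layer critic here is a stand-in for the DATGAN discriminator:

```python
import tensorflow as tf

lambda_ = tf.constant(10.0, dtype=tf.float32)  # penalty weight, as in this commit

# Stand-in critic: any differentiable map from samples to scalar scores
critic = tf.keras.Sequential([tf.keras.layers.Dense(1)])

def gradient_penalty(real, fake):
    # Sample points on straight lines between real and synthetic batches
    eps = tf.random.uniform([tf.shape(real)[0], 1], 0.0, 1.0)
    interp = eps * real + (1.0 - eps) * fake
    with tf.GradientTape() as tape:
        tape.watch(interp)
        scores = critic(interp)
    interp_grad = tape.gradient(scores, interp)
    # Penalise deviations of the gradient norm from 1, as in discr_loss above
    interp_grad_norm = tf.sqrt(tf.reduce_sum(tf.square(interp_grad), axis=[1]))
    return tf.multiply(lambda_, tf.reduce_mean((interp_grad_norm - 1.0) ** 2))

real = tf.random.normal([8, 4])
fake = tf.random.normal([8, 4])
print(gradient_penalty(real, fake))
```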

datgan/synthesizer/losses/__init__.py

Whitespace-only changes.

datgan/synthesizer/synthesizer.py

Lines changed: 2 additions & 2 deletions
@@ -279,8 +279,8 @@ def initialize(self):
             self.loss = WGANLoss(self.metadata, self.var_order)
         elif self.loss_function == 'WGGP':
             # Optimizers
-            self.optimizerG = tf.keras.optimizers.Adam(learning_rate=self.learning_rate, beta_1=0, beta_2=0.9)
-            self.optimizerD = tf.keras.optimizers.Adam(learning_rate=self.learning_rate, beta_1=0, beta_2=0.9)
+            self.optimizerG = tf.keras.optimizers.Adam(learning_rate=self.learning_rate, beta_1=0.0, beta_2=0.9)
+            self.optimizerD = tf.keras.optimizers.Adam(learning_rate=self.learning_rate, beta_1=0.0, beta_2=0.9)
 
             # Loss function
             self.loss = WGGPLoss(self.metadata, self.var_order)
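
For context: β₁ = 0 and β₂ = 0.9 are the Adam settings recommended for WGAN-GP training by Gulrajani et al. (2017), so this hunk only rewrites the integer literal `0` as the float `0.0`; Keras converts the hyperparameters to floats either way, so behaviour should be unchanged. A quick check of that equivalence, assuming TensorFlow 2:

```python
import tensorflow as tf

opt_int = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0, beta_2=0.9)
opt_float = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.0, beta_2=0.9)

# Both configurations should report the same effective hyperparameter
print(float(opt_int.get_config()['beta_1']) ==
      float(opt_float.get_config()['beta_1']))  # True
```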

example/data/CMAP_synthetic.csv

Lines changed: 8929 additions & 8929 deletions
Large diffs are not rendered by default.
Four binary files changed (size deltas: -89 bytes, -457 bytes, 632 bytes, 4.44 MB; binary diffs not shown).
