Skip to content

Commit 2061dcc

Browse files
author
Mohsen Naghipourfar
committed
change VAE Keras interface and update VAE TF arguments
1 parent 42f7a92 commit 2061dcc

File tree

2 files changed

+5
-16
lines changed

2 files changed

+5
-16
lines changed

scgen/models/_vae.py

+5-15
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,15 @@ class VAEArith:
3232
"""
3333

3434
def __init__(self, x_dimension, z_dimension=100, **kwargs):
35-
tf.reset_default_graph()
3635
self.x_dim = x_dimension
3736
self.z_dim = z_dimension
3837
self.learning_rate = kwargs.get("learning_rate", 0.001)
3938
self.dropout_rate = kwargs.get("dropout_rate", 0.2)
4039
self.model_to_use = kwargs.get("model_path", "./models/scgen")
4140
self.alpha = kwargs.get("alpha", 0.00005)
4241
self.is_training = tf.placeholder(tf.bool, name='training_flag')
43-
self.global_step = tf.Variable(0, name='global_step', trainable=False, dtype=tf.int32)
4442
self.x = tf.placeholder(tf.float32, shape=[None, self.x_dim], name="data")
4543
self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim], name="latent")
46-
self.time_step = tf.placeholder(tf.int32)
47-
self.size = tf.placeholder(tf.int32)
4844
self.init_w = tf.contrib.layers.xavier_initializer()
4945
self._create_network()
5046
self._loss_function()
@@ -119,7 +115,8 @@ def _sample_z(self):
119115
# Returns
120116
The computed Tensor of samples with shape [size, z_dim].
121117
"""
122-
eps = tf.random_normal(shape=[self.size, self.z_dim])
118+
batch_size = tf.shape(self.mu)[0]
119+
eps = tf.random_normal(shape=[batch_size, self.z_dim])
123120
return self.mu + tf.exp(self.log_var / 2) * eps
124121

125122
def _create_network(self):
@@ -174,7 +171,7 @@ def to_latent(self, data):
174171
latent: numpy nd-array
175172
Returns array containing latent space encoding of 'data'
176173
"""
177-
latent = self.sess.run(self.z_mean, feed_dict={self.x: data, self.size: data.shape[0], self.is_training: False})
174+
latent = self.sess.run(self.z_mean, feed_dict={self.x: data, self.is_training: False})
178175
return latent
179176

180177
def _avg_vector(self, data):
@@ -429,8 +426,6 @@ def train(self, train_data, use_validation=False, valid_data=None, n_epochs=25,
429426
"""
430427
if initial_run:
431428
log.info("----Training----")
432-
assign_step_zero = tf.assign(self.global_step, 0)
433-
_init_step = self.sess.run(assign_step_zero)
434429
if not initial_run:
435430
self.saver.restore(self.sess, self.model_to_use)
436431
if use_validation and valid_data is None:
@@ -442,9 +437,6 @@ def train(self, train_data, use_validation=False, valid_data=None, n_epochs=25,
442437
min_delta = threshold
443438
patience_cnt = 0
444439
for it in range(n_epochs):
445-
increment_global_step_op = tf.assign(self.global_step, self.global_step + 1)
446-
_step = self.sess.run(increment_global_step_op)
447-
current_step = self.sess.run(self.global_step)
448440
train_loss = 0.0
449441
for lower in range(0, train_data.shape[0], batch_size):
450442
upper = min(lower + batch_size, train_data.shape[0])
@@ -454,8 +446,7 @@ def train(self, train_data, use_validation=False, valid_data=None, n_epochs=25,
454446
x_mb = train_data[lower:upper, :].X
455447
if upper - lower > 1:
456448
_, current_loss_train = self.sess.run([self.solver, self.vae_loss],
457-
feed_dict={self.x: x_mb, self.time_step: current_step,
458-
self.size: len(x_mb), self.is_training: True})
449+
feed_dict={self.x: x_mb, self.is_training: True})
459450
train_loss += current_loss_train
460451
if use_validation:
461452
valid_loss = 0
@@ -466,8 +457,7 @@ def train(self, train_data, use_validation=False, valid_data=None, n_epochs=25,
466457
else:
467458
x_mb = valid_data[lower:upper, :].X
468459
current_loss_valid = self.sess.run(self.vae_loss,
469-
feed_dict={self.x: x_mb, self.time_step: current_step,
470-
self.size: len(x_mb), self.is_training: False})
460+
feed_dict={self.x: x_mb, self.is_training: False})
471461
valid_loss += current_loss_valid
472462
loss_hist.append(valid_loss / valid_data.shape[0])
473463
if it > 0 and loss_hist[it - 1] - loss_hist[it] > min_delta:

scgen/models/_vae_keras.py

-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ class VAEArithKeras:
4646
"""
4747

4848
def __init__(self, x_dimension, z_dimension=100, **kwargs):
49-
tf.reset_default_graph()
5049
self.x_dim = x_dimension
5150
self.z_dim = z_dimension
5251
self.learning_rate = kwargs.get("learning_rate", 0.001)

0 commit comments

Comments
 (0)