@@ -32,19 +32,15 @@ class VAEArith:
32
32
"""
33
33
34
34
def __init__ (self , x_dimension , z_dimension = 100 , ** kwargs ):
35
- tf .reset_default_graph ()
36
35
self .x_dim = x_dimension
37
36
self .z_dim = z_dimension
38
37
self .learning_rate = kwargs .get ("learning_rate" , 0.001 )
39
38
self .dropout_rate = kwargs .get ("dropout_rate" , 0.2 )
40
39
self .model_to_use = kwargs .get ("model_path" , "./models/scgen" )
41
40
self .alpha = kwargs .get ("alpha" , 0.00005 )
42
41
self .is_training = tf .placeholder (tf .bool , name = 'training_flag' )
43
- self .global_step = tf .Variable (0 , name = 'global_step' , trainable = False , dtype = tf .int32 )
44
42
self .x = tf .placeholder (tf .float32 , shape = [None , self .x_dim ], name = "data" )
45
43
self .z = tf .placeholder (tf .float32 , shape = [None , self .z_dim ], name = "latent" )
46
- self .time_step = tf .placeholder (tf .int32 )
47
- self .size = tf .placeholder (tf .int32 )
48
44
self .init_w = tf .contrib .layers .xavier_initializer ()
49
45
self ._create_network ()
50
46
self ._loss_function ()
@@ -119,7 +115,8 @@ def _sample_z(self):
119
115
# Returns
120
116
The computed Tensor of samples with shape [size, z_dim].
121
117
"""
122
- eps = tf .random_normal (shape = [self .size , self .z_dim ])
118
+ batch_size = tf .shape (self .mu )[0 ]
119
+ eps = tf .random_normal (shape = [batch_size , self .z_dim ])
123
120
return self .mu + tf .exp (self .log_var / 2 ) * eps
124
121
125
122
def _create_network (self ):
@@ -174,7 +171,7 @@ def to_latent(self, data):
174
171
latent: numpy nd-array
175
172
Returns array containing latent space encoding of 'data'
176
173
"""
177
- latent = self .sess .run (self .z_mean , feed_dict = {self .x : data , self .size : data . shape [ 0 ], self . is_training : False })
174
+ latent = self .sess .run (self .z_mean , feed_dict = {self .x : data , self .is_training : False })
178
175
return latent
179
176
180
177
def _avg_vector (self , data ):
@@ -429,8 +426,6 @@ def train(self, train_data, use_validation=False, valid_data=None, n_epochs=25,
429
426
"""
430
427
if initial_run :
431
428
log .info ("----Training----" )
432
- assign_step_zero = tf .assign (self .global_step , 0 )
433
- _init_step = self .sess .run (assign_step_zero )
434
429
if not initial_run :
435
430
self .saver .restore (self .sess , self .model_to_use )
436
431
if use_validation and valid_data is None :
@@ -442,9 +437,6 @@ def train(self, train_data, use_validation=False, valid_data=None, n_epochs=25,
442
437
min_delta = threshold
443
438
patience_cnt = 0
444
439
for it in range (n_epochs ):
445
- increment_global_step_op = tf .assign (self .global_step , self .global_step + 1 )
446
- _step = self .sess .run (increment_global_step_op )
447
- current_step = self .sess .run (self .global_step )
448
440
train_loss = 0.0
449
441
for lower in range (0 , train_data .shape [0 ], batch_size ):
450
442
upper = min (lower + batch_size , train_data .shape [0 ])
@@ -454,8 +446,7 @@ def train(self, train_data, use_validation=False, valid_data=None, n_epochs=25,
454
446
x_mb = train_data [lower :upper , :].X
455
447
if upper - lower > 1 :
456
448
_ , current_loss_train = self .sess .run ([self .solver , self .vae_loss ],
457
- feed_dict = {self .x : x_mb , self .time_step : current_step ,
458
- self .size : len (x_mb ), self .is_training : True })
449
+ feed_dict = {self .x : x_mb , self .is_training : True })
459
450
train_loss += current_loss_train
460
451
if use_validation :
461
452
valid_loss = 0
@@ -466,8 +457,7 @@ def train(self, train_data, use_validation=False, valid_data=None, n_epochs=25,
466
457
else :
467
458
x_mb = valid_data [lower :upper , :].X
468
459
current_loss_valid = self .sess .run (self .vae_loss ,
469
- feed_dict = {self .x : x_mb , self .time_step : current_step ,
470
- self .size : len (x_mb ), self .is_training : False })
460
+ feed_dict = {self .x : x_mb , self .is_training : False })
471
461
valid_loss += current_loss_valid
472
462
loss_hist .append (valid_loss / valid_data .shape [0 ])
473
463
if it > 0 and loss_hist [it - 1 ] - loss_hist [it ] > min_delta :
0 commit comments