@@ -98,7 +98,7 @@ def _encoder(self):
        log_var = Dense(self.z_dim, kernel_initializer=self.init_w)(h)
        z = Lambda(self._sample_z, output_shape=(self.z_dim,), name="Z")([mean, log_var])

-       self.encoder_model = Model(inputs=self.x, outputs=[mean, log_var, z], name="encoder")
+       self.encoder_model = Model(inputs=self.x, outputs=z, name="encoder")
        return mean, log_var

    def _decoder(self):
@@ -178,7 +178,7 @@ def _create_network(self):
        self.mu, self.log_var = self._encoder()

        self.x_hat = self._decoder()
-       self.vae_model = Model(inputs=self.x, outputs=self.decoder_model(self.encoder_model(self.x)[2]), name="VAE")
+       self.vae_model = Model(inputs=self.x, outputs=self.decoder_model(self.encoder_model(self.x)), name="VAE")

    def _loss_function(self):
        """
@@ -224,7 +224,7 @@ def to_latent(self, data):
            latent: numpy nd-array
                Returns array containing latent space encoding of 'data'
        """
-       latent = self.encoder_model.predict(data)[2]
+       latent = self.encoder_model.predict(data)
        return latent

    def _avg_vector(self, data):
@@ -246,7 +246,7 @@ def _avg_vector(self, data):
        latent_avg = numpy.average(latent, axis=0)
        return latent_avg

-   def reconstruct(self, data, use_data=False):
+   def reconstruct(self, data):
        """
        Map back the latent space encoding via the decoder.

@@ -265,11 +265,6 @@ def reconstruct(self, data, use_data=False):
265
265
rec_data: 'numpy nd-array'
266
266
Returns 'numpy nd-array` containing reconstructed 'data' in shape [n_obs, n_vars].
267
267
"""
268
- # if use_data:
269
- # latent = data
270
- # else:
271
- # latent = self.to_latent(data)
272
- # rec_data = self.sess.run(self.x_hat, feed_dict={self.z_mean: latent, self.is_training: False})
273
268
rec_data = self .decoder_model .predict (x = data )
274
269
return rec_data
275
270
@@ -321,7 +316,7 @@ def linear_interpolation(self, source_adata, dest_adata, n_steps):
            vector = start * (1 - alpha) + end * alpha
            vectors[i, :] = vector
        vectors = numpy.array(vectors)
-       interpolation = self.reconstruct(vectors, use_data=True)
+       interpolation = self.reconstruct(vectors)
        return interpolation

    def predict(self, adata, conditions, cell_type_key, condition_key, adata_to_predict=None, celltype_to_predict=None, obs_key="all"):
@@ -393,7 +388,7 @@ def predict(self, adata, conditions, cell_type_key, condition_key, adata_to_pred
        else:
            latent_cd = self.to_latent(ctrl_pred.X)
        stim_pred = delta + latent_cd
-       predicted_cells = self.reconstruct(stim_pred, use_data=True)
+       predicted_cells = self.reconstruct(stim_pred)
        return predicted_cells, delta

    def restore_model(self):
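A minimal usage sketch of the API after this change, assuming a trained model instance; the name `vae` and the matrix `X` below are hypothetical placeholders. Since `encoder_model` now outputs only `z`, `to_latent` returns the prediction directly and `reconstruct` takes latent vectors without a `use_data` flag.

latent = vae.to_latent(X)        # latent codes z, shape [n_obs, z_dim], via encoder_model.predict
rec = vae.reconstruct(latent)    # mapped back to [n_obs, n_vars] via decoder_model.predict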