Skip to content

Commit 42f7a92

Browse files
author
Mohsen Naghipourfar
committed
Update scGen keras network
1 parent 22dcc45 commit 42f7a92

File tree

1 file changed

+6
-11
lines changed

1 file changed

+6
-11
lines changed

scgen/models/_vae_keras.py

+6-11
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def _encoder(self):
9898
log_var = Dense(self.z_dim, kernel_initializer=self.init_w)(h)
9999
z = Lambda(self._sample_z, output_shape=(self.z_dim,), name="Z")([mean, log_var])
100100

101-
self.encoder_model = Model(inputs=self.x, outputs=[mean, log_var, z], name="encoder")
101+
self.encoder_model = Model(inputs=self.x, outputs=z, name="encoder")
102102
return mean, log_var
103103

104104
def _decoder(self):
@@ -178,7 +178,7 @@ def _create_network(self):
178178
self.mu, self.log_var = self._encoder()
179179

180180
self.x_hat = self._decoder()
181-
self.vae_model = Model(inputs=self.x, outputs=self.decoder_model(self.encoder_model(self.x)[2]), name="VAE")
181+
self.vae_model = Model(inputs=self.x, outputs=self.decoder_model(self.encoder_model(self.x)), name="VAE")
182182

183183
def _loss_function(self):
184184
"""
@@ -224,7 +224,7 @@ def to_latent(self, data):
224224
latent: numpy nd-array
225225
Returns array containing latent space encoding of 'data'
226226
"""
227-
latent = self.encoder_model.predict(data)[2]
227+
latent = self.encoder_model.predict(data)
228228
return latent
229229

230230
def _avg_vector(self, data):
@@ -246,7 +246,7 @@ def _avg_vector(self, data):
246246
latent_avg = numpy.average(latent, axis=0)
247247
return latent_avg
248248

249-
def reconstruct(self, data, use_data=False):
249+
def reconstruct(self, data):
250250
"""
251251
Map back the latent space encoding via the decoder.
252252
@@ -265,11 +265,6 @@ def reconstruct(self, data, use_data=False):
265265
rec_data: 'numpy nd-array'
266266
Returns 'numpy nd-array` containing reconstructed 'data' in shape [n_obs, n_vars].
267267
"""
268-
# if use_data:
269-
# latent = data
270-
# else:
271-
# latent = self.to_latent(data)
272-
# rec_data = self.sess.run(self.x_hat, feed_dict={self.z_mean: latent, self.is_training: False})
273268
rec_data = self.decoder_model.predict(x=data)
274269
return rec_data
275270

@@ -321,7 +316,7 @@ def linear_interpolation(self, source_adata, dest_adata, n_steps):
321316
vector = start * (1 - alpha) + end * alpha
322317
vectors[i, :] = vector
323318
vectors = numpy.array(vectors)
324-
interpolation = self.reconstruct(vectors, use_data=True)
319+
interpolation = self.reconstruct(vectors)
325320
return interpolation
326321

327322
def predict(self, adata, conditions, cell_type_key, condition_key, adata_to_predict=None, celltype_to_predict=None, obs_key="all"):
@@ -393,7 +388,7 @@ def predict(self, adata, conditions, cell_type_key, condition_key, adata_to_pred
393388
else:
394389
latent_cd = self.to_latent(ctrl_pred.X)
395390
stim_pred = delta + latent_cd
396-
predicted_cells = self.reconstruct(stim_pred, use_data=True)
391+
predicted_cells = self.reconstruct(stim_pred)
397392
return predicted_cells, delta
398393

399394
def restore_model(self):

0 commit comments

Comments
 (0)