diff --git a/scgen/_scgenvae.py b/scgen/_scgenvae.py index 113d8cb..5030b1a 100644 --- a/scgen/_scgenvae.py +++ b/scgen/_scgenvae.py @@ -101,7 +101,7 @@ def inference(self, x): """ qz_m, qz_v, z = self.z_encoder(x) - outputs = dict(z=z, qz_m=qz_m, qz_v=qz_v) + outputs = dict(z=z, qzm=qz_m, qzv=qz_v) return outputs @auto_move_data @@ -118,8 +118,8 @@ def loss( generative_outputs, ): x = tensors[REGISTRY_KEYS.X_KEY] - qz_m = inference_outputs["qz_m"] - qz_v = inference_outputs["qz_v"] + qz_m = inference_outputs["qzm"] + qz_v = inference_outputs["qzv"] p = generative_outputs["px"] kld = kl(