Skip to content

Commit c19f00c

Browse files
committed
Revert changes due to MODULE_KEYS not being exported
1 parent e51dcc2 commit c19f00c

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

scgen/_scgenvae.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import numpy as np
44
import torch
55
from scvi import REGISTRY_KEYS
6-
from scvi.module._constants import _MODULEKEYS
76
from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
87
from scvi.nn import Encoder
98
from torch.distributions import Normal
@@ -81,15 +80,15 @@ def __init__(
8180

8281
def _get_inference_input(self, tensors):
8382
x = tensors[REGISTRY_KEYS.X_KEY]
84-
input_dict = {
85-
REGISTRY_KEYS.X_KEY:x,
86-
}
83+
input_dict = dict(
84+
x=x,
85+
)
8786
return input_dict
8887

8988
def _get_generative_input(self, tensors, inference_outputs):
90-
z = inference_outputs[_MODULEKEYS.Z_KEY]
89+
z = inference_outputs["z"]
9190
input_dict = {
92-
_MODULEKEYS.Z_KEY: z,
91+
"z": z,
9392
}
9493
return input_dict
9594

@@ -102,15 +101,15 @@ def inference(self, x):
102101
"""
103102
qz_m, qz_v, z = self.z_encoder(x)
104103

105-
outputs = {_MODULEKEYS.Z_KEY:z, _MODULEKEYS.QZM_KEY:qz_m, _MODULEKEYS.QZV_KEY:qz_v}
104+
outputs = dict(z=z, qzm=qz_m, qzv=qz_v)
106105
return outputs
107106

108107
@auto_move_data
109108
def generative(self, z):
110109
"""Runs the generative model."""
111110
px = self.decoder(z)
112111

113-
return {_MODULEKEYS.PX_KEY:px}
112+
return dict(px=px)
114113

115114
def loss(
116115
self,
@@ -119,9 +118,9 @@ def loss(
119118
generative_outputs,
120119
):
121120
x = tensors[REGISTRY_KEYS.X_KEY]
122-
qz_m = inference_outputs[_MODULEKEYS.QZ_M_KEY]
123-
qz_v = inference_outputs[_MODULEKEYS.QZ_V_KEY]
124-
p = generative_outputs[_MODULEKEYS.PX_KEY]
121+
qz_m = inference_outputs["qzm"]
122+
qz_v = inference_outputs["qzv"]
123+
p = generative_outputs["px"]
125124

126125
kld = kl(
127126
Normal(qz_m, torch.sqrt(qz_v)),
@@ -165,7 +164,7 @@ def sample(
165164
inference_kwargs=inference_kwargs,
166165
compute_loss=False,
167166
)
168-
px = Normal(generative_outputs[_MODULEKEYS.PX_KEY], 1).sample()
167+
px = Normal(generative_outputs["px"], 1).sample()
169168
return px.cpu().numpy()
170169

171170
def get_reconstruction_loss(self, x, px) -> torch.Tensor:

0 commit comments

Comments
 (0)