Skip to content

Commit e51dcc2

Browse files
committed
Change all occurrences to _MODULEKEYS
1 parent 678c4aa commit e51dcc2

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

scgen/_scgenvae.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import torch
55
from scvi import REGISTRY_KEYS
6+
from scvi.module._constants import _MODULEKEYS
67
from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
78
from scvi.nn import Encoder
89
from torch.distributions import Normal
@@ -80,15 +81,15 @@ def __init__(
8081

8182
def _get_inference_input(self, tensors):
8283
x = tensors[REGISTRY_KEYS.X_KEY]
83-
input_dict = dict(
84-
x=x,
85-
)
84+
input_dict = {
85+
REGISTRY_KEYS.X_KEY:x,
86+
}
8687
return input_dict
8788

8889
def _get_generative_input(self, tensors, inference_outputs):
89-
z = inference_outputs["z"]
90+
z = inference_outputs[_MODULEKEYS.Z_KEY]
9091
input_dict = {
91-
"z": z,
92+
_MODULEKEYS.Z_KEY: z,
9293
}
9394
return input_dict
9495

@@ -101,15 +102,15 @@ def inference(self, x):
101102
"""
102103
qz_m, qz_v, z = self.z_encoder(x)
103104

104-
outputs = dict(z=z, qzm=qz_m, qzv=qz_v)
105+
outputs = {_MODULEKEYS.Z_KEY:z, _MODULEKEYS.QZM_KEY:qz_m, _MODULEKEYS.QZV_KEY:qz_v}
105106
return outputs
106107

107108
@auto_move_data
108109
def generative(self, z):
109110
"""Runs the generative model."""
110111
px = self.decoder(z)
111112

112-
return dict(px=px)
113+
return {_MODULEKEYS.PX_KEY:px}
113114

114115
def loss(
115116
self,
@@ -118,9 +119,9 @@ def loss(
118119
generative_outputs,
119120
):
120121
x = tensors[REGISTRY_KEYS.X_KEY]
121-
qz_m = inference_outputs["qzm"]
122-
qz_v = inference_outputs["qzv"]
123-
p = generative_outputs["px"]
122+
qz_m = inference_outputs[_MODULEKEYS.QZ_M_KEY]
123+
qz_v = inference_outputs[_MODULEKEYS.QZ_V_KEY]
124+
p = generative_outputs[_MODULEKEYS.PX_KEY]
124125

125126
kld = kl(
126127
Normal(qz_m, torch.sqrt(qz_v)),
@@ -164,7 +165,7 @@ def sample(
164165
inference_kwargs=inference_kwargs,
165166
compute_loss=False,
166167
)
167-
px = Normal(generative_outputs["px"], 1).sample()
168+
px = Normal(generative_outputs[_MODULEKEYS.PX_KEY], 1).sample()
168169
return px.cpu().numpy()
169170

170171
def get_reconstruction_loss(self, x, px) -> torch.Tensor:

0 commit comments

Comments
 (0)