33import numpy as np
44import torch
55from scvi import REGISTRY_KEYS
6+ from scvi .module ._constants import _MODULEKEYS
67from scvi .module .base import BaseModuleClass , LossOutput , auto_move_data
78from scvi .nn import Encoder
89from torch .distributions import Normal
@@ -80,15 +81,15 @@ def __init__(
8081
8182 def _get_inference_input (self , tensors ):
8283 x = tensors [REGISTRY_KEYS .X_KEY ]
83- input_dict = dict (
84- x = x ,
85- )
84+ input_dict = {
85+ REGISTRY_KEYS . X_KEY : x ,
86+ }
8687 return input_dict
8788
8889 def _get_generative_input (self , tensors , inference_outputs ):
89- z = inference_outputs ["z" ]
90+ z = inference_outputs [_MODULEKEYS . Z_KEY ]
9091 input_dict = {
91- "z" : z ,
92+ _MODULEKEYS . Z_KEY : z ,
9293 }
9394 return input_dict
9495
@@ -101,15 +102,15 @@ def inference(self, x):
101102 """
102103 qz_m , qz_v , z = self .z_encoder (x )
103104
104- outputs = dict ( z = z , qzm = qz_m , qzv = qz_v )
105+ outputs = { _MODULEKEYS . Z_KEY : z , _MODULEKEYS . QZM_KEY : qz_m , _MODULEKEYS . QZV_KEY : qz_v }
105106 return outputs
106107
107108 @auto_move_data
108109 def generative (self , z ):
109110 """Runs the generative model."""
110111 px = self .decoder (z )
111112
112- return dict ( px = px )
113+ return { _MODULEKEYS . PX_KEY : px }
113114
114115 def loss (
115116 self ,
@@ -118,9 +119,9 @@ def loss(
118119 generative_outputs ,
119120 ):
120121 x = tensors [REGISTRY_KEYS .X_KEY ]
121- qz_m = inference_outputs ["qzm" ]
122- qz_v = inference_outputs ["qzv" ]
123- p = generative_outputs ["px" ]
122+ qz_m = inference_outputs [_MODULEKEYS . QZ_M_KEY ]
123+ qz_v = inference_outputs [_MODULEKEYS . QZ_V_KEY ]
124+ p = generative_outputs [_MODULEKEYS . PX_KEY ]
124125
125126 kld = kl (
126127 Normal (qz_m , torch .sqrt (qz_v )),
@@ -164,7 +165,7 @@ def sample(
164165 inference_kwargs = inference_kwargs ,
165166 compute_loss = False ,
166167 )
167- px = Normal (generative_outputs ["px" ], 1 ).sample ()
168+ px = Normal (generative_outputs [_MODULEKEYS . PX_KEY ], 1 ).sample ()
168169 return px .cpu ().numpy ()
169170
170171 def get_reconstruction_loss (self , x , px ) -> torch .Tensor :
0 commit comments