33import numpy as np
44import torch
55from scvi import REGISTRY_KEYS
6- from scvi .module ._constants import _MODULEKEYS
76from scvi .module .base import BaseModuleClass , LossOutput , auto_move_data
87from scvi .nn import Encoder
98from torch .distributions import Normal
@@ -81,15 +80,15 @@ def __init__(
8180
8281 def _get_inference_input (self , tensors ):
8382 x = tensors [REGISTRY_KEYS .X_KEY ]
84- input_dict = {
85- REGISTRY_KEYS . X_KEY : x ,
86- }
83+ input_dict = dict (
84+ x = x ,
85+ )
8786 return input_dict
8887
8988 def _get_generative_input (self , tensors , inference_outputs ):
90- z = inference_outputs [_MODULEKEYS . Z_KEY ]
89+ z = inference_outputs ["z" ]
9190 input_dict = {
92- _MODULEKEYS . Z_KEY : z ,
91+ "z" : z ,
9392 }
9493 return input_dict
9594
@@ -102,15 +101,15 @@ def inference(self, x):
102101 """
103102 qz_m , qz_v , z = self .z_encoder (x )
104103
105- outputs = { _MODULEKEYS . Z_KEY : z , _MODULEKEYS . QZM_KEY : qz_m , _MODULEKEYS . QZV_KEY : qz_v }
104+ outputs = dict ( z = z , qzm = qz_m , qzv = qz_v )
106105 return outputs
107106
108107 @auto_move_data
109108 def generative (self , z ):
110109 """Runs the generative model."""
111110 px = self .decoder (z )
112111
113- return { _MODULEKEYS . PX_KEY : px }
112+ return dict ( px = px )
114113
115114 def loss (
116115 self ,
@@ -119,9 +118,9 @@ def loss(
119118 generative_outputs ,
120119 ):
121120 x = tensors [REGISTRY_KEYS .X_KEY ]
122- qz_m = inference_outputs [_MODULEKEYS . QZ_M_KEY ]
123- qz_v = inference_outputs [_MODULEKEYS . QZ_V_KEY ]
124- p = generative_outputs [_MODULEKEYS . PX_KEY ]
121+ qz_m = inference_outputs ["qzm" ]
122+ qz_v = inference_outputs ["qzv" ]
123+ p = generative_outputs ["px" ]
125124
126125 kld = kl (
127126 Normal (qz_m , torch .sqrt (qz_v )),
@@ -165,7 +164,7 @@ def sample(
165164 inference_kwargs = inference_kwargs ,
166165 compute_loss = False ,
167166 )
168- px = Normal (generative_outputs [_MODULEKEYS . PX_KEY ], 1 ).sample ()
167+ px = Normal (generative_outputs ["px" ], 1 ).sample ()
169168 return px .cpu ().numpy ()
170169
171170 def get_reconstruction_loss (self , x , px ) -> torch .Tensor :
0 commit comments