@@ -77,9 +77,9 @@ def _process_req(self, inputs, **kwargs):
77
77
)
78
78
if needs_upcasting :
79
79
self .vae = self .vae .to (torch .float32 )
80
- latents = latents .to (self .device , torch .float32 )
80
+ inputs = inputs .to (self .device , torch .float32 )
81
81
else :
82
- latents = inputs .to (self .device , self .dtype )
82
+ inputs = inputs .to (self .device , self .dtype )
83
83
84
84
# unscale/denormalize the latents
85
85
# denormalize with the mean and std if available and not None
@@ -95,21 +95,21 @@ def _process_req(self, inputs, **kwargs):
95
95
latents_mean = (
96
96
torch .tensor (self .vae .config .latents_mean )
97
97
.view (1 , 4 , 1 , 1 )
98
- .to (latents .device , latents .dtype )
98
+ .to (inputs .device , inputs .dtype )
99
99
)
100
100
latents_std = (
101
101
torch .tensor (self .vae .config .latents_std )
102
102
.view (1 , 4 , 1 , 1 )
103
- .to (latents .device , latents .dtype )
103
+ .to (inputs .device , inputs .dtype )
104
104
)
105
- latents = (
106
- latents * latents_std / self .vae .config .scaling_factor + latents_mean
105
+ inputs = (
106
+ inputs * latents_std / self .vae .config .scaling_factor + latents_mean
107
107
)
108
108
else :
109
- latents = latents / self .vae .config .scaling_factor
109
+ inputs = inputs / self .vae .config .scaling_factor
110
110
111
111
with torch .no_grad ():
112
- image = self .vae .decode (latents , return_dict = False )[0 ]
112
+ image = self .vae .decode (inputs , return_dict = False )[0 ]
113
113
114
114
if needs_upcasting :
115
115
self .vae .to (dtype = torch .float16 )
0 commit comments