@@ -46,7 +46,7 @@ class DDIMSchedulerOutput(BaseOutput):
46
46
pred_original_sample : Optional [torch .FloatTensor ] = None
47
47
48
48
49
- def betas_for_alpha_bar (num_diffusion_timesteps , max_beta = 0.999 ):
49
+ def betas_for_alpha_bar (num_diffusion_timesteps , max_beta = 0.999 ) -> torch . Tensor :
50
50
"""
51
51
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
52
52
(1-beta) over time from t = [0,1].
@@ -72,7 +72,7 @@ def alpha_bar(time_step):
72
72
t1 = i / num_diffusion_timesteps
73
73
t2 = (i + 1 ) / num_diffusion_timesteps
74
74
betas .append (min (1 - alpha_bar (t2 ) / alpha_bar (t1 ), max_beta ))
75
- return np . array (betas , dtype = np . float32 )
75
+ return torch . tensor (betas )
76
76
77
77
78
78
class DDIMScheduler (SchedulerMixin , ConfigMixin ):
@@ -106,7 +106,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
106
106
an offset added to the inference steps. You can use a combination of `offset=1` and
107
107
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
108
108
stable diffusion.
109
- tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
110
109
111
110
"""
112
111
@@ -121,36 +120,34 @@ def __init__(
121
120
clip_sample : bool = True ,
122
121
set_alpha_to_one : bool = True ,
123
122
steps_offset : int = 0 ,
124
- tensor_format : str = "pt" ,
125
123
):
126
124
if trained_betas is not None :
127
- self .betas = np . asarray (trained_betas )
125
+ self .betas = torch . from_numpy (trained_betas )
128
126
if beta_schedule == "linear" :
129
- self .betas = np .linspace (beta_start , beta_end , num_train_timesteps , dtype = np .float32 )
127
+ self .betas = torch .linspace (beta_start , beta_end , num_train_timesteps , dtype = torch .float32 )
130
128
elif beta_schedule == "scaled_linear" :
131
129
# this schedule is very specific to the latent diffusion model.
132
- self .betas = np .linspace (beta_start ** 0.5 , beta_end ** 0.5 , num_train_timesteps , dtype = np .float32 ) ** 2
130
+ self .betas = (
131
+ torch .linspace (beta_start ** 0.5 , beta_end ** 0.5 , num_train_timesteps , dtype = torch .float32 ) ** 2
132
+ )
133
133
elif beta_schedule == "squaredcos_cap_v2" :
134
134
# Glide cosine schedule
135
135
self .betas = betas_for_alpha_bar (num_train_timesteps )
136
136
else :
137
137
raise NotImplementedError (f"{ beta_schedule } does is not implemented for { self .__class__ } " )
138
138
139
139
self .alphas = 1.0 - self .betas
140
- self .alphas_cumprod = np .cumprod (self .alphas , axis = 0 )
140
+ self .alphas_cumprod = torch .cumprod (self .alphas , dim = 0 )
141
141
142
142
# At every step in ddim, we are looking into the previous alphas_cumprod
143
143
# For the final step, there is no previous alphas_cumprod because we are already at 0
144
144
# `set_alpha_to_one` decides whether we set this parameter simply to one or
145
145
# whether we use the final alpha of the "non-previous" one.
146
- self .final_alpha_cumprod = np . array (1.0 ) if set_alpha_to_one else self .alphas_cumprod [0 ]
146
+ self .final_alpha_cumprod = torch . tensor (1.0 ) if set_alpha_to_one else self .alphas_cumprod [0 ]
147
147
148
148
# setable values
149
149
self .num_inference_steps = None
150
- self .timesteps = np .arange (0 , num_train_timesteps )[::- 1 ].copy ()
151
-
152
- self .tensor_format = tensor_format
153
- self .set_format (tensor_format = tensor_format )
150
+ self .timesteps = np .arange (0 , num_train_timesteps )[::- 1 ]
154
151
155
152
def _get_variance (self , timestep , prev_timestep ):
156
153
alpha_prod_t = self .alphas_cumprod [timestep ]
@@ -186,15 +183,14 @@ def set_timesteps(self, num_inference_steps: int, **kwargs):
186
183
step_ratio = self .config .num_train_timesteps // self .num_inference_steps
187
184
# creates integer timesteps by multiplying by ratio
188
185
# casting to int to avoid issues when num_inference_step is power of 3
189
- self .timesteps = (np .arange (0 , num_inference_steps ) * step_ratio ).round ()[::- 1 ]. copy ()
186
+ self .timesteps = (np .arange (0 , num_inference_steps ) * step_ratio ).round ()[::- 1 ]
190
187
self .timesteps += offset
191
- self .set_format (tensor_format = self .tensor_format )
192
188
193
189
def step (
194
190
self ,
195
- model_output : Union [ torch .FloatTensor , np . ndarray ] ,
191
+ model_output : torch .FloatTensor ,
196
192
timestep : int ,
197
- sample : Union [ torch .FloatTensor , np . ndarray ] ,
193
+ sample : torch .FloatTensor ,
198
194
eta : float = 0.0 ,
199
195
use_clipped_model_output : bool = False ,
200
196
generator = None ,
@@ -205,9 +201,9 @@ def step(
205
201
process from the learned model outputs (most often the predicted noise).
206
202
207
203
Args:
208
- model_output (`torch.FloatTensor` or `np.ndarray` ): direct output from learned diffusion model.
204
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
209
205
timestep (`int`): current discrete timestep in the diffusion chain.
210
- sample (`torch.FloatTensor` or `np.ndarray` ):
206
+ sample (`torch.FloatTensor`):
211
207
current instance of sample being created by diffusion process.
212
208
eta (`float`): weight of noise for added noise in diffusion step.
213
209
use_clipped_model_output (`bool`): TODO
@@ -251,7 +247,7 @@ def step(
251
247
252
248
# 4. Clip "predicted x_0"
253
249
if self .config .clip_sample :
254
- pred_original_sample = self . clip (pred_original_sample , - 1 , 1 )
250
+ pred_original_sample = torch . clamp (pred_original_sample , - 1 , 1 )
255
251
256
252
# 5. compute variance: "sigma_t(η)" -> see formula (16)
257
253
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
@@ -273,9 +269,6 @@ def step(
273
269
noise = torch .randn (model_output .shape , generator = generator ).to (device )
274
270
variance = self ._get_variance (timestep , prev_timestep ) ** (0.5 ) * eta * noise
275
271
276
- if not torch .is_tensor (model_output ):
277
- variance = variance .numpy ()
278
-
279
272
prev_sample = prev_sample + variance
280
273
281
274
if not return_dict :
@@ -285,16 +278,20 @@ def step(
285
278
286
279
def add_noise (
287
280
self ,
288
- original_samples : Union [torch .FloatTensor , np .ndarray ],
289
- noise : Union [torch .FloatTensor , np .ndarray ],
290
- timesteps : Union [torch .IntTensor , np .ndarray ],
291
- ) -> Union [torch .FloatTensor , np .ndarray ]:
292
- if self .tensor_format == "pt" :
293
- timesteps = timesteps .to (self .alphas_cumprod .device )
281
+ original_samples : torch .FloatTensor ,
282
+ noise : torch .FloatTensor ,
283
+ timesteps : torch .IntTensor ,
284
+ ) -> torch .FloatTensor :
285
+ timesteps = timesteps .to (self .alphas_cumprod .device )
294
286
sqrt_alpha_prod = self .alphas_cumprod [timesteps ] ** 0.5
295
- sqrt_alpha_prod = self .match_shape (sqrt_alpha_prod , original_samples )
287
+ sqrt_alpha_prod = sqrt_alpha_prod .flatten ()
288
+ while len (sqrt_alpha_prod .shape ) < len (original_samples .shape ):
289
+ sqrt_alpha_prod = sqrt_alpha_prod .unsqueeze (- 1 )
290
+
296
291
sqrt_one_minus_alpha_prod = (1 - self .alphas_cumprod [timesteps ]) ** 0.5
297
- sqrt_one_minus_alpha_prod = self .match_shape (sqrt_one_minus_alpha_prod , original_samples )
292
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod .flatten ()
293
+ while len (sqrt_one_minus_alpha_prod .shape ) < len (original_samples .shape ):
294
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod .unsqueeze (- 1 )
298
295
299
296
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
300
297
return noisy_samples
0 commit comments