Skip to content

Commit 546d62f

Browse files
[Pytorch] Pytorch only schedulers (huggingface#534)
* pytorch only schedulers * fix style * remove match_shape * pytorch only ddpm * remove SchedulerMixin * remove numpy from karras_ve * fix types * remove numpy from lms_discrete * remove numpy from pndm * fix typo * remove mixin and numpy from sde_vp and ve * remove remaining tensor_format * fix style * sigmas has to be torch tensor * removed set_format in readme * remove set format from docs * remove set_format from pipelines * update tests * fix typo * continue to use mixin * fix imports * removed unsed imports * match shape instead of assuming image shapes * remove import typo * update call to add_noise * use math instead of numpy * fix t_index * removed commented out numpy tests * timesteps needs to be discrete * cast timesteps to int in flax scheduler too * fix device mismatch issue * small fix * Update src/diffusers/schedulers/scheduling_pndm.py Co-authored-by: Patrick von Platen <[email protected]>
1 parent 3e6214d commit 546d62f

20 files changed

+204
-311
lines changed

Diff for: pipelines/ddim/pipeline_ddim.py

-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ class DDIMPipeline(DiffusionPipeline):
3535

3636
def __init__(self, unet, scheduler):
3737
super().__init__()
38-
scheduler = scheduler.set_format("pt")
3938
self.register_modules(unet=unet, scheduler=scheduler)
4039

4140
@torch.no_grad()

Diff for: pipelines/ddpm/pipeline_ddpm.py

-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ class DDPMPipeline(DiffusionPipeline):
3535

3636
def __init__(self, unet, scheduler):
3737
super().__init__()
38-
scheduler = scheduler.set_format("pt")
3938
self.register_modules(unet=unet, scheduler=scheduler)
4039

4140
@torch.no_grad()

Diff for: pipelines/latent_diffusion/pipeline_latent_diffusion.py

-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ def __init__(
4545
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
4646
):
4747
super().__init__()
48-
scheduler = scheduler.set_format("pt")
4948
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
5049

5150
@torch.no_grad()

Diff for: pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py

-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ class LDMPipeline(DiffusionPipeline):
2323

2424
def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler):
2525
super().__init__()
26-
scheduler = scheduler.set_format("pt")
2726
self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
2827

2928
@torch.no_grad()

Diff for: pipelines/pndm/pipeline_pndm.py

-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ class PNDMPipeline(DiffusionPipeline):
3939

4040
def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler):
4141
super().__init__()
42-
scheduler = scheduler.set_format("pt")
4342
self.register_modules(unet=unet, scheduler=scheduler)
4443

4544
@torch.no_grad()

Diff for: pipelines/stable_diffusion/pipeline_stable_diffusion.py

-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ def __init__(
5757
feature_extractor: CLIPFeatureExtractor,
5858
):
5959
super().__init__()
60-
scheduler = scheduler.set_format("pt")
6160

6261
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
6362
warnings.warn(

Diff for: pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ def __init__(
6969
feature_extractor: CLIPFeatureExtractor,
7070
):
7171
super().__init__()
72-
scheduler = scheduler.set_format("pt")
7372

7473
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
7574
warnings.warn(

Diff for: pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ def __init__(
8383
feature_extractor: CLIPFeatureExtractor,
8484
):
8585
super().__init__()
86-
scheduler = scheduler.set_format("pt")
8786
logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
8887

8988
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
@@ -320,11 +319,11 @@ def __call__(
320319
if isinstance(self.scheduler, LMSDiscreteScheduler):
321320
latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
322321
# masking
323-
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor(t_index))
322+
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t_index]))
324323
else:
325324
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
326325
# masking
327-
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
326+
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t]))
328327

329328
latents = (init_latents_proper * mask) + (latents * (1 - mask))
330329

Diff for: pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py

-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ def __init__(
3535
feature_extractor: CLIPFeatureExtractor,
3636
):
3737
super().__init__()
38-
scheduler = scheduler.set_format("np")
3938
self.register_modules(
4039
vae_decoder=vae_decoder,
4140
text_encoder=text_encoder,

Diff for: pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py

-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ class KarrasVePipeline(DiffusionPipeline):
2929

3030
def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
3131
super().__init__()
32-
scheduler = scheduler.set_format("pt")
3332
self.register_modules(unet=unet, scheduler=scheduler)
3433

3534
@torch.no_grad()

Diff for: schedulers/README.md

+1-2
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88

99
- Schedulers should provide one or more `def step(...)` functions that should be called iteratively to unroll the diffusion loop during
1010
the forward pass.
11-
- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch
12-
with a `set_format(...)` method.
11+
- Schedulers should be framework specific.
1312

1413
## Examples
1514

Diff for: schedulers/scheduling_ddim.py

+28-31
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class DDIMSchedulerOutput(BaseOutput):
4646
pred_original_sample: Optional[torch.FloatTensor] = None
4747

4848

49-
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
49+
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
5050
"""
5151
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
5252
(1-beta) over time from t = [0,1].
@@ -72,7 +72,7 @@ def alpha_bar(time_step):
7272
t1 = i / num_diffusion_timesteps
7373
t2 = (i + 1) / num_diffusion_timesteps
7474
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
75-
return np.array(betas, dtype=np.float32)
75+
return torch.tensor(betas)
7676

7777

7878
class DDIMScheduler(SchedulerMixin, ConfigMixin):
@@ -106,7 +106,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
106106
an offset added to the inference steps. You can use a combination of `offset=1` and
107107
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
108108
stable diffusion.
109-
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
110109
111110
"""
112111

@@ -121,36 +120,34 @@ def __init__(
121120
clip_sample: bool = True,
122121
set_alpha_to_one: bool = True,
123122
steps_offset: int = 0,
124-
tensor_format: str = "pt",
125123
):
126124
if trained_betas is not None:
127-
self.betas = np.asarray(trained_betas)
125+
self.betas = torch.from_numpy(trained_betas)
128126
if beta_schedule == "linear":
129-
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
127+
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
130128
elif beta_schedule == "scaled_linear":
131129
# this schedule is very specific to the latent diffusion model.
132-
self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
130+
self.betas = (
131+
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
132+
)
133133
elif beta_schedule == "squaredcos_cap_v2":
134134
# Glide cosine schedule
135135
self.betas = betas_for_alpha_bar(num_train_timesteps)
136136
else:
137137
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
138138

139139
self.alphas = 1.0 - self.betas
140-
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
140+
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
141141

142142
# At every step in ddim, we are looking into the previous alphas_cumprod
143143
# For the final step, there is no previous alphas_cumprod because we are already at 0
144144
# `set_alpha_to_one` decides whether we set this parameter simply to one or
145145
# whether we use the final alpha of the "non-previous" one.
146-
self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
146+
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
147147

148148
# setable values
149149
self.num_inference_steps = None
150-
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
151-
152-
self.tensor_format = tensor_format
153-
self.set_format(tensor_format=tensor_format)
150+
self.timesteps = np.arange(0, num_train_timesteps)[::-1]
154151

155152
def _get_variance(self, timestep, prev_timestep):
156153
alpha_prod_t = self.alphas_cumprod[timestep]
@@ -186,15 +183,14 @@ def set_timesteps(self, num_inference_steps: int, **kwargs):
186183
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
187184
# creates integer timesteps by multiplying by ratio
188185
# casting to int to avoid issues when num_inference_step is power of 3
189-
self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
186+
self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1]
190187
self.timesteps += offset
191-
self.set_format(tensor_format=self.tensor_format)
192188

193189
def step(
194190
self,
195-
model_output: Union[torch.FloatTensor, np.ndarray],
191+
model_output: torch.FloatTensor,
196192
timestep: int,
197-
sample: Union[torch.FloatTensor, np.ndarray],
193+
sample: torch.FloatTensor,
198194
eta: float = 0.0,
199195
use_clipped_model_output: bool = False,
200196
generator=None,
@@ -205,9 +201,9 @@ def step(
205201
process from the learned model outputs (most often the predicted noise).
206202
207203
Args:
208-
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
204+
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
209205
timestep (`int`): current discrete timestep in the diffusion chain.
210-
sample (`torch.FloatTensor` or `np.ndarray`):
206+
sample (`torch.FloatTensor`):
211207
current instance of sample being created by diffusion process.
212208
eta (`float`): weight of noise for added noise in diffusion step.
213209
use_clipped_model_output (`bool`): TODO
@@ -251,7 +247,7 @@ def step(
251247

252248
# 4. Clip "predicted x_0"
253249
if self.config.clip_sample:
254-
pred_original_sample = self.clip(pred_original_sample, -1, 1)
250+
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
255251

256252
# 5. compute variance: "sigma_t(η)" -> see formula (16)
257253
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
@@ -273,9 +269,6 @@ def step(
273269
noise = torch.randn(model_output.shape, generator=generator).to(device)
274270
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
275271

276-
if not torch.is_tensor(model_output):
277-
variance = variance.numpy()
278-
279272
prev_sample = prev_sample + variance
280273

281274
if not return_dict:
@@ -285,16 +278,20 @@ def step(
285278

286279
def add_noise(
287280
self,
288-
original_samples: Union[torch.FloatTensor, np.ndarray],
289-
noise: Union[torch.FloatTensor, np.ndarray],
290-
timesteps: Union[torch.IntTensor, np.ndarray],
291-
) -> Union[torch.FloatTensor, np.ndarray]:
292-
if self.tensor_format == "pt":
293-
timesteps = timesteps.to(self.alphas_cumprod.device)
281+
original_samples: torch.FloatTensor,
282+
noise: torch.FloatTensor,
283+
timesteps: torch.IntTensor,
284+
) -> torch.FloatTensor:
285+
timesteps = timesteps.to(self.alphas_cumprod.device)
294286
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
295-
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
287+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
288+
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
289+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
290+
296291
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
297-
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
292+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
293+
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
294+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
298295

299296
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
300297
return noisy_samples

0 commit comments

Comments
 (0)