[Schedulers Refactoring] Phase 1: timesteps and scaling #637

Closed · wants to merge 11 commits
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
@@ -31,6 +31,7 @@
from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
from .schedulers import (
+    BaseScheduler,
    DDIMScheduler,
    DDPMScheduler,
    KarrasVeScheduler,
1 change: 0 additions & 1 deletion src/diffusers/dependency_versions_table.py
@@ -17,7 +17,6 @@
"jaxlib": "jaxlib>=0.1.65,<=0.3.6",
"modelcards": "modelcards>=0.1.4",
"numpy": "numpy",
"onnxruntime": "onnxruntime",
"onnxruntime-gpu": "onnxruntime-gpu",
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",
1 change: 1 addition & 0 deletions src/diffusers/pipeline_utils.py
@@ -43,6 +43,7 @@
LOADABLE_CLASSES = {
    "diffusers": {
        "ModelMixin": ["save_pretrained", "from_pretrained"],
+        "BaseScheduler": ["save_config", "from_config"],
        "SchedulerMixin": ["save_config", "from_config"],
        "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
        "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
5 changes: 3 additions & 2 deletions src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -82,14 +82,15 @@ def __call__(
        # set step values
        self.scheduler.set_timesteps(num_inference_steps)

-        for t in self.progress_bar(self.scheduler.timesteps):
+        for step in self.progress_bar(self.scheduler.timesteps):
            # 1. predict noise model_output
+            t = self.scheduler.get_noise_condition(step)
@patrickvonplaten (Contributor, Sep 30, 2022): it's confusing to me that get_noise_condition() returns a time integer. Also, why is this needed for DDIM?

            model_output = self.unet(image, t).sample

            # 2. predict previous mean of image x_t-1 and add variance depending on eta
            # eta corresponds to η in paper and should be between [0, 1]
            # do x_t -> x_t-1
-            image = self.scheduler.step(model_output, t, image, eta).prev_sample
+            image = self.scheduler.step(model_output, step, image, eta).prev_sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
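To ground the question above: a minimal sketch of what get_noise_condition could be doing, assuming the loop now yields an abstract `step` and each scheduler maps it to the conditioning value its model expects. The class names and internals here are illustrative guesses, not the API introduced by this PR.

```python
import torch

class DiscreteSchedulerSketch:
    """For DDIM/DDPM-style schedulers the loop value already is the timestep."""

    def get_noise_condition(self, step):
        # Identity: the UNet is conditioned on the integer timestep itself.
        return step

class ContinuousSchedulerSketch:
    """For sigma-based schedulers (e.g. K-LMS) the loop yields a step index."""

    def __init__(self, num_train_timesteps=1000, num_inference_steps=50):
        # Map inference step indices back onto the training timestep grid.
        self.timesteps = torch.linspace(num_train_timesteps - 1, 0, num_inference_steps)

    def get_noise_condition(self, step):
        # Resolve the step index to the (possibly fractional) training
        # timestep the model should be conditioned on.
        return self.timesteps[step]
```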
5 changes: 3 additions & 2 deletions src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -75,12 +75,13 @@ def __call__(
        # set step values
        self.scheduler.set_timesteps(1000)

-        for t in self.progress_bar(self.scheduler.timesteps):
+        for step in self.progress_bar(self.scheduler.timesteps):
            # 1. predict noise model_output
+            t = self.scheduler.get_noise_condition(step)
            model_output = self.unet(image, t).sample

            # 2. compute previous image: x_t -> t_t-1
-            image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
+            image = self.scheduler.step(model_output, step, image, generator=generator).prev_sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
@@ -117,6 +117,7 @@ def __call__(
            generator=generator,
        )
        latents = latents.to(self.device)
+        latents = self.scheduler.scale_initial_noise(latents)
A reviewer (Member) commented: Do schedulers expect the latents on the final device or on CPU? As per this PR, it looks like LMSDiscreteScheduler prepares sigmas on the CPU. We could either:

  • Always follow the tensor, and move the sigmas etc. if necessary.
  • Implement scheduler.to() as @anton-l suggested in that PR.
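For illustration, a minimal sketch of the first option ("follow the tensor") applied to scale_initial_noise; the class name and internals are assumptions, not code from this PR. Note also the later review comment that set_timesteps recomputes the sigmas, so call order would matter.

```python
import torch

class ContinuousSchedulerSketch:
    def __init__(self, sigmas: torch.Tensor):
        self.sigmas = sigmas  # e.g. prepared on CPU by set_timesteps

    def scale_initial_noise(self, sample: torch.Tensor) -> torch.Tensor:
        # "Follow the tensor": move sigmas to wherever the latents live.
        sigmas = self.sigmas.to(sample.device)
        # Sigma-based samplers start from noise scaled to the largest sigma;
        # discrete schedulers would simply return `sample` unchanged.
        return sample * sigmas[0]
```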


        self.scheduler.set_timesteps(num_inference_steps)

@@ -127,7 +128,7 @@
        if accepts_eta:
            extra_kwargs["eta"] = eta

-        for t in self.progress_bar(self.scheduler.timesteps):
+        for step in self.progress_bar(self.scheduler.timesteps):
            if guidance_scale == 1.0:
                # guidance_scale of 1 means no guidance
                latents_input = latents

@@ -140,14 +141,16 @@
                context = torch.cat([uncond_embeddings, text_embeddings])

            # predict the noise residual
+            latents_input = self.scheduler.scale_model_input(latents_input, step)
+            t = self.scheduler.get_noise_condition(step)
            noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample
            # perform guidance
            if guidance_scale != 1.0:
                noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, step, latents, **extra_kwargs).prev_sample

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
@@ -63,6 +63,7 @@ def __call__(
            generator=generator,
        )
        latents = latents.to(self.device)
+        latents = self.scheduler.scale_initial_noise(latents)

        self.scheduler.set_timesteps(num_inference_steps)

@@ -73,11 +74,13 @@
        if accepts_eta:
            extra_kwargs["eta"] = eta

-        for t in self.progress_bar(self.scheduler.timesteps):
+        for step in self.progress_bar(self.scheduler.timesteps):
            # predict the noise residual
+            latents = self.scheduler.scale_model_input(latents, step)
+            t = self.scheduler.get_noise_condition(step)
            noise_prediction = self.unet(latents, t).sample
            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample
+            latents = self.scheduler.step(noise_prediction, step, latents, **extra_kwargs).prev_sample

        # decode the image latents with the VAE
        image = self.vqvae.decode(latents).sample
6 changes: 4 additions & 2 deletions src/diffusers/pipelines/pndm/pipeline_pndm.py
@@ -79,12 +79,14 @@ def __call__(
            generator=generator,
        )
        image = image.to(self.device)
+        image = self.scheduler.scale_initial_noise(image)

        self.scheduler.set_timesteps(num_inference_steps)
-        for t in self.progress_bar(self.scheduler.timesteps):
+        for step in self.progress_bar(self.scheduler.timesteps):
+            t = self.scheduler.get_noise_condition(step)
            model_output = self.unet(image, t).sample

-            image = self.scheduler.step(model_output, t, image).prev_sample
+            image = self.scheduler.step(model_output, step, image).prev_sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
@@ -46,6 +46,14 @@ class StableDiffusionPipeline(DiffusionPipeline):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

+    vae: AutoencoderKL
A reviewer (Contributor) commented: this is a bit unrelated to the PR - let's try to put this in a different PR next time ;-)

+    text_encoder: CLIPTextModel
+    tokenizer: CLIPTokenizer
+    unet: UNet2DConditionModel
+    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+    safety_checker: StableDiffusionSafetyChecker
+    feature_extractor: CLIPFeatureExtractor

    def __init__(
        self,
        vae: AutoencoderKL,
@@ -230,14 +238,11 @@ def __call__(
        if latents.shape != latents_shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
        latents = latents.to(self.device)
+        latents = self.scheduler.scale_initial_noise(latents)
@patrickvonplaten (Contributor, Sep 30, 2022): Fine with this function! Think it would also be fine to add an if-statement here (to only do it if the scheduler is continuous)

A second reviewer (Contributor) commented: I don't think this works correctly at the moment because the sigmas are changed when doing self.scheduler.set_timesteps

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

-        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            latents = latents * self.scheduler.sigmas[0]
-
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502

@@ -247,13 +252,11 @@
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

-        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+        for step in self.progress_bar(self.scheduler.timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[i]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, step)
+            t = self.scheduler.get_noise_condition(step)
A reviewer (Contributor) commented: Here I'm lost -> this function is not intuitive to me. t != noise_condition for me. Not happy if we need to force schedulers to have this function

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

@@ -264,10 +267,7 @@
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, step, latents, **extra_step_kwargs).prev_sample

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
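The isinstance branches deleted above hint at what scale_model_input has to absorb. A rough sketch under that assumption, with hypothetical class names:

```python
import torch

class LMSLikeSchedulerSketch:
    def __init__(self, sigmas: torch.Tensor):
        self.sigmas = sigmas

    def scale_model_input(self, sample: torch.Tensor, step: int) -> torch.Tensor:
        # The same scaling the pipelines used to inline for LMSDiscreteScheduler:
        # match the continuous ODE formulation in K-LMS.
        sigma = self.sigmas[step]
        return sample / ((sigma**2 + 1) ** 0.5)

class DiscreteSchedulerSketch:
    def scale_model_input(self, sample: torch.Tensor, step: int) -> torch.Tensor:
        # Discrete schedulers need no input scaling; the hook exists so
        # pipelines can stay scheduler-agnostic.
        return sample
```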
@@ -206,13 +206,7 @@ def __call__(
        offset = self.scheduler.config.get("steps_offset", 0)
        init_timestep = int(num_inference_steps * strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            timesteps = torch.tensor(
-                [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
-            )
-        else:
-            timesteps = self.scheduler.timesteps[-init_timestep]
-            timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+        timesteps = torch.tensor([init_timestep - 1] * batch_size, dtype=torch.long, device=self.device)

        # add noise to latents using the timesteps
        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
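A quick worked example of the unified computation above, with illustrative values: both scheduler families now get the same step index, and resolving it to a concrete training timestep is left to the scheduler.

```python
import torch

num_inference_steps, strength, offset, batch_size = 50, 0.8, 0, 2
init_timestep = int(num_inference_steps * strength) + offset  # 40
init_timestep = min(init_timestep, num_inference_steps)       # 40
timesteps = torch.tensor([init_timestep - 1] * batch_size, dtype=torch.long)
print(timesteps)  # tensor([39, 39])
```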
@@ -265,17 +259,11 @@ def __call__(
        latents = init_latents

        t_start = max(num_inference_steps - init_timestep + offset, 0)
-        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])):
-            t_index = t_start + i
-
+        for step in self.progress_bar(self.scheduler.timesteps[t_start:]):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

-            # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[t_index]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, step)
+            t = self.scheduler.get_noise_condition(step)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -286,10 +274,7 @@
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, step, latents, **extra_step_kwargs).prev_sample

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
@@ -6,7 +6,6 @@
import torch

import PIL
-from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
@@ -240,13 +239,7 @@ def __call__(
        offset = self.scheduler.config.get("steps_offset", 0)
        init_timestep = int(num_inference_steps * strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            timesteps = torch.tensor(
-                [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
-            )
-        else:
-            timesteps = self.scheduler.timesteps[-init_timestep]
-            timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+        timesteps = torch.tensor([init_timestep - 1] * batch_size, dtype=torch.long, device=self.device)

        # add noise to latents using the timesteps
        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
@@ -298,14 +291,11 @@ def __call__(

        latents = init_latents
        t_start = max(num_inference_steps - init_timestep + offset, 0)
-        for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
-            t_index = t_start + i
+        for step in self.progress_bar(self.scheduler.timesteps[t_start:]):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[t_index]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, step)
+            t = self.scheduler.get_noise_condition(step)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -316,14 +306,9 @@
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
-                # masking
-                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t_index]))
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-                # masking
-                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t]))
+            latents = self.scheduler.step(noise_pred, step, latents, **extra_step_kwargs).prev_sample
+            # masking
+            init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor(step))

            latents = (init_latents_proper * mask) + (latents * (1 - mask))
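Since add_noise now receives a step value instead of a raw training timestep, the scheduler presumably resolves the index internally. A sketch of that idea for a DDPM-like scheduler using the standard forward process q(x_t | x_0); the attribute names are illustrative assumptions.

```python
import torch

class DiscreteSchedulerSketch:
    def __init__(self, alphas_cumprod: torch.Tensor, train_timesteps: torch.Tensor):
        self.alphas_cumprod = alphas_cumprod    # cumulative alpha-bar per training step
        self.train_timesteps = train_timesteps  # maps a step value to a training timestep

    def add_noise(self, original_samples, noise, steps):
        # Resolve step values to training timesteps, then apply the usual
        # forward process: sqrt(a_bar) * x_0 + sqrt(1 - a_bar) * eps.
        t = self.train_timesteps[steps]
        alpha_bar = self.alphas_cumprod[t].flatten()
        while alpha_bar.dim() < original_samples.dim():
            alpha_bar = alpha_bar.unsqueeze(-1)
        return alpha_bar.sqrt() * original_samples + (1 - alpha_bar).sqrt() * noise
```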

@@ -109,6 +109,7 @@ def __call__(
            latents = np.random.randn(*latents_shape).astype(np.float32)
        elif latents.shape != latents_shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+        latents = self.scheduler.scale_initial_noise(latents)

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)
@@ -126,13 +127,11 @@
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

-        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+        for step in self.progress_bar(self.scheduler.timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[i]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, step)
+            t = self.scheduler.get_noise_condition(step)

            # predict the noise residual
            noise_pred = self.unet(
@@ -146,10 +145,7 @@
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, step, latents, **extra_step_kwargs).prev_sample

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
2 changes: 1 addition & 1 deletion src/diffusers/schedulers/__init__.py
@@ -23,7 +23,7 @@
    from .scheduling_pndm import PNDMScheduler
    from .scheduling_sde_ve import ScoreSdeVeScheduler
    from .scheduling_sde_vp import ScoreSdeVpScheduler
-    from .scheduling_utils import SchedulerMixin
+    from .scheduling_utils import BaseScheduler, SchedulerMixin
else:
    from ..utils.dummy_pt_objects import * # noqa F403