
[Schedulers Refactoring] Phase 1: timesteps and scaling #637


Closed · wants to merge 11 commits
1 change: 0 additions & 1 deletion src/diffusers/dependency_versions_table.py
@@ -17,7 +17,6 @@
"jaxlib": "jaxlib>=0.1.65,<=0.3.6",
"modelcards": "modelcards>=0.1.4",
"numpy": "numpy",
"onnxruntime": "onnxruntime",
"onnxruntime-gpu": "onnxruntime-gpu",
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",
5 changes: 3 additions & 2 deletions src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -98,13 +98,14 @@ def __call__(
        # set step values
        self.scheduler.set_timesteps(num_inference_steps)

-       for t in self.progress_bar(self.scheduler.timesteps):
+       for step in self.progress_bar(self.scheduler.timesteps):
            # 1. predict noise model_output
+           t = self.scheduler.get_noise_condition(step)
@patrickvonplaten (Contributor) commented on Sep 30, 2022:
It's confusing to me that get_noise_condition() returns a time integer. Also, why is this needed for DDIM?

            model_output = self.unet(image, t).sample

            # 2. predict previous mean of image x_t-1 and add variance depending on eta
            # do x_t -> x_t-1
-           image = self.scheduler.step(model_output, t, image, eta).prev_sample
+           image = self.scheduler.step(model_output, step, image, eta).prev_sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
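The refactored loop contract is the same in every pipeline touched by this PR: `scheduler.timesteps` now yields step indices rather than train timesteps, `get_noise_condition(step)` recovers the integer timestep the UNet is conditioned on, and `step()` takes the index. A minimal sketch of that contract (the `unet` and `scheduler` objects are assumed in scope; this shows the loop shape, not a finalized API):

```python
import torch

def denoise(unet, scheduler, image: torch.Tensor, num_inference_steps: int = 50) -> torch.Tensor:
    # Sketch of the loop shape introduced by this PR.
    scheduler.set_timesteps(num_inference_steps)
    for step in scheduler.timesteps:                 # step indices: 49, 48, ..., 0
        t = scheduler.get_noise_condition(step)      # train timestep, e.g. 980 for step 49
        model_output = unet(image, t).sample
        image = scheduler.step(model_output, step, image).prev_sample  # step() takes the index
    return image
```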
@@ -130,6 +130,7 @@ def __call__(
            generator=generator,
        )
        latents = latents.to(self.device)
+       latents = self.scheduler.scale_initial_noise(latents)
A Member commented:
Do schedulers expect the latents on the final device or on the CPU? As per this PR, it looks like LMSDiscreteScheduler prepares sigmas on the CPU.

We could either:

  • Always follow the tensor, and move the sigmas etc. if necessary.
  • Implement scheduler.to(), as @anton-l suggested in that PR (see the sketch below).
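A minimal sketch of the `scheduler.to()` idea from this comment; the mixin name and the attribute scan are assumptions, not this PR's API:

```python
import torch

class SchedulerToMixin:
    # Hypothetical helper illustrating the scheduler.to() idea discussed above.
    def to(self, device):
        for name, value in list(vars(self).items()):
            if torch.is_tensor(value):
                setattr(self, name, value.to(device))  # move sigmas, timesteps, etc.
        return self
```

A pipeline could then call `self.scheduler.to(self.device)` once, next to the existing `latents.to(self.device)`.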


        self.scheduler.set_timesteps(num_inference_steps)

@@ -140,7 +141,7 @@ def __call__(
        if accepts_eta:
            extra_kwargs["eta"] = eta

-       for t in self.progress_bar(self.scheduler.timesteps):
+       for step in self.progress_bar(self.scheduler.timesteps):
            if guidance_scale == 1.0:
                # guidance_scale of 1 means no guidance
                latents_input = latents
@@ -153,14 +154,16 @@
                context = torch.cat([uncond_embeddings, text_embeddings])

            # predict the noise residual
+           latents_input = self.scheduler.scale_model_input(latents_input, step)
+           t = self.scheduler.get_noise_condition(step)
            noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample
            # perform guidance
            if guidance_scale != 1.0:
                noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
-           latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
+           latents = self.scheduler.step(noise_pred, step, latents, **extra_kwargs).prev_sample

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
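The guidance update above is standard classifier-free guidance; in isolation it is just the following (a sketch, with the `apply_guidance` name and shapes assumed):

```python
import torch

def apply_guidance(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # noise_pred stacks the unconditional and text-conditioned predictions
    # along the batch dimension, as produced by the chunked forward pass above.
    uncond, text = noise_pred.chunk(2)
    return uncond + guidance_scale * (text - uncond)
```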
@@ -77,6 +77,7 @@ def __call__(
            generator=generator,
        )
        latents = latents.to(self.device)
+       latents = self.scheduler.scale_initial_noise(latents)

        self.scheduler.set_timesteps(num_inference_steps)

@@ -87,11 +88,13 @@ def __call__(
        if accepts_eta:
            extra_kwargs["eta"] = eta

-       for t in self.progress_bar(self.scheduler.timesteps):
+       for step in self.progress_bar(self.scheduler.timesteps):
            # predict the noise residual
+           latents = self.scheduler.scale_model_input(latents, step)
+           t = self.scheduler.get_noise_condition(step)
            noise_prediction = self.unet(latents, t).sample
            # compute the previous noisy sample x_t -> x_t-1
-           latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample
+           latents = self.scheduler.step(noise_prediction, step, latents, **extra_kwargs).prev_sample

        # decode the image latents with the VAE
        image = self.vqvae.decode(latents).sample
@@ -42,6 +42,14 @@ class StableDiffusionPipeline(DiffusionPipeline):
        Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

+   vae: AutoencoderKL
A Contributor commented:
This is a bit unrelated to the PR; let's try to put this in a different PR next time ;-)

+   text_encoder: CLIPTextModel
+   tokenizer: CLIPTokenizer
+   unet: UNet2DConditionModel
+   scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+   safety_checker: StableDiffusionSafetyChecker
+   feature_extractor: CLIPFeatureExtractor

    def __init__(
        self,
        vae: AutoencoderKL,
@@ -231,14 +239,11 @@ def __call__(
        if latents.shape != latents_shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
        latents = latents.to(self.device)
+       latents = self.scheduler.scale_initial_noise(latents)
@patrickvonplaten (Contributor) commented on Sep 30, 2022:
Fine with this function! Think it would also be fine to add an if-statement here (to only do it if the scheduler is continuous).

A Contributor commented:
I don't think this works correctly at the moment, because the sigmas are changed when doing self.scheduler.set_timesteps.


        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

-       # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
-       if isinstance(self.scheduler, LMSDiscreteScheduler):
-           latents = latents * self.scheduler.sigmas[0]

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
@@ -248,13 +253,11 @@
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

-       for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+       for step in self.progress_bar(self.scheduler.timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-           if isinstance(self.scheduler, LMSDiscreteScheduler):
-               sigma = self.scheduler.sigmas[i]
-               # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-               latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+           latent_model_input = self.scheduler.scale_model_input(latent_model_input, step)
+           t = self.scheduler.get_noise_condition(step)
A Contributor commented:
Here I'm lost -> this function is not intuitive to me. t != noise_condition for me. Not happy if we need to force schedulers to have this function.
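For concreteness, a worked example of the mapping being questioned (a sketch with assumed values: 1000 train timesteps, 50 inference steps, offset 0, mirroring DDIMScheduler.set_timesteps further down):

```python
import numpy as np

num_train_timesteps, num_inference_steps, offset = 1000, 50, 0
step_ratio = num_train_timesteps // num_inference_steps              # 20
schedule = (np.arange(0, num_inference_steps) * step_ratio).round() + offset
timesteps = np.arange(0, num_inference_steps)[::-1]                  # 49, 48, ..., 0

step = timesteps[0]                 # 49: first loop iteration
t = schedule[step]                  # 980: the "noise condition" fed to the UNet
```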


            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

@@ -265,10 +268,7 @@
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
-           if isinstance(self.scheduler, LMSDiscreteScheduler):
-               latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
-           else:
-               latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+           latents = self.scheduler.step(noise_pred, step, latents, **extra_step_kwargs).prev_sample

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
@@ -241,13 +241,7 @@ def __call__(
        offset = self.scheduler.config.get("steps_offset", 0)
        init_timestep = int(num_inference_steps * strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)
-       if isinstance(self.scheduler, LMSDiscreteScheduler):
-           timesteps = torch.tensor(
-               [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
-           )
-       else:
-           timesteps = self.scheduler.timesteps[-init_timestep]
-           timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+       timesteps = torch.tensor([num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device)

        # add noise to latents using the timesteps
        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
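A quick numeric check of the simplified computation above (values assumed: 50 steps, strength 0.75, offset 0, batch size 2):

```python
import torch

num_inference_steps, strength, offset, batch_size = 50, 0.75, 0, 2
init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)   # 37
# the step index to resume denoising from, now identical for every scheduler:
timesteps = torch.tensor([num_inference_steps - init_timestep] * batch_size, dtype=torch.long)
print(timesteps)   # tensor([13, 13])
```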
@@ -291,14 +285,11 @@ def __call__(

        latents = init_latents
        t_start = max(num_inference_steps - init_timestep + offset, 0)
-       for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
-           t_index = t_start + i
+       for step in self.progress_bar(self.scheduler.timesteps[t_start:]):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-           if isinstance(self.scheduler, LMSDiscreteScheduler):
-               sigma = self.scheduler.sigmas[t_index]
-               # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-               latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+           latent_model_input = self.scheduler.scale_model_input(latent_model_input, step)
+           t = self.scheduler.get_noise_condition(step)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -309,14 +300,9 @@
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
-           if isinstance(self.scheduler, LMSDiscreteScheduler):
-               latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
-               # masking
-               init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor(t_index))
-           else:
-               latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-               # masking
-               init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
+           latents = self.scheduler.step(noise_pred, step, latents, **extra_step_kwargs).prev_sample
+           # masking
+           init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor(step))

            latents = (init_latents_proper * mask) + (latents * (1 - mask))

22 changes: 17 additions & 5 deletions src/diffusers/schedulers/scheduling_ddim.py
@@ -25,7 +25,7 @@

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
-from .scheduling_utils import SchedulerMixin
+from .scheduling_utils import BaseScheduler, SchedulerMixin


@dataclass
@@ -75,7 +75,7 @@ def alpha_bar(time_step):
    return np.array(betas, dtype=np.float32)


-class DDIMScheduler(SchedulerMixin, ConfigMixin):
+class DDIMScheduler(BaseScheduler, SchedulerMixin, ConfigMixin):
    """
    Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
    diffusion probabilistic models (DDPMs) with non-Markovian guidance.
@@ -147,7 +147,8 @@ def __init__(
        # setable values
        self.num_inference_steps = None
-       self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+       self.schedule = np.arange(0, num_train_timesteps)
A Contributor commented:
What do we need the schedule for?

+       self.timesteps = self.schedule[::-1].copy()

        self.tensor_format = tensor_format
        self.set_format(tensor_format=tensor_format)
@@ -162,6 +163,12 @@ def _get_variance(self, timestep, prev_timestep):

        return variance

+   def get_noise_condition(self, step: int):
+       """
+       Returns the input noise condition for a model.
+       """
+       return self.schedule[step]
A Contributor commented:
This function is not intuitive for me.

    def set_timesteps(self, num_inference_steps: int, **kwargs):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -186,8 +193,10 @@
        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
-       self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
-       self.timesteps += offset
+       self.schedule = (np.arange(0, num_inference_steps) * step_ratio).round().copy()
+       self.schedule += offset
+
+       self.timesteps = np.arange(0, num_inference_steps)[::-1].copy()
        self.set_format(tensor_format=self.tensor_format)

    def step(
@@ -236,6 +245,8 @@ def step(
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

+       timestep = self.schedule[timestep]
A Contributor commented:
Here I'm a bit lost; the code here is now more complex than it was before.

A Contributor commented:
Why are we now passing a different timestep to the scheduler than before, which we then revert here?


        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
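To trace the re-indexing the comments above are asking about (assumed values: 1000 train timesteps, 50 inference steps): the pipeline now passes a step index, and step() immediately maps it back to a train timestep before doing the usual DDIM update.

```python
import numpy as np

# Assumed setup mirroring set_timesteps(50) with 1000 train timesteps:
schedule = (np.arange(0, 50) * 20).round()        # [0, 20, ..., 980]

timestep = 49                                     # step index passed by the pipeline
timestep = schedule[timestep]                     # 980: recovered train timestep
prev_timestep = timestep - 1000 // 50             # 960: same value the old code computed
```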

@@ -291,6 +302,7 @@ def add_noise(
    ) -> Union[torch.FloatTensor, np.ndarray]:
        if self.tensor_format == "pt":
            timesteps = timesteps.to(self.alphas_cumprod.device)
+       timesteps = self.schedule[timesteps]
        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
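add_noise therefore expects step indices as well; a small sketch of the lookup (assumed: the schedule produced by set_timesteps(50) above):

```python
import numpy as np
import torch

schedule = (np.arange(0, 50) * 20).round()     # [0, 20, ..., 980], as above
steps = torch.tensor([13, 13])                 # step indices, e.g. from the inpaint pipeline
train_timesteps = schedule[steps.numpy()]      # array([260, 260]) -> indexes alphas_cumprod
```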
34 changes: 27 additions & 7 deletions src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -22,7 +22,7 @@

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
-from .scheduling_utils import SchedulerMixin
+from .scheduling_utils import BaseScheduler, SchedulerMixin


@dataclass
@@ -43,7 +43,7 @@ class LMSDiscreteSchedulerOutput(BaseOutput):
    pred_original_sample: Optional[torch.FloatTensor] = None


-class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
+class LMSDiscreteScheduler(BaseScheduler, SchedulerMixin, ConfigMixin):
    """
    Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
    Katherine Crowson:
@@ -93,15 +93,36 @@ def __init__(
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)

        self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
+       self.sigmas = self.sigmas[::-1].copy()

        # setable values
        self.num_inference_steps = None
-       self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+       self.schedule = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)
        self.derivatives = []

        self.tensor_format = tensor_format
        self.set_format(tensor_format=tensor_format)

+   def scale_initial_noise(self, noise: torch.FloatTensor):
+       """
+       Scales the initial noise to the correct range for the scheduler.
+       """
+       return noise * self.sigmas[0]
+
+   def scale_model_input(self, sample: torch.FloatTensor, step: int):
+       """
+       Scales the model input (`sample`) to the correct range for the scheduler.
+       """
+       sigma = self.sigmas[self.num_inference_steps - step - 1]
+       return sample / ((sigma**2 + 1) ** 0.5)
+
+   def get_noise_condition(self, step: int):
+       """
+       Returns the input noise condition for a model.
+       """
+       return self.schedule[step]
A Contributor commented:
No one-liner functions please.
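As a sanity check on the indexing in scale_model_input: after set_timesteps the sigmas are stored largest-first, so the step index has to be flipped. A sketch with assumed values (1000 train timesteps, 50 inference steps, and an assumed linear beta schedule):

```python
import numpy as np

num_train_timesteps, num_inference_steps = 1000, 50
alphas_cumprod = np.cumprod(1 - np.linspace(1e-4, 2e-2, num_train_timesteps))
base_sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5        # ascending in t

schedule = np.linspace(0, num_train_timesteps - 1, num_inference_steps)
sigmas = np.interp(schedule[::-1], np.arange(num_train_timesteps), base_sigmas)

step = 49                                        # first step index from the pipeline
sigma = sigmas[num_inference_steps - step - 1]   # index 0 -> the largest sigma
scale = 1.0 / ((sigma**2 + 1) ** 0.5)            # factor applied to the model input
```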


    def get_lms_coefficient(self, order, t, current_order):
        """
        Compute a linear multistep coefficient.
@@ -133,13 +154,11 @@ def set_timesteps(self, num_inference_steps: int):
            the number of diffusion steps used when generating samples with a pre-trained model.
        """
        self.num_inference_steps = num_inference_steps
-       self.timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
+       self.timesteps = np.arange(0, num_inference_steps)[::-1].copy()

-       low_idx = np.floor(self.timesteps).astype(int)
-       high_idx = np.ceil(self.timesteps).astype(int)
-       frac = np.mod(self.timesteps, 1.0)
+       self.schedule = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
-       sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
+       sigmas = np.interp(self.schedule[::-1], np.arange(0, len(sigmas)), sigmas)
        self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)

        self.derivatives = []
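The np.interp call above is equivalent to the floor/ceil/frac blend it replaces; a quick check with small assumed values:

```python
import numpy as np

sigmas = np.array([0.1, 0.5, 1.0, 2.0, 4.0])
query = np.array([3.2, 1.5, 0.0])            # descending fractional timesteps

# old formulation: explicit linear blend between neighboring sigmas
low, high = np.floor(query).astype(int), np.ceil(query).astype(int)
frac = np.mod(query, 1.0)
old = (1 - frac) * sigmas[low] + frac * sigmas[high]

# new formulation: one interpolation call
new = np.interp(query, np.arange(len(sigmas)), sigmas)
assert np.allclose(old, new)
```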
@@ -172,6 +191,7 @@ def step(
            When returning a tuple, the first element is the sample tensor.

        """
+       timestep = int(self.num_inference_steps - timestep - 1)
        sigma = self.sigmas[timestep]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise