[Schedulers Refactoring] Phase 1: timesteps and scaling #637
Changes from 1 commit
@@ -98,13 +98,14 @@ def __call__(
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)

-        for t in self.progress_bar(self.scheduler.timesteps):
+        for step in self.progress_bar(self.scheduler.timesteps):
             # 1. predict noise model_output
+            t = self.scheduler.get_noise_condition(step)
             model_output = self.unet(image, t).sample

             # 2. predict previous mean of image x_t-1 and add variance depending on eta
             # do x_t -> x_t-1
-            image = self.scheduler.step(model_output, t, image, eta).prev_sample
+            image = self.scheduler.step(model_output, step, image, eta).prev_sample

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()

Review comment (on the added `t = self.scheduler.get_noise_condition(step)` line): it's confusing to me that …
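To summarize the new control flow: the loop variable is now a plain position in the inference schedule (`step`), and the noise level actually fed to the model is looked up via `get_noise_condition`. A minimal sketch of the resulting loop (hedged: `scheduler` is any scheduler implementing the Phase-1 hooks; `unet`, `image`, and `eta` are as in the pipeline above):

```python
# Sketch of the refactored denoising loop from this diff.
scheduler.set_timesteps(num_inference_steps)

for step in scheduler.timesteps:
    # `step` indexes the inference schedule; `t` is the training timestep
    # (noise level) the model expects as conditioning.
    t = scheduler.get_noise_condition(step)
    model_output = unet(image, t).sample

    # scheduler.step() is now indexed by `step`, not by `t`.
    image = scheduler.step(model_output, step, image, eta).prev_sample
```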
@@ -130,6 +130,7 @@ def __call__(
             generator=generator,
         )
         latents = latents.to(self.device)
+        latents = self.scheduler.scale_initial_noise(latents)

         self.scheduler.set_timesteps(num_inference_steps)

@@ -140,7 +141,7 @@ def __call__(
         if accepts_eta:
             extra_kwargs["eta"] = eta

-        for t in self.progress_bar(self.scheduler.timesteps):
+        for step in self.progress_bar(self.scheduler.timesteps):
             if guidance_scale == 1.0:
                 # guidance_scale of 1 means no guidance
                 latents_input = latents

@@ -153,14 +154,16 @@ def __call__(
                 context = torch.cat([uncond_embeddings, text_embeddings])

             # predict the noise residual
+            latents_input = self.scheduler.scale_model_input(latents_input, step)
+            t = self.scheduler.get_noise_condition(step)
             noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample
             # perform guidance
             if guidance_scale != 1.0:
                 noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, step, latents, **extra_kwargs).prev_sample

         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
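As context for the hunk above, the `chunk`/recombine arithmetic is standard classifier-free guidance; a self-contained sketch (the shapes and names mirror the pipeline code, the values are illustrative):

```python
import torch

def guide(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # The batched UNet output stacks the unconditional and text-conditioned
    # predictions along the batch dimension.
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    # Push the prediction away from the unconditional branch.
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# e.g. with a dummy prediction of batch size 2 (uncond + cond):
noise_pred = torch.randn(2, 4, 64, 64)
guided = guide(noise_pred, guidance_scale=7.5)
```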
@@ -42,6 +42,14 @@ class StableDiffusionPipeline(DiffusionPipeline):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """

+    vae: AutoencoderKL
+    text_encoder: CLIPTextModel
+    tokenizer: CLIPTokenizer
+    unet: UNet2DConditionModel
+    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+    safety_checker: StableDiffusionSafetyChecker
+    feature_extractor: CLIPFeatureExtractor
+
     def __init__(
         self,
         vae: AutoencoderKL,

Review comment (on the class attribute annotations): this is a bit unrelated to the PR - let's try to put this in a different PR next time ;-)
(pcuenca marked this conversation as resolved.)

@@ -231,14 +239,11 @@ def __call__(
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
             latents = latents.to(self.device)
+        latents = self.scheduler.scale_initial_noise(latents)

         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)

-        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            latents = latents * self.scheduler.sigmas[0]
-
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502

Review comment (on `scale_initial_noise`): Fine with this function! Think it would also be fine to add an …
Review comment: I don't think this works correctly at the moment because the sigmas are changed when doing …

@@ -248,13 +253,11 @@ def __call__(
         if accepts_eta:
             extra_step_kwargs["eta"] = eta

-        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+        for step in self.progress_bar(self.scheduler.timesteps):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[i]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, step)
+            t = self.scheduler.get_noise_condition(step)

             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

Review comment (on `get_noise_condition`): Here I'm lost -> this function is not intuitive to me.

@@ -265,10 +268,7 @@ def __call__(
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, step, latents, **extra_step_kwargs).prev_sample

         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
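The net effect of these four hunks is that the Stable Diffusion loop no longer branches on the scheduler type. A condensed sketch of the resulting scheduler-agnostic loop (assumption: the new base class gives `scale_initial_noise`/`scale_model_input` identity defaults for schedulers that need no scaling; the call order mirrors the diff, including the `scale_initial_noise`-before-`set_timesteps` ordering the reviewer questions above):

```python
# Scheduler-agnostic sampling loop after this refactor (sketch).
latents = scheduler.scale_initial_noise(latents)   # * sigmas[0] for LMS; assumed identity otherwise
scheduler.set_timesteps(num_inference_steps)

for step in scheduler.timesteps:
    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
    latent_model_input = scheduler.scale_model_input(latent_model_input, step)
    t = scheduler.get_noise_condition(step)

    noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
    # ... classifier-free guidance as above ...
    latents = scheduler.step(noise_pred, step, latents, **extra_step_kwargs).prev_sample
```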
@@ -25,7 +25,7 @@

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput
-from .scheduling_utils import SchedulerMixin
+from .scheduling_utils import BaseScheduler, SchedulerMixin


 @dataclass

@@ -75,7 +75,7 @@ def alpha_bar(time_step):
     return np.array(betas, dtype=np.float32)


-class DDIMScheduler(SchedulerMixin, ConfigMixin):
+class DDIMScheduler(BaseScheduler, SchedulerMixin, ConfigMixin):
     """
     Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
     diffusion probabilistic models (DDPMs) with non-Markovian guidance.
@@ -147,7 +148,8 @@ def __init__(

         # setable values
         self.num_inference_steps = None
-        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+        self.schedule = np.arange(0, num_train_timesteps)
+        self.timesteps = self.schedule[::-1].copy()

         self.tensor_format = tensor_format
         self.set_format(tensor_format=tensor_format)

Review comment (on `self.schedule`): What do we need the …
@@ -162,6 +163,12 @@ def _get_variance(self, timestep, prev_timestep):

         return variance

+    def get_noise_condition(self, step: int):
+        """
+        Returns the input noise condition for a model.
+        """
+        return self.schedule[step]
+
     def set_timesteps(self, num_inference_steps: int, **kwargs):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

Review comment (on `get_noise_condition`): This function is not intuitive for me
@@ -186,8 +193,10 @@ def set_timesteps(self, num_inference_steps: int, **kwargs):
         step_ratio = self.config.num_train_timesteps // self.num_inference_steps
         # creates integer timesteps by multiplying by ratio
         # casting to int to avoid issues when num_inference_step is power of 3
-        self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
-        self.timesteps += offset
+        self.schedule = (np.arange(0, num_inference_steps) * step_ratio).round().copy()
+        self.schedule += offset
+
+        self.timesteps = np.arange(0, num_inference_steps)[::-1].copy()
         self.set_format(tensor_format=self.tensor_format)

     def step(
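A worked example of the new split (assuming the common `num_train_timesteps=1000` config with 50 inference steps and `offset=0`):

```python
import numpy as np

num_train_timesteps, num_inference_steps, offset = 1000, 50, 0
step_ratio = num_train_timesteps // num_inference_steps           # 20

schedule = (np.arange(0, num_inference_steps) * step_ratio).round() + offset
timesteps = np.arange(0, num_inference_steps)[::-1]

print(timesteps[:3])             # [49 48 47]   -> plain loop indices
print(schedule[timesteps[:3]])   # [980 960 940] -> training timesteps (noise conditions)
```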
@@ -236,6 +245,8 @@ def step(
         # - pred_sample_direction -> "direction pointing to x_t"
         # - pred_prev_sample -> "x_t-1"

+        timestep = self.schedule[timestep]
+
         # 1. get previous step value (=t-1)
         prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

Review comment: Here I'm a bit lost - the code now here is more complex than it was before
Review comment: Why are we now passing a different timestep to the scheduler than before that we then revert here?
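Continuing the worked example above: `step()` now receives the loop index and maps it back onto the training schedule before computing the previous timestep, which is the round-trip the reviewer is asking about:

```python
step = 49                                  # loop index handed in by the pipeline
timestep = schedule[step]                  # 980 -> training timestep
prev_timestep = timestep - num_train_timesteps // num_inference_steps   # 960
```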
@@ -291,6 +302,7 @@ def add_noise(
     ) -> Union[torch.FloatTensor, np.ndarray]:
         if self.tensor_format == "pt":
             timesteps = timesteps.to(self.alphas_cumprod.device)
+        timesteps = self.schedule[timesteps]
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
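With the indices remapped, `add_noise` applies the usual forward-diffusion formula x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise. A minimal numpy sketch (hypothetical beta schedule standing in for the configured one; shapes elided):

```python
import numpy as np

betas = np.linspace(1e-4, 2e-2, 1000)      # hypothetical; real values come from the config
alphas_cumprod = np.cumprod(1.0 - betas)

# Loop indices -> training timesteps, reusing `schedule` from the example above.
timesteps = schedule[np.array([49, 25, 0])].astype(int)
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
# noisy = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
```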
@@ -22,7 +22,7 @@

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput
-from .scheduling_utils import SchedulerMixin
+from .scheduling_utils import BaseScheduler, SchedulerMixin


 @dataclass

@@ -43,7 +43,7 @@ class LMSDiscreteSchedulerOutput(BaseOutput):
     pred_original_sample: Optional[torch.FloatTensor] = None


-class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
+class LMSDiscreteScheduler(BaseScheduler, SchedulerMixin, ConfigMixin):
     """
     Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
     Katherine Crowson:
@@ -93,15 +93,36 @@ def __init__(
         self.alphas_cumprod = np.cumprod(self.alphas, axis=0)

         self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
+        self.sigmas = self.sigmas[::-1].copy()

         # setable values
         self.num_inference_steps = None
-        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+        self.schedule = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)
         self.derivatives = []

         self.tensor_format = tensor_format
         self.set_format(tensor_format=tensor_format)

+    def scale_initial_noise(self, noise: torch.FloatTensor):
+        """
+        Scales the initial noise to the correct range for the scheduler.
+        """
+        return noise * self.sigmas[0]
+
+    def scale_model_input(self, sample: torch.FloatTensor, step: int):
+        """
+        Scales the model input (`sample`) to the correct range for the scheduler.
+        """
+        sigma = self.sigmas[self.num_inference_steps - step - 1]
+        return sample / ((sigma**2 + 1) ** 0.5)
+
+    def get_noise_condition(self, step: int):
+        """
+        Returns the input noise condition for a model.
+        """
+        return self.schedule[step]
+
     def get_lms_coefficient(self, order, t, current_order):
         """
         Compute a linear multistep coefficient.

Review comment (on `get_noise_condition`): No one-liner functions please
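For intuition, here is what the two scaling hooks compute in isolation (a sketch with a hypothetical beta schedule standing in for the configured one; note that since `sigmas` is stored largest-first and `step` counts down from `num_inference_steps - 1`, the first loop iteration hits `sigmas[0]`):

```python
import numpy as np

betas = np.linspace(1e-4, 2e-2, 1000)               # hypothetical config values
alphas_cumprod = np.cumprod(1.0 - betas)
sigmas = (((1 - alphas_cumprod) / alphas_cumprod) ** 0.5)[::-1]   # largest noise level first

noise = np.random.randn(4, 64, 64)
latents = noise * sigmas[0]                         # scale_initial_noise: start at max sigma
model_input = latents / ((sigmas[0] ** 2 + 1) ** 0.5)   # scale_model_input at the first step
```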
@@ -133,13 +154,11 @@ def set_timesteps(self, num_inference_steps: int):
             the number of diffusion steps used when generating samples with a pre-trained model.
         """
         self.num_inference_steps = num_inference_steps
-        self.timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
+        self.timesteps = np.arange(0, num_inference_steps)[::-1].copy()

-        low_idx = np.floor(self.timesteps).astype(int)
-        high_idx = np.ceil(self.timesteps).astype(int)
-        frac = np.mod(self.timesteps, 1.0)
+        self.schedule = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
-        sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
+        sigmas = np.interp(self.schedule[::-1], np.arange(0, len(sigmas)), sigmas)
         self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)

         self.derivatives = []
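The `np.interp` call is a drop-in replacement for the floor/ceil/frac interpolation it removes; a quick equivalence check (hypothetical sigma table and small sizes for readability):

```python
import numpy as np

num_train_timesteps, num_inference_steps = 1000, 4
sigmas_full = np.linspace(0.01, 150.0, num_train_timesteps)      # hypothetical sigma table

schedule = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)
ts = schedule[::-1]                                              # descending, like the old timesteps

low = np.floor(ts).astype(int)
high = np.ceil(ts).astype(int)
frac = np.mod(ts, 1.0)
old = (1 - frac) * sigmas_full[low] + frac * sigmas_full[high]   # removed formulation
new = np.interp(ts, np.arange(len(sigmas_full)), sigmas_full)    # added formulation

assert np.allclose(old, new)
```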
@@ -172,6 +191,7 @@ def step(
             When returning a tuple, the first element is the sample tensor.

         """
+        timestep = int(self.num_inference_steps - timestep - 1)
         sigma = self.sigmas[timestep]

         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise