
[Schedulers Refactoring] Phase 1: timesteps and scaling #637


Closed
wants to merge 11 commits

Conversation

anton-l
Member

@anton-l anton-l commented Sep 26, 2022

As the refactor in #616 started getting too big to review in one go, I decided to approach this in a couple of smaller PRs that touch as few files as possible.

Note: it's easier to start reviewing from pipeline_stable_diffusion.py, as it shows the reasoning behind the API changes better.

Basic ideas that this PR implements:

  • scheduler.timesteps are now always just integer (descending?) indices in the range [0, num_inference_steps) in inference mode, e.g. [49, 48, ..., 1, 0]. This makes sure that we always iterate over the same range for every scheduler.
  • scheduler.schedule replaces the original timesteps and contains either the discrete noise conditions (t or sigma) for the model (like in DDIM: [999, 978, ... 20, 0]) or the resampled ones (like in LMS: [999.0, 977.7, ... 0.0]).
  • scheduler.scale_initial_noise() scales the initial torch.randn, as sometimes (e.g. in the Karras, LMS or VE schedulers) the initial noise is not N(0, 1) but rather N(0, max_sigma^2). This function has to be applied after sampling the noise in the pipeline.
  • scheduler.scale_model_input(sample, step) has to be applied to each UNet input sample, as sometimes the inputs need to be scaled (e.g. for Karras or LMS).
  • scheduler.get_noise_condition(step) gets the noise condition (t or sigma) for the UNet. Sometimes the t needs to be scaled (e.g. in Karras+Euler), so this is implemented as a future-proof function that can access the scheduler's parameters.
  • scheduler.step() must always accept a value from scheduler.timesteps (rather than scheduler.schedule) as input, so that we can use it as an index into schedule, sigmas, or whatever the scheduler needs for the step. (A minimal usage sketch follows this list.)
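
For illustration only (not code from this PR): a minimal, self-contained sketch of how a pipeline loop might use the proposed API, using a toy discrete scheduler. ToyScheduler, dummy_model, the step() update rule, and all numeric values here are hypothetical placeholders.

import torch


class ToyScheduler:
    """Toy discrete scheduler implementing the proposed interface (illustrative only)."""

    def __init__(self, num_train_timesteps: int = 1000):
        self.num_train_timesteps = num_train_timesteps
        self.timesteps = None
        self.schedule = None

    def set_timesteps(self, num_inference_steps: int):
        # timesteps: plain descending loop indices [N-1, ..., 1, 0]
        self.timesteps = torch.arange(num_inference_steps - 1, -1, -1)
        # schedule: the noise conditions the model was trained on, indexed by step
        self.schedule = torch.linspace(0, self.num_train_timesteps - 1, num_inference_steps).long()

    def scale_initial_noise(self, noise):
        return noise  # no-op for a discrete N(0, 1) scheduler

    def scale_model_input(self, sample, step):
        return sample  # no-op here; a K-LMS-style scheduler would rescale

    def get_noise_condition(self, step):
        return self.schedule[step]  # maps the loop index to t (or sigma)

    def step(self, model_output, step, sample):
        return sample - 0.1 * model_output  # toy update, just to make the sketch run


def dummy_model(sample, t):
    return torch.zeros_like(sample)  # stands in for the UNet


scheduler = ToyScheduler()
scheduler.set_timesteps(50)

latents = torch.randn(1, 4, 8, 8)
latents = scheduler.scale_initial_noise(latents)

for step in scheduler.timesteps:
    model_input = scheduler.scale_model_input(latents, step)
    t = scheduler.get_noise_condition(step)
    noise_pred = dummy_model(model_input, t)
    latents = scheduler.step(noise_pred, step, latents)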

TODO: merge the PyTorch schedulers and rebase these changes on top.

Coming up in the next PRs (Phase 2+):

  • t and sigma interchangeability, to use any continuous scheduler with a discrete model and vice versa
  • better (readable/intuitive) support for Nth order solvers that require combining multiple forward passes of the UNet

Comment on lines 41 to 62
class BaseScheduler(abc.ABC):

def scale_initial_noise(self, noise: torch.FloatTensor):
"""
Scales the initial noise to the correct range for the scheduler.
"""
return noise

def scale_model_input(self, sample: torch.FloatTensor, step: int):
"""
Scales the model input (`sample`) to the correct range for the scheduler.
"""
return sample

@abc.abstractmethod
def get_noise_condition(self, step: int):
"""
Returns the input noise condition for the model (e.g. `timestep` or `sigma`).
"""
raise NotImplementedError("Scheduler must implement the `get_noise_condition` function.")


Member Author

This class combines the new required functions and ideally should be merged with SchedulerMixin (left standalone for easier reviewing for now).

Member

Yes, I agree to have it merged; otherwise schedulers need to inherit from both.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@anton-l anton-l requested a review from patil-suraj September 26, 2022 01:38
@anton-l
Member Author

anton-l commented Sep 26, 2022

Not touching the VE schedulers yet; they will follow once we have convertible t <-> sigmas.
But overall: ready for review!

Comment on lines 95 to 97
sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
sigmas = sigmas[::-1].copy()
self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
Member Author

Fixes #454
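
For context, a hedged sketch (illustrative only, not the PR's code) of what the quoted snippet computes: sigmas derived from a toy linear beta schedule, reversed so the largest noise level comes first, with a trailing 0.0 appended for the final noise-free step. The beta values below are assumptions, not the library's defaults.

import numpy as np

# toy linear beta schedule (assumed values, for illustration only)
betas = np.linspace(1e-4, 2e-2, 1000, dtype=np.float64)
alphas_cumprod = np.cumprod(1.0 - betas)

sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
sigmas = sigmas[::-1].copy()                                  # descending: largest noise first
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)   # final step has zero noise

print(sigmas[0], sigmas[-2], sigmas[-1])                      # max sigma ... smallest sigma, 0.0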


Member

@pcuenca pcuenca left a comment

Initially I thought the changes would result in more verbose code in the pipelines, but then I realized they just replace special cases with function calls, and it's much clearer this way.

I just left a few comments about details I may not be understanding properly.

@@ -130,6 +130,7 @@ def __call__(
generator=generator,
)
latents = latents.to(self.device)
latents = self.scheduler.scale_initial_noise(latents)
Member

Do schedulers expect the latents on the final device or on the CPU? As per this PR it looks like LMSDiscreteScheduler prepares sigmas on the CPU.

We could either:

  • Always follow the tensor, and move the sigmas etc. if necessary (a minimal sketch of this option follows the list).
  • Implement scheduler.to() as @anton-l suggested in that PR.
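
A minimal sketch of the "follow the tensor" option, under assumptions: the function name and signature below are illustrative rather than the PR's API, and sigmas is assumed to be a tensor prepared on the CPU.

import torch


def scale_model_input(sample: torch.Tensor, sigmas: torch.Tensor, step: int) -> torch.Tensor:
    # move sigmas to wherever the sample lives (CPU or GPU) before indexing
    sigma = sigmas.to(device=sample.device, dtype=sample.dtype)[step]
    return sample / ((sigma**2 + 1) ** 0.5)


sample = torch.randn(1, 4, 8, 8)          # e.g. latents already on the target device
sigmas = torch.linspace(14.6, 0.0, 50)    # assumed CPU-side sigma schedule
scaled = scale_model_input(sample, sigmas, step=0)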

@@ -37,6 +38,27 @@ class SchedulerOutput(BaseOutput):
prev_sample: torch.FloatTensor


class BaseScheduler(abc.ABC):
Member

Agree to merge it with SchedulerMixin if we can.

Member Author

Can't remove SchedulerMixin completely yet, as it's used by the Flax schedulers, so I guess we'll address it after the Flax refactoring

return sample

@abc.abstractmethod
def get_noise_condition(self, step: int):
Member

I didn't understand why this was different to step() but then I read @anton-l's comment in the description. Sounds good!

Contributor

@patrickvonplaten patrickvonplaten Sep 30, 2022

The function is not intuitive to me, to be honest. We pass a step variable to get_noise_condition, then pass its output to the step(...) function, and for most of our schedulers all of this just represents a timestep.

"""
Returns the input noise condition for the model (e.g. `timestep` or `sigma`).
"""
raise NotImplementedError("Scheduler must implement the `get_noise_condition` function.")
Member

However, should we default it to invoking step() rather than raising an exception? Then only the schedulers that need it would implement it. If you are reading the code for current schedulers, there's one thing less to worry about.

@patrickvonplaten patrickvonplaten self-assigned this Sep 27, 2022
@anton-l
Member Author

anton-l commented Sep 27, 2022

Resolved the merge conflicts, ready for review again :)

anton-l added a commit that referenced this pull request Sep 28, 2022
anton-l added a commit that referenced this pull request Sep 28, 2022
anton-l added a commit that referenced this pull request Sep 28, 2022
* Fix the LMS pytorch regression

* Copy over the changes from #637

* Copy over the changes from #637

* Fix betas test
# 1. predict noise model_output
t = self.scheduler.get_noise_condition(step)
Contributor

@patrickvonplaten patrickvonplaten Sep 30, 2022

It's confusing to me that get_noise_condition() returns a time integer. Also, why is this needed for DDIM?

@@ -155,7 +155,8 @@ def __init__(

# setable values
self.num_inference_steps = None
self.timesteps = np.arange(0, num_train_timesteps)[::-1]
self.schedule = np.arange(0, num_train_timesteps)
Contributor

What do we need the schedule for?

"""
Returns the input noise condition for a model.
"""
return self.schedule[step]
Contributor

This function is not intuitive to me.

@@ -240,6 +249,8 @@ def step(
# - pred_sample_direction -> "direction pointing to x_t"
# - pred_prev_sample -> "x_t-1"

timestep = self.schedule[timestep]
Contributor

Here I'm a bit lost - the code here is now more complex than it was before.

Contributor

Why are we now passing a different timestep to the scheduler than before, which we then revert here?

@@ -176,6 +197,8 @@ def step(
When returning a tuple, the first element is the sample tensor.

"""
# FIXME: accounting for the descending sigmas
timestep = int(len(self.timesteps) - timestep - 1)
Contributor

I don't understand this -> we are reverting timestep here but giving it the same name.

This is hard to understand / read.
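
For what it's worth, a small standalone illustration (not PR code) of what the quoted index flip does: the loop index counts down, while the LMS sigmas array is stored largest-first, so the step index has to be mirrored before it can be used to index sigmas.

num_inference_steps = 50
timesteps = list(range(num_inference_steps - 1, -1, -1))  # [49, 48, ..., 0]

for timestep in timesteps[:3]:
    sigma_index = int(len(timesteps) - timestep - 1)
    print(timestep, "->", sigma_index)  # 49 -> 0, 48 -> 1, 47 -> 2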

@@ -46,6 +46,14 @@ class StableDiffusionPipeline(DiffusionPipeline):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""

vae: AutoencoderKL
Contributor

this is a bit unrelated to the PR - let's try to put this in a different PR next time ;-)

@@ -230,14 +238,11 @@ def __call__(
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
latents = latents.to(self.device)
latents = self.scheduler.scale_initial_noise(latents)
Contributor

@patrickvonplaten patrickvonplaten Sep 30, 2022

Fine with this function! Think it would also be fine to add an if-statement here (to only do it if the scheduler is continuous).

Contributor

I don't think this works correctly at the moment because the sigmas are changed when doing self.scheduler.set_timesteps

# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, step)
t = self.scheduler.get_noise_condition(step)
Contributor

Here I'm lost -> this function is not intuitive to me. t != noise_condition for me. Not happy if we need to force schedulers to have this function

"""
Returns the input noise condition for a model.
"""
return self.schedule[step]
Contributor

No one-liner functions please

@@ -36,6 +37,38 @@ class SchedulerOutput(BaseOutput):
prev_sample: torch.FloatTensor


class BaseScheduler(abc.ABC):
Contributor

Suggested change
- class BaseScheduler(abc.ABC):
+ class BaseScheduler:

@@ -36,6 +37,38 @@ class SchedulerOutput(BaseOutput):
prev_sample: torch.FloatTensor


class BaseScheduler(abc.ABC):
Contributor

Not super happy about this. If some schedulers don't need a scale_initial_noise function, we shouldn't force them to call it just to get back a no-op; same for scale_model_input.

Contributor

@patrickvonplaten patrickvonplaten left a comment

To be honest, I'm not super convinced by this design:

  • Don't think we should force schedulers to call functions if they don't need to. Instead we should raise an error if a function is not called for a scheduler. This is also better for backwards compatibility.
  • Currently, I don't think the BaseScheduler provides much value and I would not be in favor of adding it. The functions scale_initial_noise and scale_model_input are not required by DDIM or PNDM, just by K-LMS. Just because K-LMS needs them doesn't warrant adding them as no-ops to DDIM and PNDM IMO. I think it's not good design to say "the default we put in a class every scheduler inherits from is to not scale, and then schedulers that need to scale have to overwrite the method". This means if I add a new continuous scheduler now and forget to add scale_initial_noise, I'll get silent errors. Instead I want to get a big error.
  • I think we're changing too much with too little gain here. We're still starting from a design we have chosen and should try to reduce the mental energy it takes to go from the current design to the new design. Here we introduce a lot of new variables and function names: self.schedule, scale_initial_noise, scale_model_input, get_noise_condition, the meaning of step is redefined, sometimes we pass step to timesteps and sometimes to schedule => This is too much. Let's try to limit the new things to be learnt.

Ideas:

  • I think we should focus much more on global differences between schedulers and derive simple logic from there. Why don't we just give each scheduler a class variable "discrete" or "continuous" (I think every scheduler has to be one of the two)? We can enforce this by checking in SchedulerMixin that every scheduler has exactly one of the two. Then, depending on whether it's "continuous" or "discrete", we scale inputs or not. We can throw nice error messages in the step(...) function of the scheduler and the model forward that complain if variables are passed in the incorrect space. (A rough sketch of this idea follows the list.)
  • I don't think we really need to change DDIM or PNDM's functionality, we just need to adapt K-LMS.
  • I think it would make a lot of sense to also add a "trained_on_continuous" and "trained_on_discrete" flag to the model's config so that we can also throw a nice error in the model's forward if the wrong dtype is passed.
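
Purely as an illustration of the class-variable idea above (not code from this PR or the library): a toy mixin that forces every scheduler to declare whether it is discrete or continuous, and a pipeline-side check that only rescales the initial latents for continuous schedulers. All names (noise_space, ToyDiscreteScheduler, ToyKLMSScheduler, init_noise_sigma) and values are hypothetical.

import torch


class SchedulerCheckMixin:
    # every scheduler must declare exactly one noise space: "discrete" or "continuous"
    noise_space: str

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if getattr(cls, "noise_space", None) not in ("discrete", "continuous"):
            raise TypeError(
                f"{cls.__name__} must set `noise_space` to 'discrete' or 'continuous'"
            )


class ToyDiscreteScheduler(SchedulerCheckMixin):
    noise_space = "discrete"


class ToyKLMSScheduler(SchedulerCheckMixin):
    noise_space = "continuous"
    init_noise_sigma = 14.6  # assumed max sigma, for illustration


# pipeline side: the flag decides whether the initial latents get rescaled
scheduler = ToyKLMSScheduler()
latents = torch.randn(1, 4, 8, 8)
if scheduler.noise_space == "continuous":
    latents = latents * scheduler.init_noise_sigma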

@keturn
Contributor

keturn commented Sep 30, 2022

There were some things here that had me scratching my head a bit too, but

  1. I don't understand the differences between the continuous and discrete approaches well enough. Maybe I should try reading Karras's Elucidating paper again? Or is there an example of a scheduler with the same underlying math written in both styles that I could read to compare?
  2. I was extending significant credit on the basis of this being a "part one of two" PR, expecting that fields or methods that seem trivial now will get extended in interesting ways in Part Two.

@keturn
Contributor

keturn commented Sep 30, 2022

a possibly-related thread on sampler/scheduler interface design is crowsonkb/k-diffusion#23

@patrickvonplaten
Contributor

Alternative to this design: #711

@anton-l
Member Author

anton-l commented Oct 5, 2022

Superseded by #719

@anton-l anton-l closed this Oct 5, 2022
prathikr pushed a commit to prathikr/diffusers that referenced this pull request Oct 26, 2022
* Fix the LMS pytorch regression

* Copy over the changes from huggingface#637

* Copy over the changes from huggingface#637

* Fix betas test
@anton-l anton-l deleted the scheduler-refactor-v2 branch November 17, 2022 14:54
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* Fix the LMS pytorch regression

* Copy over the changes from huggingface#637

* Copy over the changes from huggingface#637

* Fix betas test