[Schedulers Refactoring] Phase 1: timesteps and scaling #637
Conversation
class BaseScheduler(abc.ABC):

    def scale_initial_noise(self, noise: torch.FloatTensor):
        """
        Scales the initial noise to the correct range for the scheduler.
        """
        return noise

    def scale_model_input(self, sample: torch.FloatTensor, step: int):
        """
        Scales the model input (`sample`) to the correct range for the scheduler.
        """
        return sample

    @abc.abstractmethod
    def get_noise_condition(self, step: int):
        """
        Returns the input noise condition for the model (e.g. `timestep` or `sigma`).
        """
        raise NotImplementedError("Scheduler must implement the `get_noise_condition` function.")
This class combines the new required functions and ideally should be merged with SchedulerMixin (left standalone for easier reviewing for now).
Yes, I agree to have it merged; otherwise schedulers need to inherit from both.
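For illustration, a minimal sketch of what the merged base class could look like (hypothetical: the merge is deferred in this PR, and the save/load plumbing that SchedulerMixin already provides is omitted):

import abc

import torch


class SchedulerMixin(abc.ABC):
    # Only the methods proposed in this PR are shown; the existing
    # SchedulerMixin functionality would live alongside them.

    def scale_initial_noise(self, noise: torch.FloatTensor):
        """Scales the initial noise to the correct range for the scheduler."""
        return noise

    def scale_model_input(self, sample: torch.FloatTensor, step: int):
        """Scales the model input (`sample`) to the correct range for the scheduler."""
        return sample

    @abc.abstractmethod
    def get_noise_condition(self, step: int):
        """Returns the input noise condition for the model (e.g. `timestep` or `sigma`)."""
        raise NotImplementedError("Scheduler must implement the `get_noise_condition` function.")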
Not touching the VE schedulers yet; they will follow when we have convertible t<->sigmas.
        sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
        sigmas = sigmas[::-1].copy()
        self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
Fixes #454
Initially I thought the changes would result in more verbose code in the pipelines, but then I realized they just replace special cases with function calls, and it's much clearer this way.
I just left a few comments about details I may not be understanding properly.
@@ -130,6 +130,7 @@ def __call__(
            generator=generator,
        )
        latents = latents.to(self.device)
        latents = self.scheduler.scale_initial_noise(latents)
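For context, a hedged sketch of what `scale_initial_noise` might do for a sigma-based scheduler such as K-LMS, following the PR description (initial latents should come from N(0, max_sigma^2) rather than N(0, 1)); illustrative only, not the actual diff:

import numpy as np
import torch


class LMSLikeScheduler:
    def __init__(self, sigmas: np.ndarray):
        # Descending sigmas, as computed from alphas_cumprod elsewhere in the PR.
        self.sigmas = sigmas

    def scale_initial_noise(self, noise: torch.FloatTensor) -> torch.FloatTensor:
        # Stretch unit-variance noise to the scheduler's maximum noise level.
        return noise * float(self.sigmas.max())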
@@ -37,6 +38,27 @@ class SchedulerOutput(BaseOutput):
    prev_sample: torch.FloatTensor


class BaseScheduler(abc.ABC):
Agree to merge it with SchedulerMixin if we can.
Can't remove SchedulerMixin completely yet, as it's used by the Flax schedulers, so I guess we'll address it after the Flax refactoring.
        return sample

    @abc.abstractmethod
    def get_noise_condition(self, step: int):
I didn't understand why this was different to `step()`, but then I read @anton-l's comment in the description. Sounds good!
The function is not intuitive for me, to be honest. We pass a `step` variable to `get_noise_condition`, then pass its output to the `step(...)` function, and for most of our schedulers all of this just represents a `timestep`.
""" | ||
Returns the input noise condition for the model (e.g. `timestep` or `sigma`). | ||
""" | ||
raise NotImplementedError("Scheduler must implement the `get_noise_condition` function.") |
However, should we default it to invoking `step()` rather than raising an exception? Then only the schedulers that need it would implement it. If you are reading the code for current schedulers, there's one less thing to worry about.
Resolved the merge conflicts, ready for review again :)
                # 1. predict noise model_output
                t = self.scheduler.get_noise_condition(step)
It's confusing to me that `get_noise_condition()` returns a time integer. Also, why is this needed for DDIM?
@@ -155,7 +155,8 @@ def __init__(

        # setable values
        self.num_inference_steps = None
        self.timesteps = np.arange(0, num_train_timesteps)[::-1]
        self.schedule = np.arange(0, num_train_timesteps)
What do we need the `schedule` for?
""" | ||
Returns the input noise condition for a model. | ||
""" | ||
return self.schedule[step] |
This function is not intuitive for me
@@ -240,6 +249,8 @@ def step(
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        timestep = self.schedule[timestep]
Here I'm a bit lost - the code here is now more complex than it was before.
Why are we now passing a different timestep to the scheduler than before, which we then revert here?
@@ -176,6 +197,8 @@ def step(
                When returning a tuple, the first element is the sample tensor.

        """
        # FIXME: accounting for the descending sigmas
        timestep = int(len(self.timesteps) - timestep - 1)
I don't understand this -> we are reversing `timestep` here but giving it the same name. This is hard to understand/read.
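One way to address the readability concern (a sketch, not something the PR does) would be to keep `timestep` untouched and give the reversed index its own name:

def sigma_index_for(timestep: int, num_inference_steps: int) -> int:
    """Maps a descending `scheduler.timesteps` index onto the matching position
    in the descending `sigmas` array (num_inference_steps == len(self.timesteps)),
    instead of overwriting `timestep` in place."""
    return num_inference_steps - int(timestep) - 1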
@@ -46,6 +46,14 @@ class StableDiffusionPipeline(DiffusionPipeline):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    vae: AutoencoderKL
this is a bit unrelated to the PR - let's try to put this in a different PR next time ;-)
@@ -230,14 +238,11 @@ def __call__(
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            latents = latents.to(self.device)
        latents = self.scheduler.scale_initial_noise(latents)
Fine with this function! Think it would also be fine to add an if-statement here (to only do it if the scheduler is continuous).
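A sketch of that guard (the `is_continuous` flag is hypothetical here; no such attribute exists in this PR):

# Only rescale the initial latents for continuous (sigma-based) schedulers;
# for discrete schedulers scale_initial_noise would be a no-op anyway.
if getattr(self.scheduler, "is_continuous", False):
    latents = self.scheduler.scale_initial_noise(latents)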
I don't think this works correctly at the moment because the sigmas are changed when doing self.scheduler.set_timesteps
                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, step)
                t = self.scheduler.get_noise_condition(step)
Here I'm lost -> this function is not intuitive to me. `t` != `noise_condition` for me. Not happy if we need to force schedulers to have this function.
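For reference, the scheduler-side counterpart would presumably look roughly like the following for K-LMS, simply moving the scaling the pipeline used to do inline (a sketch; `_sigma_for_step` is a placeholder for whatever step-to-sigma mapping the scheduler settles on):

import torch


class LMSLikeScheduler:
    def scale_model_input(self, sample: torch.FloatTensor, step: int) -> torch.FloatTensor:
        # How `step` maps onto the descending sigmas is exactly the indexing
        # question raised elsewhere in this review; abstracted away here.
        sigma = self._sigma_for_step(step)
        # Same scaling the pipeline previously applied inline for K-LMS.
        return sample / ((sigma**2 + 1) ** 0.5)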
""" | ||
Returns the input noise condition for a model. | ||
""" | ||
return self.schedule[step] |
No one-liner functions please
@@ -36,6 +37,38 @@ class SchedulerOutput(BaseOutput):
    prev_sample: torch.FloatTensor


class BaseScheduler(abc.ABC):
Suggested change:
- class BaseScheduler(abc.ABC):
+ class BaseScheduler:
Not super happy about this. If some schedulers don't need a `scale_initial_noise` function we shouldn't force them to call it to get back a no-op; same with `scale_model_input`.
To be honest, I'm not super convinced by this design:

- I don't think we should force schedulers to call functions if they don't need to. Instead we should raise an error if a function is not called for a scheduler. This is also better for backwards compatibility.
- Currently, I don't think the `BaseScheduler` provides much value and I would not be in favor of adding it. The functions `scale_initial_noise` and `scale_model_input` are not required by DDIM or PNDM, only by K-LMS. Just because K-LMS needs them doesn't warrant adding them as no-ops to DDIM and PNDM IMO. I think it's not good design to say "the default we put in a class every scheduler inherits from is to not scale, and then schedulers that need to scale have to overwrite the method". This means that if I add a new continuous scheduler now and forget to add `scale_initial_noise`, I'll get silent errors. Instead I want to get a big error.
- I think we're changing too much with too little gain here. We're still starting from a design we have chosen and should try to reduce the mental energy it takes to go from the current design to the new design. Here we introduce a lot of new variables and function names: `self.schedule`, `scale_initial_noise`, `scale_model_input`, `get_noise_condition`; the meaning of `step` is redefined; sometimes we pass `step` to `timesteps`, sometimes to `schedule`. This is too much. Let's try to limit the new things to be learnt.

Ideas:

- I think we should focus much more on global differences between schedulers and derive simple logic from there. Why don't we just give each scheduler a class variable `"discrete"` or `"continuous"`? (I think every scheduler has to be one of the two.) We can enforce this by checking in `SchedulerMixin` that every scheduler has exactly one of the two. Then, depending on whether it's `"continuous"` or `"discrete"`, we scale inputs or not. We can throw nice error messages in the scheduler's `step(...)` function and in the model's forward that complain if variables are passed in the incorrect space.
- I don't think we really need to change DDIM's or PNDM's functionality, we just need to adapt K-LMS.
- I think it would make a lot of sense to also add `"trained_on_continuous"` and `"trained_on_discrete"` to the model's config so that we can also throw a nice error in the model's forward if the wrong `dtype` is passed.
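A rough sketch of the class-variable idea (names such as `noise_space` are made up here for illustration; nothing like this exists in the codebase yet):

class SchedulerMixin:
    # Every scheduler must declare exactly one of the two spaces; the mixin enforces it.
    noise_space = None  # "discrete" or "continuous"

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if cls.noise_space not in ("discrete", "continuous"):
            raise ValueError(
                f"{cls.__name__} must set `noise_space` to 'discrete' or 'continuous'."
            )


class DDIMScheduler(SchedulerMixin):
    noise_space = "discrete"


class LMSDiscreteScheduler(SchedulerMixin):
    noise_space = "continuous"


# In the pipeline, the flag (rather than per-scheduler no-op methods) would decide
# whether to scale inputs, and step(...) / the model forward could raise clear errors
# when a value in the wrong space is passed.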
There were some things here that had me scratching my head a bit too.
A possibly-related thread on sampler/scheduler interface design is crowsonkb/k-diffusion#23.
Alternative to this design: #711
Superseded by #719
* Fix the LMS pytorch regression * Copy over the changes from huggingface#637 * Copy over the changes from huggingface#637 * Fix betas test
As the refactor in #616 started getting too big to review in one go, I decided to approach this in a couple of smaller PRs that touch as few files as possible.

Note: it's easier to start reviewing from `pipeline_stable_diffusion.py`, as it shows the reasoning behind the API changes better.

Basic ideas that this PR implements:

- `scheduler.timesteps` are now always just integer (descending) indices in the range `[0, num_inference_steps)` in inference mode, e.g. `[49, 48, ..., 1, 0]`. This makes sure that we always iterate over the same range for every scheduler.
- `scheduler.schedule` replaces the original `timesteps` and contains either the discrete noise conditions (`t` or `sigma`) for the model (like in DDIM: `[999, 978, ..., 20, 0]`) or the resampled ones (like in LMS: `[999.0, 977.7, ..., 0.0]`).
- `scheduler.scale_initial_noise()` scales the initial `torch.randn`, as sometimes (e.g. in Karras, LMS or VE schedulers) the initial noise is not `N(0, 1)` but rather `N(0, max_sigma^2)`. This function has to be applied after sampling the noise in the pipeline.
- `scheduler.scale_model_input(sample, step)` has to be applied to each UNet input sample, as sometimes the inputs need to be scaled (e.g. for Karras or LMS).
- `scheduler.get_noise_condition(step)` gets the noise condition (`t` or `sigma`) for the UNet. Sometimes the `t` needs to be scaled (e.g. in Karras+Euler), so this is implemented as a future-proof function that can access the scheduler's parameters.
- `scheduler.step()` must always accept a value from `scheduler.timesteps` (rather than `scheduler.schedule`) as input, so that we can use it as an index into `schedule`, `sigmas` or whatever else the scheduler needs for the step.

TODO: merge the PyTorch schedulers and rebase these changes on top.

Coming up in the next PRs (Phase 2+):

- `t` and `sigma` interchangeability, to use any continuous scheduler with a discrete model and vice versa.
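Putting the pieces together, the denoising loop in a pipeline would look roughly like this under the proposed API (a sketch distilled from the points above, not code from the diff; `unet`, `scheduler`, `text_embeddings`, `latents_shape` and the other variables are assumed to be set up as in `pipeline_stable_diffusion.py`):

scheduler.set_timesteps(num_inference_steps)

# Initial noise: schedulers whose initial noise is not N(0, 1) rescale it here.
latents = torch.randn(latents_shape, generator=generator, device=device)
latents = scheduler.scale_initial_noise(latents)

# scheduler.timesteps is always the same integer index range, e.g. [49, 48, ..., 1, 0].
for step in scheduler.timesteps:
    # Per-step input scaling: a no-op for DDIM/PNDM, sigma-based scaling for K-LMS.
    model_input = scheduler.scale_model_input(latents, step)
    # Noise condition for the UNet: a discrete t or a sigma, taken from scheduler.schedule.
    t = scheduler.get_noise_condition(step)
    noise_pred = unet(model_input, t, encoder_hidden_states=text_embeddings).sample
    # step(...) always receives the plain index from scheduler.timesteps.
    latents = scheduler.step(noise_pred, step, latents).prev_sample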