[SchedulerDesign] Alternative scheduler design #711
```
@@ -178,7 +179,8 @@ def step(
        When returning a tuple, the first element is the sample tensor.
        """
        sigma = self.sigmas[timestep]
        index = (self.config.num_train_timesteps - timestep) // (self.config.num_train_timesteps // self.num_inference_steps)
```
👍 I like this bit, having the scheduler responsible for figuring out what to do with the timestep instead of having the pipeline keep track of how schedulers interpret their arguments.
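A minimal sketch of that idea with assumed values (the names mirror the diff above, but nothing here is the actual diffusers API): the scheduler owns the mapping from a training-scale timestep to an index into its own inference-time tables, so the pipeline never has to know the convention.

```python
# Illustrative sketch: the scheduler converts a training-scale timestep into
# an index into its inference-time tables (e.g. sigmas). Values are assumed.
num_train_timesteps = 1000  # assumed length of the training schedule
num_inference_steps = 50    # assumed number of sampling steps

def timestep_to_index(timestep: int) -> int:
    # descending training timesteps map to ascending indices:
    # larger (earlier) timesteps give smaller indices
    step_ratio = num_train_timesteps // num_inference_steps
    return (num_train_timesteps - timestep) // step_ratio

print(timestep_to_index(1000), timestep_to_index(980))  # 0, 1
```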
```python
if self.scheduler.type == SchedulerType.CONTINUOUS:
    latents = latents * self.scheduler.init_sigma
```
A step in the right direction. I'd love to get rid of the `if` entirely, but as long as we have it, defining a `scheduler.type` enum is much preferable to `isinstance`!
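A hedged sketch of what that enum could look like (the member names follow this thread; the scheduler class is an illustrative stand-in, not an actual diffusers class):

```python
from enum import Enum

class SchedulerType(Enum):
    DISCRETE = "discrete"      # integer timesteps (e.g. DDIM, DDPM)
    CONTINUOUS = "continuous"  # continuous sigmas (e.g. K-LMS style)

class KLMSLikeScheduler:
    # declared once on the class, so pipelines branch on a stable property
    type = SchedulerType.CONTINUOUS
    init_sigma = 14.6  # assumed initial noise level

scheduler = KLMSLikeScheduler()
if scheduler.type == SchedulerType.CONTINUOUS:  # no isinstance needed
    print("scale initial latents by", scheduler.init_sigma)
```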
@anton-l I can also be convinced here to change it to something that might be cleaner (again, if we don't have to force a function upon DDIM or DDPM).
Overall, I feel quite strongly about the following though:
- Let's not make easy schedulers more complex because we'd like to support newer, more complex schedulers.
- Forcing every scheduler to implement a certain method that can grow arbitrarily in complexity is much worse than explicit `if` statements with a nice, educational comment.
```python
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

if self.scheduler.type == SchedulerType.CONTINUOUS and self.model.config.trained_scheduler_type == SchedulerType.DISCRETE:
    latent_model_input = self.scheduler.scale_(latent_model_input)
```
Is this `scale_` implemented?
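For reference, a hedged sketch of what a scheduler-side `scale_` could look like if it simply encapsulates the K-LMS input scaling shown above (the class and attribute names are assumptions, not actual diffusers API):

```python
import torch

class ContinuousSchedulerSketch:
    def __init__(self, sigma: float):
        self.sigma = sigma  # current noise level, assumed tracked by the scheduler

    def scale_(self, latent_model_input: torch.Tensor) -> torch.Tensor:
        # same expression as the pipeline code above: match the continuous
        # ODE formulation when the model was trained on a discrete schedule
        return latent_model_input / ((self.sigma**2 + 1) ** 0.5)

scheduler = ContinuousSchedulerSketch(sigma=14.6)
scaled = scheduler.scale_(torch.randn(1, 4, 8, 8))  # ~x / 14.63 at max noise
```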
@patrickvonplaten as you said in #637:

> I think we should focus much more on global differences
Which isn't solved fully here for a couple of reasons (that I see at the moment):
- Putting the schedulers into continuous and discrete categories is a bit restrictive IMO, as it implies that these categories have different APIs, while mathematically there's no such restriction (we can freely convert between the timestep representations, as Karras et al. have shown; see the sketch at the end of this comment). Having this distinction on the model side can be beneficial to learn what type of noise conditioning the model expects, but we can get away with having it on the pipeline side too.
- With this design we still have a very different representation of `timesteps` between different schedulers: sometimes they're int, sometimes they're float, sometimes they're descending, sometimes ascending. [Schedulers Refactoring] Phase 1: timesteps and scaling #637 addresses this by allowing the schedulers to have their internal notion of a `schedule` (or `sigmas`, or whatever they choose), while only having integer descending `timesteps` in the public API. This makes debugging the pipelines way easier (i.e. timesteps no longer jump all over the place, we only have indices/steps), while also ensuring that literally any scheduler (with further refactoring for our VP, VE and Karras) is usable with e.g. Stable Diffusion.
I agree that #637 turned out to be bulky, so hopefully we can meet somewhere in the middle after iterating a bit.
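A hedged sketch of the conversion point above (all names and values are illustrative, not diffusers API): with an internal sigma table, a scheduler can expose integer descending steps publicly and still recover the continuous representation by lookup, or go the other way by a nearest-match search.

```python
import numpy as np

num_inference_steps = 50
# assumed internal continuous schedule, descending from high to low noise
sigmas = np.linspace(14.6, 0.03, num_inference_steps)

def sigma_for_step(step: int) -> float:
    # public API direction: integer descending step -> internal continuous sigma
    return float(sigmas[step])

def step_for_sigma(sigma: float) -> int:
    # inverse direction: nearest internal sigma -> public integer step
    return int(np.abs(sigmas - sigma).argmin())

assert step_for_sigma(sigma_for_step(10)) == 10
```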
```
@@ -178,7 +179,8 @@ def step(
        When returning a tuple, the first element is the sample tensor.
        """
        sigma = self.sigmas[timestep]
        index = (self.config.num_train_timesteps - timestep) // (self.config.num_train_timesteps // self.num_inference_steps)
```
Since `timestep` is float (don't mind the wrong type annotation and doc for now), I think the surest way to get its index is `self.timesteps.where(timesteps)`, as timesteps are linearly interpolated here. But `where()` just looks like a hack, while we can refactor the timesteps and scheduler properly.
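One illustrative alternative to a `where()` hack, purely as an assumption-level sketch (not diffusers code): a nearest-match lookup, which also avoids the float-equality pitfalls of an exact comparison.

```python
import torch

timesteps = torch.linspace(999.0, 0.0, 50)  # assumed float, descending schedule

def index_of(timestep: float) -> int:
    # nearest match instead of exact equality, which is fragile for floats
    return int((timesteps - timestep).abs().argmin())

print(index_of(999.0), index_of(0.0))  # 0, 49
```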
@anton-l RE:

-> I don't understand this. All of our schedulers always belong to one of the two classes, no? The math doesn't have to be perfect there; it's just important that people understand what timestep representation is expected. We can also find a better name.

-> I've seen only two representations so far, continuous and int, and all of these should be called
```diff
-    latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
-else:
-    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
```
Could be considered a bug correction IMO
```python
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

if self.scheduler.type == SchedulerType.CONTINUOUS and self.model.config.trained_scheduler_type == SchedulerType.DISCRETE:
```
Here I'm very open to hearing better suggestions @anton-l (if you can think of something cleaner that doesn't require changes to DDIM or DDPM).
There's a lot of back-and-forth about that timesteps list, what its internal representation is, and what the interfaces to it look like. But is it necessary for a Scheduler to expose that list at all? There might be some uses, but I'm not sure the current pipelines need them at all. As far as I can tell, the only thing they need to know is the t of the current step so they can pass it along.

Exposing the list also embodies the assumption that there is a fixed schedule known up front, which doesn't need to be the case. For example, https://github.com/LuChengTHU/dpm-solver implements a sampler that uses adaptive time steps: it just keeps going until it's close enough to 0. That's the sort of thing I wasn't going to bring up in this early pass of the redesign, but I do so now because it seems like it might be more freeing than complicating.
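A hedged sketch of that adaptive idea (everything below is an illustrative stand-in, not the dpm-solver API): a sampling loop with no precomputed timesteps list, which simply runs until time is close enough to 0.

```python
import torch

def model(latents: torch.Tensor, t: float) -> torch.Tensor:
    return torch.zeros_like(latents)  # stand-in for the denoising network

def adaptive_step(noise_pred, t, latents):
    # stand-in update; a real adaptive solver would pick dt from an error estimate
    dt = min(t, 0.1)
    return latents - dt * noise_pred, t - dt

latents, t = torch.randn(1, 4, 8, 8), 1.0  # assumed: t = 1.0 is pure noise
while t > 1e-3:  # no fixed schedule: keep going until close enough to 0
    latents, t = adaptive_step(model(latents, t), t, latents)
```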
Thanks for bringing it up! Regarding the current design, I think we can assume that timesteps can always be written as a list of values. Timesteps that cannot be written as a list of values are exceptional for now IMO (these things can't be traced etc., so I'm not sure how useful they are).
Closing this PR in favor of: #719
What does "traced" mean in this context? The only disadvantage I could think of to giving up a fixed number of iterations for something adaptive is if there were some JIT compiler doing loop unrolling or parallelization or something. But CPython isn't that smart, and it's not parallelizable anyway because each iteration depends on the output of the previous one.
Yeah, sorry, that's more or less what I meant by "traced": JAX won't like this, and I'm also not sure ONNX would be happy about it.
Oh, do we have to play by these JAX rules? Good to know. …though I'm skeptical about the value of JIT-ing that whole inference loop, even if you did want to. But we can leave that until later. No need to worry over how to implement the adaptive form of DPM-Solver before we're even able to implement the more predictable DPM-Solver-fast sampler.
We also don't have to play by the JAX rules at all ;-) We want to support JAX as soon as it runs on free Google Colabs: googlecolab/colabtools#3009 (comment), so that we give users more power (8 TPUs are pretty powerful). That being said, it doesn't mean that all PyTorch functionality has to take into account how its mirror would work in JAX. It's totally fine if the two frameworks diverge. More generally, though, it helps optimization libraries like ONNX and TensorRT a lot if all memory can be pre-allocated.
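To illustrate the tracing point (a minimal sketch, assuming JAX is installed; none of this is from the PR itself): under `jax.jit`, a fixed trip count traces fine, while a Python `while` on a traced value fails, which is what makes an adaptive number of steps awkward for JAX- or ONNX-style export.

```python
import jax
import jax.numpy as jnp

@jax.jit
def fixed_steps(x):
    # static trip count: traceable, and memory can be pre-allocated
    return jax.lax.fori_loop(0, 50, lambda i, x: x * 0.9, x)

@jax.jit
def adaptive_steps(x):
    # data-dependent trip count: this Python `while` on a traced value
    # raises a tracer/concretization error inside jit
    while jnp.abs(x).max() > 1e-3:
        x = x * 0.9
    return x

print(fixed_steps(jnp.ones(3)))  # works
# adaptive_steps(jnp.ones(3))    # fails under jit; jax.lax.while_loop would
#                                # trace, but its trip count stays dynamic
```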
Alternative to #637 cc @anton-l