[Pytorch] Pytorch only schedulers #534

Merged: 41 commits merged into huggingface:main on Sep 27, 2022

Conversation

@kashif (Contributor) commented Sep 16, 2022

Removes the numpy clauses from the schedulers to make them PyTorch-only, and fixes the use of timesteps in the pipelines.
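
For illustration, here is a minimal, hypothetical sketch of the dual-format pattern this PR removes; the function names and the `tensor_format` branch below are stand-ins for the pattern, not the exact diffusers code:

```python
import numpy as np
import torch

# Before: schedulers carried a tensor_format flag and branched on it,
# accepting either numpy arrays or torch tensors.
def add_noise_dual(original_samples, noise, sigma, tensor_format="pt"):
    if tensor_format == "np":
        return original_samples + noise * np.asarray(sigma)
    if tensor_format == "pt":
        return original_samples + noise * torch.as_tensor(sigma)
    raise ValueError(f"`{tensor_format}` is not a valid tensor_format")

# After: schedulers operate on torch tensors only, so the branch (and the
# set_format("pt") calls in the pipelines) can go away.
def add_noise_pt(
    original_samples: torch.Tensor, noise: torch.Tensor, sigma: torch.Tensor
) -> torch.Tensor:
    return original_samples + noise * sigma
```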

@kashif marked this pull request as draft on September 16, 2022 15:53
@kashif changed the title from "pytorch only schedulers" to "[WIP] pytorch only schedulers" on Sep 16, 2022
@HuggingFaceDocBuilderDev commented Sep 16, 2022

The documentation is not available anymore as the PR was closed or merged.

@kashif changed the title from "[WIP] pytorch only schedulers" to "Pytorch only schedulers" on Sep 18, 2022
@kashif marked this pull request as ready for review on September 18, 2022 19:45
@kashif changed the title from "Pytorch only schedulers" to "[Pytorch] Pytorch only schedulers" on Sep 19, 2022
@vishnu-anirudh (Contributor) commented

> @vishnu-anirudh can you kindly check if I am doing the right thing in the pipelines, especially with the add_noise() calls. Thanks!

Hello @kashif, thanks for your changes so far. In the commit related to add_noise, the type hint seems to be for timesteps (example).

It is a good idea, but it may be better to keep the timesteps type hint consistent throughout the code: if we plan to use a numpy type hint for the other timesteps, then it may be better to use a numpy type hint here as well (and likewise if we standardize on a torch type hint).

@kashif, I hope that makes sense. Please let me know if that's not what you expected me to check as part of the add_noise calls.
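
For concreteness, a sketch of what a consistently torch-typed add_noise signature could look like; the exact tensor subtypes here are an assumption, not the final diffusers signature:

```python
import torch

class SchedulerSketch:
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,  # torch type hint, matching the other args
    ) -> torch.FloatTensor:
        ...
```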

@anton-l (Member) left a comment

I like the changes overall, and the tests are mostly running smoothly, thank you @kashif!

Getting a device mismatch due to self.sigmas always being on cpu here:

```
>       noisy_samples = original_samples + noise * sigma
E       RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

../src/diffusers/schedulers/scheduling_lms_discrete.py:209: RuntimeError
```

So scheduler.to(device) might have to be implemented. But I haven't thought about a solution too much, so maybe you have a workaround.

@anton-l (Member) commented Sep 26, 2022

@patrickvonplaten @patil-suraj could you give this PR a quick review if you have time? It'll be easier to rebase #637 if this is merged first.

@pcuenca (Member) commented Sep 26, 2022

> I like the changes overall, and the tests are mostly running smoothly, thank you @kashif!
>
> Getting a device mismatch due to self.sigmas always being on cpu here:
>
> ```
> >       noisy_samples = original_samples + noise * sigma
> E       RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
>
> ../src/diffusers/schedulers/scheduling_lms_discrete.py:209: RuntimeError
> ```
>
> So scheduler.to(device) might have to be implemented. But I haven't thought about a solution too much, so maybe you have a workaround.

We already have a move immediately before:

```python
timesteps = timesteps.to(self.sigmas.device)
```

I would suggest something like this in this case:

```python
sigmas = self.sigmas.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
```

Unless we want to do the computation on the CPU, as I think we do in the other schedulers.
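
Putting that together, a minimal sketch of an add_noise that moves both buffers to the samples' device before the broadcast; the indexing and unsqueeze loop are assumptions based on the traceback above, not a verbatim copy of scheduling_lms_discrete.py:

```python
import torch

def add_noise(self, original_samples, noise, timesteps):
    # Move scheduler state to wherever the samples live (e.g. cuda:0),
    # so the final broadcast does not mix devices.
    sigmas = self.sigmas.to(original_samples.device)
    timesteps = timesteps.to(original_samples.device)

    sigma = sigmas[timesteps].flatten()
    # Add trailing singleton dims until sigma broadcasts over
    # (batch, channels, height, width).
    while len(sigma.shape) < len(original_samples.shape):
        sigma = sigma.unsqueeze(-1)

    noisy_samples = original_samples + noise * sigma
    return noisy_samples
```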

@kashif (Contributor, Author) commented Sep 26, 2022

@pcuenca fixed

@patrickvonplaten self-assigned this Sep 27, 2022
```diff
@@ -35,7 +35,6 @@ class DDPMPipeline(DiffusionPipeline):

     def __init__(self, unet, scheduler):
         super().__init__()
-        scheduler = scheduler.set_format("pt")
```
A Contributor commented on the removed line:
nice!
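
For context, a rough sketch of the pipeline __init__ once the call is dropped; the register_modules usage is assumed from the surrounding diffusers code:

```python
from diffusers import DiffusionPipeline

class DDPMPipeline(DiffusionPipeline):
    def __init__(self, unet, scheduler):
        super().__init__()
        # No scheduler.set_format("pt") anymore: schedulers are torch-native,
        # so the pipeline can register its modules directly.
        self.register_modules(unet=unet, scheduler=scheduler)
```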

@patrickvonplaten (Contributor) left a comment

Once tests are green, let's merge this one as it's quite important :-)

@anton-l (Member) commented Sep 27, 2022

Great work @kashif!

@anton-l merged commit bd8df2d into huggingface:main on Sep 27, 2022
@kashif deleted the remove-numpy branch on September 27, 2022 13:31
@ghost mentioned this pull request Sep 27, 2022
@pcuenca mentioned this pull request Sep 29, 2022
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* pytorch only schedulers

* fix style

* remove match_shape

* pytorch only ddpm

* remove SchedulerMixin

* remove numpy from karras_ve

* fix types

* remove numpy from lms_discrete

* remove numpy from pndm

* fix typo

* remove mixin and numpy from sde_vp and ve

* remove remaining tensor_format

* fix style

* sigmas has to be torch tensor

* removed set_format in readme

* remove set format from docs

* remove set_format from pipelines

* update tests

* fix typo

* continue to use mixin

* fix imports

* removed unused imports

* match shape instead of assuming image shapes

* remove import typo

* update call to add_noise

* use math instead of numpy

* fix t_index

* removed commented out numpy tests

* timesteps needs to be discrete

* cast timesteps to int in flax scheduler too

* fix device mismatch issue

* small fix

* Update src/diffusers/schedulers/scheduling_pndm.py

Co-authored-by: Patrick von Platen <[email protected]>