
Add callback parameters for Stable Diffusion pipelines #521


Merged
merged 18 commits into huggingface:main on Oct 2, 2022

Conversation

jamestiotio
Contributor

@jamestiotio jamestiotio commented Sep 15, 2022

This PR adds the parameters callback and callback_frequency to the __call__ methods of the Stable Diffusion pipelines.

This PR closes #459.

As discussed in the accompanying issue, these callbacks can be helpful for consumers of the diffusers API who would like to inspect the current progress of the pipeline.

Instead of defining both callback and img_callback, as implemented in the original Stable Diffusion repository, this implementation uses a single, more general callback with the following signature:

def callback(
    step: int,
    timestep: np.ndarray,
    latents: torch.FloatTensor,
    image: Union[List[PIL.Image.Image], np.ndarray]
):
    # ...

This way, downstream consumers can select whichever parameters they want to use to implement incremental diffusion. This also avoids unnecessary, duplicated parameters.

Consumers can also specify the callback_frequency parameter to indicate how often they would like the callback function to be called. If callback is defined but callback_frequency is not specified, the callback function will be called at every step.
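The step-frequency behaviour described above can be sketched with a toy loop. Note this is a hypothetical stand-in for a pipeline's `__call__`, not diffusers code, and only the step number is passed for brevity:

```python
from typing import Callable, List, Optional

def run_pipeline(
    num_steps: int,
    callback: Optional[Callable[[int], None]] = None,
    callback_frequency: int = 1,  # default: invoke the callback at every step
) -> List[int]:
    """Toy stand-in for a diffusion loop that mirrors the behaviour above:
    `callback` fires every `callback_frequency` steps."""
    steps_run = []
    for step in range(num_steps):
        # ...the actual denoising update would happen here...
        if callback is not None and step % callback_frequency == 0:
            callback(step)
        steps_run.append(step)
    return steps_run

seen: List[int] = []
run_pipeline(10, callback=seen.append, callback_frequency=3)
# seen is now [0, 3, 6, 9]
```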

Signed-off-by: James R T [email protected]

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Sep 15, 2022

The documentation is not available anymore as the PR was closed or merged.

@keturn
Contributor

keturn commented Sep 15, 2022

When I wanted to see how things were working under the hood, I did this exact thing. Except since I was snoopy I sent the latents themselves and the timestep, not just the decoded image and step number.

I'm on the fence whether it would be a good idea to make the callback signature more complicated to support more use cases, or whether it's best to do the simpler thing that is what people are interested in 95% of the time.

The other thing to note is that this dovetails with the discussion in #374. We could expose intermediate results by yielding them from a generator, or by sending them back through a callback, but putting both interfaces in the same pipeline would be a mess.

@keturn
Contributor

keturn commented Sep 15, 2022

I'm on the fence whether it would be a good idea to make the callback signature more complicated to support more use cases, or whether it's best to do the simpler thing that is what people are interested in 95% of the time.

If it's worth including an option to return the latents in the output (as proposed by #506), then it's worth including them in the intermediate results.

@deepdiffuser

thanks for making this PR @jamestiotio! this will be useful for us in https://github.com/franksalim/franksalim-imagetools

@jamestiotio
Contributor Author

jamestiotio commented Sep 16, 2022

@keturn @deepdiffuser I have implemented your suggestions and fixes.

If it's worth including an option to return the latents in the output (as proposed by #506), then it's worth including them in the intermediate results.

Alright, I have refactored the implementation to allow the callback function to take in the timesteps and latents as parameters as well.

The other thing to note is that this dovetails with the discussion in #374. We could expose intermediate results by yielding them from a generator, or by sending them back through a callback, but putting both interfaces in the same pipeline would be a mess.

Hmm, I am not sure whether a callback function or a generator would be preferable for this. Personally, I don't mind either one, but I will leave the decision of which one to implement, and, if it is the generator version, how to go about implementing it (should we simply yield a tuple containing all 4 parameters?), to the maintainers.

Also, the linked issue contains a comment that raised the problem of returning intermediate results due to the need to run the safety checker. Because of that, I have refactored the implementation to also run the safety checker on every intermediate image before passing it to the callback function. While this will slow down the overall execution time whenever a callback function is defined, I believe that this should be an acceptable compromise.

@patrickvonplaten
Contributor

Overall, a bit worried that this is not "bare-bone" / "core-functionality" enough to deserve to be in the main pipelines. If we add more and more of such features the pipelines will quickly explode in terms of complexity.

Would a community pipeline be fine for you maybe? (working on making community pipelines easier to use with pip diffusers)

Also @patil-suraj @anton-l @pcuenca here

@keturn
Contributor

keturn commented Sep 22, 2022

On "core functionality":

I think it's a feature common in other implementations, it's useful for any kind of interactive application, and "how do I save the intermediate images" is one of the FAQs we get over on the Stable Diffusion Discord.

Granted, that last bit is partly because some people start with incorrect assumptions about how the inference process works and expect the 20th step of a 100-step task to give them the same thing they see at the 20th step of a 20-step task, but that's not an argument against it. If there are low-cost ways to better enable explainability tools like diffusers-interpret, that sounds like great core functionality.

I am not sure whether a callback function or a generator would be preferable for this.

After building out a little demo, I don't know whether I'd argue a generator is a better external interface for it, but it is certainly easier to wrap a generator implementation and make it invoke callbacks than it would be to try to do it the other way around.
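The point about wrapping directions — a generator implementation is easy to drive from a callback wrapper, while the reverse is awkward — can be sketched with toy stand-ins (`denoise_steps` and `run_with_callback` are hypothetical names, not diffusers API):

```python
from typing import Callable, Iterator, Optional

def denoise_steps(num_steps: int) -> Iterator[int]:
    """Toy generator form of the loop: yields each intermediate state
    (the real pipeline would yield step, timestep, and latents)."""
    for step in range(num_steps):
        yield step

def run_with_callback(num_steps: int,
                      callback: Optional[Callable[[int], None]] = None) -> int:
    """Callback interface layered on top of the generator: drive the
    generator internally and forward each intermediate to the callback."""
    last = -1
    for state in denoise_steps(num_steps):
        if callback is not None:
            callback(state)
        last = state
    return last
```

Going the other way — turning a callback-style pipeline into a generator — would require threads or coroutine machinery to invert control, which is why the wrapping order matters.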

@jamestiotio
Contributor Author

jamestiotio commented Sep 23, 2022

Regarding the concern on "core functionality", I agree with @keturn.

After building out a little demo, I don't know whether I'd argue a generator is a better external interface for it, but it is certainly easier to wrap a generator implementation and make it invoke callbacks than it would be to try to do it the other way around.

Thank you for the demo @keturn, it definitely helped me visualise how a generator version of the pipeline can be used.

Another concern that came to my mind after checking out your demo was backward compatibility. Implementing a generator version by default (i.e., without the user specifying some kind of flag to __call__) might break current applications, since any calls to StableDiffusionPipeline(...) and its derivatives would need to be modified in two major ways:

  1. Users who are not interested in the intermediate outputs would not be able to use this code snippet anymore:

    image = pipe(prompt).images[0]

    Instead, users would need to refactor their implementation by using yield from and needlessly iterating through all of the intermediate outputs (and throwing them away). This brings us to the second point.

  2. The calls to the pipelines have to be put into a function to utilize the yield from keyword. While this might seem trivial since using a function to wrap around the pipeline is highly recommended for all of its usual benefits (modularity, maintainability, etc.), and while production-grade applications will most likely use functions anyway, there might be applications that do not use functions for this. The demo examples in various READMEs of diffusers here, here, and here also do not explicitly put the pipelines in functions (understandably so, since they are simply showing the most basic usage of the pipelines).

Providing an explicit, user-specified flag to __call__ to indicate that the user is aware of this might solve this issue, but I am not sure if this is desirable. Such a flag would basically make the generator version opt-in. Any thoughts about this?
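One way such an opt-in flag could look, sketched with hypothetical names (`pipeline`, `return_intermediates`) and strings standing in for images:

```python
from typing import Iterator, Union

def _pipeline_steps(prompt: str, num_steps: int) -> Iterator[str]:
    # Strings stand in for intermediate images.
    for step in range(num_steps):
        yield f"{prompt}@{step}"

def pipeline(prompt: str, num_steps: int = 4,
             return_intermediates: bool = False) -> Union[str, Iterator[str]]:
    """Opt-in generator sketch: default callers keep the familiar eager
    return value; only callers that pass the flag receive a generator."""
    gen = _pipeline_steps(prompt, num_steps)
    if return_intermediates:
        return gen  # the caller iterates over intermediates
    *_, final = gen  # exhaust the generator internally, keep the last item
    return final
```

With this shape, existing callers like `image = pipe(prompt).images[0]` keep working unchanged, at the cost of the function having two different return types.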

@pcuenca
Member

pcuenca commented Sep 23, 2022

Overall, a bit worried that this is not "bare-bone" / "core-functionality" enough

I'm torn about this one. On one hand, it is true that adding too many pieces will imply conditional code, additional arguments, and other forms of complexity. From that point of view, I'd rather have it as a community pipeline. On the other hand, this particular feature might be general enough to support many different use cases, as shown by the community. In particular, explainability and interpretability could benefit from it, like diffusers-interpret shows, or this puzzling latents visualization. A community pipeline would also work fine, but it has a couple of drawbacks too: discoverability, and keeping implementations in sync with the main code.

Other considerations:

  • I don't see an easy way to make this compatible with the upcoming Flax implementations.
  • What to do with the safety checker.

I think diffusers has to be welcoming to tinkerers and at the same time not overwhelming. Is this general enough to deserve to be part of the core? After reflecting a bit, I'm leaning towards yes, but I'd love to hear more arguments.

On the design itself, after reading the different versions of the code I have a couple of initial comments:

  • It should just be the latents alone, in my opinion. Not all uses require decoding.
  • I suggest callback_frequency should be called callback_steps or something along those lines.
  • The generator approach is really interesting, I loved the code. But I think it's hard to keep backwards compatibility and support the most common case (no intermediates are required) without substantially breaking the linearity of the code, which is important for clarity.

I'm happy to do a more thorough review after some more discussion :)

@jamestiotio
Contributor Author

jamestiotio commented Sep 23, 2022

  • It should just be the latents alone, in my opinion. Not all uses require decoding.
  • I suggest callback_frequency should be called callback_steps or something along those lines.

@pcuenca I have implemented some of your suggestions above.

  • What to do with the safety checker.

Some of my thoughts on this aspect:

  1. If we do not run the safety checker, the safety checker is effectively rendered useless, since users can bypass it by extracting the latents from the final step and decoding them themselves to get the final output image.

  2. If we run the safety checker, then there are two remaining concerns:

    1. This will require decoding the latents using the VAE and passing the result to the safety checker. Since the latents are already decoded anyway, we might as well pass the decoded image to the callback function to avoid duplicate effort by users who might want to decode the latents.
    2. What value should we return as the latents if the safety checker detects an image with potential NSFW content? If we simply pass the original, unmodified latents back, then users would be able to bypass the safety checker as well (same scenario as the one described in 1).
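One possible answer to point 2.2 can be sketched as follows: zero out the latents of any flagged image before handing them to the callback, mirroring the black image returned for flagged final outputs. This is only an illustration of the option under discussion; `is_nsfw` is a hypothetical stand-in for the real safety checker (which operates on decoded images, not raw latents):

```python
import numpy as np

def safe_intermediate_latents(latents: np.ndarray, is_nsfw) -> np.ndarray:
    """Zero out the latents of any image the (stand-in) safety checker
    flags, so flagged content cannot be recovered from intermediates."""
    out = latents.copy()
    for i in range(out.shape[0]):
        if is_nsfw(out[i]):
            out[i] = 0.0
    return out
```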

Signed-off-by: James R T <[email protected]>
@ahrm

ahrm commented Sep 24, 2022

My two cents on this issue is that a callback functionality is essential for most interactive applications. I had to implement it myself for my desktop frontend and I was very surprised to find out that it is not a built-in feature.

I understand your desire to keep the interface simple, but I would say this functionality is too important not to include. The main problem is that you cannot simply subclass StableDiffusionPipeline and add this functionality; currently, the only way to do it is to copy and modify the source code, which may cause a lot of outdated and fragmented implementations of this functionality in the wild.

One possible solution for this issue without complicating the interface is to have a method, say, on_iteration_end which is a noop by default but it is called after every iteration. If we have this, the user can simply subclass StableDiffusionPipeline and override this method, without complicating the call signature.
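The subclassing approach can be sketched like this (a toy integer state stands in for latents; `on_iteration_end` follows the naming proposed above, and nothing here is real diffusers code):

```python
class Pipeline:
    """Sketch of the proposal: a no-op hook called after every iteration,
    which subclasses override. Integer state stands in for latents."""

    def on_iteration_end(self, step: int, state: int) -> None:
        pass  # no-op by default; override to observe intermediates

    def __call__(self, num_steps: int) -> int:
        state = 0
        for step in range(num_steps):
            state += step  # stand-in for a denoising update
            self.on_iteration_end(step, state)
        return state

class LoggingPipeline(Pipeline):
    """Example subclass that records every intermediate state."""

    def __init__(self) -> None:
        self.history: list = []

    def on_iteration_end(self, step: int, state: int) -> None:
        self.history.append((step, state))
```

The call signature stays untouched, and users who don't care about intermediates pay only the cost of an empty method call per step.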

Contributor

@patrickvonplaten patrickvonplaten left a comment


The PR is good to go for me then! Thanks a lot for your thoughtful comments here.
Agree that callbacks are general enough to be implemented in the core-pipelines.

I like the changes as they are done now since they are kept simple. Regarding your two questions @pcuenca:

  1. Flax doesn't need to mirror the behavior here
  2. I think the way the safety checker is handled in this PR is nice.

We just need to add three tests to merge this PR. @jamestiotio, could you maybe add three tests showing intermediate decoding of the pipelines in test_pipelines.py?

Member

@pcuenca pcuenca left a comment


I reviewed the main Stable Diffusion pipeline and suggested some changes, essentially to restore the decoding and safety checker to inside the __call__ function as they were before, since we are just sending the latents to the callback. I also commented a few nits about the timestep sent to the callback, I believe it should just be an int.

The same comments would apply to the other pipelines, so I didn't repeat them there. I did add an additional comment in the ONNX pipeline to send the latents as numpy arrays for consistency, but @anton-l is much more knowledgeable about that.

@pcuenca
Member

pcuenca commented Sep 28, 2022

  1. I think the way the safety checker is handled in this PR is nice.

I'm asking because in my opinion we should send the latents to the callback instead of the decoded image, to make it as general as possible. There are several examples of users leveraging the latents for their work. If that's the case, we won't decode the image and therefore won't run the safety checker for the user.

Downstream applications that expose the intermediate latents should therefore do it themselves. Is that an acceptable solution @patrickvonplaten? I wrote my review assuming that it was :)

Final question, what do you think about @ahrm's proposal to use a subclassable function (empty in the StableDiffusionPipeline) to provide the latents? The advantage is that fewer parameters would be required in the __call__ function.

@pcuenca
Member

pcuenca commented Sep 28, 2022

We just need to add three tests to merge this PR. @jamestiotio, could you maybe add three tests showing intermediate decoding of the pipelines in test_pipelines.py?

@jamestiotio would you be able to take care of that? Otherwise I can do it.

@pcuenca
Member

pcuenca commented Sep 28, 2022

By the way, great work @jamestiotio, I really liked that you paid so much attention to detail and included docstrings and type hints :) From my point of view this is very cool and ready to go when we iron out the final details!

@jamestiotio
Contributor Author

jamestiotio commented Sep 29, 2022

If that's the case, we won't decode the image and therefore won't run the safety checker for the user.

Downstream applications that expose the intermediate latents should therefore do it themselves.

In that case, as mentioned above by @keturn, would it be better to keep the decode_latents and run_safety_checker functions separated from __call__ as they are now to allow downstream applications to call those functions by themselves? This way, we can minimize code duplication.
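The structure being proposed here — helpers that both `__call__` and downstream applications can call — can be sketched with stand-in bodies (`SketchPipeline` is a hypothetical illustration, and the method bodies are not the real VAE decode or safety checker):

```python
class SketchPipeline:
    """Toy sketch: `decode_latents` and `run_safety_checker` remain
    standalone methods, so downstream code that receives raw latents
    from the callback can invoke them itself."""

    def decode_latents(self, latents):
        return [x * 0.5 for x in latents]  # stand-in for VAE decoding

    def run_safety_checker(self, images):
        return images, [False for _ in images]  # stand-in: nothing flagged

    def __call__(self, latents, callback=None):
        if callback is not None:
            callback(latents)  # the callback receives raw latents only
        images, _nsfw = self.run_safety_checker(self.decode_latents(latents))
        return images
```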

Final question, what do you think about @ahrm's proposal to use a subclassable function (empty in the StableDiffusionPipeline) to provide the latents? The advantage is that fewer parameters would be required in the __call__ function.

As a user of the library, I am indifferent between the two implementations (a subclassable function and __call__ function parameters). Since this is more of a stylistic and ease-of-maintenance choice, I will leave this decision to the maintainers (maybe @patrickvonplaten?). Feedback from other community members is also welcome! :)

We just need to add three tests to merge this PR. @jamestiotio, could you maybe add three tests showing intermediate decoding of the pipelines in test_pipelines.py?

@jamestiotio would you be able to take care of that? Otherwise I can do it.

I should be able to take care of this. Do give me some time to add them.

By the way, great work @jamestiotio, I really liked that you paid so much attention to detail and included docstrings and type hints :) From my point of view this is very cool and ready to go when we iron out the final details!

Thank you! 😄

@patrickvonplaten
Contributor

Awesome this looks good to me. Ok to merge whenever @pcuenca is happy with it :-)

@jamestiotio
Contributor Author

jamestiotio commented Oct 1, 2022

@pcuenca As previously mentioned, I have added the corresponding 4 tests (one for each Stable Diffusion pipeline) in commit ddbdec7.

Do let me know if you require any further changes. Cheers!

@jamestiotio jamestiotio requested a review from pcuenca October 1, 2022 11:13
@pcuenca
Member

pcuenca commented Oct 1, 2022

Hi @jamestiotio!

Thanks a lot for writing the tests, they look great! However, they fail in my system. I'm seeing two problems:

  • The latents are received by the callback on a CUDA device, so they can't be converted to a numpy array.
  • The computed values do not match the ones you wrote in the expected slice.

Are you testing in GPU too? Otherwise, I can try to fix them myself :)

@jamestiotio
Contributor Author

jamestiotio commented Oct 2, 2022

Hi @pcuenca, apologies for the previously failing tests. It seems that my assumption that each step on a CPU would produce the same latent values as the corresponding step on a GPU was wrong. 😅

I have fixed the tests accordingly in commit fe05ea2.

Since my GPU does not seem to possess enough VRAM to be able to load the full-precision model, I have decided to modify the tests slightly:

  1. For the text2img, img2img, and inpaint pipelines, I used half-precision weights via the fp16 branch, used automatic mixed precision via autocast, and enabled attention slicing.
  2. For the onnx pipeline, I changed the test to use CPUExecutionProvider instead.

All 4 tests pass on my side. Feel free to modify the values of the expected_slice on your side if you decide to use CUDAExecutionProvider for the onnx pipeline and the full-precision model for the rest.
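The `expected_slice` comparison pattern mentioned above can be sketched as follows; `check_latents_slice` is a hypothetical helper, not the actual diffusers test code:

```python
import numpy as np

def check_latents_slice(latents, expected_slice, atol=1e-3) -> bool:
    """Flatten the intermediate latents, take the first few values, and
    compare them against hard-coded expected values within a tolerance.
    Tolerances matter because fp16 vs fp32 and CPU vs GPU runs produce
    slightly different values."""
    actual = np.asarray(latents, dtype=np.float64).flatten()[: len(expected_slice)]
    return bool(np.allclose(actual, expected_slice, atol=atol))
```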

@pcuenca
Member

pcuenca commented Oct 2, 2022

Hi @jamestiotio don't worry, I'll merge this one and then update the tests so they use standard arguments.

Thanks a lot!

@pcuenca pcuenca merged commit 2558977 into huggingface:main Oct 2, 2022
prathikr pushed a commit to prathikr/diffusers that referenced this pull request Oct 26, 2022
* Add callback parameters for Stable Diffusion pipelines

Signed-off-by: James R T <[email protected]>

* Lint code with `black --preview`

Signed-off-by: James R T <[email protected]>

* Refactor callback implementation for Stable Diffusion pipelines

* Fix missing imports

Signed-off-by: James R T <[email protected]>

* Fix documentation format

Signed-off-by: James R T <[email protected]>

* Add kwargs parameter to standardize with other pipelines

Signed-off-by: James R T <[email protected]>

* Modify Stable Diffusion pipeline callback parameters

Signed-off-by: James R T <[email protected]>

* Remove useless imports

Signed-off-by: James R T <[email protected]>

* Change types for timestep and onnx latents

* Fix docstring style

* Return decode_latents and run_safety_checker back into __call__

* Remove unused imports

* Add intermediate state tests for Stable Diffusion pipelines

Signed-off-by: James R T <[email protected]>

* Fix intermediate state tests for Stable Diffusion pipelines

Signed-off-by: James R T <[email protected]>

Signed-off-by: James R T <[email protected]>
@nicollegah

nicollegah commented Jan 22, 2023

Hi peeps, thanks for your amazing work. Is it possible to use this callback parameter to do masking with the img2img pipeline? What I want is, mask an area of an image that should remain unchanged while the img2img pipeline "normally" affects the rest of the image. If yes: How would I do that?

@patrickvonplaten
Contributor

Hey @nicollegah,

Could you maybe open a new issue for this one?

@nicollegah

Sure! I just did #2073

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023

Successfully merging this pull request may close these issues.

Incremental Diffusion
8 participants