Modular backend - add FreeU (#6641)
## Summary

Adds the FreeU code from #6577.
Also fixes an issue where outputs could sometimes differ slightly.

## Related Issues / Discussions

#6606 

https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d

## QA Instructions

Run with and without the `USE_MODULAR_DENOISE` environment variable set.
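For example, roughly (a hypothetical sketch only; `run_generation()` stands in for however you normally trigger a generation, and the variable may need to be set before the process starts):

```python
import os

# Compare outputs of the legacy and modular denoise paths.
os.environ["USE_MODULAR_DENOISE"] = "1"  # or unset it to use the legacy path
run_generation()  # placeholder for your usual generation entry point
```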

## Merge Plan

Nope.
If you think there should be tests, feel free to add them.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
RyanJDick authored Jul 23, 2024
2 parents de39c5e + db52f56 commit 7b8e25f
Showing 5 changed files with 55 additions and 9 deletions.
9 changes: 7 additions & 2 deletions invokeai/app/invocations/denoise_latents.py
@@ -58,6 +58,7 @@
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
@@ -795,18 +796,22 @@ def step_callback(state: PipelineIntermediateState) -> None:
if self.cfg_rescale_multiplier > 0:
ext_manager.add_extension(RescaleCFGExt(self.cfg_rescale_multiplier))

### freeu
if self.unet.freeu_config:
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))

# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)

unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
unet_info.model_on_device() as (model_state_dict, unet),
unet_info.model_on_device() as (cached_weights, unet),
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
# ext: controlnet
ext_manager.patch_extensions(unet),
# ext: freeu, seamless, ip adapter, lora
ext_manager.patch_unet(model_state_dict, unet),
ext_manager.patch_unet(unet, cached_weights),
):
sd_backend = StableDiffusionBackend(unet, scheduler)
denoise_ctx.unet = unet
6 changes: 4 additions & 2 deletions invokeai/backend/stable_diffusion/diffusion_backend.py
@@ -100,8 +100,10 @@ def combine_noise_preds(ctx: DenoiseContext) -> torch.Tensor:
if isinstance(guidance_scale, list):
guidance_scale = guidance_scale[ctx.step_index]

return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
# return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
# Note: Although this `torch.lerp(...)` line is logically equivalent to the current CFG line, it seems to result
# in slightly different outputs. It is suspected that this is caused by small precision differences.
# return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)

def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode):
sample = ctx.latent_model_input
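The precision note in this hunk is easy to reproduce in isolation. A minimal sketch (not part of the PR; the tensors are random stand-ins for the negative/positive noise predictions) showing that the two CFG formulations are algebraically identical but can round differently in float32:

```python
import torch

torch.manual_seed(0)
neg = torch.randn(2, 4, 64, 64)  # stand-in for ctx.negative_noise_pred
pos = torch.randn(2, 4, 64, 64)  # stand-in for ctx.positive_noise_pred
guidance_scale = 7.5

via_lerp = torch.lerp(neg, pos, guidance_scale)
via_expansion = neg + guidance_scale * (pos - neg)

# Same formula on paper, but the two code paths can round differently,
# so the results may not be bit-identical.
print(torch.equal(via_lerp, via_expansion))
print((via_lerp - via_expansion).abs().max())
```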
4 changes: 2 additions & 2 deletions invokeai/backend/stable_diffusion/extensions/base.py
@@ -2,7 +2,7 @@

from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, List
from typing import TYPE_CHECKING, Callable, Dict, List, Optional

import torch
from diffusers import UNet2DConditionModel
@@ -56,5 +56,5 @@ def patch_extension(self, context: DenoiseContext):
yield None

@contextmanager
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
yield None
35 changes: 35 additions & 0 deletions invokeai/backend/stable_diffusion/extensions/freeu.py
@@ -0,0 +1,35 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING, Dict, Optional

import torch
from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase

if TYPE_CHECKING:
from invokeai.app.shared.models import FreeUConfig


class FreeUExt(ExtensionBase):
def __init__(
self,
freeu_config: FreeUConfig,
):
super().__init__()
self._freeu_config = freeu_config

@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
unet.enable_freeu(
b1=self._freeu_config.b1,
b2=self._freeu_config.b2,
s1=self._freeu_config.s1,
s2=self._freeu_config.s2,
)

try:
yield
finally:
unet.disable_freeu()
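For reference, a minimal sketch of how the new extension plugs into the manager (it mirrors the wiring shown above in denoise_latents.py; `freeu_config`, `unet`, and the exact `ExtensionsManager` construction are assumptions standing in for the surrounding invocation code):

```python
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager

ext_manager = ExtensionsManager()  # assumed construction; the invocation may pass extra args

if freeu_config is not None:  # e.g. self.unet.freeu_config in the invocation
    ext_manager.add_extension(FreeUExt(freeu_config))

# patch_unet now takes (unet, cached_weights); FreeU is enabled on entry
# and disabled again when the context exits.
with ext_manager.patch_unet(unet, cached_weights=None):
    ...  # run denoising with FreeU active
```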
10 changes: 7 additions & 3 deletions invokeai/backend/stable_diffusion/extensions_manager.py
@@ -63,9 +63,13 @@ def patch_extensions(self, context: DenoiseContext):
yield None

@contextmanager
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
if self._is_canceled and self._is_canceled():
raise CanceledException

# TODO: create logic in PR with extension which uses it
yield None
# TODO: create weight patch logic in PR with extension which uses it
with ExitStack() as exit_stack:
for ext in self._extensions:
exit_stack.enter_context(ext.patch_unet(unet, cached_weights))

yield None
