
Commit 7b8e25f

Modular backend - add FreeU (#6641)
## Summary

FreeU code from #6577. Also fixes an issue where output could sometimes be slightly different.

## Related Issues / Discussions

#6606
https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d

## QA Instructions

Run with and without the `USE_MODULAR_DENOISE` environment variable set.

## Merge Plan

None. If you think there should be some kind of tests, feel free to add them.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2 parents de39c5e + db52f56 commit 7b8e25f

File tree

5 files changed: +55 additions, -9 deletions

- invokeai/app/invocations/denoise_latents.py
- invokeai/backend/stable_diffusion/diffusion_backend.py
- invokeai/backend/stable_diffusion/extensions/base.py
- invokeai/backend/stable_diffusion/extensions/freeu.py
- invokeai/backend/stable_diffusion/extensions_manager.py
invokeai/app/invocations/denoise_latents.py

Lines changed: 7 additions & 2 deletions
@@ -58,6 +58,7 @@
 from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
 from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
 from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
 from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
 from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
@@ -795,18 +796,22 @@ def step_callback(state: PipelineIntermediateState) -> None:
         if self.cfg_rescale_multiplier > 0:
             ext_manager.add_extension(RescaleCFGExt(self.cfg_rescale_multiplier))
 
+        ### freeu
+        if self.unet.freeu_config:
+            ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
+
         # ext: t2i/ip adapter
         ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
 
         unet_info = context.models.load(self.unet.unet)
         assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
-            unet_info.model_on_device() as (model_state_dict, unet),
+            unet_info.model_on_device() as (cached_weights, unet),
             ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
             # ext: controlnet
             ext_manager.patch_extensions(unet),
             # ext: freeu, seamless, ip adapter, lora
-            ext_manager.patch_unet(model_state_dict, unet),
+            ext_manager.patch_unet(unet, cached_weights),
         ):
             sd_backend = StableDiffusionBackend(unet, scheduler)
             denoise_ctx.unet = unet
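
For context, the registration gate above follows the same pattern as the existing RescaleCFG extension: an extension is only added when its configuration is present on the invocation. The sketch below shows that pattern in isolation; the `is_canceled=None` argument and the placeholder config values are assumptions made to keep the example standalone, not part of this diff.

```python
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager

# In the node, is_canceled is normally context.util.is_canceled; None keeps this sketch standalone.
ext_manager = ExtensionsManager(is_canceled=None)

cfg_rescale_multiplier = 0.7  # placeholder; comes from the invocation fields in the real node
freeu_config = None           # placeholder; a FreeUConfig instance when FreeU is enabled on the UNet field

if cfg_rescale_multiplier > 0:
    ext_manager.add_extension(RescaleCFGExt(cfg_rescale_multiplier))

if freeu_config:
    ext_manager.add_extension(FreeUExt(freeu_config))
```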

invokeai/backend/stable_diffusion/diffusion_backend.py

Lines changed: 4 additions & 2 deletions
@@ -100,8 +100,10 @@ def combine_noise_preds(ctx: DenoiseContext) -> torch.Tensor:
         if isinstance(guidance_scale, list):
             guidance_scale = guidance_scale[ctx.step_index]
 
-        return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
-        # return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
+        # Note: Although this `torch.lerp(...)` line is logically equivalent to the current CFG line, it seems to result
+        # in slightly different outputs. It is suspected that this is caused by small precision differences.
+        # return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
+        return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
 
     def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode):
         sample = ctx.latent_model_input
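
The precision note above can be reproduced outside the pipeline: `torch.lerp(neg, pos, g)` and `neg + g * (pos - neg)` are algebraically the same, but they round intermediate values differently at low precision. A minimal sketch with made-up tensors and guidance scale (not pipeline values):

```python
import torch

torch.manual_seed(0)
neg = torch.randn(4, 64, 64, dtype=torch.float16)  # stand-in for negative_noise_pred
pos = torch.randn(4, 64, 64, dtype=torch.float16)  # stand-in for positive_noise_pred
guidance_scale = 7.5

via_lerp = torch.lerp(neg, pos, guidance_scale)
via_cfg = neg + guidance_scale * (pos - neg)

# Algebraically identical, but the intermediate rounding differs, so the two results
# are usually not bit-identical in float16 (the difference is tiny but nonzero).
print(torch.equal(via_lerp, via_cfg))
print((via_lerp - via_cfg).abs().max().item())
```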

invokeai/backend/stable_diffusion/extensions/base.py

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@
 
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Dict, List
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional
 
 import torch
 from diffusers import UNet2DConditionModel
@@ -56,5 +56,5 @@ def patch_extension(self, context: DenoiseContext):
         yield None
 
     @contextmanager
-    def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
+    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
         yield None
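
The signature change makes the UNet the primary argument and turns the state dict into an optional `cached_weights` parameter. A hypothetical sketch (not part of this PR) of why an extension might want it: an extension that edits UNet weights in place can restore them on exit from the RAM-cached copy instead of keeping its own clones. The class name and the chosen weight key are illustrative assumptions.

```python
from contextlib import contextmanager
from typing import Dict, Optional

import torch
from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase


class ScaleTimeEmbeddingExt(ExtensionBase):
    """Illustrative only: scales one UNet weight in place and restores it afterwards."""

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
        key = "time_embedding.linear_1.weight"
        param = unet.get_parameter(key)
        # Prefer the cached copy of the original weights when the model manager provides one;
        # otherwise fall back to cloning the current (on-device) value.
        original = cached_weights[key] if cached_weights is not None else param.detach().clone()
        try:
            with torch.no_grad():
                param.mul_(0.5)  # in-place weight edit while patched
            yield
        finally:
            with torch.no_grad():
                param.copy_(original)  # restore the original weights on exit
```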

invokeai/backend/stable_diffusion/extensions/freeu.py

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Dict, Optional
+
+import torch
+from diffusers import UNet2DConditionModel
+
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
+
+if TYPE_CHECKING:
+    from invokeai.app.shared.models import FreeUConfig
+
+
+class FreeUExt(ExtensionBase):
+    def __init__(
+        self,
+        freeu_config: FreeUConfig,
+    ):
+        super().__init__()
+        self._freeu_config = freeu_config
+
+    @contextmanager
+    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
+        unet.enable_freeu(
+            b1=self._freeu_config.b1,
+            b2=self._freeu_config.b2,
+            s1=self._freeu_config.s1,
+            s2=self._freeu_config.s2,
+        )
+
+        try:
+            yield
+        finally:
+            unet.disable_freeu()
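
A hedged usage sketch of the new extension outside the graph: `patch_unet` is a context manager, so FreeU is active only inside the `with` block and is always disabled again on exit. The stub config class and the deliberately tiny UNet below are assumptions made so the example runs without downloading weights; the b1/b2/s1/s2 values are the ones commonly suggested for SD 1.5.

```python
from dataclasses import dataclass

from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt


@dataclass
class StubFreeUConfig:
    # Stand-in for invokeai.app.shared.models.FreeUConfig; only these attribute names are read by FreeUExt.
    b1: float = 1.5
    b2: float = 1.6
    s1: float = 0.9
    s2: float = 0.2


# A deliberately tiny UNet so the sketch runs without pretrained weights.
unet = UNet2DConditionModel(
    sample_size=8,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
    layers_per_block=1,
)

with FreeUExt(StubFreeUConfig()).patch_unet(unet):
    # Denoising would run here; unet.enable_freeu(...) has been applied.
    print(getattr(unet.up_blocks[0], "b1", None))  # 1.5 while patched

# On exit, unet.disable_freeu() has run and the FreeU attributes are cleared again.
```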

invokeai/backend/stable_diffusion/extensions_manager.py

Lines changed: 7 additions & 3 deletions
@@ -63,9 +63,13 @@ def patch_extensions(self, context: DenoiseContext):
         yield None
 
     @contextmanager
-    def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
+    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
         if self._is_canceled and self._is_canceled():
             raise CanceledException
 
-        # TODO: create logic in PR with extension which uses it
-        yield None
+        # TODO: create weight patch logic in PR with extension which uses it
+        with ExitStack() as exit_stack:
+            for ext in self._extensions:
+                exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
+
+            yield None
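
The ExitStack is what lets an arbitrary number of extensions wrap the UNet at once: contexts are entered in registration order and unwound in reverse, even if an error occurs mid-denoise. A self-contained sketch of that ordering with toy extension classes (not InvokeAI code):

```python
from contextlib import ExitStack, contextmanager


class ToyExt:
    def __init__(self, name: str):
        self.name = name

    @contextmanager
    def patch_unet(self, unet, cached_weights=None):
        print(f"patch   {self.name}")
        try:
            yield
        finally:
            print(f"unpatch {self.name}")


extensions = [ToyExt("freeu"), ToyExt("lora")]
with ExitStack() as exit_stack:
    # Enter each extension's patch in registration order.
    for ext in extensions:
        exit_stack.enter_context(ext.patch_unet(unet=None))
    print("denoise here")
# Prints: patch freeu, patch lora, denoise here, unpatch lora, unpatch freeu
```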
