
Commit db52f56

Merge branch 'main' into stalker-modular_freeu
2 parents: 5f0fe3c + de39c5e


50 files changed: +1430 -255 lines changed. (Only a subset of the changed files is shown below.)

invokeai/app/invocations/denoise_latents.py

Lines changed: 5 additions & 0 deletions

@@ -60,6 +60,7 @@
 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
 from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
 from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
+from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
 from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
 from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES

@@ -791,6 +792,10 @@ def step_callback(state: PipelineIntermediateState) -> None:
 
         ext_manager.add_extension(PreviewExt(step_callback))
 
+        ### cfg rescale
+        if self.cfg_rescale_multiplier > 0:
+            ext_manager.add_extension(RescaleCFGExt(self.cfg_rescale_multiplier))
+
         ### freeu
         if self.unet.freeu_config:
             ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
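
Note: these registrations take effect because extensions declare hooks with the @callback decorator and the ExtensionsManager fires them when the backend reaches the matching ExtensionCallbackType. A hypothetical, simplified sketch of that dispatch pattern (SketchExtensionsManager and its get_callbacks() contract are assumptions for illustration; the real ExtensionsManager may differ in ordering, overrides, and error handling):

from collections import defaultdict

class SketchExtensionsManager:
    """Simplified stand-in for ExtensionsManager, for illustration only."""

    def __init__(self):
        # Maps an ExtensionCallbackType to the handlers registered for it.
        self._callbacks = defaultdict(list)

    def add_extension(self, ext):
        # Assumes extensions expose their @callback-decorated methods via a
        # get_callbacks() mapping; the real API may differ.
        for callback_type, handlers in ext.get_callbacks().items():
            self._callbacks[callback_type].extend(handlers)

    def run_callback(self, callback_type, ctx):
        # Invoke every handler registered for this event with the shared context.
        for handler in self._callbacks[callback_type]:
            handler(ctx)
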
invokeai/app/invocations/spandrel_image_to_image.py

Lines changed: 137 additions & 59 deletions

@@ -1,3 +1,5 @@
+from typing import Callable
+
 import numpy as np
 import torch
 from PIL import Image

@@ -21,7 +23,7 @@
 from invokeai.backend.tiles.utils import TBLR, Tile
 
 
-@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.1.0")
+@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.2.0")
 class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     """Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
 

@@ -34,8 +36,19 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     tile_size: int = InputField(
         default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
     )
+    scale: float = InputField(
+        default=4.0,
+        gt=0.0,
+        le=16.0,
+        description="The final scale of the output image. If the model does not upscale the image, this will be ignored.",
+    )
+    fit_to_multiple_of_8: bool = InputField(
+        default=False,
+        description="If true, the output image will be resized to the nearest multiple of 8 in both dimensions.",
+    )
 
-    def _scale_tile(self, tile: Tile, scale: int) -> Tile:
+    @classmethod
+    def scale_tile(cls, tile: Tile, scale: int) -> Tile:
         return Tile(
             coords=TBLR(
                 top=tile.coords.top * scale,

@@ -51,20 +64,22 @@ def _scale_tile(self, tile: Tile, scale: int) -> Tile:
             ),
         )
 
-    @torch.inference_mode()
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        # Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
-        # revisit this.
-        image = context.images.get_pil(self.image.image_name, mode="RGB")
-
+    @classmethod
+    def upscale_image(
+        cls,
+        image: Image.Image,
+        tile_size: int,
+        spandrel_model: SpandrelImageToImageModel,
+        is_canceled: Callable[[], bool],
+    ) -> Image.Image:
         # Compute the image tiles.
-        if self.tile_size > 0:
+        if tile_size > 0:
             min_overlap = 20
             tiles = calc_tiles_min_overlap(
                 image_height=image.height,
                 image_width=image.width,
-                tile_height=self.tile_size,
-                tile_width=self.tile_size,
+                tile_height=tile_size,
+                tile_width=tile_size,
                 min_overlap=min_overlap,
             )
         else:

@@ -85,60 +100,123 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
         # Prepare input image for inference.
         image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
 
-        # Load the model.
-        spandrel_model_info = context.models.load(self.image_to_image_model)
-
-        # Run the model on each tile.
-        with spandrel_model_info as spandrel_model:
-            assert isinstance(spandrel_model, SpandrelImageToImageModel)
+        # Scale the tiles for re-assembling the final image.
+        scale = spandrel_model.scale
+        scaled_tiles = [cls.scale_tile(tile, scale=scale) for tile in tiles]
 
-            # Scale the tiles for re-assembling the final image.
-            scale = spandrel_model.scale
-            scaled_tiles = [self._scale_tile(tile, scale=scale) for tile in tiles]
+        # Prepare the output tensor.
+        _, channels, height, width = image_tensor.shape
+        output_tensor = torch.zeros(
+            (height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
+        )
 
-            # Prepare the output tensor.
-            _, channels, height, width = image_tensor.shape
-            output_tensor = torch.zeros(
-                (height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
-            )
+        image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
 
-            image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
-
-            for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
-                # Exit early if the invocation has been canceled.
-                if context.util.is_canceled():
-                    raise CanceledException
-
-                # Extract the current tile from the input tensor.
-                input_tile = image_tensor[
-                    :, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
-                ].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
-
-                # Run the model on the tile.
-                output_tile = spandrel_model.run(input_tile)
-
-                # Convert the output tile into the output tensor's format.
-                # (N, C, H, W) -> (C, H, W)
-                output_tile = output_tile.squeeze(0)
-                # (C, H, W) -> (H, W, C)
-                output_tile = output_tile.permute(1, 2, 0)
-                output_tile = output_tile.clamp(0, 1)
-                output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))
-
-                # Merge the output tile into the output tensor.
-                # We only keep half of the overlap on the top and left side of the tile. We do this in case there are
-                # edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
-                # it seems unnecessary, but we may find a need in the future.
-                top_overlap = scaled_tile.overlap.top // 2
-                left_overlap = scaled_tile.overlap.left // 2
-                output_tensor[
-                    scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
-                    scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
-                    :,
-                ] = output_tile[top_overlap:, left_overlap:, :]
+        # Run the model on each tile.
+        for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
+            # Exit early if the invocation has been canceled.
+            if is_canceled():
+                raise CanceledException
+
+            # Extract the current tile from the input tensor.
+            input_tile = image_tensor[
+                :, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
+            ].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
+
+            # Run the model on the tile.
+            output_tile = spandrel_model.run(input_tile)
+
+            # Convert the output tile into the output tensor's format.
+            # (N, C, H, W) -> (C, H, W)
+            output_tile = output_tile.squeeze(0)
+            # (C, H, W) -> (H, W, C)
+            output_tile = output_tile.permute(1, 2, 0)
+            output_tile = output_tile.clamp(0, 1)
+            output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))
+
+            # Merge the output tile into the output tensor.
+            # We only keep half of the overlap on the top and left side of the tile. We do this in case there are
+            # edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
+            # it seems unnecessary, but we may find a need in the future.
+            top_overlap = scaled_tile.overlap.top // 2
+            left_overlap = scaled_tile.overlap.left // 2
+            output_tensor[
+                scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
+                scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
+                :,
+            ] = output_tile[top_overlap:, left_overlap:, :]
 
         # Convert the output tensor to a PIL image.
         np_image = output_tensor.detach().numpy().astype(np.uint8)
         pil_image = Image.fromarray(np_image)
+
+        return pil_image
+
+    @torch.inference_mode()
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        # Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
+        # revisit this.
+        image = context.images.get_pil(self.image.image_name, mode="RGB")
+
+        # Load the model.
+        spandrel_model_info = context.models.load(self.image_to_image_model)
+
+        # The target size of the image, determined by the provided scale. We'll run the upscaler until we hit this size.
+        # Later, we may mutate this value if the model doesn't upscale the image or if the user requested a multiple of 8.
+        target_width = int(image.width * self.scale)
+        target_height = int(image.height * self.scale)
+
+        # Do the upscaling.
+        with spandrel_model_info as spandrel_model:
+            assert isinstance(spandrel_model, SpandrelImageToImageModel)
+
+            # First pass of upscaling. Note: `pil_image` will be mutated.
+            pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
+
+            # Some models don't upscale the image, but we have no way to know this in advance. We'll check if the model
+            # upscaled the image and run the loop below if it did. We'll require the model to upscale both dimensions
+            # to be considered an upscale model.
+            is_upscale_model = pil_image.width > image.width and pil_image.height > image.height
+
+            if is_upscale_model:
+                # This is an upscale model, so we should keep upscaling until we reach the target size.
+                iterations = 1
+                while pil_image.width < target_width or pil_image.height < target_height:
+                    pil_image = self.upscale_image(pil_image, self.tile_size, spandrel_model, context.util.is_canceled)
+                    iterations += 1
+
+                    # Sanity check to prevent excessive or infinite loops. All known upscaling models are at least 2x.
+                    # Our max scale is 16x, so with a 2x model, we should never exceed 16x == 2^4 -> 4 iterations.
+                    # We'll allow one extra iteration "just in case" and bail at 5 upscaling iterations. In practice,
+                    # we should never reach this limit.
+                    if iterations >= 5:
+                        context.logger.warning(
+                            "Upscale loop reached maximum iteration count of 5, stopping upscaling early."
+                        )
+                        break
+            else:
+                # This model doesn't upscale the image. We should ignore the scale parameter, modifying the output size
+                # to be the same as the processed image size.
+
+                # The output size is now the size of the processed image.
+                target_width = pil_image.width
+                target_height = pil_image.height
+
+                # Warn the user if they requested a scale greater than 1.
+                if self.scale > 1:
+                    context.logger.warning(
+                        "Model does not increase the size of the image, but a greater scale than 1 was requested. Image will not be scaled."
+                    )
+
+        # We may need to resize the image to a multiple of 8. Use floor division to ensure we don't scale the image up
+        # in the final resize
+        if self.fit_to_multiple_of_8:
+            target_width = int(target_width // 8 * 8)
+            target_height = int(target_height // 8 * 8)
+
+        # Final resize. Per PIL documentation, Lanczos provides the best quality for both upscale and downscale.
+        # See: https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters-comparison-table
+        pil_image = pil_image.resize((target_width, target_height), resample=Image.Resampling.LANCZOS)
+
         image_dto = context.images.save(image=pil_image)
         return ImageOutput.build(image_dto)
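
Note: to make the tile-merge arithmetic above concrete: with the default min_overlap of 20 and a hypothetical 4x model, neighboring tiles share 80 output pixels, and the merge discards the first half of that overlap on each tile's top/left edge. A small worked example (the numbers are illustrative, not from the source):

# Illustrative numbers: 20 px input overlap under a hypothetical 4x model.
scale = 4
min_overlap = 20

scaled_overlap = min_overlap * scale  # 80 output px shared between neighboring tiles
discarded = scaled_overlap // 2       # 40 px dropped from the top/left of each tile
assert (scaled_overlap, discarded) == (80, 40)

# fit_to_multiple_of_8 floors each dimension, so the final resize never upscales:
target_width = 1023
assert target_width // 8 * 8 == 1016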

invokeai/backend/stable_diffusion/denoise_context.py

Lines changed: 10 additions & 10 deletions

@@ -83,47 +83,47 @@ class DenoiseContext:
     unet: Optional[UNet2DConditionModel] = None
 
     # Current state of latent-space image in denoising process.
-    # None until `pre_denoise_loop` callback.
+    # None until `PRE_DENOISE_LOOP` callback.
     # Shape: [batch, channels, latent_height, latent_width]
     latents: Optional[torch.Tensor] = None
 
     # Current denoising step index.
-    # None until `pre_step` callback.
+    # None until `PRE_STEP` callback.
     step_index: Optional[int] = None
 
     # Current denoising step timestep.
-    # None until `pre_step` callback.
+    # None until `PRE_STEP` callback.
     timestep: Optional[torch.Tensor] = None
 
     # Arguments which will be passed to UNet model.
-    # Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
+    # Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
     unet_kwargs: Optional[UNetKwargs] = None
 
     # SchedulerOutput class returned from step function(normally, generated by scheduler).
-    # Supposed to be used only in `post_step` callback, otherwise can be None.
+    # Supposed to be used only in `POST_STEP` callback, otherwise can be None.
     step_output: Optional[SchedulerOutput] = None
 
     # Scaled version of `latents`, which will be passed to unet_kwargs initialization.
-    # Available in events inside step(between `pre_step` and `post_stop`).
+    # Available in events inside step(between `PRE_STEP` and `POST_STEP`).
     # Shape: [batch, channels, latent_height, latent_width]
     latent_model_input: Optional[torch.Tensor] = None
 
     # [TMP] Defines on which conditionings current unet call will be runned.
-    # Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
+    # Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
     conditioning_mode: Optional[ConditioningMode] = None
 
     # [TMP] Noise predictions from negative conditioning.
-    # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
+    # Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
     # Shape: [batch, channels, latent_height, latent_width]
     negative_noise_pred: Optional[torch.Tensor] = None
 
     # [TMP] Noise predictions from positive conditioning.
-    # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
+    # Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
     # Shape: [batch, channels, latent_height, latent_width]
     positive_noise_pred: Optional[torch.Tensor] = None
 
     # Combined noise prediction from passed conditionings.
-    # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
+    # Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
     # Shape: [batch, channels, latent_height, latent_width]
     noise_pred: Optional[torch.Tensor] = None
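
Note: the field comments above amount to a lifecycle contract: each field is only guaranteed to be populated inside the callback window its comment names. A hypothetical illustration of an extension honoring that contract, using the ExtensionBase/@callback API that appears in the rescale_cfg.py diff below (StepLoggerExt itself is not part of this commit):

from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback

class StepLoggerExt(ExtensionBase):
    """Illustrative only: reads each DenoiseContext field inside its valid window."""

    @callback(ExtensionCallbackType.PRE_STEP)
    def log_step(self, ctx):
        # step_index and timestep are populated from PRE_STEP onward.
        print(f"step {ctx.step_index}, timestep {ctx.timestep}")

    @callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS)
    def log_noise(self, ctx):
        # noise_pred is only valid in this callback window.
        print(f"noise_pred std: {ctx.noise_pred.std().item():.4f}")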

invokeai/backend/stable_diffusion/diffusion_backend.py

Lines changed: 4 additions & 4 deletions

@@ -76,12 +76,12 @@ def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> Scheduler
         both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both)
         ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)
 
-        # ext: override apply_cfg
-        ctx.noise_pred = self.apply_cfg(ctx)
+        # ext: override combine_noise_preds
+        ctx.noise_pred = self.combine_noise_preds(ctx)
 
         # ext: cfg_rescale [modify_noise_prediction]
         # TODO: rename
-        ext_manager.run_callback(ExtensionCallbackType.POST_APPLY_CFG, ctx)
+        ext_manager.run_callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS, ctx)
 
         # compute the previous noisy sample x_t -> x_t-1
         step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)

@@ -95,7 +95,7 @@ def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> Scheduler
         return step_output
 
     @staticmethod
-    def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
+    def combine_noise_preds(ctx: DenoiseContext) -> torch.Tensor:
         guidance_scale = ctx.inputs.conditioning_data.guidance_scale
         if isinstance(guidance_scale, list):
             guidance_scale = guidance_scale[ctx.step_index]
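
Note: the hunk shows only the rename and the guidance_scale indexing (which allows a per-step schedule), not the function body. The standard classifier-free guidance combination such a function computes is ε = ε⁻ + g·(ε⁺ − ε⁻). A sketch under that assumption (tensor shapes and names are illustrative):

import torch

# Standard CFG combine: start from the unconditional prediction and move
# guidance_scale times along the direction toward the conditional one.
negative_noise_pred = torch.randn(1, 4, 64, 64)  # illustrative shape
positive_noise_pred = torch.randn(1, 4, 64, 64)
guidance_scale = 7.5

noise_pred = negative_noise_pred + guidance_scale * (positive_noise_pred - negative_noise_pred)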

invokeai/backend/stable_diffusion/extension_callback_type.py

Lines changed: 1 addition & 1 deletion

@@ -9,4 +9,4 @@ class ExtensionCallbackType(Enum):
     POST_STEP = "post_step"
     PRE_UNET = "pre_unet"
     POST_UNET = "post_unet"
-    POST_APPLY_CFG = "post_apply_cfg"
+    POST_COMBINE_NOISE_PREDS = "post_combine_noise_preds"
invokeai/backend/stable_diffusion/extensions/rescale_cfg.py

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+
+from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
+
+if TYPE_CHECKING:
+    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
+
+
+class RescaleCFGExt(ExtensionBase):
+    def __init__(self, rescale_multiplier: float):
+        super().__init__()
+        self._rescale_multiplier = rescale_multiplier
+
+    @staticmethod
+    def _rescale_cfg(total_noise_pred: torch.Tensor, pos_noise_pred: torch.Tensor, multiplier: float = 0.7):
+        """Implementation of Algorithm 2 from https://arxiv.org/pdf/2305.08891.pdf."""
+        ro_pos = torch.std(pos_noise_pred, dim=(1, 2, 3), keepdim=True)
+        ro_cfg = torch.std(total_noise_pred, dim=(1, 2, 3), keepdim=True)
+
+        x_rescaled = total_noise_pred * (ro_pos / ro_cfg)
+        x_final = multiplier * x_rescaled + (1.0 - multiplier) * total_noise_pred
+        return x_final
+
+    @callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS)
+    def rescale_noise_pred(self, ctx: DenoiseContext):
+        if self._rescale_multiplier > 0:
+            ctx.noise_pred = self._rescale_cfg(
+                ctx.noise_pred,
+                ctx.positive_noise_pred,
+                self._rescale_multiplier,
+            )
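
Note: in equation form, `_rescale_cfg` computes the following, where φ is the rescale multiplier and each standard deviation is taken per sample over the channel and spatial dimensions (`dim=(1, 2, 3)`), matching Algorithm 2 of the cited paper:

\sigma_{\mathrm{pos}} = \operatorname{std}(\epsilon_{\mathrm{pos}}), \qquad
\sigma_{\mathrm{cfg}} = \operatorname{std}(\epsilon_{\mathrm{cfg}})

x_{\mathrm{rescaled}} = \epsilon_{\mathrm{cfg}} \cdot \frac{\sigma_{\mathrm{pos}}}{\sigma_{\mathrm{cfg}}}, \qquad
x_{\mathrm{final}} = \phi \, x_{\mathrm{rescaled}} + (1 - \phi) \, \epsilon_{\mathrm{cfg}}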
