@@ -1,3 +1,5 @@
+from typing import Callable
+
 import numpy as np
 import torch
 from PIL import Image
@@ -21,7 +23,7 @@
 from invokeai.backend.tiles.utils import TBLR, Tile
 
 
-@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.1.0")
+@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.2.0")
 class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     """Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
 
@@ -34,8 +36,19 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     tile_size: int = InputField(
         default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
     )
+    scale: float = InputField(
+        default=4.0,
+        gt=0.0,
+        le=16.0,
+        description="The final scale of the output image. If the model does not upscale the image, this will be ignored.",
+    )
+    fit_to_multiple_of_8: bool = InputField(
+        default=False,
+        description="If true, the output image will be resized to the nearest multiple of 8 in both dimensions.",
+    )
 
-    def _scale_tile(self, tile: Tile, scale: int) -> Tile:
+    @classmethod
+    def scale_tile(cls, tile: Tile, scale: int) -> Tile:
         return Tile(
             coords=TBLR(
                 top=tile.coords.top * scale,
@@ -51,20 +64,22 @@ def _scale_tile(self, tile: Tile, scale: int) -> Tile:
             ),
         )
 
-    @torch.inference_mode()
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        # Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
-        # revisit this.
-        image = context.images.get_pil(self.image.image_name, mode="RGB")
-
+    @classmethod
+    def upscale_image(
+        cls,
+        image: Image.Image,
+        tile_size: int,
+        spandrel_model: SpandrelImageToImageModel,
+        is_canceled: Callable[[], bool],
+    ) -> Image.Image:
         # Compute the image tiles.
-        if self.tile_size > 0:
+        if tile_size > 0:
             min_overlap = 20
             tiles = calc_tiles_min_overlap(
                 image_height=image.height,
                 image_width=image.width,
-                tile_height=self.tile_size,
-                tile_width=self.tile_size,
+                tile_height=tile_size,
+                tile_width=tile_size,
                 min_overlap=min_overlap,
             )
         else:
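Note on the hunk above: turning the body of `invoke` into the classmethod `upscale_image`, with the cancel check injected as a plain `Callable[[], bool]`, means the tiled-upscale loop no longer needs an `InvocationContext`. A minimal sketch of calling it directly, e.g. from a test; `load_model_somehow` is a hypothetical placeholder, since in the app the model comes from `context.models.load(...)`:

    from PIL import Image

    # Hypothetical direct use of the new classmethod. `load_model_somehow` is a
    # placeholder for however a SpandrelImageToImageModel is obtained in a test.
    spandrel_model = load_model_somehow()
    image = Image.open("input.png").convert("RGB")

    upscaled = SpandrelImageToImageInvocation.upscale_image(
        image=image,
        tile_size=512,  # 0 disables tiling
        spandrel_model=spandrel_model,
        is_canceled=lambda: False,  # never cancel in this sketch
    )
    upscaled.save("output.png")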
@@ -85,60 +100,123 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
         # Prepare input image for inference.
         image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
 
-        # Load the model.
-        spandrel_model_info = context.models.load(self.image_to_image_model)
-
-        # Run the model on each tile.
-        with spandrel_model_info as spandrel_model:
-            assert isinstance(spandrel_model, SpandrelImageToImageModel)
+        # Scale the tiles for re-assembling the final image.
+        scale = spandrel_model.scale
+        scaled_tiles = [cls.scale_tile(tile, scale=scale) for tile in tiles]
 
-            # Scale the tiles for re-assembling the final image.
-            scale = spandrel_model.scale
-            scaled_tiles = [self._scale_tile(tile, scale=scale) for tile in tiles]
+        # Prepare the output tensor.
+        _, channels, height, width = image_tensor.shape
+        output_tensor = torch.zeros(
+            (height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
+        )
 
-            # Prepare the output tensor.
-            _, channels, height, width = image_tensor.shape
-            output_tensor = torch.zeros(
-                (height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
-            )
+        image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
 
-            image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
-
-            for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
-                # Exit early if the invocation has been canceled.
-                if context.util.is_canceled():
-                    raise CanceledException
-
-                # Extract the current tile from the input tensor.
-                input_tile = image_tensor[
-                    :, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
-                ].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
-
-                # Run the model on the tile.
-                output_tile = spandrel_model.run(input_tile)
-
-                # Convert the output tile into the output tensor's format.
-                # (N, C, H, W) -> (C, H, W)
-                output_tile = output_tile.squeeze(0)
-                # (C, H, W) -> (H, W, C)
-                output_tile = output_tile.permute(1, 2, 0)
-                output_tile = output_tile.clamp(0, 1)
-                output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))
-
-                # Merge the output tile into the output tensor.
-                # We only keep half of the overlap on the top and left side of the tile. We do this in case there are
-                # edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
-                # it seems unnecessary, but we may find a need in the future.
-                top_overlap = scaled_tile.overlap.top // 2
-                left_overlap = scaled_tile.overlap.left // 2
-                output_tensor[
-                    scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
-                    scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
-                    :,
-                ] = output_tile[top_overlap:, left_overlap:, :]
+        # Run the model on each tile.
+        for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
+            # Exit early if the invocation has been canceled.
+            if is_canceled():
+                raise CanceledException
+
+            # Extract the current tile from the input tensor.
+            input_tile = image_tensor[
+                :, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
+            ].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
+
+            # Run the model on the tile.
+            output_tile = spandrel_model.run(input_tile)
+
+            # Convert the output tile into the output tensor's format.
+            # (N, C, H, W) -> (C, H, W)
+            output_tile = output_tile.squeeze(0)
+            # (C, H, W) -> (H, W, C)
+            output_tile = output_tile.permute(1, 2, 0)
+            output_tile = output_tile.clamp(0, 1)
+            output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))
+
+            # Merge the output tile into the output tensor.
+            # We only keep half of the overlap on the top and left side of the tile. We do this in case there are
+            # edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
+            # it seems unnecessary, but we may find a need in the future.
+            top_overlap = scaled_tile.overlap.top // 2
+            left_overlap = scaled_tile.overlap.left // 2
+            output_tensor[
+                scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
+                scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
+                :,
+            ] = output_tile[top_overlap:, left_overlap:, :]
 
         # Convert the output tensor to a PIL image.
         np_image = output_tensor.detach().numpy().astype(np.uint8)
         pil_image = Image.fromarray(np_image)
+
+        return pil_image
+
+    @torch.inference_mode()
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        # Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
+        # revisit this.
+        image = context.images.get_pil(self.image.image_name, mode="RGB")
+
+        # Load the model.
+        spandrel_model_info = context.models.load(self.image_to_image_model)
+
+        # The target size of the image, determined by the provided scale. We'll run the upscaler until we hit this size.
+        # Later, we may mutate this value if the model doesn't upscale the image or if the user requested a multiple of 8.
+        target_width = int(image.width * self.scale)
+        target_height = int(image.height * self.scale)
+
+        # Do the upscaling.
+        with spandrel_model_info as spandrel_model:
+            assert isinstance(spandrel_model, SpandrelImageToImageModel)
+
+            # First pass of upscaling. Note: `pil_image` will be mutated.
+            pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
+
+            # Some models don't upscale the image, but we have no way to know this in advance. We'll check if the model
+            # upscaled the image and run the loop below if it did. We'll require the model to upscale both dimensions
+            # to be considered an upscale model.
+            is_upscale_model = pil_image.width > image.width and pil_image.height > image.height
+
+            if is_upscale_model:
+                # This is an upscale model, so we should keep upscaling until we reach the target size.
+                iterations = 1
+                while pil_image.width < target_width or pil_image.height < target_height:
+                    pil_image = self.upscale_image(pil_image, self.tile_size, spandrel_model, context.util.is_canceled)
+                    iterations += 1
+
+                    # Sanity check to prevent excessive or infinite loops. All known upscaling models are at least 2x.
+                    # Our max scale is 16x, so with a 2x model, we should never exceed 16x == 2^4 -> 4 iterations.
+                    # We'll allow one extra iteration "just in case" and bail at 5 upscaling iterations. In practice,
+                    # we should never reach this limit.
+                    if iterations >= 5:
+                        context.logger.warning(
+                            "Upscale loop reached maximum iteration count of 5, stopping upscaling early."
+                        )
+                        break
+            else:
+                # This model doesn't upscale the image. We should ignore the scale parameter, modifying the output size
+                # to be the same as the processed image size.
+
+                # The output size is now the size of the processed image.
+                target_width = pil_image.width
+                target_height = pil_image.height
+
+                # Warn the user if they requested a scale greater than 1.
+                if self.scale > 1:
+                    context.logger.warning(
+                        "Model does not increase the size of the image, but a greater scale than 1 was requested. Image will not be scaled."
+                    )
+
+        # We may need to resize the image to a multiple of 8. Use floor division to ensure we don't scale the image up
+        # in the final resize.
+        if self.fit_to_multiple_of_8:
+            target_width = int(target_width // 8 * 8)
+            target_height = int(target_height // 8 * 8)
+
+        # Final resize. Per PIL documentation, Lanczos provides the best quality for both upscale and downscale.
+        # See: https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters-comparison-table
+        pil_image = pil_image.resize((target_width, target_height), resample=Image.Resampling.LANCZOS)
+
         image_dto = context.images.save(image=pil_image)
         return ImageOutput.build(image_dto)
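Below is a minimal sketch of the resize arithmetic this diff introduces. It is not part of the change: `plan_output_size` and `model_scale` are hypothetical stand-ins, and the real code runs the Spandrel model on image tiles rather than multiplying dimensions. The iteration cap, the non-upscaling fallback, and the multiple-of-8 floor mirror the logic in `invoke` above.

    def plan_output_size(
        width: int,
        height: int,
        model_scale: int,  # hypothetical stand-in for spandrel_model.scale, e.g. 2 for a 2x model
        requested_scale: float,  # the new `scale` field: 0 < scale <= 16
        fit_to_multiple_of_8: bool,  # the new `fit_to_multiple_of_8` field
    ) -> tuple[int, int, int]:
        """Return (target_width, target_height, model_passes)."""
        target_width = int(width * requested_scale)
        target_height = int(height * requested_scale)

        # The first pass always runs; `is_upscale_model` is then decided from its output.
        out_w, out_h = width * model_scale, height * model_scale
        iterations = 1

        if model_scale > 1:  # plays the role of `is_upscale_model`
            # Keep upscaling until the target is reached, bailing at 5 iterations
            # exactly as the loop in `invoke` does.
            while out_w < target_width or out_h < target_height:
                out_w, out_h = out_w * model_scale, out_h * model_scale
                iterations += 1
                if iterations >= 5:
                    break
        else:
            # Non-upscaling model: the scale parameter is ignored.
            target_width, target_height = out_w, out_h

        if fit_to_multiple_of_8:
            # Floor to a multiple of 8 so the final Lanczos resize never scales up.
            target_width = target_width // 8 * 8
            target_height = target_height // 8 * 8

        return target_width, target_height, iterations


    # With the default scale=4.0 and a 2x model, a 512x512 input takes two model
    # passes (512 -> 1024 -> 2048) and the final resize is a no-op:
    assert plan_output_size(512, 512, 2, 4.0, False) == (2048, 2048, 2)
    # With scale=3.0 the loop overshoots to 2048, then Lanczos downsizes to 1536:
    assert plan_output_size(512, 512, 2, 3.0, False) == (1536, 1536, 2)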