13
13
from threading import Thread
14
14
from typing import List
15
15
import gradio
16
+ from torchvision .transforms import ToTensor
16
17
import urllib3
17
18
from PIL import Image
18
19
from modules import processing
24
25
from scripts .spartan .control_net import pack_control_net
25
26
from scripts .spartan .shared import logger
26
27
from scripts .spartan .ui import UI
27
- from scripts .spartan .world import World , State
28
+ from scripts .spartan .world import World , State , Job
28
29
29
30
old_sigint_handler = signal .getsignal (signal .SIGINT )
30
31
old_sigterm_handler = signal .getsignal (signal .SIGTERM )
@@ -61,7 +62,7 @@ def show(self, is_img2img):
61
62
return scripts .AlwaysVisible
62
63
63
64
def ui (self , is_img2img ):
64
- extension_ui = UI (world = self .world )
65
+ extension_ui = UI (world = self .world , is_img2img = is_img2img )
65
66
# root, api_exposed = extension_ui.create_ui()
66
67
components = extension_ui .create_ui ()
67
68
@@ -71,77 +72,61 @@ def ui(self, is_img2img):
71
72
# return some components that should be exposed to the api
72
73
return components
73
74
74
- def add_to_gallery (self , processed , p ):
75
- """adds generated images to the image gallery after waiting for all workers to finish"""
75
+ def api_to_internal (self , job ) -> ([], [], [], [], [] ):
76
+ # takes worker response received from api and returns parsed objects in internal sdwui format. E.g. all_seeds
76
77
77
- def processed_inject_image (image , info_index , save_path_override = None , grid = False , response = None ):
78
- image_params : json = response ['parameters' ]
79
- image_info_post : json = json .loads (response ["info" ]) # image info known after processing
80
- num_response_images = image_params ["batch_size" ] * image_params ["n_iter" ]
81
-
82
- seed = None
83
- subseed = None
84
- negative_prompt = None
85
- pos_prompt = None
78
+ image_params : json = job .worker .response ['parameters' ]
79
+ image_info_post : json = json .loads (job .worker .response ["info" ]) # image info known after processing
80
+ all_seeds , all_subseeds , all_negative_prompts , all_prompts , images = [], [], [], [], []
86
81
82
+ for i in range (len (job .worker .response ["images" ])):
87
83
try :
88
- if num_response_images > 1 :
89
- seed = image_info_post ['all_seeds' ][info_index ]
90
- subseed = image_info_post ['all_subseeds' ][info_index ]
91
- negative_prompt = image_info_post ['all_negative_prompts' ][info_index ]
92
- pos_prompt = image_info_post ['all_prompts' ][info_index ]
93
- else :
94
- seed = image_info_post ['seed' ]
95
- subseed = image_info_post ['subseed' ]
96
- negative_prompt = image_info_post ['negative_prompt' ]
97
- pos_prompt = image_info_post ['prompt' ]
84
+ if image_params [ "batch_size" ] * image_params [ "n_iter" ] > 1 :
85
+ all_seeds . append ( image_info_post ['all_seeds' ][i ])
86
+ all_subseeds . append ( image_info_post ['all_subseeds' ][i ])
87
+ all_negative_prompts . append ( image_info_post ['all_negative_prompts' ][i ])
88
+ all_prompts . append ( image_info_post ['all_prompts' ][i ])
89
+ else : # only a single image received
90
+ all_seeds . append ( image_info_post ['seed' ])
91
+ all_subseeds . append ( image_info_post ['subseed' ])
92
+ all_negative_prompts . append ( image_info_post ['negative_prompt' ])
93
+ all_prompts . append ( image_info_post ['prompt' ])
98
94
except IndexError :
99
- # like with controlnet masks, there isn't always full post-gen info, so we use the first images'
100
- logger .debug (f"Image at index { i } for '{ job .worker .label } ' was missing some post-generation data" )
101
- processed_inject_image (image = image , info_index = 0 , response = response )
102
- return
103
-
104
- processed .all_seeds .append (seed )
105
- processed .all_subseeds .append (subseed )
106
- processed .all_negative_prompts .append (negative_prompt )
107
- processed .all_prompts .append (pos_prompt )
108
- processed .images .append (image ) # actual received image
109
-
110
- # generate info-text string
111
- # modules.ui_common -> update_generation_info renders to html below gallery
112
- images_per_batch = p .n_iter * p .batch_size
113
- # zero-indexed position of image in total batch (so including master results)
114
- true_image_pos = len (processed .images ) - 1
115
- num_remote_images = images_per_batch * p .batch_size
116
- if p .n_iter > 1 : # if splitting by batch count
117
- num_remote_images *= p .n_iter - 1
95
+ # # like with controlnet masks, there isn't always full post-gen info, so we use the first images'
96
+ # logger.debug(f"Image at index {info_index} for '{job.worker.label}' was missing some post-generation data")
97
+ # self.processed_inject_image(image=image, info_index=0, job=job, p=p)
98
+ # return
99
+ logger .critical (f"Image at index { i } for '{ job .worker .label } ' was missing some post-generation data" )
100
+ continue
118
101
119
- logger .debug (f"image { true_image_pos + 1 } /{ self .world .p .batch_size * p .n_iter } , "
120
- f"info-index: { info_index } " )
102
+ # parse image
103
+ image_bytes = base64 .b64decode (job .worker .response ["images" ][i ])
104
+ image = Image .open (io .BytesIO (image_bytes ))
105
+ transform = ToTensor ()
106
+ images .append (transform (image ))
121
107
122
- if self .world .thin_client_mode :
123
- p .all_negative_prompts = processed .all_negative_prompts
108
+ return all_seeds , all_subseeds , all_negative_prompts , all_prompts , images
124
109
125
- try :
126
- info_text = image_info_post [ 'infotexts' ][ i ]
127
- except IndexError :
128
- if not grid :
129
- logger . warning ( f"image { true_image_pos + 1 } was missing info-text" )
130
- info_text = processed . infotexts [ 0 ]
131
- info_text += f", Worker Label: { job . worker . label } "
132
- processed . infotexts . append ( info_text )
133
-
134
- # automatically save received image to local disk if desired
135
- if cmd_opts . distributed_remotes_autosave :
136
- save_image (
137
- image = image ,
138
- path = p . outpath_samples if save_path_override is None else save_path_override ,
139
- basename = "" ,
140
- seed = seed ,
141
- prompt = pos_prompt ,
142
- info = info_text ,
143
- extension = opts . samples_format
144
- )
110
+ def inject_job ( self , job : Job , p , pp ) :
111
+ """Adds the work completed by one Job via its worker response to the processing and postprocessing objects"""
112
+ all_seeds , all_subseeds , all_negative_prompts , all_prompts , images = self . api_to_internal ( job )
113
+
114
+ p . seeds . extend ( all_seeds )
115
+ p . subseeds . extend ( all_subseeds )
116
+ p . negative_prompts . extend ( all_negative_prompts )
117
+ p . prompts . extend ( all_prompts )
118
+
119
+ num_local = self . world . p . n_iter * self . world . p . batch_size + ( opts . return_grid - self . world . thin_client_mode )
120
+ num_injected = len ( pp . images ) - self . world . p . batch_size
121
+ for i , image in enumerate ( images ):
122
+ # modules.ui_common -> update_generation_info renders to html below gallery
123
+ gallery_index = num_local + num_injected + i # zero-indexed point of image in total gallery
124
+ job . gallery_map . append ( gallery_index ) # so we know where to edit infotext
125
+ pp . images . append ( image )
126
+ logger . debug ( f"image { gallery_index + 1 + self . world . thin_client_mode } / { self . world . num_gallery () } " )
127
+
128
+ def update_gallery ( self , pp , p ):
129
+ """adds all remotely generated images to the image gallery after waiting for all workers to finish"""
145
130
146
131
# get master ipm by estimating based on worker speed
147
132
master_elapsed = time .time () - self .master_start
@@ -158,8 +143,7 @@ def processed_inject_image(image, info_index, save_path_override=None, grid=Fals
158
143
logger .debug ("all worker request threads returned" )
159
144
webui_state .textinfo = "Distributed - injecting images"
160
145
161
- # some worker which we know has a good response that we can use for generating the grid
162
- donor_worker = None
146
+ received_images = False
163
147
for job in self .world .jobs :
164
148
if job .worker .response is None or job .batch_size < 1 or job .worker .master :
165
149
continue
@@ -170,8 +154,7 @@ def processed_inject_image(image, info_index, save_path_override=None, grid=Fals
170
154
if (job .batch_size * p .n_iter ) < len (images ):
171
155
logger .debug (f"requested { job .batch_size } image(s) from '{ job .worker .label } ', got { len (images )} " )
172
156
173
- if donor_worker is None :
174
- donor_worker = job .worker
157
+ received_images = True
175
158
except KeyError :
176
159
if job .batch_size > 0 :
177
160
logger .warning (f"Worker '{ job .worker .label } ' had no images" )
@@ -185,41 +168,27 @@ def processed_inject_image(image, info_index, save_path_override=None, grid=Fals
185
168
logger .exception (e )
186
169
continue
187
170
188
- # visibly add work from workers to the image gallery
189
- for i in range (0 , len (images )):
190
- image_bytes = base64 .b64decode (images [i ])
191
- image = Image .open (io .BytesIO (image_bytes ))
171
+ # adding the images in
172
+ self .inject_job (job , p , pp )
192
173
193
- # inject image
194
- processed_inject_image (image = image , info_index = i , response = job .worker .response )
195
-
196
- if donor_worker is None :
174
+ # TODO fix controlnet masks returned via api having no generation info
175
+ if received_images is False :
197
176
logger .critical ("couldn't collect any responses, the extension will have no effect" )
198
177
return
199
178
200
- # generate and inject grid
201
- if opts .return_grid and len (processed .images ) > 1 :
202
- grid = image_grid (processed .images , len (processed .images ))
203
- processed_inject_image (
204
- image = grid ,
205
- info_index = 0 ,
206
- save_path_override = p .outpath_grids ,
207
- grid = True ,
208
- response = donor_worker .response
209
- )
210
-
211
- # cleanup after we're doing using all the responses
212
- for worker in self .world .get_workers ():
213
- worker .response = None
214
-
215
- p .batch_size = len (processed .images )
179
+ p .batch_size = len (pp .images )
180
+ webui_state .textinfo = ""
216
181
return
217
182
218
183
# p's type is
219
184
# "modules.processing.StableDiffusionProcessing*"
220
185
def before_process (self , p , * args ):
221
- if not self .world .enabled :
222
- logger .debug ("extension is disabled" )
186
+ is_img2img = getattr (p , 'init_images' , False )
187
+ if is_img2img and self .world .enabled_i2i is False :
188
+ logger .debug ("extension is disabled for i2i" )
189
+ return
190
+ elif not is_img2img and self .world .enabled is False :
191
+ logger .debug ("extension is disabled for t2i" )
223
192
return
224
193
self .world .update (p )
225
194
@@ -234,6 +203,14 @@ def before_process(self, p, *args):
234
203
continue
235
204
title = script .title ()
236
205
206
+ if title == "ADetailer" :
207
+ adetailer_args = p .script_args [script .args_from :script .args_to ]
208
+
209
+ # InputAccordion main toggle, skip img2img toggle
210
+ if adetailer_args [0 ] and adetailer_args [1 ]:
211
+ logger .debug (f"adetailer is skipping img2img, returning control to wui" )
212
+ return
213
+
237
214
# check for supported scripts
238
215
if title == "ControlNet" :
239
216
# grab all controlnet units
@@ -346,18 +323,34 @@ def before_process(self, p, *args):
346
323
p .batch_size = self .world .master_job ().batch_size
347
324
self .master_start = time .time ()
348
325
349
- # generate images assigned to local machine
350
- p .do_not_save_grid = True # don't generate grid from master as we are doing this later.
351
326
self .runs_since_init += 1
352
327
return
353
328
354
- def postprocess (self , p , processed , * args ):
355
- if not self .world .enabled :
329
+ def postprocess_batch_list (self , p , pp , * args , ** kwargs ):
330
+ if not self .world .thin_client_mode and p .n_iter != kwargs ['batch_number' ] + 1 : # skip if not the final batch
331
+ return
332
+
333
+ is_img2img = getattr (p , 'init_images' , False )
334
+ if is_img2img and self .world .enabled_i2i is False :
335
+ return
336
+ elif not is_img2img and self .world .enabled is False :
356
337
return
357
338
358
339
if self .master_start is not None :
359
- self .add_to_gallery (p = p , processed = processed )
340
+ self .update_gallery (p = p , pp = pp )
341
+
360
342
343
+ def postprocess (self , p , processed , * args ):
344
+ for job in self .world .jobs :
345
+ if job .worker .response is not None :
346
+ for i , v in enumerate (job .gallery_map ):
347
+ infotext = json .loads (job .worker .response ['info' ])['infotexts' ][i ]
348
+ infotext += f", Worker Label: { job .worker .label } "
349
+ processed .infotexts [v ] = infotext
350
+
351
+ # cleanup
352
+ for worker in self .world .get_workers ():
353
+ worker .response = None
361
354
# restore process_images_inner if it was monkey-patched
362
355
processing .process_images_inner = self .original_process_images_inner
363
356
# save any dangling state to prevent load_config in next iteration overwriting it
0 commit comments