Skip to content

Commit 91df0fa

Browse files
authored
Merge pull request #968 from rhyswynn/automatic1111-webui
Update to split sampler and scheduler selections
2 parents 3224268 + c37c19a commit 91df0fa

File tree

6 files changed

+51
-28
lines changed

6 files changed

+51
-28
lines changed

scripts/deforum_helpers/args.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import modules.paths as ph
2323
import modules.shared as sh
2424
from modules.processing import get_fixed_seed
25-
from .defaults import get_guided_imgs_default_json, mask_fill_choices, get_samplers_list
25+
from .defaults import get_guided_imgs_default_json, mask_fill_choices, get_samplers_list, get_schedulers_list
2626
from .deforum_controlnet import controlnet_component_names
2727
from .general_utils import get_os, substitute_placeholders
2828

@@ -766,6 +766,12 @@ def DeforumArgs():
766766
"choices": get_samplers_list().values(),
767767
"value": "Euler a",
768768
},
769+
"scheduler": {
770+
"label": "Scheduler",
771+
"type": "dropdown",
772+
"choices": get_schedulers_list().values(),
773+
"value": "Automatic",
774+
},
769775
"steps": {
770776
"label": "Steps",
771777
"type": "slider",

scripts/deforum_helpers/defaults.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,27 +24,27 @@ def get_samplers_list():
2424
'dpm2 a': 'DPM2 a',
2525
'dpm++ 2s a': 'DPM++ 2S a',
2626
'dpm++ 2m': 'DPM++ 2M',
27+
'dpm++ 2m sde': 'DPM++ 2M SDE',
2728
'dpm++ sde': 'DPM++ SDE',
28-
'dpm++ 2m sde karras': 'DPM++ 2M SDE Karras',
2929
'dpm fast': 'DPM fast',
3030
'dpm adaptive': 'DPM adaptive',
31-
'lms karras': 'LMS Karras',
32-
'dpm2 karras': 'DPM2 Karras',
33-
'dpm2 a karras': 'DPM2 a Karras',
34-
'dpm++ 2s a karras': 'DPM++ 2S a Karras',
35-
'dpm++ 2m karras': 'DPM++ 2M Karras',
36-
'dpm++ sde karras': 'DPM++ SDE Karras',
37-
'dpm++ 2m sde exponential': 'DPM++ 2M SDE Exponential',
3831
'dpm++ 2m sde heun': 'DPM++ 2M SDE Heun',
39-
'dpm++ 2m sde heun karras': 'DPM++ 2M SDE Heun Karras',
40-
'dpm++ 2m sde Heun Exponential': 'DPM++ 2M SDE Heun Exponential',
4132
'dpm++ 3m sde': 'DPM++ 3M SDE',
42-
'dpm++ 3m sde karras': 'DPM++ 3M SDE Karras',
43-
'dpm++ 3m sde exponential': 'DPM++ 3M SDE Exponential',
4433
'ddim': 'DDIM',
4534
'plms': 'PLMS',
4635
'unipc': 'UniPC',
47-
'restart': 'Restart'
36+
'restart': 'Restart',
37+
'lcm': 'LCM'
38+
}
39+
40+
def get_schedulers_list():
41+
return {
42+
'automatic': 'Automatic',
43+
'uniform': 'Uniform',
44+
'karras': 'Karras',
45+
'exponential': 'Exponential',
46+
'polyexponential': 'Polyexponential',
47+
'sgm uniform': 'SGM Uniform'
4848
}
4949

5050
def DeforumAnimPrompts():

scripts/deforum_helpers/generate.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from .load_images import load_img, prepare_mask, check_mask_for_errors
2929
from .webui_sd_pipeline import get_webui_sd_pipeline
3030
from .rich import console
31-
from .defaults import get_samplers_list
31+
from .defaults import get_samplers_list, get_schedulers_list
3232
from .prompt import check_is_number
3333
from .opts_overrider import A1111OptionsOverrider
3434
import cv2
@@ -70,14 +70,14 @@ def pairwise_repl(iterable):
7070
next(b, None)
7171
return zip(a, b)
7272

73-
def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None):
73+
def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None, scheduler_name=None):
7474
if state.interrupted:
7575
return None
7676

7777
if args.reroll_blank_frames == 'ignore':
78-
return generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
78+
return generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name, scheduler_name)
7979

80-
image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
80+
image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name, scheduler_name)
8181

8282
if caught_vae_exception or not image.getbbox():
8383
patience = args.reroll_patience
@@ -86,7 +86,7 @@ def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_ada
8686
while caught_vae_exception or not image.getbbox():
8787
print("Rerolling with +1 seed...")
8888
args.seed += 1
89-
image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
89+
image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name, scheduler_name=None)
9090
patience -= 1
9191
if patience == 0:
9292
print("Rerolling with +1 seed failed for 10 iterations! Try setting webui's precision to 'full' and if it fails, please report this to the devs! Interrupting...")
@@ -100,12 +100,12 @@ def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_ada
100100
return None
101101
return image
102102

103-
def generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None):
103+
def generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None, scheduler_name=None):
104104
if cmd_opts.disable_nan_check:
105-
image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
105+
image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name, scheduler_name)
106106
else:
107107
try:
108-
image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
108+
image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name, scheduler_name)
109109
except Exception as e:
110110
if "A tensor with all NaNs was produced in VAE." in repr(e):
111111
print(e)
@@ -114,7 +114,7 @@ def generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args,
114114
raise e
115115
return image, False
116116

117-
def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None):
117+
def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None, scheduler_name=None):
118118
# Setup the pipeline
119119
p = get_webui_sd_pipeline(args, root)
120120
p.prompt, p.negative_prompt = split_weighted_subprompts(args.prompt, frame, anim_args.max_frames)
@@ -176,6 +176,13 @@ def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, pars
176176
else:
177177
raise RuntimeError(f"Sampler name '{sampler_name}' is invalid. Please check the available sampler list in the 'Run' tab")
178178

179+
available_schedulers = get_schedulers_list()
180+
if scheduler_name is not None:
181+
if scheduler_name in available_schedulers.keys():
182+
p.scheduler = available_schedulers[scheduler_name]
183+
else:
184+
raise RuntimeError(f"Scheduler name '{scheduler_name}' is invalid. Please check the available scheduler list in the 'Run' tab")
185+
179186
if args.checkpoint is not None:
180187
info = sd_models.get_closet_checkpoint_match(args.checkpoint)
181188
if info is None:
@@ -220,6 +227,7 @@ def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, pars
220227
seed_resize_from_h=p.seed_resize_from_h,
221228
seed_resize_from_w=p.seed_resize_from_w,
222229
sampler_name=p.sampler_name,
230+
scheduler=p.scheduler,
223231
batch_size=p.batch_size,
224232
n_iter=p.n_iter,
225233
steps=p.steps,

scripts/deforum_helpers/settings.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,18 @@ def load_all_settings(*args, ui_launch=False, **kwargs):
129129
result = {}
130130
for key, default_val in data.items():
131131
val = jdata.get(key, default_val)
132-
if key == 'sampler' and isinstance(val, int):
133-
from modules.sd_samplers import samplers_for_img2img
134-
val = samplers_for_img2img[val].name
132+
if key == 'sampler' and isinstance(val, str):
133+
samp_val = val.split()
134+
scheduler_val = None
135+
if samp_val[-1] in ['Uniform','SGM Uniform','Karras','Exponential','Polyexponential']:
136+
scheduler_val = samp_val[-1]
137+
val = (val.split(" " + samp_val[-1]))[0]
138+
if key == 'scheduler' and isinstance(val, str):
139+
if scheduler_val is not None:
140+
val = scheduler_val
141+
else:
142+
from modules.sd_schedulers import schedulers_map
143+
val = schedulers_map[val].label
135144
elif key == 'fill' and isinstance(val, int):
136145
val = mask_fill_choices[val]
137146
elif key in {'reroll_blank_frames', 'noise_type'} and key not in jdata:
@@ -142,7 +151,6 @@ def load_all_settings(*args, ui_launch=False, **kwargs):
142151
val = jdata.get(key, default_val)
143152
elif key == 'animation_prompts':
144153
val = json.dumps(jdata['prompts'], ensure_ascii=False, indent=4)
145-
146154
result[key] = val
147155

148156
if ui_launch:

scripts/deforum_helpers/ui_elements.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def get_tab_run(d, da):
4646
motion_preview_mode = create_gr_elem(d.motion_preview_mode)
4747
with FormRow():
4848
sampler = create_gr_elem(d.sampler)
49+
scheduler = create_gr_elem(d.scheduler)
4950
steps = create_gr_elem(d.steps)
5051
with FormRow():
5152
W = create_gr_elem(d.W)

scripts/deforum_helpers/webui_sd_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def get_webui_sd_pipeline(args, root):
4040
p.batch_size = 1 # b.size 1 as this is DEFORUM :)
4141
p.seed = args.seed
4242
p.do_not_save_samples = True # Setting this to False will trigger webui's saving mechanism - and we will end up with duplicated files, and another folder within our destination folder - big no no.
43-
p.sampler_name = args.sampler
43+
p.scheduler = args.scheduler
4444
p.mask_blur = args.mask_overlay_blur
4545
p.extra_generation_params["Mask blur"] = args.mask_overlay_blur
4646
p.n_iter = 1

0 commit comments

Comments
 (0)