2828from .load_images import load_img , prepare_mask , check_mask_for_errors
2929from .webui_sd_pipeline import get_webui_sd_pipeline
3030from .rich import console
31- from .defaults import get_samplers_list
31+ from .defaults import get_samplers_list , get_schedulers_list
3232from .prompt import check_is_number
3333from .opts_overrider import A1111OptionsOverrider
3434import cv2
@@ -70,14 +70,14 @@ def pairwise_repl(iterable):
7070 next (b , None )
7171 return zip (a , b )
7272
73- def generate (args , keys , anim_args , loop_args , controlnet_args , root , parseq_adapter , frame = 0 , sampler_name = None ):
73+ def generate (args , keys , anim_args , loop_args , controlnet_args , root , parseq_adapter , frame = 0 , sampler_name = None , scheduler_name = None ):
7474 if state .interrupted :
7575 return None
7676
7777 if args .reroll_blank_frames == 'ignore' :
78- return generate_inner (args , keys , anim_args , loop_args , controlnet_args , root , parseq_adapter , frame , sampler_name )
78+ return generate_inner (args , keys , anim_args , loop_args , controlnet_args , root , parseq_adapter , frame , sampler_name , scheduler_name )
7979
80- image , caught_vae_exception = generate_with_nans_check (args , keys , anim_args , loop_args , controlnet_args , root , parseq_adapter , frame , sampler_name )
80+ image , caught_vae_exception = generate_with_nans_check (args , keys , anim_args , loop_args , controlnet_args , root , parseq_adapter , frame , sampler_name , scheduler_name )
8181
8282 if caught_vae_exception or not image .getbbox ():
8383 patience = args .reroll_patience
@@ -86,7 +86,7 @@ def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_ada
8686 while caught_vae_exception or not image .getbbox ():
8787 print ("Rerolling with +1 seed..." )
8888 args .seed += 1
89- image , caught_vae_exception = generate_with_nans_check (args , keys , anim_args , loop_args , controlnet_args , root , parseq_adapter , frame , sampler_name )
89+ image , caught_vae_exception = generate_with_nans_check (args , keys , anim_args , loop_args , controlnet_args , root , parseq_adapter , frame , sampler_name , scheduler_name = None )
9090 patience -= 1
9191 if patience == 0 :
9292 print ("Rerolling with +1 seed failed for 10 iterations! Try setting webui's precision to 'full' and if it fails, please report this to the devs! Interrupting..." )
@@ -100,12 +100,12 @@ def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_ada
100100 return None
101101 return image
102102
103- def generate_with_nans_check (args , keys , anim_args , loop_args , controlnet_args , root , parseq_adapter , frame = 0 , sampler_name = None ):
103+ def generate_with_nans_check (args , keys , anim_args , loop_args , controlnet_args , root , parseq_adapter , frame = 0 , sampler_name = None , scheduler_name = None ):
104104 if cmd_opts .disable_nan_check :
105- image = generate_inner (args , keys , anim_args , loop_args , controlnet_args , root , parseq_adapter , frame , sampler_name )
105+ image = generate_inner (args , keys , anim_args , loop_args , controlnet_args , root , parseq_adapter , frame , sampler_name , scheduler_name )
106106 else :
107107 try :
108- image = generate_inner (args , keys , anim_args , loop_args , controlnet_args , root , parseq_adapter , frame , sampler_name )
108+ image = generate_inner (args , keys , anim_args , loop_args , controlnet_args , root , parseq_adapter , frame , sampler_name , scheduler_name )
109109 except Exception as e :
110110 if "A tensor with all NaNs was produced in VAE." in repr (e ):
111111 print (e )
@@ -114,7 +114,7 @@ def generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args,
114114 raise e
115115 return image , False
116116
117- def generate_inner (args , keys , anim_args , loop_args , controlnet_args , root , parseq_adapter , frame = 0 , sampler_name = None ):
117+ def generate_inner (args , keys , anim_args , loop_args , controlnet_args , root , parseq_adapter , frame = 0 , sampler_name = None , scheduler_name = None ):
118118 # Setup the pipeline
119119 p = get_webui_sd_pipeline (args , root )
120120 p .prompt , p .negative_prompt = split_weighted_subprompts (args .prompt , frame , anim_args .max_frames )
@@ -176,6 +176,13 @@ def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, pars
176176 else :
177177 raise RuntimeError (f"Sampler name '{ sampler_name } ' is invalid. Please check the available sampler list in the 'Run' tab" )
178178
179+ available_schedulers = get_schedulers_list ()
180+ if scheduler_name is not None :
181+ if scheduler_name in available_schedulers .keys ():
182+ p .scheduler = available_schedulers [scheduler_name ]
183+ else :
184+ raise RuntimeError (f"Scheduler name '{ scheduler_name } ' is invalid. Please check the available scheduler list in the 'Run' tab" )
185+
179186 if args .checkpoint is not None :
180187 info = sd_models .get_closet_checkpoint_match (args .checkpoint )
181188 if info is None :
@@ -220,6 +227,7 @@ def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, pars
220227 seed_resize_from_h = p .seed_resize_from_h ,
221228 seed_resize_from_w = p .seed_resize_from_w ,
222229 sampler_name = p .sampler_name ,
230+ scheduler = p .scheduler ,
223231 batch_size = p .batch_size ,
224232 n_iter = p .n_iter ,
225233 steps = p .steps ,
0 commit comments