Skip to content
2 changes: 1 addition & 1 deletion scripts/animatediff_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def set_p(self, p: StableDiffusionProcessing):
cn_unit.batch_mask_dir = self.mask_path

# find minimun control images in CN batch
cn_unit_batch_params = cn_unit.batch_images.split('\n')
cn_unit_batch_params = cn_unit.batch_images.split('\n') if cn_unit.batch_images is not None else []
if cn_unit.input_mode.name == 'BATCH':
cn_unit.animatediff_batch = True # for A1111 sd-webui-controlnet
if not any([cn_param.startswith("keyframe:") for cn_param in cn_unit_batch_params[1:]]):
Expand Down
26 changes: 20 additions & 6 deletions scripts/animatediff_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,28 @@ def get_controlnet_units(p: StableDiffusionProcessing):
cn_units = p.script_args[script.args_from:script.args_to]

if p.is_api and len(cn_units) > 0 and isinstance(cn_units[0], dict):
from scripts import external_code
from scripts.batch_hijack import InputMode
cn_units_dataclass = external_code.get_all_units_in_processing(p)
for cn_unit_dict, cn_unit_dataclass in zip(cn_units, cn_units_dataclass):
from scripts import external_code
from scripts.batch_hijack import InputMode
cn_units_dataclass = external_code.get_all_units_in_processing(p)
for cn_unit_dict, cn_unit_dataclass in zip(cn_units, cn_units_dataclass):
# NB: Unfortunately this setattr section is required because those attributes don't exist
# in the default ControlNetUnit class defined in sd-webui-controlnet library.
# So we have to use this hack to append extra batch processing related attributes to the object
# until sd-webui-controlnet makes an update.
setattr(cn_unit_dataclass, "input_mode", InputMode.SIMPLE)
setattr(cn_unit_dataclass, "batch_images", None)
setattr(cn_unit_dataclass, "batch_mask_dir", None)
setattr(cn_unit_dataclass, "batch_input_gallery", None)
setattr(cn_unit_dataclass, "batch_modifiers", [])
setattr(cn_unit_dataclass, "animatediff_batch", False)

if cn_unit_dataclass.image is None:
cn_unit_dataclass.input_mode = InputMode.BATCH
cn_unit_dataclass.batch_images = cn_unit_dict.get("batch_images", None)
p.script_args[script.args_from:script.args_to] = cn_units_dataclass
cn_unit_dataclass.batch_images = getattr(cn_unit_dict, "batch_images", None)
cn_unit_dataclass.animatediff_batch = True

p.script_args[script.args_from:script.args_to] = cn_units_dataclass
cn_units = cn_units_dataclass

return [x for x in cn_units if x.enabled] if not p.is_api else cn_units

Expand Down