12 changes: 6 additions & 6 deletions scripts/animatediff.py
@@ -18,7 +18,7 @@
from scripts.animatediff_settings import on_ui_settings
from scripts.animatediff_infotext import update_infotext, infotext_pasted
from scripts.animatediff_utils import get_animatediff_arg
from scripts.animatediff_i2ibatch import * # this is necessary for CN to find the function
from scripts.animatediff_i2ibatch import animatediff_i2i_init, animatediff_i2i_batch # Make functions available for CN
from scripts.animatediff_freeinit import AnimateDiffFreeInit

script_dir = scripts.basedir()
@@ -28,7 +28,7 @@
class AnimateDiffScript(scripts.Script):

def __init__(self):
self.hacked = False
self.module_injected_by_this_script_run = False
self.infotext_fields: List[Tuple[gr.components.IOComponent, str]] = []
self.paste_field_names: List[str] = []

@@ -68,10 +68,10 @@ def before_process(self, p: StableDiffusionProcessing, params: AnimateDiffProces
if params.freeinit_enable:
self.freeinit_hacker = AnimateDiffFreeInit(params)
self.freeinit_hacker.hack(p, params)
self.hacked = True
elif self.hacked:
self.module_injected_by_this_script_run = True
elif self.module_injected_by_this_script_run:
motion_module.restore(p.sd_model)
self.hacked = False
self.module_injected_by_this_script_run = False


def before_process_batch(self, p: StableDiffusionProcessing, params: AnimateDiffProcess, **kwargs):
@@ -93,7 +93,7 @@ def postprocess(self, p: StableDiffusionProcessing, res: Processed, params: Anim
if params.enable:
params.prompt_scheduler.save_infotext_txt(res)
motion_module.restore(p.sd_model)
self.hacked = False
self.module_injected_by_this_script_run = False
AnimateDiffOutput().output(p, res, params)
logger.info("AnimateDiff process end.")

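The rename from self.hacked spells out what the flag guards: before_process injects when AnimateDiff is enabled and undoes a leftover injection from an earlier run when it is not, while postprocess restores at the end of an enabled run. A minimal, self-contained sketch of that lifecycle, with InjectionGuard, inject and restore as hypothetical stand-ins for the script class and motion_module.inject/restore:

class InjectionGuard:
    def __init__(self):
        self.injected = False

    def before_process(self, enabled: bool):
        if enabled:
            self.inject()             # patch the model for this run
            self.injected = True
        elif self.injected:
            self.restore()            # clean up a patch left over from a previous run
            self.injected = False

    def postprocess(self):
        if self.injected:
            self.restore()            # always undo the patch when the run finishes
            self.injected = False

    def inject(self):
        print("motion module injected")

    def restore(self):
        print("motion module restored")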
2 changes: 1 addition & 1 deletion scripts/animatediff_i2ibatch.py
@@ -231,7 +231,7 @@ def animatediff_i2i_batch(
try:
img = Image.open(image)
except UnidentifiedImageError as e:
print(e)
logger.error(f"Skipping image {image} due to UnidentifiedImageError: {e}")
continue
# Use the EXIF orientation of photos taken by smartphones.
img = ImageOps.exif_transpose(img)
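For reference, the skip-and-log pattern above in isolation; a short sketch using PIL (load_frames and the print fallback are illustrative, the extension logs through its own logger):

from PIL import Image, ImageOps, UnidentifiedImageError

def load_frames(paths):
    """Yield opened, orientation-corrected images, skipping unreadable files."""
    for path in paths:
        try:
            img = Image.open(path)
        except UnidentifiedImageError as e:
            # Log and continue instead of aborting the whole batch.
            print(f"Skipping image {path} due to UnidentifiedImageError: {e}")
            continue
        # Honor the EXIF orientation of photos taken by smartphones.
        yield ImageOps.exif_transpose(img)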
28 changes: 26 additions & 2 deletions scripts/animatediff_mm.py
@@ -73,6 +73,8 @@ def inject(self, sd_model, model_name="mm_sd15_v3.safetensors"):
if self.mm.is_v2:
logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet middle block.")
unet.middle_block.insert(-1, self.mm.mid_block.motion_modules[0])
# This section applies a global monkey patch to GroupNorm32.forward for specific motion modules
# that require reshaping of tensors before and after the GroupNorm operation.
elif self.mm.enable_gn_hack():
logger.info(f"Hacking {sd_ver} GroupNorm32 forward function.")
if self.mm.is_hotshot:
@@ -83,14 +85,21 @@ def inject(self, sd_model, model_name="mm_sd15_v3.safetensors"):
gn32_original_forward = self.gn32_original_forward

def groupnorm32_mm_forward(self, x):
# Fold the frame axis out of the batch before the original GroupNorm, so that
# normalization sees every frame of a video at once. b=2 corresponds to the two
# classifier-free guidance batches (cond and uncond) stacked along dim 0.
x = rearrange(x, "(b f) c h w -> b c f h w", b=2)
x = gn32_original_forward(self, x)
# Reshape the tensor back to its expected format after the original GroupNorm.
x = rearrange(x, "b c f h w -> (b f) c h w", b=2)
return x

# This is a global modification of the GroupNorm32.forward method.
GroupNorm32.forward = groupnorm32_mm_forward

logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet input blocks.")
# These indices are specific to the UNet architecture of SD1.5 and SDXL models.
# They might need adjustment if the underlying UNet structure changes significantly.
for mm_idx, unet_idx in enumerate([1, 2, 4, 5, 7, 8, 10, 11]):
if inject_sdxl and mm_idx >= 6:
break
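To make the reshape concrete: during classifier-free guidance the incoming batch holds 2 * frames samples (cond and uncond), and the hack regroups them so GroupNorm normalizes across all frames of a video at once. A quick self-contained check of the round-trip, with arbitrary tensor sizes:

import torch
from einops import rearrange

frames, channels, height, width = 8, 320, 32, 32
x = torch.randn(2 * frames, channels, height, width)  # cond + uncond stacked on dim 0

y = rearrange(x, "(b f) c h w -> b c f h w", b=2)     # isolate the frame axis
assert y.shape == (2, channels, frames, height, width)

z = rearrange(y, "b c f h w -> (b f) c h w")          # restore the original layout
assert torch.equal(x, z)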
@@ -99,6 +108,8 @@ def groupnorm32_mm_forward(self, x):
unet.input_blocks[unet_idx].append(mm_inject)

logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet output blocks.")
# These indices are specific to the UNet architecture of SD1.5 and SDXL models.
# They might need adjustment if the underlying UNet structure changes significantly.
for unet_idx in range(12):
if inject_sdxl and unet_idx >= 9:
break
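Spelled out, the two loop guards above give the following block coverage (derived directly from the loops; a sketch, not code from the extension):

input_indices = [1, 2, 4, 5, 7, 8, 10, 11]
sd15_input_blocks = input_indices       # all eight listed input blocks
sdxl_input_blocks = input_indices[:6]   # break at mm_idx >= 6 -> [1, 2, 4, 5, 7, 8]

sd15_output_blocks = list(range(12))    # output blocks 0..11
sdxl_output_blocks = list(range(9))     # break at unet_idx >= 9 -> 0..8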
@@ -126,12 +137,23 @@ def restore(self, sd_model):
unet = sd_model.model.diffusion_model

logger.info(f"Removing motion module from {sd_ver} UNet input blocks.")
for unet_idx in [1, 2, 4, 5, 7, 8, 10, 11]:
if inject_sdxl and unet_idx >= 9:
# These indices are specific to the UNet architecture of SD1.5 and SDXL models.
# They might need adjustment if the underlying UNet structure changes significantly.
input_block_indices_to_process = []
original_input_indices = [1, 2, 4, 5, 7, 8, 10, 11]
for mm_idx, unet_idx in enumerate(original_input_indices):
if inject_sdxl and mm_idx >= 6:
break
input_block_indices_to_process.append(unet_idx)

# Each motion module was appended to the end of its block during injection,
# so pop(-1) removes exactly that module and the iteration order is irrelevant.
for unet_idx in input_block_indices_to_process:
unet.input_blocks[unet_idx].pop(-1)

logger.info(f"Removing motion module from {sd_ver} UNet output blocks.")
# These indices are specific to the UNet architecture of SD1.5 and SDXL models.
# They might need adjustment if the underlying UNet structure changes significantly.
for unet_idx in range(12):
if inject_sdxl and unet_idx >= 9:
break
@@ -143,6 +165,8 @@ def restore(self, sd_model):
if self.mm.is_v2:
logger.info(f"Removing motion module from {sd_ver} UNet middle block.")
unet.middle_block.pop(-2)
# This section restores the original GroupNorm32.forward method,
# removing the global monkey patch applied during injection if it was enabled.
elif self.mm.enable_gn_hack():
logger.info(f"Restoring {sd_ver} GroupNorm32 forward function.")
if self.mm.is_hotshot:
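The GroupNorm32 hack and its removal follow the usual save/patch/restore discipline for monkey patches. A standalone illustration of the pattern (GroupNormLike is a toy class, not the real GroupNorm32):

class GroupNormLike:
    def forward(self, x):
        return f"original({x})"

original_forward = GroupNormLike.forward              # 1. save the original

def patched_forward(self, x):
    return f"patched({original_forward(self, x)})"    # 2. wrap the saved function

GroupNormLike.forward = patched_forward               # global patch: affects every instance
assert GroupNormLike().forward("x") == "patched(original(x))"

GroupNormLike.forward = original_forward              # 3. restore on cleanup
assert GroupNormLike().forward("x") == "original(x)"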
94 changes: 94 additions & 0 deletions tests/test_animatediff_mm.py
@@ -0,0 +1,94 @@
import unittest
from unittest.mock import MagicMock, patch, call
import sys
import os

# Add the scripts directory to sys.path so animatediff_mm can be imported directly.
# This assumes the tests are run from the root of the repository.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'scripts')))

# Mock the 'shared' and 'devices' objects from 'modules' that animatediff_mm uses.
# These would normally be provided by the A1111/Forge environment.
mock_shared = MagicMock()
mock_shared.cmd_opts = MagicMock()
mock_shared.cmd_opts.no_half = False
mock_shared.opts = MagicMock()
mock_shared.opts.data = {}  # For animatediff_model_path

mock_devices = MagicMock()
mock_devices.device = 'cpu'  # Mock device
mock_devices.cpu = 'cpu'
mock_devices.fp8 = False

# Install the mocks in sys.modules BEFORE importing animatediff_mm, so that its
# module-level imports resolve to the mocks. The parent 'modules' mock has its
# attributes wired explicitly so that 'from modules import shared' also returns
# the configured mock rather than an auto-created one.
mock_modules = MagicMock()
mock_modules.shared = mock_shared
mock_modules.devices = mock_devices
sys.modules['modules'] = mock_modules
sys.modules['modules.shared'] = mock_shared
sys.modules['modules.devices'] = mock_devices
sys.modules['modules.hashes'] = mock_modules.hashes
sys.modules['modules.sd_models'] = mock_modules.sd_models
sys.modules['ldm.modules.diffusionmodules.util'] = MagicMock()  # For GroupNorm32
sys.modules['sgm.modules.diffusionmodules.util'] = MagicMock()  # For GroupNorm32 (SDXL)

from animatediff_mm import AnimateDiffMM, MotionModuleType
# Note: adjust this if MotionWrapper needs to be used/mocked from its original location.
# from motion_module import MotionWrapper


class TestAnimateDiffMM(unittest.TestCase):

def setUp(self):
self.mm_instance = AnimateDiffMM()
# Reset class variable for injection state between tests
AnimateDiffMM.mm_injected = False
# Set script_dir, normally done externally
self.mm_instance.set_script_dir(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'scripts')))

# Mock MotionWrapper class that AnimateDiffMM instantiates
self.mock_mw_instance = MagicMock()
self.mock_mw_instance.mm_name = "initial_name"
self.mock_mw_instance.load_state_dict = MagicMock()
self.mock_mw_instance.to = MagicMock(return_value=self.mock_mw_instance)
self.mock_mw_instance.eval = MagicMock()
self.mock_mw_instance.half = MagicMock()
self.mock_mw_instance.modules = MagicMock(return_value=[])
self.mock_mw_instance.is_v2 = False
self.mock_mw_instance.is_xl = False
self.mock_mw_instance.enable_gn_hack = MagicMock(return_value=False)
# Add other necessary attributes for MotionWrapper mock

@patch('animatediff_mm.MotionWrapper')  # Patch where AnimateDiffMM looks it up (module imported as plain animatediff_mm)
@patch('animatediff_mm.os.path.isfile', return_value=True)
@patch('modules.sd_models.read_state_dict')  # Provided by the sys.modules mock above
@patch('modules.hashes.sha256')
@patch('animatediff_mm.MotionModuleType.get_mm_type')
def test_load_model_success(self, mock_get_mm_type, mock_sha256, mock_read_state_dict, mock_isfile, MockMotionWrapper):
# Configure the mock MotionWrapper that is returned when MotionWrapper() is called
MockMotionWrapper.return_value = self.mock_mw_instance

mock_read_state_dict.return_value = {"test_key": "test_value"}
mock_sha256.return_value = "test_hash"
mock_get_mm_type.return_value = MotionModuleType.STANDARD

# Action
self.mm_instance.load("test_model.safetensors")

# Assertions
self.assertIsNotNone(self.mm_instance.mm)
MockMotionWrapper.assert_called_once_with(mm_name="test_model.safetensors", mm_hash="test_hash", mm_type=MotionModuleType.STANDARD)
self.mm_instance.mm.load_state_dict.assert_called_once_with({"test_key": "test_value"})
self.mm_instance.mm.to.assert_called_with('cpu') # from mock_devices.device
self.mm_instance.mm.eval.assert_called_once()
# self.mm_instance.mm.half.assert_called_once() # Depends on no_half

@patch('animatediff_mm.os.path.isfile', return_value=False)
def test_load_model_file_not_found(self, mock_isfile):
with self.assertRaises(RuntimeError) as context:
self.mm_instance.load("non_existent_model.safetensors")
self.assertIn("Please download models manually.", str(context.exception))

if __name__ == '__main__':
    unittest.main()
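With the sys.path insertion at the top of the file, the suite should run from the repository root, for example:

python tests/test_animatediff_mm.py
# or, assuming a standard unittest layout:
python -m unittest discover -s tests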