45 changes: 14 additions & 31 deletions src/musubi_tuner/flux_kontext_generate_image.py
@@ -1,30 +1,26 @@
import argparse
from importlib.util import find_spec
import random
import copy
import logging
import os
import random
import time
import copy
from typing import Tuple, Optional, List, Any, Dict
from importlib.util import find_spec
from typing import Any, Dict, List, Optional, Tuple

from einops import rearrange
import torch
from safetensors.torch import load_file, save_file
from einops import rearrange
from safetensors import safe_open
from safetensors.torch import load_file, save_file
from tqdm import tqdm

from musubi_tuner.flux import flux_utils
from musubi_tuner.flux import flux_models, flux_utils
from musubi_tuner.flux.flux_utils import load_flow_model
from musubi_tuner.flux import flux_models

lycoris_available = find_spec("lycoris") is not None

from musubi_tuner.hv_generate_video import get_time_flag, save_images_grid, synchronize_device
from musubi_tuner.merge_lora import merge_lora_weights
from musubi_tuner.networks import lora_flux
from musubi_tuner.utils.device_utils import clean_memory_on_device
from musubi_tuner.hv_generate_video import get_time_flag, save_images_grid, synchronize_device
from musubi_tuner.wan_generate_video import merge_lora_weights

import logging

lycoris_available = find_spec("lycoris") is not None
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

@@ -572,13 +568,9 @@ def generate(
args.exclude_patterns,
device,
args.lycoris,
args.save_merged_model,
save_merged_model=args.save_merged_model,
)

# if we only want to save the model, we can skip the rest
if args.save_merged_model:
return None, None

# optimize model: fp8 conversion, block swap etc.
optimize_model(model, args, device)

@@ -929,13 +921,8 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->
first_prompt_args.exclude_patterns,
device,
first_prompt_args.lycoris,
first_prompt_args.save_merged_model,
save_merged_model=first_prompt_args.save_merged_model,
)
if first_prompt_args.save_merged_model:
logger.info("Merged DiT model saved. Skipping generation.")
del dit_model
clean_memory_on_device(device)
return

logger.info("Optimizing DiT model...")
optimize_model(dit_model, first_prompt_args, device) # Handles device placement, fp8 etc.
@@ -959,7 +946,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->
prompt_args_item, gen_settings, shared_models_for_generate, current_image_data, current_text_data
)

if latent is None: # and prompt_args_item.save_merged_model: # Should be caught earlier
if latent is None:
continue

# Save latent if needed (using data from precomputed_image_data for H/W)
@@ -1177,14 +1164,10 @@ def main():
# For single mode, precomputed data is None, shared_models is None.
# generate will load all necessary models (VAE, Text/Image Encoders, DiT).
returned_vae, latent = generate(args, gen_settings)
# print(f"Generated latent shape: {latent.shape}")
# if args.save_merged_model:
# return

# Save latent and video
# returned_vae from generate will be used for decoding here.
save_output(args, returned_vae, latent[0], device)

logger.info("Done!")


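Across both call sites in this file the pattern is the same: save_merged_model moves from a positional to a keyword argument of merge_lora_weights (now imported from musubi_tuner.merge_lora instead of musubi_tuner.wan_generate_video), and the callers' own `if args.save_merged_model: return` branches disappear. One plausible reading, consistent with the return → sys.exit(0) change in hv_generate_video.py further down, is that the shared helper now saves the merged model and stops the process itself. A hypothetical sketch of that contract — not the actual musubi_tuner.merge_lora code:

import sys
from safetensors.torch import save_file

def merge_lora_weights(network_module, model, lora_weights, multipliers,
                       include_patterns, exclude_patterns, device,
                       lycoris=False, save_merged_model=None, converter=None):
    # ... merge each LoRA state dict into the base model's weights in place ...
    if save_merged_model is not None:
        # Callers no longer branch on save_merged_model afterwards, so the
        # save-and-stop behavior has to live here.
        save_file(model.state_dict(), save_merged_model)
        sys.exit(0)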
101 changes: 40 additions & 61 deletions src/musubi_tuner/fpack_generate_video.py
@@ -1,40 +1,37 @@
import argparse
import copy
import gc
from importlib.util import find_spec
import random
import logging
import os
import random
import re
import time
import copy
from typing import Tuple, Optional, List, Any, Dict
from importlib.util import find_spec
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from PIL import Image
import numpy as np
from safetensors import safe_open
from safetensors.torch import load_file, save_file
from tqdm import tqdm

from musubi_tuner.networks import lora_framepack
from musubi_tuner.hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
from musubi_tuner.dataset import image_video_dataset
from musubi_tuner.frame_pack import hunyuan
from musubi_tuner.frame_pack.clip_vision import hf_clip_vision_encode
from musubi_tuner.frame_pack.framepack_utils import load_image_encoders, load_text_encoder1, load_text_encoder2, load_vae
from musubi_tuner.frame_pack.hunyuan_video_packed import load_packed_model
from musubi_tuner.frame_pack.hunyuan_video_packed_inference import HunyuanVideoTransformer3DModelPackedInference
from musubi_tuner.frame_pack.utils import crop_or_pad_yield_mask, soft_append_bcthw
from musubi_tuner.frame_pack.clip_vision import hf_clip_vision_encode
from musubi_tuner.frame_pack.k_diffusion_hunyuan import sample_hunyuan
from musubi_tuner.dataset import image_video_dataset
from musubi_tuner.frame_pack.utils import crop_or_pad_yield_mask, soft_append_bcthw
from musubi_tuner.hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
from musubi_tuner.hv_generate_video import get_time_flag, save_images_grid, save_videos_grid, synchronize_device
from musubi_tuner.merge_lora import merge_lora_weights
from musubi_tuner.networks import lora_framepack
from musubi_tuner.utils.device_utils import clean_memory_on_device
from musubi_tuner.utils.lora_utils import filter_lora_state_dict

lycoris_available = find_spec("lycoris") is not None

from musubi_tuner.utils.device_utils import clean_memory_on_device
from musubi_tuner.hv_generate_video import get_time_flag, save_images_grid, save_videos_grid, synchronize_device
from musubi_tuner.wan_generate_video import merge_lora_weights
from musubi_tuner.frame_pack.framepack_utils import load_vae, load_text_encoder1, load_text_encoder2, load_image_encoders

import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

@@ -459,6 +456,7 @@ def load_dit_model(args: argparse.Namespace, device: torch.device) -> HunyuanVid
for_inference=True,
lora_weights_list=lora_weights_list,
lora_multipliers=args.lora_multiplier,
save_merged_model=args.save_merged_model if not args.lycoris else None,
)

# apply RoPE scaling factor
@@ -471,36 +469,30 @@ def load_dit_model(args: argparse.Namespace, device: torch.device) -> HunyuanVid
# magcache
initialize_magcache(args, model)

if args.lycoris:
# merge LoRA weights statically
if args.lora_weight is not None and len(args.lora_weight) > 0:
# ugly hack to common merge_lora_weights function
merge_lora_weights(
lora_framepack,
model,
args.lora_weight,
args.lora_multiplier,
args.include_patterns,
args.exclude_patterns,
device,
lycoris=True,
save_merged_model=args.save_merged_model,
converter=convert_lora_for_framepack,
)

if args.fp8_scaled:
state_dict = model.state_dict() # bf16 state dict
if args.lycoris and args.lora_weight is not None and len(args.lora_weight) > 0:
# ugly hack to common merge_lora_weights function
merge_lora_weights(
lora_framepack,
model,
args.lora_weight,
args.lora_multiplier,
args.include_patterns,
args.exclude_patterns,
device,
lycoris=True,
save_merged_model=args.save_merged_model,
converter=convert_lora_for_framepack,
)

# if no blocks to swap, we can move the weights to GPU after optimization on GPU (omit redundant CPU->GPU copy)
move_to_device = args.blocks_to_swap == 0 # if blocks_to_swap > 0, we will keep the model on CPU
state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=False) # args.fp8_fast)
if args.lycoris and args.fp8_scaled:
state_dict = model.state_dict() # bf16 state dict

info = model.load_state_dict(state_dict, strict=True, assign=True)
logger.info(f"Loaded FP8 optimized weights: {info}")
# if no blocks to swap, we can move the weights to GPU after optimization on GPU (omit redundant CPU->GPU copy)
move_to_device = args.blocks_to_swap == 0 # if blocks_to_swap > 0, we will keep the model on CPU
state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=False) # args.fp8_fast)

# if we only want to save the model, we can skip the rest
if args.save_merged_model:
return model
info = model.load_state_dict(state_dict, strict=True, assign=True)
logger.info(f"Loaded FP8 optimized weights: {info}")

if not args.fp8_scaled:
# simple cast to dit_dtype
@@ -1195,9 +1187,6 @@ def generate(

if shared_models is None or "model" not in shared_models:
model = load_dit_model(args, device)
if args.save_merged_model:
# If we only want to save the model, we can skip the rest
return model, None

if shared_models is not None:
shared_models["model"] = model
@@ -1918,11 +1907,6 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->
first_prompt_args = all_prompt_args_list[0]

dit_model = load_dit_model(first_prompt_args, device) # Load directly to target device if possible
if first_prompt_args.save_merged_model:
logger.info("Merged DiT model saved. Skipping generation.")
del dit_model
clean_memory_on_device(device)
return

shared_models_for_generate = {"model": dit_model} # Pass DiT via shared_models

@@ -1943,7 +1927,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->
prompt_args_item, gen_settings, shared_models_for_generate, current_image_data, current_text_data
)

if latent is None and prompt_args_item.save_merged_model: # Should be caught earlier
if latent is None:
continue

# Save latent if needed (using data from precomputed_image_data for H/W)
@@ -2087,8 +2071,6 @@ def main():
# Parse arguments
args = parse_args()

assert (not args.save_merged_model) or (not args.fp8_scaled), "Save merged model is not compatible with fp8_scaled"

# Check if latents are provided
latents_mode = args.latent_path is not None and len(args.latent_path) > 0

@@ -2179,9 +2161,6 @@ def main():
# For single mode, precomputed data is None, shared_models is None.
# generate will load all necessary models (VAE, Text/Image Encoders, DiT).
returned_vae, latent = generate(args, gen_settings)
# print(f"Generated latent shape: {latent.shape}")
if args.save_merged_model:
return

# Save latent and video
# returned_vae from generate will be used for decoding here.
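Because the rendered diff drops its +/- markers, the new shape of load_dit_model is easier to check laid out flat. A best-effort reconstruction of the surviving body of the -471,36 hunk, assuming the unmarked lines split as GitHub usually renders them (non-LyCORIS LoRAs are now merged inside load_packed_model, so both remaining steps are gated on args.lycoris):

if args.lycoris and args.lora_weight is not None and len(args.lora_weight) > 0:
    # ugly hack to common merge_lora_weights function
    merge_lora_weights(
        lora_framepack, model, args.lora_weight, args.lora_multiplier,
        args.include_patterns, args.exclude_patterns, device,
        lycoris=True, save_merged_model=args.save_merged_model,
        converter=convert_lora_for_framepack,
    )

if args.lycoris and args.fp8_scaled:
    state_dict = model.state_dict()  # bf16 state dict after the LyCORIS merge
    # if no blocks to swap, weights can move to GPU after optimization
    move_to_device = args.blocks_to_swap == 0
    state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=False)
    info = model.load_state_dict(state_dict, strict=True, assign=True)
    logger.info(f"Loaded FP8 optimized weights: {info}")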
3 changes: 3 additions & 0 deletions src/musubi_tuner/frame_pack/hunyuan_video_packed.py
@@ -2043,6 +2043,7 @@ def load_packed_model(
for_inference: bool = False,
lora_weights_list: Optional[Dict[str, torch.Tensor]] = None,
lora_multipliers: Optional[List[float]] = None,
save_merged_model: Optional[str] = None,
) -> HunyuanVideoTransformer3DModelPacked:
"""
Load a packed DiT model from a given path.
@@ -2058,6 +2059,7 @@
for_inference (bool): Whether to create the model for inference.
lora_weights_list (Optional[Dict[str, torch.Tensor]]): List of state_dicts for LoRA weights.
lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights.
save_merged_model (Optional[str]): Path to save the merged model. If None, the model will not be saved.

Returns:
HunyuanVideoTransformer3DModelPacked: The loaded DiT model.
@@ -2119,6 +2121,7 @@
move_to_device=(loading_device == device),
target_keys=FP8_OPTIMIZATION_TARGET_KEYS,
exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS,
save_merged_model=save_merged_model,
)

if fp8_scaled:
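The new save_merged_model parameter is forwarded into the internal LoRA-merge/fp8 helper, which explains the guard at the fpack_generate_video.py call site above: with --lycoris, the LyCORIS weights are merged only after load_packed_model returns, so saving inside it would write a merged model that is missing those weights. Repeating the call shape from that diff, with the leading positional arguments elided (they are unchanged by this PR):

model = load_packed_model(
    ...,  # device, model path, attention mode, loading device — elided here
    for_inference=True,
    lora_weights_list=lora_weights_list,
    lora_multipliers=args.lora_multiplier,
    # None disables saving; on the LyCORIS path the merge happens later,
    # so the model must not be saved at this point.
    save_merged_model=args.save_merged_model if not args.lycoris else None,
)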
3 changes: 2 additions & 1 deletion src/musubi_tuner/hv_generate_video.py
@@ -5,6 +5,7 @@
import random
import os
import time
import sys
from typing import Union

import numpy as np
@@ -668,7 +669,7 @@ def main():
logger.info(f"Saving merged model to {args.save_merged_model}")
mem_eff_save_file(transformer.state_dict(), args.save_merged_model) # save_file needs a lot of memory
logger.info("Merged model saved")
return
sys.exit(0)

logger.info(f"Casting model to {dit_weight_dtype}")
transformer.to(dtype=dit_weight_dtype)
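The switch from return to sys.exit(0) matters because --save_merged_model handling now sits in code paths that other entry points call into: a bare return only unwinds the current function, while SystemExit propagates through every caller, so nothing can accidentally continue into generation after the merged model is written. A minimal illustration of the difference, with a print standing in for mem_eff_save_file:

import sys

def save_and_stop(state_dict, path):
    print(f"saving merged model to {path}")  # stand-in for mem_eff_save_file
    # `return` here would hand control back to the caller, which may not
    # re-check the save flag; SystemExit ends the whole process instead.
    sys.exit(0)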