
Conversation

@kohya-ss (Owner) commented Sep 11, 2025

This adds inference and LoRA training support for HunyuanImage-2.1.

It also incorporates the fp8 scaling feature and dynamic merging during LoRA inference from Musubi Tuner, as well as the unified attention method.

The official weights can be used for the DiT; ComfyUI weights are used for all other components.

The arguments for hunyuan_image_minimal_inference.py are almost the same as those of the Qwen-Image inference script in Musubi Tuner.

Block swap is not supported yet, so inference uses about 24GB of VRAM even with --fp8_scaled.

@FurkanGozukara

Block swap would be nice. I could use it then. Nice work.

@kohya-ss (Owner, Author) commented Sep 11, 2025

hunyuan_image_minimal_inference.py now supports block swap. It runs with 8GB of VRAM and 64GB of main RAM when the --fp8_scaled --blocks_to_swap 37 options are specified. --text_encoder_cpu is also recommended if VRAM is less than 16GB.

LoRA training is not working yet.

@kohya-ss (Owner, Author)

Sample command:

python hunyuan_image_minimal_inference.py --dit path/to/hunyuanimage2.1.safetensors \
    --text_encoder path/to/qwen_2.5_vl_7b.safetensors --byt5 path/to/byt5_small_glyphxl_fp16.safetensors \
    --vae path/to/hunyuan_image_2.1_vae_fp16.safetensors --save_path path/to/save_path --infer_steps 50 \
    --prompt "A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, wearing a red knitted scarf and a red beret with the word \"Tencent\" on it, holding a paintbrush with a focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style." \
    --seed 542017 --guidance_scale 3.25 --flow_shift 4 --image_size 2048 2048

@kohya-ss changed the title from "feat: initial commit for HunyuanImage-2.1 inference" to "feat: support HunyuanImage-2.1" on Sep 12, 2025
@FurkanGozukara

@kohya-ss amazing work

Do you think it could be possible to make the block swap count automatic based on available VRAM somehow?

That would make this repo way stronger.

@kohya-ss (Owner, Author)

Sample training command:

accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 hunyuan_image_train_network.py \
    --pretrained_model_name_or_path path/to/hunyuanimage2.1.safetensors \
    --text_encoder path/to/qwen_2.5_vl_7b.safetensors --byt5 path/to/byt5_small_glyphxl_fp16.safetensors \
    --vae path/to/hunyuan_image_2.1_vae_fp16.safetensors --cache_latents_to_disk --save_model_as safetensors \
    --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 \
    --max_train_epochs 8 --save_every_n_epochs 1 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 \
    --network_module networks.lora_hunyuan_image --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-3 \
    --network_train_unet_only --cache_text_encoder_outputs_to_disk --highvram \
    --dataset_config path/to/dataset_config.toml --logging_dir ./logs --log_prefix hi21-lora1- \
    --output_dir path/to/output_dir --output_name hunyuan-image-21-lora1 --timestep_sampling shift --discrete_flow_shift 5.0 \
    --sample_prompts=path/to/prompt.txt --sample_every_n_epochs 1 --disable_mmap

--xformers, --fp8_vl, --fp8_scaled, and --blocks_to_swap are also available. --fp8_vl is recommended for GPUs with less than 16GB of VRAM. The maximum --blocks_to_swap is 37.

Approximate VRAM and main RAM usage for 2048x2048 resolution training (batch size of 1) with --fp8_scaled:

blocks_to_swap  VRAM  main RAM
0 (no swap)     40GB  40GB (peak)
9               24GB  40GB (peak)
24              16GB  48GB
32              12GB  64GB

The learning rate may be a little too high.
Inference becomes unstable at sizes other than 2048x2048, so a 2048x2048 dataset would be better.

@kohya-ss (Owner, Author) commented Sep 14, 2025

I added a conversion script to use the trained LoRA in ComfyUI. Please use it like this:

python networks/convert_hunyuan_image_lora_to_comfy.py sd-scripts-lora.safetensors comfyui-lora.safetensors

@sdbds (Contributor) commented Sep 16, 2025

1. The model download repository in the documentation is incorrect; it seems they have switched to the new tencent/HunyuanImage-2.1 repository instead of the old HunyuanDiT one.
2. A few days ago they released the fp8-scaled weights, which seem to be directly usable?

@kohya-ss (Owner, Author)

1. The model download repository in the documentation is incorrect; it seems they have switched to the new tencent/HunyuanImage-2.1 repository instead of the old HunyuanDiT one.

Thank you, I gave instructions to Claude and fixed it.

2. A few days ago they released the fp8-scaled weights, which seem to be directly usable?

The weights seem to be per-tensor scaled, so it may be better to use per-channel dynamic scaling from the bf16 weights.
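
For context, here is a rough sketch of the difference between the two scaling schemes (illustrative only, not the actual Musubi Tuner / sd-scripts implementation):

import torch

def per_channel_fp8_scale(weight: torch.Tensor):
    """Quantize a bf16 linear weight [out, in] to fp8 with one scale per output channel."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    absmax = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)  # [out, 1]
    scale = absmax / fp8_max
    q = (weight / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q, scale.to(torch.float32)

# Per-tensor scaling (as in the released fp8 weights) uses a single scalar instead:
#   scale = weight.abs().amax() / fp8_max
# which loses precision on channels whose magnitudes sit far below the global maximum,
# whereas dynamic per-channel scaling from the bf16 weights preserves each channel's range.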

@Sarania commented Sep 16, 2025

@kohya-ss amazing work

Do you think it could be possible to make the block swap count automatic based on available VRAM somehow?

That would make this repo way stronger.

I've actually considered ways of doing this for blissful-tuner, but it's not exactly trivial. You have a couple of options: you can directly calculate the amount of VRAM needed, but there are a lot of variables that must be considered to know that, different for each model and possibly even differing by PyTorch version, OS, etc. Or you could use a try/except approach where you try to load the model with N blocks swapped; if it works, great, and if not, try more until you find something that works. But that's messy IMO.

The truth is, with these large models, once you're swapping any number of blocks the speed difference between, say, 10 blocks swapped and 20 is not that big (caveat: once you go over about 90% of blocks swapped, speed tanks, but that's very rarely necessary). I usually aim to hit about 90% of my VRAM so I have a little headroom to avoid OOM. But the point is you don't have to get the number exact, and it's okay to err on the safer side to avoid OOMing.
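
For what it's worth, a minimal sketch of the direct-calculation variant (block_bytes and static_bytes are placeholders you'd still have to measure per model, PyTorch version, OS, etc., which is exactly the non-trivial part):

import torch

def estimate_blocks_to_swap(num_blocks: int, block_bytes: int, static_bytes: int,
                            headroom: float = 0.9) -> int:
    """Swap just enough blocks so the resident part of the model fits in ~90% of free VRAM."""
    free_bytes, _total = torch.cuda.mem_get_info()
    budget = int(free_bytes * headroom) - static_bytes  # VRAM left over for transformer blocks
    resident = max(budget // block_bytes, 0)            # how many blocks can stay on the GPU
    return min(max(num_blocks - resident, 0), num_blocks)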

@Sarania commented Sep 16, 2025

A small suggestion I might make for HyImage is the option to run the LLM on the CPU in FP32 as an alternative to fp8. It takes about 15 seconds per encode on my 4.2 GHz, 8-core, 16-thread i7-6900K, which is a ten-year-old CPU, so it's definitely a feasible option for those who have the system RAM! Unless you're encoding like... tons XD.

@kohya-ss (Owner, Author) commented Sep 20, 2025

The unified (multi-backend) attention interface has been added.

The results of a small speed and numerical-accuracy verification test are shown below. From top to bottom: xformers (old), xformers (new), flash (new), torch (new), and torch with split_attn (new).

[image: table comparing speed and outputs for each attention backend]

After this PR is merged, flash attention and sage attention will also become available during inference for SDXL, FLUX.1, and other models in the future.
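
For readers following along, such a unified interface is essentially a thin dispatcher over the backends; a simplified sketch (illustrative only, not the actual library/attention.py API) might look like this:

import torch
import torch.nn.functional as F

def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mode: str = "torch") -> torch.Tensor:
    """q, k, v are [batch, heads, seq_len, head_dim]; additional backends plug in the same way."""
    if mode == "torch":
        return F.scaled_dot_product_attention(q, k, v)
    if mode == "xformers":
        import xformers.ops as xops
        # xformers expects [batch, seq_len, heads, head_dim]
        return xops.memory_efficient_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        ).transpose(1, 2)
    if mode == "flash":
        from flash_attn import flash_attn_func
        # flash-attn also expects [batch, seq_len, heads, head_dim]
        return flash_attn_func(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        ).transpose(1, 2)
    raise ValueError(f"unknown attention mode: {mode}")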

@kohya-ss (Owner, Author)

The --text_encoder_cpu option has been added to the training script.

@Sarania commented Sep 20, 2025

A couple of things I noticed about the minimal inference script: in hunyuan_image_utils you seem to have inherited the old defaults from the upstream repo:

    apg_start_step_ocr: int = 75,
    apg_start_step_general: int = 10,

But it was updated, from hunyuan_image_pipeline:

        if self.cfg_mode == "APG_mode_0":
            self.cfg_guider = AdaptiveProjectedGuidance(guidance_scale=10.0, eta=0.0,
                                                        adaptive_projected_guidance_rescale=10.0,
                                                        adaptive_projected_guidance_momentum=-0.5)
            self.apg_start_step = 10
        elif self.cfg_mode == "MIX_mode_0":
            self.cfg_guider_ocr = AdaptiveProjectedGuidance(guidance_scale=10.0, eta=0.0,
                                                            adaptive_projected_guidance_rescale=10.0,
                                                            adaptive_projected_guidance_momentum=-0.5)
            self.apg_start_step_ocr = 38

            self.cfg_guider_general = AdaptiveProjectedGuidance(guidance_scale=10.0, eta=0.0,
                                                                adaptive_projected_guidance_rescale=10.0,
                                                                adaptive_projected_guidance_momentum=-0.5)
            self.apg_start_step_general = 5

        self.ocr_mask = []

which makes sense: the default of 75 for the OCR start step makes no sense when the default number of steps is 50, and I also inherited this issue into my workspace until I noticed the change! Also, in my personal experimentation, the step at which you switch from CFG to APG makes a large difference in the aesthetic and quality of the output, text, etc. I can't give you perfect numbers, and I think it might be different depending on what you're doing, so my suggestion would be to allow the user to optionally choose the step at which to switch.
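
Conceptually, making the switch step user-selectable is just a per-step branch; a toy sketch (the apg_guider call signature here is hypothetical):

def guide(step: int, cond, uncond, guidance_scale: float, apg_start_step: int, apg_guider):
    if step < apg_start_step:
        # plain classifier-free guidance before the switch point
        return uncond + guidance_scale * (cond - uncond)
    # adaptive projected guidance from the switch point onward (hypothetical callable)
    return apg_guider(cond, uncond, step=step)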

Also, there doesn't seem to be any functional difference between cfg_guider_ocr and cfg_guider_general? I know it's like that in the source repo too, but aren't they both just instances of the same APG class with the same params? Are both necessary, or could we get by with one and simplify?

And lastly, when using regular CFG for a significant portion of the steps, guidance rescaling definitely seems beneficial for avoiding oversaturation/overexposure. Just for reference, I'm not referring to the rescale that's part of APG, but to this:

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    r"""
    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg (`torch.Tensor`):
            The predicted noise tensor for the guided diffusion process.
        noise_pred_text (`torch.Tensor`):
            The predicted noise tensor for the text-guided diffusion process.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            A rescale factor applied to the noise predictions.
    Returns:
        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg
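
For reference, it slots in right after the plain CFG combine, roughly like this (variable names are illustrative):

# noise_pred_text / noise_pred_uncond come from the cond / uncond model calls
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=0.7)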

And I find a default guidance_rescale of 0.7 seems good, though it's only really necessary when using raw CFG for more than 10 steps, so it should be optional (though it doesn't seem to hurt otherwise, tbh)! Anyway, these are my thoughts after having a look, based on my own experimentation over the last week. Sorry if I missed anything that's already there! One last small note: I know the refiner is not part of this repo right now and may never be, which is fine; as we discussed, I think it's optional. But if you do ever mess with it, it's worth noting that the refiner VAE DOES seem to support spatial tiling, unlike the base VAE! Just thought I'd share that too, cheers!

@Sarania commented Sep 20, 2025

I notice that the inference script in this repo runs at about half the speed of my version of the source repo (which is just their repo plus offloading everything, block swap, and a few small fixes) and engages my GPU less (~170 watts versus 270), and I wonder if this is because they are using batched_cfg? I'm working on testing that theory, but I'm having the issue that our embed and negative embed aren't the same size in dimension 1; I'm guessing one is unpadded or something, so I'm looking into that!

@kohya-ss (Owner, Author)

A new option, --vae_chunk_size, has been added to the training and inference scripts. It reduces memory usage by chunking the Conv2d layers inside the VAE. When specifying it, we recommend a value around 16, such as --vae_chunk_size 16.

As a result, the --vae_enable_tiling option has been removed. We apologize for the inconvenience.
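
To illustrate the general idea behind chunking (a conceptual sketch for stride-1 convolutions, not the actual implementation): processing the tensor in overlapping horizontal slabs bounds the peak activation size while producing the same output as a full-frame convolution.

import torch

def chunked_conv2d(conv: torch.nn.Conv2d, x: torch.Tensor, chunk_size: int = 16) -> torch.Tensor:
    """Run a stride-1 Conv2d over horizontal slabs of `chunk_size` rows, overlapping each slab by
    the padding amount so the result matches a single full-frame convolution."""
    pad = conv.padding[0]
    height = x.shape[-2]
    outs = []
    for start in range(0, height, chunk_size):
        end = min(start + chunk_size, height)
        # take `pad` extra rows on each side so every kept output row sees real data
        s, e = max(start - pad, 0), min(end + pad, height)
        y = conv(x[..., s:e, :])
        top = start - s  # trim the rows that belong to the overlap
        outs.append(y[..., top:top + (end - start), :])
    return torch.cat(outs, dim=-2)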

@kohya-ss (Owner, Author) commented Sep 21, 2025

But it was updated, from hunyuan_image_pipeline:

Thank you for letting me know! It's annoying that the official script contains errors... apg_start_step and guidance_rescale can now be specified as options😀

Regarding cfg_guider_ocr and cfg_guider_general, it's true that the same thing is currently being used. I think it's possible to make them common, but I'll leave them as they are as I may want to change some parameters in the future.

I haven't checked the refiner yet, but it's great that its VAE supports tiling. It will be worth checking out that VAE implementation.

EDIT: --guidance_rescale_apg has been also added.

@kohya-ss (Owner, Author)

I notice that the inference script in this repo runs at about half the speed of my version of the source repo (which is just their repo plus offloading everything, block swap, and a few small fixes) and engages my GPU less (~170 watts versus 270), and I wonder if this is because they are using batched_cfg? I'm working on testing that theory, but I'm having the issue that our embed and negative embed aren't the same size in dimension 1; I'm guessing one is unpadded or something, so I'm looking into that!

I don't know why it's more efficient, but it may be due to changes in the attention mechanism.

Additionally, a "forward only" mode (which already existed in Musubi Tuner) has been added to the block swap feature. This may improve the efficiency of block swapping during inference.
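
Conceptually (a rough sketch, not the actual Musubi Tuner code), forward-only mode means the weights only ever stream in one direction, because inference has no backward pass that needs the blocks again:

import torch

def forward_only_swap(blocks, x, device="cuda"):
    """Stream transformer blocks through the GPU one at a time; nothing is copied back for a backward pass."""
    for block in blocks:  # blocks start out in (ideally pinned) CPU memory
        block.to(device, non_blocking=True)
        with torch.no_grad():
            x = block(x)
        block.to("cpu")  # free VRAM for the next block
    return x

# In practice the next block is prefetched on a separate CUDA stream while the current one runs;
# the point is simply that no block ever has to be written back.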

@kohya-ss marked this pull request as ready for review on September 21, 2025 at 04:10
@kohya-ss requested a review from Copilot on September 21, 2025 at 04:18
@Copilot (Copilot AI) left a comment

Pull Request Overview

This PR adds support for HunyuanImage-2.1, including inference and LoRA training functionality. It incorporates fp8 scaling and dynamic merging during LoRA inference from Musubi Tuner, and implements a unified attention method. The implementation allows using official weights for DiT models and ComfyUI weights for other components.

  • Implements HunyuanImage-2.1 DiT model with multimodal double-stream and single-stream transformer blocks
  • Adds comprehensive LoRA training support with regex-based learning rates and dimensions
  • Integrates fp8 optimization for memory efficiency and performance improvements

Reviewed Changes

Copilot reviewed 23 out of 23 changed files in this pull request and generated 4 comments.

File summary:
  • train_network.py: Adds lazy loading, casting control methods, and model initialization improvements
  • networks/lora_hunyuan_image.py: Complete LoRA implementation for HunyuanImage with qkv/mlp splitting
  • networks/lora_flux.py: Refactors hardcoded dimensions to use class methods for better extensibility
  • library/hunyuan_image_*.py: Core model implementation including VAE, text encoders, utilities, and modules
  • library/fp8_optimization_utils.py: FP8 quantization system with block-wise and channel-wise modes
  • library/attention.py: Unified attention interface supporting multiple backends
  • hunyuan_image_train_network.py: Training script with sampling, caching, and optimization features
Comments suppressed due to low confidence (2):

library/lora_utils.py:1 (import os)
  • Remove commented code that appears to be old implementation. If this is needed for reference, move it to a separate comment block with explanation.

library/hunyuan_image_vae.py:1 (from typing import Optional, Tuple)
  • This TODO comment should be updated to explain the specific conditions when float() conversion is needed and provide a plan for removing this workaround.


@Sarania commented Sep 21, 2025

Regarding cfg_guider_ocr and cfg_guider_general, it's true that the same thing is currently being used. I think it's possible to make them common, but I'll leave them as they are as I may want to change some parameters in the future.

I thought about this as well; I experimented with using different values for the two guiders myself. It's not a huge burden to keep both anyway, so that makes sense to me.

I don't know why it's more efficient, but it may be due to changes in the attention mechanism.

I pulled the new changes, but the speed isn't really any different. Could it have anything to do with the fact that I'm using the split Q/K/V state dict that I saved after their initial processing in my own stuff, versus the combined QKV base here? That bit is a little above my current skill level, but it does stand as a major difference between the repos, so I wondered. Edit: see below, it's the batched uncond. Other than that, I wonder about the batching, because my version of their repo ticks at about 6 sec per iteration regardless of whether I use CFG or not; I can literally hardcode it off, and while it clearly affects the results, the speed is almost the same, which is roughly equal to the speed of this repo without CFG, i.e. a single model call. It's weird that batching would be that effective, but maybe it's offsetting the swapping?

Thank you for letting me know! It's annoying that the official script contains errors...

I agree. It also had an issue where OCR guidance was negated when using a negative prompt, because the same encode_glyph is called for both, which sets a bool self.ocr_mask on the class; since it's called in the order positive, then negative... that bool ends up negated unless there's text in your negative prompt too, and why would there be? I almost suspect the text features were a last-minute addition in response to Qwen-Image, though that's complete speculation on my part. I can almost imagine some poor researcher getting yelled at by the higher-ups: "Qwen has text! We can't ship without text and we ship in less than a week! Make it happen!" Regardless of the small issues, this model is awesome, and I get more excited about it with every new generation I make and every new thing I try. I know I'm engaging a lot and typing up big posts, but that's just my excitement showing; please don't let it overwhelm you 🙂 Edit: Oh, and thank you for implementing my suggestions!

@Sarania commented Sep 21, 2025

Update: I was able to implement batched uncond by having GPT help me pad the Qwen outputs to matching lengths, and yeah, it improved the speed by a LOT (6.8 s/it with CFG versus over 11 before; no-CFG is 5.5 s/it with my current settings, for reference) while costing relatively little VRAM (about a gig when making 3072x2304). This is in bf16 with 26 blocks swapped. This is excellent, because now I can make this script my base and start working on it rather than the hacked-up version of the base repo I've been using until now! This is the padding code; it's honestly just GPT's, because I wasn't sure how XD:

from typing import Tuple

import torch
import torch.nn.functional as F


def pad_prompt_embeds_and_mask(
    prompt_embeds: torch.Tensor,           # [B, L, C]
    encoder_attention_mask: torch.Tensor,  # [B, L]
    target_len: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Right-pad embeddings with zeros and the mask with 0s to reach target_len."""
    B, L, C = prompt_embeds.shape
    pad_len = target_len - L
    if pad_len <= 0:
        return prompt_embeds, encoder_attention_mask
    # F.pad pads the last dim first; we want to pad along the sequence dim (dim=1):
    # pad format is (last_dim_left, last_dim_right, seq_left, seq_right, batch_left, batch_right)
    prompt_embeds = F.pad(prompt_embeds, (0, 0, 0, pad_len, 0, 0), value=0.0)
    encoder_attention_mask = F.pad(encoder_attention_mask, (0, pad_len, 0, 0), value=0)
    return prompt_embeds, encoder_attention_mask


def prepare_cfg_batch(
    pos_embeds: torch.Tensor, pos_mask: torch.Tensor,   # [1, Lp, C], [1, Lp]
    neg_embeds: torch.Tensor, neg_mask: torch.Tensor    # [1, Ln, C], [1, Ln]
) -> Tuple[torch.Tensor, torch.Tensor]:
    # 1) Find common length and right-pad both to it
    target_len = max(pos_embeds.shape[1], neg_embeds.shape[1])
    pos_embeds, pos_mask = pad_prompt_embeds_and_mask(pos_embeds, pos_mask, target_len)
    neg_embeds, neg_mask = pad_prompt_embeds_and_mask(neg_embeds, neg_mask, target_len)
    print(f"embed shape: {pos_embeds.shape}, mean: {pos_embeds.mean()}, std: {pos_embeds.std()}")
    print(f"negative_embed shape: {neg_embeds.shape}, mean: {neg_embeds.mean()}, std: {neg_embeds.std()}")
    # 2) Stack as a batch of 2: [neg, pos] (order is your choice; just be consistent downstream)
    prompt_embeds = torch.cat([neg_embeds, pos_embeds], dim=0)     # [2, L, C]
    encoder_attention_mask = torch.cat([neg_mask, pos_mask], dim=0)  # [2, L]

    return prompt_embeds, encoder_attention_mask

and then just something like this (which I'm sure you know; just sharing mine for reference):

    # Denoising loop
    do_cfg = args.guidance_scale != 1.0
    batched_cfg = args.batched_cfg
    embed_in = embed
    mask_in = mask
    embed_byt5_in = embed_byt5
    mask_byt5_in = mask_byt5
    if do_cfg and batched_cfg:
        embed_in, mask_in = prepare_cfg_batch(embed, mask, negative_embed, negative_mask)
        embed_byt5_in = torch.cat([negative_embed_byt5, embed_byt5], dim=0)
        mask_byt5_in = torch.cat([negative_mask_byt5, mask_byt5], dim=0)

    autocast_enabled = args.fp8

    with tqdm(total=len(timesteps), desc="Denoising steps") as pbar:
        for i, t in enumerate(timesteps):
            latent_model_input = torch.cat([latents] * 2) if do_cfg and batched_cfg else latents
            t_expand = t.repeat(latent_model_input.shape[0])

            with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled):
                noise_pred = model(latent_model_input, t_expand, embed_in, mask_in, embed_byt5_in, mask_byt5_in)

            if do_cfg:
                if not batched_cfg:
                    with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled):
                        uncond_noise_pred = model(
                            latents, t_expand, negative_embed, negative_mask, negative_embed_byt5, negative_mask_byt5
                        )
                else:
                    uncond_noise_pred, noise_pred = noise_pred.chunk(2, dim=0)
                noise_pred = hunyuan_image_utils.apply_classifier_free_guidance(
                    noise_pred,
                    uncond_noise_pred,
                    ocr_mask[0],
                    args.guidance_scale,
                    i,
                    apg_start_step_ocr=args.apg_start_step_ocr,
                    apg_start_step_general=args.apg_start_step_general,
                    cfg_guider_ocr=cfg_guider_ocr,
                    cfg_guider_general=cfg_guider_general,
                    guidance_rescale=args.guidance_rescale,
                )

            # ensure latents dtype is consistent
            latents = hunyuan_image_utils.step(latents, noise_pred, sigmas, i).to(latents.dtype)

            pbar.update()

It seems to just make a really big difference here, unlike with other recent models, so I definitely think adding it as an option would be good; it's not very heavy, and the output is identical. Cheers, and thanks for this script, which will now become my base of operations!

Edit: Small laugh: I see at one point you added regular single-quote support too... and then later found out why it might have been left out, just like I did 😂 So many of us walk the same paths without realizing.

@kohya-ss (Owner, Author)

It's great that batching cond and uncond improved the speed. It's certainly strange that there's almost no difference in speed, but that's possible if the bottleneck is somewhere else.

In this repository, I've kept cond and uncond separate to minimize memory usage.
I also feel that the implementation around OCR is quite ad hoc. Conversely, there may be room for research in the community.

When I looked at the official repository, the single quote had been fixed, but the APG step had not yet been fixed😂

@Sarania commented Sep 22, 2025

It's great that batching cond and uncond improved the speed. It's certainly strange that there's almost no difference in speed, but that's possible if the bottleneck is somewhere else.

In this repository, I've kept cond and uncond separate to minimize memory usage. I also feel that the implementation around OCR is quite ad hoc. Conversely, there may be room for research in the community.

When I looked at the official repository, the single quote had been fixed, but the APG step had not yet been fixed😂

Totally fair if you want to keep it that way to minimize resources and keep it simple; I thought I'd share my results either way for others who may come across this. My suspicion is the gain might be significantly less (relatively) when not using block swap. This is a BIG model in full BF16, so I'm swapping 26 blocks. With no batched uncond, I see around 150 watts of usage on my GPU; with batched, around 220. This makes sense, since we're only doing one round of block swapping with these large blocks instead of two, so I think that's where the biggest gain lies, tbh. I tested the outputs and they are identical whether uncond is batched or not, so the computation is ultimately the same, and this is my theory of why it's that much faster 🙂

Edit: Yeah, my guess is the 1.3 s per-iteration increase with batched uncond is probably what my single-batch iteration speed would be if I weren't swapping like 60% of the model. But since we've already paid the ~4.2 s cost to swap the blocks once, batching only adds that much on top: 1.3 s for cond + 1.3 s for uncond + 4.2 s for the swap = 6.8 s per iteration! Add another 4.2 s to swap again when unbatched = 11 s, which is exactly what I see.
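
Spelled out with the numbers above, the timing model looks like this:

# back-of-envelope check, values in seconds
t_swap, t_cond, t_uncond = 4.2, 1.3, 1.3

batched = t_swap + t_cond + t_uncond          # one sweep over the swapped blocks -> ~6.8 s/it
unbatched = 2 * t_swap + t_cond + t_uncond    # two sweeps (cond pass + uncond pass) -> ~11.0 s/it
print(f"{batched:.1f} s/it batched vs {unbatched:.1f} s/it unbatched")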

Edit 2: Yeah, that sounds like the state when I first looked at the repo. But I added single quotes back in and then learned why they were removed XD

@kohya-ss (Owner, Author)

Thanks for the detailed explanation. That makes sense! I'm sure the bottleneck is transfer speed, not computation speed.

@kohya-ss (Owner, Author)

We tested LoRA training on the existing SDXL, FLUX.1, Chroma, and Lumina support, and the loss trends seem to be roughly consistent.

@kohya-ss merged commit 121853c into sd3 on Sep 23, 2025 (3 checks passed)
@kohya-ss deleted the feat-hunyuan-image-2.1-inference branch on September 23, 2025 at 10:11