Skip to content

[Model] Add Hunyuan Image3 AR Support#759

Draft
usberkeley wants to merge 4 commits intovllm-project:mainfrom
usberkeley:hunyuan-image3
Draft

[Model] Add Hunyuan Image3 AR Support#759
usberkeley wants to merge 4 commits intovllm-project:mainfrom
usberkeley:hunyuan-image3

Conversation

@usberkeley
Copy link

@usberkeley usberkeley commented Jan 13, 2026

Purpose

This PR adds support for the Hunyuan Image3 model to vLLM-Omni. Hunyuan Image3 is a multimodal image generation model developed by Tencent, supporting text-to-image generation tasks.

Test Plan

  1. Text input test
  • GPU: 8 x L40S (48GB)
  • TP: 8

Note: The default configuration in hunyuan_image_3_moe.yaml is tensor_parallel_size: 8.

from vllm_omni.entrypoints.omni import Omni

if __name__ == "__main__":
    omni = Omni(model="tencent/HunyuanImage-3.0")
    prompts = [
    {
        "prompt": "<|im_start|>system\nYou are Qwen.<|im_end|>\n<|im_start|>user\nExplain the system architecture for a scalable audio generation pipeline. Answer in 15 words.<|im_end|>\n<|im_start|>assistant\n",
        "modalities": ["text"]
    }
    ]
    omni_outputs = omni.generate(prompts)
    print(omni_outputs[0].request_output[0].outputs[0].text)
a68c8020-1416-4f58-98b1-73d7bbd61ee8
  1. Multimodal input test
    TODO

Test Result

TODO


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft.

BEFORE SUBMITTING, PLEASE READ https://github.com/vllm-project/vllm-omni/blob/main/CONTRIBUTING.md (anything written below this line will be removed by GitHub Actions)

@usberkeley usberkeley force-pushed the hunyuan-image3 branch 2 times, most recently from ed4d687 to bb011f2 Compare January 14, 2026 09:09
@usberkeley usberkeley marked this pull request as ready for review January 15, 2026 03:21
Copy link

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: bb011f27c7

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

@hsliuustc0106
Copy link
Collaborator

please paste your test example command

@usberkeley
Copy link
Author

usberkeley commented Jan 15, 2026

please paste your test example command

Hi @hsliuustc0106

  1. Text input test
  • GPU: 8 x L40S (48GB)
  • TP: 8

Note: The default configuration in hunyuan_image_3_moe.yaml is tensor_parallel_size: 8.

from vllm_omni.entrypoints.omni import Omni

if __name__ == "__main__":
    omni = Omni(model="tencent/HunyuanImage-3.0")
    prompts = [
    {
        "prompt": "<|im_start|>system\nYou are Qwen.<|im_end|>\n<|im_start|>user\nExplain the system architecture for a scalable audio generation pipeline. Answer in 15 words.<|im_end|>\n<|im_start|>assistant\n",
        "modalities": ["text"]
    }
    ]
    omni_outputs = omni.generate(prompts)
    print(omni_outputs[0].request_output[0].outputs[0].text)

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds initial vLLM-Omni autoregressive (AR) integration for Tencent’s Hunyuan Image3 model, including model registration and a default stage config.

Changes:

  • Updates AR GPU runner postprocessing to use a shared multimodal-output extraction helper.
  • Registers HunyuanImage3ForCausalMM in the Omni model registry.
  • Introduces a new Hunyuan Image3 model implementation + utilities and a new stage config YAML.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 12 comments.

Show a summary per file
File Description
vllm_omni/worker/gpu_ar_model_runner.py Switches to extract_multimodal_outputs for postprocessing model outputs.
vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe.yaml Adds a default stage config for running Hunyuan Image3 with the AR worker/scheduler.
vllm_omni/model_executor/models/registry.py Registers the Hunyuan Image3 model architecture for lazy loading.
vllm_omni/model_executor/models/hunyuan_image3_0/hunyuan_image3_0_utils.py Adds Hunyuan-specific RoPE2D + image KV cache helper utilities.
vllm_omni/model_executor/models/hunyuan_image3_0/hunyuan_image3_0.py Adds the main Hunyuan Image3 model implementation (decoder, attention, MoE, weight loading).
vllm_omni/model_executor/models/hunyuan_image3_0/__init__.py Exposes the new model class for import.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines 197 to 198
text_hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states)
hidden_states = text_hidden_states
Copy link

Copilot AI Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extract_multimodal_outputs is called unconditionally here, but on non-last pipeline-parallel ranks model_output can be an IntermediateTensors instance (the code even asserts that a few lines later). extract_multimodal_outputs does not handle IntermediateTensors, so this will raise ValueError before the early return for non-last PP ranks. Guard this call (e.g., skip extraction when isinstance(hidden_states, IntermediateTensors)) or move it into the get_pp_group().is_last_rank branch.

Suggested change
text_hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states)
hidden_states = text_hidden_states
if isinstance(hidden_states, IntermediateTensors):
# On non-last pipeline-parallel ranks, hidden_states is an
# IntermediateTensors instance. extract_multimodal_outputs does
# not handle this type, so skip extraction here.
multimodal_outputs = None
else:
text_hidden_states, multimodal_outputs = self.extract_multimodal_outputs(
hidden_states
)
hidden_states = text_hidden_states

Copilot uses AI. Check for mistakes.
default 4097 (timestamp + 4096 image tokens).
"""
self.image_token_len: int = image_token_len
self.image_kv_cache: tuple[torch.Tensor, torch.Tensor] = None
Copy link

Copilot AI Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ImageKVCacheManager initializes self.image_kv_cache, but the implementation actually uses self.image_kv_cache_map (which is never initialized in __init__). If __call__ is ever invoked with first_step=False before a first_step=True call, this will raise an AttributeError. Initialize a single cache attribute (e.g., self.image_kv_cache_map: tuple[Tensor, Tensor] | None = None) and use it consistently.

Suggested change
self.image_kv_cache: tuple[torch.Tensor, torch.Tensor] = None
self.image_kv_cache_map: Optional[tuple[torch.Tensor, torch.Tensor]] = None

Copilot uses AI. Check for mistakes.
Comment on lines 144 to 146
# 5. Restore original shape + convert to bfloat16
q = q.transpose(1, 2).reshape(hidden_states.shape[0], self.num_heads * self.head_dim).to(torch.bfloat16)
k = k.transpose(1, 2).reshape(hidden_states.shape[0], self.num_kv_heads * self.head_dim).to(torch.bfloat16)
Copy link

Copilot AI Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This forces q/k to torch.bfloat16 unconditionally. If the model is running in float16 (or another dtype), this introduces extra casts and can cause dtype mismatches in attention/KV-cache code paths. Prefer preserving the input dtype (e.g., q.dtype/hidden_states.dtype) instead of hard-coding bfloat16.

Suggested change
# 5. Restore original shape + convert to bfloat16
q = q.transpose(1, 2).reshape(hidden_states.shape[0], self.num_heads * self.head_dim).to(torch.bfloat16)
k = k.transpose(1, 2).reshape(hidden_states.shape[0], self.num_kv_heads * self.head_dim).to(torch.bfloat16)
# 5. Restore original shape and cast back to the model dtype
target_dtype = hidden_states.dtype
q = q.transpose(1, 2).reshape(hidden_states.shape[0], self.num_heads * self.head_dim).to(target_dtype)
k = k.transpose(1, 2).reshape(hidden_states.shape[0], self.num_kv_heads * self.head_dim).to(target_dtype)

Copilot uses AI. Check for mistakes.
Comment on lines 706 to 708
hidden_states = hidden_states.contiguous()
torch.cuda.synchronize()

Copy link

Copilot AI Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.cuda.synchronize() inside forward_block will severely degrade performance and can break overlap/streaming assumptions (and CUDA graph capture). Unless this is strictly for debugging, it should be removed or gated behind an explicit debug/profiling flag.

Copilot uses AI. Check for mistakes.
Comment on lines 903 to 904
# if tp_rank == 0:
# print(f"origin weight_name: {weight_name}, param_name: {param_name}, name: {name}")
Copy link

Copilot AI Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment appears to contain commented-out code.

Copilot uses AI. Check for mistakes.
Comment on lines 913 to 914
# if tp_rank == 0:
# print(f"remapped weight_name: {weight_name}, offset: {offset}, den: {den}")
Copy link

Copilot AI Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment appears to contain commented-out code.

Copilot uses AI. Check for mistakes.
Comment on lines 925 to 926
# if tp_rank == 0:
# print(f"name_mapped: {name_mapped}, found_num: {found_num}")
Copy link

Copilot AI Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment appears to contain commented-out code.

Copilot uses AI. Check for mistakes.
attention_bias = getattr(config, "attention_bias", False) or getattr(config, "bias", False)
cla_factor = _get_cla_factor(config)
attention_type = (
AttentionType.ENCODER_DECODER if layer_id >= 0 and layer_id % cla_factor != 0 else AttentionType.DECODER
Copy link

Copilot AI Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test is always true, because of this condition.

Suggested change
AttentionType.ENCODER_DECODER if layer_id >= 0 and layer_id % cla_factor != 0 else AttentionType.DECODER
AttentionType.ENCODER_DECODER if layer_id % cla_factor != 0 else AttentionType.DECODER

Copilot uses AI. Check for mistakes.
@@ -0,0 +1,1097 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import typing
Copy link

Copilot AI Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Module 'typing' is imported with both 'import' and 'import from'.

Suggested change
import typing

Copilot uses AI. Check for mistakes.
@david6666666
Copy link
Collaborator

@usberkeley
Copy link
Author

any updated? Tencent has just released https://huggingface.co/tencent/HunyuanImage-3.0-Instruct https://huggingface.co/tencent/HunyuanImage-3.0-Instruct-Distil

Got it. we are working on image encoder and will follow up the update of new release

@usberkeley
Copy link
Author

Hi @princepride

When you have a moment, please review this code. thanks!

@princepride
Copy link
Contributor

@usberkeley Can you rebase your code first, we have changed some code in ar_model_runner.

@usberkeley usberkeley marked this pull request as draft February 2, 2026 10:14
@usberkeley usberkeley force-pushed the hunyuan-image3 branch 3 times, most recently from 71570e7 to b8d58b5 Compare February 4, 2026 03:17
@usberkeley usberkeley marked this pull request as ready for review February 4, 2026 03:19
Copy link

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: b8d58b560e

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines 1953 to 1964
for vae_token_embed in vae_token_embeddings:
# 1. Timestep embedding
cond_timestep = torch.zeros((1,))
timestep_emb = self._timestep_encode(cond_timestep)
combined_embeddings += (timestep_emb,)

# 2. VAE image token embeddings
combined_embeddings += (vae_token_embed,)

# 3. ViT image embeddings
for vit_embed in vit_embeddings:
combined_embeddings += (vit_embed,)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve per-image embedding order for multi-image prompts

embed_multimodal appends all ViT embeddings after all VAE/timestep embeddings, but _get_prompt_updates constructs placeholders in per‑image order (timestep → VAE → ViT for each image). With multiple images, this misaligns embeddings so image N’s ViT tokens get the wrong vectors, breaking multimodal conditioning. The embedding concatenation should interleave each image’s ViT embeddings immediately after its VAE/timestep block to match the placeholder order.

Useful? React with 👍 / 👎.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When we have multi image, your combine image tokens will looks like:[ts1][vae1][ts2][vae2]...[vit1][vit2]... , but I think it should be: [ts1][vae1][vit1][ts2][vae2][vit2]...

Comment on lines 2038 to 2045
# Keywords for image generation components that we skip
generation_keywords = [
"final_layer",
# "patch_embed",
# "timestep_emb",
"time_embed",
"time_embed_2",
"guidance_emb",

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Load time_embed weights used in multimodal VAE encoding

load_weights skips parameters with names containing time_embed and time_embed_2, but time_embed is used in embed_multimodal to convert VAE latents into token embeddings. Skipping these weights leaves the timestep embedder randomly initialized, so multimodal image embeddings will be inconsistent with the checkpoint even when images are provided. If image embeddings are supported, time_embed should be loaded (or the embed path should avoid it).

Useful? React with 👍 / 👎.

@princepride
Copy link
Contributor

@usberkeley pre-commit failed, PTAL

Copy link
Contributor

@princepride princepride left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@usberkeley good job! just a little advice.

stage_type: llm # Use llm stage type to launch OmniLLM
runtime:
process: true # Run this stage in a separate process
devices: "0,1,2,3,4,5,6,7" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need so many devices?

Comment on lines 1953 to 1964
for vae_token_embed in vae_token_embeddings:
# 1. Timestep embedding
cond_timestep = torch.zeros((1,))
timestep_emb = self._timestep_encode(cond_timestep)
combined_embeddings += (timestep_emb,)

# 2. VAE image token embeddings
combined_embeddings += (vae_token_embed,)

# 3. ViT image embeddings
for vit_embed in vit_embeddings:
combined_embeddings += (vit_embed,)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree

Comment on lines 1953 to 1964
for vae_token_embed in vae_token_embeddings:
# 1. Timestep embedding
cond_timestep = torch.zeros((1,))
timestep_emb = self._timestep_encode(cond_timestep)
combined_embeddings += (timestep_emb,)

# 2. VAE image token embeddings
combined_embeddings += (vae_token_embed,)

# 3. ViT image embeddings
for vit_embed in vit_embeddings:
combined_embeddings += (vit_embed,)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for vae_token_embed in vae_token_embeddings:
# 1. Timestep embedding
cond_timestep = torch.zeros((1,))
timestep_emb = self._timestep_encode(cond_timestep)
combined_embeddings += (timestep_emb,)
# 2. VAE image token embeddings
combined_embeddings += (vae_token_embed,)
# 3. ViT image embeddings
for vit_embed in vit_embeddings:
combined_embeddings += (vit_embed,)
for i, vae_token_embed in enumerate(vae_token_embeddings):
cond_timestep = torch.zeros((1,))
timestep_emb = self._timestep_encode(cond_timestep)
combined_embeddings += (timestep_emb,)
combined_embeddings += (vae_token_embed,)
if i < len(vit_embeddings):
combined_embeddings += (vit_embeddings[i],)

Comment on lines 1953 to 1964
for vae_token_embed in vae_token_embeddings:
# 1. Timestep embedding
cond_timestep = torch.zeros((1,))
timestep_emb = self._timestep_encode(cond_timestep)
combined_embeddings += (timestep_emb,)

# 2. VAE image token embeddings
combined_embeddings += (vae_token_embed,)

# 3. ViT image embeddings
for vit_embed in vit_embeddings:
combined_embeddings += (vit_embed,)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When we have multi image, your combine image tokens will looks like:[ts1][vae1][ts2][vae2]...[vit1][vit2]... , but I think it should be: [ts1][vae1][vit1][ts2][vae2][vit2]...

Comment on lines 918 to 1061
class ImageKVCacheManager:
"""
Manages specialized caching and updating of KV-Cache for image tokens in multimodal models.
"""

def __init__(self, image_token_len: int = 4097):
"""
Args:
image_token_len: Number of tokens per image (including special placeholders),
default 4097 (timestamp + 4096 image tokens).
"""
self.image_token_len: int = image_token_len
self.image_kv_cache: tuple[torch.Tensor, torch.Tensor] = None

def _save_image_kv_caches(
self,
key: torch.Tensor,
value: torch.Tensor,
seq_len: int,
) -> None:
bs, q_len, num_kv_heads, head_dim = key.shape
assert q_len == seq_len, f"for first-step, {q_len} != {seq_len}"

key = key.reshape(-1, num_kv_heads, head_dim)
value = value.reshape(-1, num_kv_heads, head_dim)

cached_prompt_len = seq_len - self.image_token_len - 1
cached_key = [key[:cached_prompt_len], key[seq_len - 1: seq_len]]
cached_value = [value[:cached_prompt_len], value[seq_len - 1: seq_len]]

if bs > 1:
assert bs == 2, "for cfg case, bs must be 2"
cached_key.append(key[seq_len: seq_len + cached_prompt_len])
cached_key.append(key[-1:])

cached_value.append(value[seq_len: seq_len + cached_prompt_len])
cached_value.append(value[-1:])

cached_key = torch.cat(cached_key, dim=0)
cached_value = torch.cat(cached_value, dim=0)
self.image_kv_cache_map = (cached_key, cached_value)

def _update_image_kv_caches(
self,
key: torch.Tensor,
value: torch.Tensor,
seq_len: int,
) -> tuple[torch.Tensor, torch.Tensor]:
cached_key, cached_value = self.image_kv_cache_map
bs, q_len, num_kv_heads, head_dim = key.shape

cached_prompt_len = cached_key.shape[0] // bs - 1
assert (cached_prompt_len + 1) == (seq_len - q_len), f"{cached_prompt_len + 1} != {seq_len - q_len}"

key = key.reshape(-1, num_kv_heads, head_dim)
value = value.reshape(-1, num_kv_heads, head_dim)

new_key = [
cached_key[:cached_prompt_len],
key[:q_len],
cached_key[cached_prompt_len: cached_prompt_len + 1],
]
new_value = [
cached_value[:cached_prompt_len],
value[:q_len],
cached_value[cached_prompt_len: cached_prompt_len + 1],
]

if bs > 1:
assert bs == 2, "for cfg case, bs must be 2"
new_key.append(cached_key[cached_prompt_len + 1: cached_prompt_len + 1 + cached_prompt_len])
new_key.append(key[q_len:])
new_key.append(cached_key[-1:])

new_value.append(cached_value[cached_prompt_len + 1: cached_prompt_len + 1 + cached_prompt_len])
new_value.append(value[q_len:])
new_value.append(cached_value[-1:])

new_key = torch.cat(new_key, dim=0)
new_value = torch.cat(new_value, dim=0)
new_key = new_key.reshape(bs, seq_len, num_kv_heads, head_dim)
new_value = new_value.reshape(bs, seq_len, num_kv_heads, head_dim)

return new_key.contiguous(), new_value.contiguous()

def __call__(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: Optional["HunyuanImageAttentionMeta"], # 前向引用加引号
attention_mask: torch.Tensor | None = None,
) -> torch.Tensor:
assert attn_metadata is not None, "attn_metadata is required"
self.image_token_len = attn_metadata.num_image_tokens
first_step = attn_metadata.first_step

bs = len(attn_metadata.query_lens)
q_len = attn_metadata.query_lens[0]
seq_len = attn_metadata.seq_lens[0]
assert query.shape[0] == bs * q_len, f"{query.shape[0]} != {bs * q_len}"

head_num_per_rank = query.shape[1]
kv_head_num_per_rank = key.shape[1]
repeat_num = head_num_per_rank // kv_head_num_per_rank
head_dim = query.shape[2]

query = query.reshape(bs, q_len, head_num_per_rank, head_dim)
key = key.reshape(bs, q_len, kv_head_num_per_rank, head_dim)
value = value.reshape(bs, q_len, kv_head_num_per_rank, head_dim)

if first_step:
self.image_kv_cache_map = None
self._save_image_kv_caches(key, value, seq_len)
else:
key, value = self._update_image_kv_caches(key, value, seq_len)

query = query.transpose(1, 2).contiguous()
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()

key = ImageKVCacheManager.repeat_kv(key, repeat_num)
value = ImageKVCacheManager.repeat_kv(value, repeat_num)

attention_mask = attention_mask.contiguous()

attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0)

attn_output = attn_output.transpose(1, 2).contiguous() # [bs, q_len, heads, head_dim]
attn_output = attn_output.reshape(bs * q_len, head_num_per_rank, head_dim)
return attn_output

@staticmethod
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
The hidden states go from (batch, num_key_value_heads, seqlen, head_dim)
to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought it should appeared in DiT part, right?

Comment on lines +1086 to +699
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
act_layer(),
nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use RowParallelLinear and ColumnParallelLinear

info=HunyuanImage3ProcessingInfo,
dummy_inputs=HunyuanImage3DummyInputsBuilder,
)
class HunyuanImage3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can combine HunyuanImage3ConditionalGenerationMixin and HunyuanImage3ForConditionalGeneration

Comment on lines 634 to 672
class ImageTensorInputs(TensorSchema):
"""
Dimensions:
- ni: Number of images
- seq_len: Sequence length (varies per image)
- dim: Feature dimension

Historical context:
- For Hunyuan Image 3.0, images are processed through VAE and ViT encoders
- image_tensors contains vae_image_tensors, vit_image_tensors, and vit_kwargs
"""

type: Literal["image_tensors"]

image_tensors: Annotated[
dict[str, Any],
TensorShape(),
]

class ImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
- nf: Number of image features
- hs: Hidden size
- ni: Number of images

Historical context:
- image_embeds shape: (num_image_features, hidden_size)
- num_image_features varies based on the number and resolution of the
images.
- hidden_size must match the hidden size of language model backbone.
"""

type: Literal["image_embeds"]

image_embeds: Annotated[
torch.Tensor,
TensorShape("nf", "hs"),
]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typically, we put these classes outside.

@princepride
Copy link
Contributor

@usberkeley I think as for AR model, it should have text as the output

@usberkeley usberkeley marked this pull request as draft February 4, 2026 09:48
@hsliuustc0106 hsliuustc0106 requested a review from Copilot February 5, 2026 15:31
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 12 comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

# ========== 临时代码结束 ==========
# print(f"Loading weight name: {name}, tp_rank: {tp_rank}", flush=True)
if contains_unexpected_keyword(name, unexpected_keywords):
print(f"Skipping unexpected weight name: {name}")
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using print() for logging is inconsistent with the rest of the codebase which uses the logger (initialized at line 81). This print statement should use logger.info() or logger.warning() instead to maintain consistency and allow for proper log level control. The logger is already imported and initialized as 'logger = init_logger(name)'.

Suggested change
print(f"Skipping unexpected weight name: {name}")
logger.warning("Skipping unexpected weight name: %s", name)

Copilot uses AI. Check for mistakes.
vae_weights = []
vision_model_weights = []
vision_aligner_weights = []
generation_weights = [] # For image generation components we don't use
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 'generation_weights' list is declared but never used. It was intended to collect weights for image generation components, but those weights are now directly skipped using the 'generation_keywords' filter at line 1913. This unused variable should be removed to clean up the code.

Suggested change
generation_weights = [] # For image generation components we don't use

Copilot uses AI. Check for mistakes.
Comment on lines 1223 to 1228
# ========== 临时代码开始:跳过中间层的权重加载 ==========
# TODO: 这是临时代码,用于调试目的,需要删除
if should_skip_layer_weight(name):
# 跳过中间层的权重,只加载第一层和最后一层
continue
# ========== 临时代码结束 ==========
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Temporary debugging code must be removed before merging. This code skips loading weights for intermediate layers (only loads first and last layer), which will cause the model to not function correctly in production. The comments explicitly state this is for debugging purposes and needs to be deleted.

Suggested change
# ========== 临时代码开始:跳过中间层的权重加载 ==========
# TODO: 这是临时代码,用于调试目的,需要删除
if should_skip_layer_weight(name):
# 跳过中间层的权重,只加载第一层和最后一层
continue
# ========== 临时代码结束 ==========

Copilot uses AI. Check for mistakes.
Comment on lines +1207 to +1212
"patch_embed",
"timestep_emb",
"time_embed",
"time_embed_2",
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's an inconsistency in weight loading logic between HunyuanImage3Model.load_weights and HunyuanImage3ForConditionalGeneration.load_weights. In HunyuanImage3Model (lines 1203-1213), 'patch_embed', 'timestep_emb', 'time_embed', and 'time_embed_2' are in the unexpected_keywords list and will be skipped. However, in HunyuanImage3ForConditionalGeneration.load_weights (lines 1902-1910), 'patch_embed' and 'timestep_emb' are commented out (not skipped), while 'time_embed' and 'time_embed_2' are still skipped. This creates confusion about which components should actually load weights. Since these components are initialized in init (lines 1456, 1472, 1487) and used in embed_multimodal (lines 1792, 1815), their weights should be loaded. The commented-out keywords suggest uncertainty about the correct behavior.

Suggested change
"patch_embed",
"timestep_emb",
"time_embed",
"time_embed_2",

Copilot uses AI. Check for mistakes.
Comment on lines +920 to +956
# Handle attention computation for different modes
if attn_meta is not None and isinstance(attn_meta, HunyuanImageAttentionMeta):
# Image generation mode (Text -> Image)
# Note: Currently not supported - this is reserved for future image generation functionality
assert False, (
"Image generation mode (Text -> Image) is not currently supported. "
"This model only supports Text + Image -> Text scenarios. "
"If image generation is needed, ImageKVCacheManager should be implemented."
)
# Future implementation would use: attn_output = self.image_attn(q, k, v, attn_meta, attention_mask=attention_mask)
else:
# Standard text generation mode (Text + Image -> Text)
attn_output = self.attn(q, k, v)
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assertion will unconditionally fail and prevent the code from running whenever HunyuanImageAttentionMeta is used. The assertion message indicates that "Image generation mode (Text -> Image) is not currently supported", but the code should not reach this point if that mode is truly unsupported. Consider removing the entire if-block (lines 921-929) if image generation is not supported, or remove the assertion if it will be implemented in the future. The assertion makes the code unreachable and creates dead code.

Suggested change
# Handle attention computation for different modes
if attn_meta is not None and isinstance(attn_meta, HunyuanImageAttentionMeta):
# Image generation mode (Text -> Image)
# Note: Currently not supported - this is reserved for future image generation functionality
assert False, (
"Image generation mode (Text -> Image) is not currently supported. "
"This model only supports Text + Image -> Text scenarios. "
"If image generation is needed, ImageKVCacheManager should be implemented."
)
# Future implementation would use: attn_output = self.image_attn(q, k, v, attn_meta, attention_mask=attention_mask)
else:
# Standard text generation mode (Text + Image -> Text)
attn_output = self.attn(q, k, v)
# Compute attention output (image-specific behavior is handled in the RoPE step above)
attn_output = self.attn(q, k, v)

Copilot uses AI. Check for mistakes.
Returns:
JointImageInfo containing VAE and ViT preprocessed image info
"""
from .hunyuan_image.image_processor import resize_and_crop
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The import statement references a non-existent module 'hunyuan_image.image_processor'. This import is expected to provide the 'resize_and_crop' function, but there is no 'hunyuan_image' subdirectory in the 'hunyuan_image3' folder. This will cause an ImportError when the preprocess method is called. The resize_and_crop function needs to be either implemented in this file or imported from the correct location.

Copilot uses AI. Check for mistakes.
Comment on lines +1331 to +1328
# if tp_rank == 0:
# print(f"origin weight_name: {weight_name}, param_name: {param_name}, name: {name}")
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment appears to contain commented-out code.

Copilot uses AI. Check for mistakes.
Comment on lines +1341 to +1338
# if tp_rank == 0:
# print(f"remapped weight_name: {weight_name}, offset: {offset}, den: {den}")
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment appears to contain commented-out code.

Copilot uses AI. Check for mistakes.
Comment on lines +1353 to +1350
# if tp_rank == 0:
# print(f"name_mapped: {name_mapped}, found_num: {found_num}")
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment appears to contain commented-out code.

Copilot uses AI. Check for mistakes.
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
import typing
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Module 'typing' is imported with both 'import' and 'import from'.

Copilot uses AI. Check for mistakes.
Signed-off-by: Bradley <bradley.b.pitt@gmail.com>
Signed-off-by: Bradley <bradley.b.pitt@gmail.com>
Signed-off-by: Bradley <bradley.b.pitt@gmail.com>
Signed-off-by: Bradley <bradley.b.pitt@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants