Commit 5da0839

committed
update
1 parent e7ffeae commit 5da0839

File tree

2 files changed: +258 -22 lines changed


src/diffusers/models/transformers/transformer_hunyuan_video.py

+212 -12
@@ -29,11 +29,14 @@
 from ..embeddings import (
     CombinedTimestepGuidanceTextProjEmbeddings,
     CombinedTimestepTextProjEmbeddings,
+    PixArtAlphaTextProjection,
+    TimestepEmbedding,
+    Timesteps,
     get_1d_rotary_pos_embed,
 )
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -173,6 +176,84 @@ def forward(
         return gate_msa, gate_mlp
 
 
+class HunyuanVideoTokenReplaceAdaLayerNormZero(nn.Module):
+    def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
+        super().__init__()
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
+
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        elif norm_type == "fp32_layer_norm":
+            self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
+        else:
+            raise ValueError(
+                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+            )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        emb: torch.Tensor,
+        token_replace_emb: torch.Tensor,
+        first_frame_num_tokens: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        emb = self.linear(self.silu(emb))
+        token_replace_emb = self.linear(self.silu(token_replace_emb))
+
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
+        tr_shift_msa, tr_scale_msa, tr_gate_msa, tr_shift_mlp, tr_scale_mlp, tr_gate_mlp = token_replace_emb.chunk(
+            6, dim=1
+        )
+
+        norm_hidden_states = self.norm(hidden_states)
+        hidden_states_zero = (
+            norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
+        )
+        hidden_states_orig = (
+            norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        )
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+
+        return (
+            hidden_states,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            tr_gate_msa,
+            tr_shift_mlp,
+            tr_scale_mlp,
+            tr_gate_mlp,
+        )
+
+
+class HunyuanVideoTimestepTextProjEmbeddings(nn.Module):
+    def __init__(self, embedding_dim: int, pooled_projection_dim: int, image_condition_type: Optional[str] = None):
+        super().__init__()
+
+        self.image_condition_type = image_condition_type
+
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+    def forward(self, timestep: torch.Tensor, pooled_projection: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+        pooled_projections = self.text_embedder(pooled_projection)
+        conditioning = timesteps_emb + pooled_projections
+
+        token_replace_emb = None
+        if self.image_condition_type == "token_replace":
+            token_replace_timestep = torch.zeros_like(timestep)
+            token_replace_proj = self.time_proj(token_replace_timestep)
+            token_replace_emb = self.timestep_embedder(token_replace_proj)
+            token_replace_emb = token_replace_emb + conditioning
+
+        return conditioning, token_replace_emb
+
+
 class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
     def __init__(
         self,
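Note: the two classes added above implement the token-replace conditioning. The timestep embedder additionally re-embeds a zero timestep, and the norm layer applies that zero-timestep shift/scale only to the first frame's tokens before concatenating everything back. Below is a minimal standalone sketch of that split-modulate-concat step in plain PyTorch; the shapes are hypothetical and this is not the diffusers API itself.

import torch

# Hypothetical sizes: 2 videos, 4 first-frame tokens out of 10 total, 8 channels.
batch, first_frame_tokens, total_tokens, dim = 2, 4, 10, 8

norm_hidden_states = torch.randn(batch, total_tokens, dim)
shift, scale = torch.randn(batch, dim), torch.randn(batch, dim)        # from the regular timestep embedding
tr_shift, tr_scale = torch.randn(batch, dim), torch.randn(batch, dim)  # from the zero-timestep ("token replace") embedding

# First-frame tokens are modulated with the zero-timestep parameters, the rest with the regular ones.
first = norm_hidden_states[:, :first_frame_tokens] * (1 + tr_scale[:, None]) + tr_shift[:, None]
rest = norm_hidden_states[:, first_frame_tokens:] * (1 + scale[:, None]) + shift[:, None]
modulated = torch.cat([first, rest], dim=1)
assert modulated.shape == (batch, total_tokens, dim)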
@@ -468,6 +549,7 @@ def forward(
         temb: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # 1. Input normalization
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
@@ -503,6 +585,101 @@ def forward(
         return hidden_states, encoder_hidden_states
 
 
+class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        mlp_ratio: float,
+        qk_norm: str = "rms_norm",
+    ) -> None:
+        super().__init__()
+
+        hidden_size = num_attention_heads * attention_head_dim
+
+        self.norm1 = HunyuanVideoTokenReplaceAdaLayerNormZero(hidden_size, norm_type="layer_norm")
+        self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
+
+        self.attn = Attention(
+            query_dim=hidden_size,
+            cross_attention_dim=None,
+            added_kv_proj_dim=hidden_size,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=hidden_size,
+            context_pre_only=False,
+            bias=True,
+            processor=HunyuanVideoAttnProcessor2_0(),
+            qk_norm=qk_norm,
+            eps=1e-6,
+        )
+
+        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+        self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        token_replace_emb: torch.Tensor = None,
+        num_tokens: int = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # 1. Input normalization
+        (
+            norm_hidden_states,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            tr_gate_msa,
+            tr_shift_mlp,
+            tr_scale_mlp,
+            tr_gate_mlp,
+        ) = self.norm1(hidden_states, temb, token_replace_emb, num_tokens)
+        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+            encoder_hidden_states, emb=temb
+        )
+
+        # 2. Joint attention
+        attn_output, context_attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            attention_mask=attention_mask,
+            image_rotary_emb=freqs_cis,
+        )
+
+        # 3. Modulation and residual connection
+        hidden_states_zero = hidden_states[:, :num_tokens] + attn_output[:, :num_tokens] * tr_gate_msa
+        hidden_states_orig = hidden_states[:, num_tokens:] + attn_output[:, num_tokens:] * gate_msa
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+        encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
+
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+
+        hidden_states_zero = norm_hidden_states[:, :num_tokens] * (1 + tr_scale_mlp[:, None]) + tr_shift_mlp[:, None]
+        hidden_states_orig = norm_hidden_states[:, num_tokens:] * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+        # 4. Feed-forward
+        ff_output = self.ff(norm_hidden_states)
+        context_ff_output = self.ff_context(norm_encoder_hidden_states)
+
+        hidden_states_zero = hidden_states[:, :num_tokens] + ff_output[:, :num_tokens] * tr_gate_mlp
+        hidden_states_orig = hidden_states[:, num_tokens:] + ff_output[:, num_tokens:] * gate_mlp
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+        return hidden_states, encoder_hidden_states
+
+
 class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
     r"""
     A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
@@ -570,8 +747,10 @@ def __init__(
         pooled_projection_dim: int = 768,
         rope_theta: float = 256.0,
         rope_axes_dim: Tuple[int] = (16, 56, 56),
+        image_condition_type: Optional[str] = None,
     ) -> None:
         super().__init__()
+        assert image_condition_type is None or image_condition_type in ["latent_concat", "token_replace"]
 
         inner_dim = num_attention_heads * attention_head_dim
         out_channels = out_channels or in_channels
@@ -585,20 +764,32 @@ def __init__(
         if guidance_embeds:
             self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
         else:
-            self.time_text_embed = CombinedTimestepTextProjEmbeddings(inner_dim, pooled_projection_dim)
+            self.time_text_embed = HunyuanVideoTimestepTextProjEmbeddings(
+                inner_dim, pooled_projection_dim, image_condition_type
+            )
 
         # 2. RoPE
         self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
 
         # 3. Dual stream transformer blocks
-        self.transformer_blocks = nn.ModuleList(
-            [
-                HunyuanVideoTransformerBlock(
-                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
-                )
-                for _ in range(num_layers)
-            ]
-        )
+        if image_condition_type == "token_replace":
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    HunyuanVideoTokenReplaceTransformerBlock(
+                        num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                    )
+                    for _ in range(num_layers)
+                ]
+            )
+        else:
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    HunyuanVideoTransformerBlock(
+                        num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                    )
+                    for _ in range(num_layers)
+                ]
+            )
 
         # 4. Single stream transformer blocks
         self.single_transformer_blocks = nn.ModuleList(
@@ -707,6 +898,7 @@ def forward(
         post_patch_num_frames = num_frames // p_t
         post_patch_height = height // p
         post_patch_width = width // p
+        first_frame_num_tokens = 1 * post_patch_height * post_patch_width
 
         # 1. RoPE
         image_rotary_emb = self.rope(hidden_states)
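Note: a quick arithmetic check of the new first_frame_num_tokens value; the latent sizes below are hypothetical and only illustrate the formula added in this hunk.

# Hypothetical latent sizes for illustration only.
p_t, p = 1, 2                             # patch_size_t, patch_size
num_frames, height, width = 61, 60, 104   # latent frames, latent height, latent width
post_patch_height, post_patch_width = height // p, width // p
first_frame_num_tokens = 1 * post_patch_height * post_patch_width
print(first_frame_num_tokens)  # 30 * 52 = 1560 tokens belong to the first latent frame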
@@ -715,7 +907,7 @@
         if self.config.guidance_embeds:
             temb = self.time_text_embed(timestep, guidance, pooled_projections)
         else:
-            temb = self.time_text_embed(timestep, pooled_projections)
+            temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections)
 
         hidden_states = self.x_embedder(hidden_states)
         encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
@@ -746,6 +938,8 @@ def forward(
                     temb,
                     attention_mask,
                     image_rotary_emb,
+                    token_replace_emb,
+                    first_frame_num_tokens,
                 )
 
             for block in self.single_transformer_blocks:
@@ -761,7 +955,13 @@ def forward(
         else:
             for block in self.transformer_blocks:
                 hidden_states, encoder_hidden_states = block(
-                    hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    attention_mask,
+                    image_rotary_emb,
+                    token_replace_emb,
+                    first_frame_num_tokens,
                 )
 
             for block in self.single_transformer_blocks:
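Note: a hedged sketch of how the new config flag might be exercised. Only keywords that appear in this diff are passed explicitly, the reduced sizes are arbitrary, and default values are assumed to exist for every parameter not listed; this is not an official usage example.

from diffusers import HunyuanVideoTransformer3DModel  # assumes the class is exported at the package root

# Tiny, arbitrary sizes so the token-replace branch is constructed quickly.
transformer = HunyuanVideoTransformer3DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    num_layers=1,
    rope_axes_dim=(2, 3, 3),               # assumed to need to sum to attention_head_dim
    guidance_embeds=False,                 # the non-guidance branch uses the new HunyuanVideoTimestepTextProjEmbeddings
    image_condition_type="token_replace",  # new flag; must be None, "latent_concat", or "token_replace"
)
print(type(transformer.transformer_blocks[0]).__name__)  # expected: HunyuanVideoTokenReplaceTransformerBlock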

0 commit comments
