from ..embeddings import (
    CombinedTimestepGuidanceTextProjEmbeddings,
    CombinedTimestepTextProjEmbeddings,
+    PixArtAlphaTextProjection,
+    TimestepEmbedding,
+    Timesteps,
    get_1d_rotary_pos_embed,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -173,6 +176,84 @@ def forward(
        return gate_msa, gate_mlp


+class HunyuanVideoTokenReplaceAdaLayerNormZero(nn.Module):
+    def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
+        super().__init__()
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
+
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        elif norm_type == "fp32_layer_norm":
+            self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
+        else:
+            raise ValueError(
+                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+            )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        emb: torch.Tensor,
+        token_replace_emb: torch.Tensor,
+        first_frame_num_tokens: int,
+    ) -> Tuple[torch.Tensor, ...]:
+        emb = self.linear(self.silu(emb))
+        token_replace_emb = self.linear(self.silu(token_replace_emb))
+
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
+        tr_shift_msa, tr_scale_msa, tr_gate_msa, tr_shift_mlp, tr_scale_mlp, tr_gate_mlp = token_replace_emb.chunk(
+            6, dim=1
+        )
+
+        norm_hidden_states = self.norm(hidden_states)
+        hidden_states_zero = (
+            norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
+        )
+        hidden_states_orig = (
+            norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        )
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+
+        return (
+            hidden_states,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            tr_gate_msa,
+            tr_shift_mlp,
+            tr_scale_mlp,
+            tr_gate_mlp,
+        )
+
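
HunyuanVideoTokenReplaceAdaLayerNormZero applies two sets of AdaLN-Zero parameters to a single
token sequence: the first `first_frame_num_tokens` tokens (the clean conditioning frame) are
modulated with the token-replace embedding, and the remaining tokens with the ordinary
timestep/text embedding. A minimal, self-contained sketch of that split-modulate-concat pattern
(tensor sizes are illustrative, not the model's real dimensions):

import torch

batch, tokens, dim, first = 2, 10, 8, 4
norm_hidden_states = torch.randn(batch, tokens, dim)
scale, shift = torch.randn(batch, dim), torch.randn(batch, dim)        # ordinary AdaLN shift/scale
tr_scale, tr_shift = torch.randn(batch, dim), torch.randn(batch, dim)  # token-replace shift/scale

# First-frame tokens get the token-replace modulation, the rest the ordinary one.
zero = norm_hidden_states[:, :first] * (1 + tr_scale[:, None]) + tr_shift[:, None]
orig = norm_hidden_states[:, first:] * (1 + scale[:, None]) + shift[:, None]
out = torch.cat([zero, orig], dim=1)
assert out.shape == (batch, tokens, dim)
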
+class HunyuanVideoTimestepTextProjEmbeddings(nn.Module):
+    def __init__(self, embedding_dim: int, pooled_projection_dim: int, image_condition_type: Optional[str] = None):
+        super().__init__()
+
+        self.image_condition_type = image_condition_type
+
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+    def forward(self, timestep: torch.Tensor, pooled_projection: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+        pooled_projections = self.text_embedder(pooled_projection)
+        conditioning = timesteps_emb + pooled_projections
+
+        token_replace_emb = None
+        if self.image_condition_type == "token_replace":
+            token_replace_timestep = torch.zeros_like(timestep)
+            token_replace_proj = self.time_proj(token_replace_timestep)
+            token_replace_emb = self.timestep_embedder(token_replace_proj)
+            token_replace_emb = token_replace_emb + conditioning
+
+        return conditioning, token_replace_emb
+
+
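When `image_condition_type == "token_replace"`, the embedder above additionally embeds a zero
timestep for the conditioning frame (the first frame is a clean latent, i.e. noise level 0) and
adds the shared timestep/text conditioning on top. A rough sketch of a call, assuming the
surrounding diffusers modules are importable and using illustrative dimensions (3072-dim model,
768-dim pooled text features):

import torch

embed = HunyuanVideoTimestepTextProjEmbeddings(
    embedding_dim=3072, pooled_projection_dim=768, image_condition_type="token_replace"
)
timestep = torch.tensor([981.0])  # current denoising timestep
pooled = torch.randn(1, 768)      # pooled text embedding
temb, token_replace_emb = embed(timestep, pooled)
# token_replace_emb = embedder(t=0) + temb, so the first frame is conditioned as already denoised.
assert temb.shape == token_replace_emb.shape == (1, 3072)
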
class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
    def __init__(
        self,
@@ -468,6 +549,7 @@ def forward(
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        *args, **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. Input normalization
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
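
The variadic parameters added here let every dual-stream block variant be called with one shared
argument list: the model's forward (later in this diff) now always passes `token_replace_emb` and
`first_frame_num_tokens`, and the plain `HunyuanVideoTransformerBlock` simply absorbs and ignores
them. Note that the extras are passed positionally at those call sites, which is why keyword-only
`**kwargs` on its own would not be enough.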
@@ -503,6 +585,101 @@ def forward(
        return hidden_states, encoder_hidden_states


+class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        mlp_ratio: float,
+        qk_norm: str = "rms_norm",
+    ) -> None:
+        super().__init__()
+
+        hidden_size = num_attention_heads * attention_head_dim
+
+        self.norm1 = HunyuanVideoTokenReplaceAdaLayerNormZero(hidden_size, norm_type="layer_norm")
+        self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
+
+        self.attn = Attention(
+            query_dim=hidden_size,
+            cross_attention_dim=None,
+            added_kv_proj_dim=hidden_size,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=hidden_size,
+            context_pre_only=False,
+            bias=True,
+            processor=HunyuanVideoAttnProcessor2_0(),
+            qk_norm=qk_norm,
+            eps=1e-6,
+        )
+
+        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+        self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        token_replace_emb: torch.Tensor = None,
+        num_tokens: int = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # 1. Input normalization
+        (
+            norm_hidden_states,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            tr_gate_msa,
+            tr_shift_mlp,
+            tr_scale_mlp,
+            tr_gate_mlp,
+        ) = self.norm1(hidden_states, temb, token_replace_emb, num_tokens)
+        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+            encoder_hidden_states, emb=temb
+        )
+
+        # 2. Joint attention
+        attn_output, context_attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            attention_mask=attention_mask,
+            image_rotary_emb=freqs_cis,
+        )
+
+        # 3. Modulation and residual connection
+        hidden_states_zero = hidden_states[:, :num_tokens] + attn_output[:, :num_tokens] * tr_gate_msa
+        hidden_states_orig = hidden_states[:, num_tokens:] + attn_output[:, num_tokens:] * gate_msa
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+        encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
+
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+
+        hidden_states_zero = norm_hidden_states[:, :num_tokens] * (1 + tr_scale_mlp[:, None]) + tr_shift_mlp[:, None]
+        hidden_states_orig = norm_hidden_states[:, num_tokens:] * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+        # 4. Feed-forward
+        ff_output = self.ff(norm_hidden_states)
+        context_ff_output = self.ff_context(norm_encoder_hidden_states)
+
+        hidden_states_zero = hidden_states[:, :num_tokens] + ff_output[:, :num_tokens] * tr_gate_mlp
+        hidden_states_orig = hidden_states[:, num_tokens:] + ff_output[:, num_tokens:] * gate_mlp
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+        return hidden_states, encoder_hidden_states
+
+
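Compared with the plain dual-stream block, only the video-token stream is split at `num_tokens`:
text tokens (`encoder_hidden_states`) keep the standard AdaLN-Zero path. One caveat worth
flagging: the context gates are applied as `c_gate_msa.unsqueeze(1)`, while the video gates
(`tr_gate_msa`, `gate_msa`, `tr_gate_mlp`, `gate_mlp`) are multiplied without a broadcast
dimension, so a (batch, tokens, dim) * (batch, dim) product only broadcasts cleanly when the
batch size is 1; adding a `[:, None]` to those gates would make the block batch-safe.
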
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
    r"""
    A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
@@ -570,8 +747,10 @@ def __init__(
        pooled_projection_dim: int = 768,
        rope_theta: float = 256.0,
        rope_axes_dim: Tuple[int] = (16, 56, 56),
+        image_condition_type: Optional[str] = None,
    ) -> None:
        super().__init__()
+        assert image_condition_type is None or image_condition_type in ["latent_concat", "token_replace"]

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels
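
The new `image_condition_type` flag selects how an image-to-video checkpoint injects its
conditioning frame: `"latent_concat"` concatenates the image latents with the noisy video latents
along the channel dimension, while `"token_replace"` keeps the channel layout and instead treats
the first frame's tokens as clean (timestep-0) latents, as implemented by the token-replace
modules above.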
@@ -585,20 +764,32 @@ def __init__(
        if guidance_embeds:
            self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
        else:
-            self.time_text_embed = CombinedTimestepTextProjEmbeddings(inner_dim, pooled_projection_dim)
+            self.time_text_embed = HunyuanVideoTimestepTextProjEmbeddings(
+                inner_dim, pooled_projection_dim, image_condition_type
+            )

        # 2. RoPE
        self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)

        # 3. Dual stream transformer blocks
-        self.transformer_blocks = nn.ModuleList(
-            [
-                HunyuanVideoTransformerBlock(
-                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
-                )
-                for _ in range(num_layers)
-            ]
-        )
+        if image_condition_type == "token_replace":
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    HunyuanVideoTokenReplaceTransformerBlock(
+                        num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                    )
+                    for _ in range(num_layers)
+                ]
+            )
+        else:
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    HunyuanVideoTransformerBlock(
+                        num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                    )
+                    for _ in range(num_layers)
+                ]
+            )
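
A toy instantiation to show which block type the flag selects (a hypothetical tiny config for
illustration only; real checkpoints use far larger dimensions and are loaded via `from_pretrained`):

model = HunyuanVideoTransformer3DModel(
    num_attention_heads=2,
    attention_head_dim=128,
    num_layers=1,
    num_single_layers=1,
    guidance_embeds=False,  # token_replace relies on the non-guidance embedder above
    image_condition_type="token_replace",
)
print(type(model.transformer_blocks[0]).__name__)  # HunyuanVideoTokenReplaceTransformerBlock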

        # 4. Single stream transformer blocks
        self.single_transformer_blocks = nn.ModuleList(
@@ -707,6 +898,7 @@ def forward(
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p
        post_patch_width = width // p
+        first_frame_num_tokens = 1 * post_patch_height * post_patch_width

        # 1. RoPE
        image_rotary_emb = self.rope(hidden_states)
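
`first_frame_num_tokens` counts the patchified tokens of exactly one latent frame, which is the
span the token-replace blocks treat as clean. A quick sanity check with illustrative latent sizes
(patch_size p = 2, patch_size_t p_t = 1):

p, p_t = 2, 1                            # spatial / temporal patch sizes
num_frames, height, width = 33, 68, 120  # latent grid after the VAE (illustrative)
post_patch_height, post_patch_width = height // p, width // p
first_frame_num_tokens = 1 * post_patch_height * post_patch_width
print(first_frame_num_tokens)            # 2040 tokens for the single conditioning frame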
@@ -715,7 +907,7 @@ def forward(
        if self.config.guidance_embeds:
            temb = self.time_text_embed(timestep, guidance, pooled_projections)
        else:
-            temb = self.time_text_embed(timestep, pooled_projections)
+            temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections)

        hidden_states = self.x_embedder(hidden_states)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
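
Note that `token_replace_emb` is only bound in the `else` branch: a checkpoint configured with
both `guidance_embeds=True` and `image_condition_type="token_replace"` would hit a `NameError`
further down. The two options are evidently assumed to be mutually exclusive; initializing
`token_replace_emb = None` before the branch would make that assumption explicit.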
@@ -746,6 +938,8 @@ def forward(
                    temb,
                    attention_mask,
                    image_rotary_emb,
+                    token_replace_emb,
+                    first_frame_num_tokens,
                )

        for block in self.single_transformer_blocks:
@@ -761,7 +955,13 @@ def forward(
        else:
            for block in self.transformer_blocks:
                hidden_states, encoder_hidden_states = block(
-                    hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    attention_mask,
+                    image_rotary_emb,
+                    token_replace_emb,
+                    first_frame_num_tokens,
                )

            for block in self.single_transformer_blocks: