@@ -296,6 +296,47 @@ def forward(
         return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
 
 
+class TimestepProjPrepare(nn.Module):
+    """Prepares timestep_proj for sequence parallel in TI2V models.
+
+    Encapsulates the unflatten operation for timestep_proj to enable _sp_plan sharding.
+    """
+
+    def forward(
+        self,
+        timestep_proj: torch.Tensor,
+        ts_seq_len: int | None,
+    ) -> torch.Tensor:
+        if ts_seq_len is not None:
+            # TI2V mode: [batch, seq_len, 6, inner_dim]
+            timestep_proj = timestep_proj.unflatten(2, (6, -1))
+        else:
+            # T2V mode: [batch, 6, inner_dim]
+            timestep_proj = timestep_proj.unflatten(1, (6, -1))
+        return timestep_proj
+
+
+class OutputScaleShiftPrepare(nn.Module):
+    """Prepares output scale/shift for SP sharding in TI2V models."""
+
+    def __init__(self, inner_dim: int):
+        super().__init__()
+        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
+
+    def forward(self, temb: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        if temb.ndim == 3:
+            # TI2V: [B, seq, D] -> 3D outputs for SP sharding
+            shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
+            shift = shift.squeeze(2)
+            scale = scale.squeeze(2)
+        else:
+            # T2V: [B, D] -> 2D outputs (skip SP sharding via expected_dims=3)
+            shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
+            shift = shift.squeeze(1)
+            scale = scale.squeeze(1)
+        return shift, scale
+
+
 class WanSelfAttention(nn.Module):
     """
     Optimized self-attention module using vLLM layers.
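
Note: a quick standalone sketch of the shapes the two new helpers produce, which may help when reviewing the SP plan below. It mirrors the module logic with plain torch; the batch/sequence/inner_dim values are made-up toy sizes, not taken from the model config.

import torch

B, seq, inner_dim = 2, 16, 64  # assumed toy sizes

# TI2V: the condition embedder returns timestep_proj as [B, seq, 6 * inner_dim];
# unflatten(2, (6, -1)) turns it into [B, seq, 6, inner_dim], whose dim=1 lines up
# with the sharded hidden_states sequence dimension.
ti2v_proj = torch.randn(B, seq, 6 * inner_dim).unflatten(2, (6, -1))
assert ti2v_proj.shape == (B, seq, 6, inner_dim)

# T2V: timestep_proj is [B, 6 * inner_dim] -> [B, 6, inner_dim]; no sequence dim, so no sharding.
t2v_proj = torch.randn(B, 6 * inner_dim).unflatten(1, (6, -1))
assert t2v_proj.shape == (B, 6, inner_dim)

# OutputScaleShiftPrepare, TI2V path: temb is [B, seq, inner_dim]; the [1, 2, inner_dim]
# table broadcasts against temb.unsqueeze(2) to give shift/scale of [B, seq, inner_dim] each.
table = torch.randn(1, 2, inner_dim) / inner_dim**0.5
temb_ti2v = torch.randn(B, seq, inner_dim)
shift, scale = (table.unsqueeze(0) + temb_ti2v.unsqueeze(2)).chunk(2, dim=2)
assert shift.squeeze(2).shape == scale.squeeze(2).shape == (B, seq, inner_dim)

# T2V path: temb is [B, inner_dim] -> shift/scale of [B, inner_dim] each (2D, skipped by the SP plan).
temb_t2v = torch.randn(B, inner_dim)
shift, scale = (table + temb_t2v.unsqueeze(1)).chunk(2, dim=1)
assert shift.squeeze(1).shape == scale.squeeze(1).shape == (B, inner_dim)
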
@@ -679,6 +720,7 @@ class WanTransformer3DModel(nn.Module):
     #
     # The _sp_plan specifies sharding/gathering at module boundaries:
     # - rope: Split both RoPE outputs (freqs_cos, freqs_sin) via split_output=True
+    # - timestep_proj_prepare: Split timestep_proj for TI2V models (4D tensor)
     # - blocks.0: Split hidden_states input at the first transformer block
     # - proj_out: Gather outputs after the final projection layer
     #
@@ -689,11 +731,22 @@ class WanTransformer3DModel(nn.Module):
             0: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True),  # freqs_cos [1, seq, 1, dim]
             1: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True),  # freqs_sin [1, seq, 1, dim]
         },
+        # Shard timestep_proj for TI2V models (4D tensor: [batch, seq_len, 6, inner_dim])
+        # This is only active when ts_seq_len is not None (TI2V mode)
+        # Output is a single tensor, shard along dim=1 (sequence dimension)
+        "timestep_proj_prepare": {
+            0: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True),  # [B, seq, 6, dim]
+        },
         # Shard hidden_states at first transformer block input
         # (after patch_embedding + flatten + transpose)
         "blocks.0": {
             "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),  # [B, seq, dim]
         },
+        # Shard output scale/shift for TI2V (3D); T2V outputs 2D and skips sharding
+        "output_scale_shift_prepare": {
+            0: SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True),
+            1: SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True),
+        },
         # Gather at proj_out (final linear projection before unpatchify)
         "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
     }
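
Note: a rough sketch of what the split_dim=1 entries ask the SP framework to do, assuming an even split of the sequence dimension across ranks. The shard_for_rank helper and the tensor sizes below are illustrative stand-ins, not the actual SequenceParallelInput implementation.

import torch

def shard_for_rank(t: torch.Tensor, split_dim: int, sp_size: int, rank: int) -> torch.Tensor:
    # Illustrative stand-in for the framework-side split requested by split_dim/split_output.
    return t.chunk(sp_size, dim=split_dim)[rank]

sp_size, B, seq, dim = 2, 1, 32, 128             # assumed toy sizes
hidden_states = torch.randn(B, seq, dim)         # input to blocks.0, expected_dims=3
timestep_proj = torch.randn(B, seq, 6, dim)      # timestep_proj_prepare output (TI2V), expected_dims=4

for rank in range(sp_size):
    hs = shard_for_rank(hidden_states, split_dim=1, sp_size=sp_size, rank=rank)
    tp = shard_for_rank(timestep_proj, split_dim=1, sp_size=sp_size, rank=rank)
    # Both shards keep the same local sequence length, so per-block modulation stays aligned.
    assert hs.shape[1] == tp.shape[1] == seq // sp_size
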
@@ -774,7 +827,10 @@ def __init__(
         # 4. Output norm & projection
         self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
         self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
-        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
+
+        # SP helper modules
+        self.timestep_proj_prepare = TimestepProjPrepare()
+        self.output_scale_shift_prepare = OutputScaleShiftPrepare(inner_dim)
 
     @property
     def dtype(self) -> torch.dtype:
@@ -814,10 +870,10 @@ def forward(
         temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
             timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
         )
-        if ts_seq_len is not None:
-            timestep_proj = timestep_proj.unflatten(2, (6, -1))
-        else:
-            timestep_proj = timestep_proj.unflatten(1, (6, -1))
+        # Prepare timestep_proj via the TimestepProjPrepare module.
+        # _sp_plan shards its output via split_output=True (when ts_seq_len is not None),
+        # so the timestep_proj sequence dimension matches the sharded hidden_states.
+        timestep_proj = self.timestep_proj_prepare(timestep_proj, ts_seq_len)
 
         if encoder_hidden_states_image is not None:
             encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
@@ -827,15 +883,12 @@ def forward(
             hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
 
         # Output norm, projection & unpatchify
-        if temb.ndim == 3:
-            shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
-            shift = shift.squeeze(2)
-            scale = scale.squeeze(2)
-        else:
-            shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
-
+        shift, scale = self.output_scale_shift_prepare(temb)
         shift = shift.to(hidden_states.device)
         scale = scale.to(hidden_states.device)
+        if shift.ndim == 2:  # T2V mode: unsqueeze for broadcasting
+            shift = shift.unsqueeze(1)
+            scale = scale.unsqueeze(1)
 
         hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
         hidden_states = self.proj_out(hidden_states)
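
Note: a minimal sketch of why the T2V branch needs unsqueeze(1) before the final modulation; the sizes are assumed toy values and this is not part of the change itself.

import torch

B, seq, D = 2, 16, 64
hidden_states = torch.randn(B, seq, D)     # output of the last transformer block

# TI2V: per-token shift/scale are already [B, seq, D] and broadcast elementwise.
shift = torch.randn(B, seq, D)
scale = torch.randn(B, seq, D)
assert (hidden_states * (1 + scale) + shift).shape == (B, seq, D)

# T2V: shift/scale come back as [B, D]; broadcasting would align B against seq,
# so a singleton sequence dim is inserted to get [B, 1, D] before the modulation.
shift = torch.randn(B, D).unsqueeze(1)
scale = torch.randn(B, D).unsqueeze(1)
assert (hidden_states * (1 + scale) + shift).shape == (B, seq, D)
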
@@ -867,19 +920,23 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
         # Stacked params mapping for self-attention QKV fusion
-        # Format: (param_name, shard_name, shard_id)
-        # Note: Only fuse attn1 (self-attention), NOT attn2 (cross-attention)
         stacked_params_mapping = [
             # self-attention QKV fusion
             (".attn1.to_qkv", ".attn1.to_q", "q"),
             (".attn1.to_qkv", ".attn1.to_k", "k"),
             (".attn1.to_qkv", ".attn1.to_v", "v"),
         ]
 
+        # Remap scale_shift_table to new module location
+        weight_name_remapping = {
+            "scale_shift_table": "output_scale_shift_prepare.scale_shift_table",
+        }
+
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
 
         for name, loaded_weight in weights:
+            name = weight_name_remapping.get(name, name)
             original_name = name
             lookup_name = name
 
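
Note: a small illustration of how the remapping behaves during load_weights. Only the top-level scale_shift_table key is rewritten; the block-level key below is a hypothetical example name and passes through unchanged.

weight_name_remapping = {
    "scale_shift_table": "output_scale_shift_prepare.scale_shift_table",
}

for name in ["scale_shift_table", "blocks.0.attn1.to_q.weight"]:
    print(name, "->", weight_name_remapping.get(name, name))
# scale_shift_table -> output_scale_shift_prepare.scale_shift_table
# blocks.0.attn1.to_q.weight -> blocks.0.attn1.to_q.weight
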