
Commit 72c03f6

Merge branch 'main' into wan22_online
2 parents 12049f6 + 11a5a91

File tree

2 files changed: +76 -14 lines changed


vllm_omni/diffusion/distributed/sp_sharding.py

Lines changed: 5 additions & 0 deletions
@@ -63,6 +63,11 @@ def sp_shard(
             f"world_size ({world_size}) for sequence parallel sharding."
         )
 
+    if size < world_size:
+        raise ValueError(
+            f"Tensor size along dim {dim} ({size}) must be >= world_size ({world_size}). Tensor shape: {tensor.shape}"
+        )
+
     return tensor.chunk(world_size, dim=dim)[rank]
 
 
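Why the new guard helps (illustration only, not part of the patch): torch.chunk may return fewer than world_size chunks when the split dimension is smaller than world_size, so indexing the result with rank would fail with an IndexError on the higher ranks instead of surfacing the real problem. The shapes below are made up.

import torch

world_size, dim = 4, 1
t = torch.randn(1, 3, 8)            # only 3 elements along dim=1, fewer than world_size=4
chunks = t.chunk(world_size, dim=dim)
print(len(chunks))                  # 3 -> chunks[3] would raise IndexError for rank 3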

vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py

Lines changed: 71 additions & 14 deletions
@@ -296,6 +296,47 @@ def forward(
         return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
 
 
+class TimestepProjPrepare(nn.Module):
+    """Prepares timestep_proj for sequence parallel in TI2V models.
+
+    Encapsulates the unflatten operation for timestep_proj to enable _sp_plan sharding.
+    """
+
+    def forward(
+        self,
+        timestep_proj: torch.Tensor,
+        ts_seq_len: int | None,
+    ) -> torch.Tensor:
+        if ts_seq_len is not None:
+            # TI2V mode: [batch, seq_len, 6, inner_dim]
+            timestep_proj = timestep_proj.unflatten(2, (6, -1))
+        else:
+            # T2V mode: [batch, 6, inner_dim]
+            timestep_proj = timestep_proj.unflatten(1, (6, -1))
+        return timestep_proj
+
+
+class OutputScaleShiftPrepare(nn.Module):
+    """Prepares output scale/shift for SP sharding in TI2V models."""
+
+    def __init__(self, inner_dim: int):
+        super().__init__()
+        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
+
+    def forward(self, temb: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        if temb.ndim == 3:
+            # TI2V: [B, seq, D] -> 3D outputs for SP sharding
+            shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
+            shift = shift.squeeze(2)
+            scale = scale.squeeze(2)
+        else:
+            # T2V: [B, D] -> 2D outputs (skip SP sharding via expected_dims=3)
+            shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
+            shift = shift.squeeze(1)
+            scale = scale.squeeze(1)
+        return shift, scale
+
+
 class WanSelfAttention(nn.Module):
     """
     Optimized self-attention module using vLLM layers.
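A standalone shape check of the two new helpers, using the same tensor ops as the hunk above (the batch/sequence/inner dims are made up, and the pre-unflatten timestep_proj shapes are inferred from the unflatten calls):

import torch

B, S, D = 2, 8, 16

# TimestepProjPrepare, TI2V branch: [B, S, 6*D] -> [B, S, 6, D]
assert torch.randn(B, S, 6 * D).unflatten(2, (6, -1)).shape == (B, S, 6, D)
# TimestepProjPrepare, T2V branch: [B, 6*D] -> [B, 6, D]
assert torch.randn(B, 6 * D).unflatten(1, (6, -1)).shape == (B, 6, D)

# OutputScaleShiftPrepare, TI2V branch: table [1, 2, D] + temb [B, S, D] -> shift/scale [B, S, D]
table, temb = torch.randn(1, 2, D), torch.randn(B, S, D)
shift, scale = (table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
assert shift.squeeze(2).shape == (B, S, D) and scale.squeeze(2).shape == (B, S, D)

# OutputScaleShiftPrepare, T2V branch: table [1, 2, D] + temb [B, D] -> shift/scale [B, D]
temb_2d = torch.randn(B, D)
shift, scale = (table + temb_2d.unsqueeze(1)).chunk(2, dim=1)
assert shift.squeeze(1).shape == (B, D) and scale.squeeze(1).shape == (B, D)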
@@ -679,6 +720,7 @@ class WanTransformer3DModel(nn.Module):
     #
     # The _sp_plan specifies sharding/gathering at module boundaries:
     # - rope: Split both RoPE outputs (freqs_cos, freqs_sin) via split_output=True
+    # - timestep_proj_prepare: Split timestep_proj for TI2V models (4D tensor)
     # - blocks.0: Split hidden_states input at the first transformer block
     # - proj_out: Gather outputs after the final projection layer
     #
@@ -689,11 +731,22 @@ class WanTransformer3DModel(nn.Module):
             0: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True),  # freqs_cos [1, seq, 1, dim]
             1: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True),  # freqs_sin [1, seq, 1, dim]
         },
+        # Shard timestep_proj for TI2V models (4D tensor: [batch, seq_len, 6, inner_dim])
+        # This is only active when ts_seq_len is not None (TI2V mode)
+        # Output is a single tensor, shard along dim=1 (sequence dimension)
+        "timestep_proj_prepare": {
+            0: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True),  # [B, seq, 6, dim]
+        },
         # Shard hidden_states at first transformer block input
         # (after patch_embedding + flatten + transpose)
         "blocks.0": {
             "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),  # [B, seq, dim]
         },
+        # Shard output scale/shift for TI2V (3D); T2V outputs 2D and skips sharding
+        "output_scale_shift_prepare": {
+            0: SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True),
+            1: SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True),
+        },
         # Gather at proj_out (final linear projection before unpatchify)
         "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
     }
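A rough sketch of what the new plan entries buy (this is not the vllm_omni runtime; shard below is a simplified stand-in for sp_shard, and all dims are made up): everything the plan splits shares split_dim=1, so the per-rank slices of hidden_states, timestep_proj, and the RoPE tables cover the same sequence positions, while T2V shift/scale come out 2D, miss expected_dims=3, and are left whole.

import torch

def shard(t, world_size, rank, dim=1):
    # simplified stand-in for sp_shard: equal chunks along the sequence dim
    return t.chunk(world_size, dim=dim)[rank]

world_size, rank = 2, 0
B, S, D = 1, 8, 16

hidden_states = torch.randn(B, S, D)       # blocks.0 input, expected_dims=3
timestep_proj = torch.randn(B, S, 6, D)    # timestep_proj_prepare output, expected_dims=4
freqs_cos = torch.randn(1, S, 1, D)        # rope output, expected_dims=4

for t in (hidden_states, timestep_proj, freqs_cos):
    assert shard(t, world_size, rank).shape[1] == S // world_size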
@@ -774,7 +827,10 @@ def __init__(
         # 4. Output norm & projection
         self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
         self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
-        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
+
+        # SP helper modules
+        self.timestep_proj_prepare = TimestepProjPrepare()
+        self.output_scale_shift_prepare = OutputScaleShiftPrepare(inner_dim)
 
     @property
     def dtype(self) -> torch.dtype:
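One consequence of registering the helper as a submodule (shown with minimal stand-in classes, not the real transformer): the table is now exposed as output_scale_shift_prepare.scale_shift_table in named_parameters(), which is exactly the rename the load_weights hunk further down compensates for.

import torch
import torch.nn as nn

class _Prepare(nn.Module):                  # stand-in for OutputScaleShiftPrepare
    def __init__(self, inner_dim: int):
        super().__init__()
        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)

class _Model(nn.Module):                    # stand-in for WanTransformer3DModel
    def __init__(self, inner_dim: int = 16):
        super().__init__()
        self.output_scale_shift_prepare = _Prepare(inner_dim)

print([name for name, _ in _Model().named_parameters()])
# ['output_scale_shift_prepare.scale_shift_table']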
@@ -814,10 +870,10 @@ def forward(
         temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
             timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
         )
-        if ts_seq_len is not None:
-            timestep_proj = timestep_proj.unflatten(2, (6, -1))
-        else:
-            timestep_proj = timestep_proj.unflatten(1, (6, -1))
+        # Prepare timestep_proj via TimestepProjPrepare module
+        # _sp_plan will shard timestep_proj via split_output=True (when ts_seq_len is not None)
+        # This ensures timestep_proj sequence dimension matches sharded hidden_states
+        timestep_proj = self.timestep_proj_prepare(timestep_proj, ts_seq_len)
 
         if encoder_hidden_states_image is not None:
             encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
@@ -827,15 +883,12 @@
             hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
 
         # Output norm, projection & unpatchify
-        if temb.ndim == 3:
-            shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
-            shift = shift.squeeze(2)
-            scale = scale.squeeze(2)
-        else:
-            shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
-
+        shift, scale = self.output_scale_shift_prepare(temb)
         shift = shift.to(hidden_states.device)
         scale = scale.to(hidden_states.device)
+        if shift.ndim == 2:  # T2V mode: unsqueeze for broadcasting
+            shift = shift.unsqueeze(1)
+            scale = scale.unsqueeze(1)
 
         hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
         hidden_states = self.proj_out(hidden_states)
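A quick broadcasting check for the norm_out modulation (shapes assumed from the comments in the hunk above): TI2V shift/scale are already [B, S, D] and line up with hidden_states element-wise, while T2V shift/scale come back as [B, D] and need the unsqueeze(1) to broadcast over the sequence dimension.

import torch

B, S, D = 2, 8, 16
hidden_states = torch.randn(B, S, D)

ti2v_scale = torch.randn(B, S, D)                 # TI2V: per-position modulation
t2v_scale = torch.randn(B, D).unsqueeze(1)        # T2V: [B, D] -> [B, 1, D]

assert (hidden_states * (1 + ti2v_scale)).shape == (B, S, D)
assert (hidden_states * (1 + t2v_scale)).shape == (B, S, D)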
@@ -867,19 +920,23 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
         # Stacked params mapping for self-attention QKV fusion
-        # Format: (param_name, shard_name, shard_id)
-        # Note: Only fuse attn1 (self-attention), NOT attn2 (cross-attention)
         stacked_params_mapping = [
             # self-attention QKV fusion
             (".attn1.to_qkv", ".attn1.to_q", "q"),
             (".attn1.to_qkv", ".attn1.to_k", "k"),
             (".attn1.to_qkv", ".attn1.to_v", "v"),
         ]
 
+        # Remap scale_shift_table to new module location
+        weight_name_remapping = {
+            "scale_shift_table": "output_scale_shift_prepare.scale_shift_table",
+        }
+
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
 
         for name, loaded_weight in weights:
+            name = weight_name_remapping.get(name, name)
             original_name = name
             lookup_name = name
 
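The remapping itself is a plain key substitution; a minimal sketch (the second checkpoint key is illustrative): a legacy checkpoint that still stores scale_shift_table at the top level resolves to the new parameter name, and every other key passes through unchanged.

weight_name_remapping = {
    "scale_shift_table": "output_scale_shift_prepare.scale_shift_table",
}

checkpoint_keys = [
    "scale_shift_table",              # legacy top-level name
    "blocks.0.attn1.to_q.weight",     # illustrative key, passes through unchanged
]
print([weight_name_remapping.get(k, k) for k in checkpoint_keys])
# ['output_scale_shift_prepare.scale_shift_table', 'blocks.0.attn1.to_q.weight']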