|
22 | 22 | import torch.nn as nn |
23 | 23 | from diffusers.configuration_utils import ConfigMixin, register_to_config |
24 | 24 | from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin |
25 | | -from diffusers.models._modeling_parallel import ContextParallelInput, ContextParallelOutput |
26 | 25 | from diffusers.models.attention import AttentionMixin, AttentionModuleMixin, FeedForward |
27 | 26 | from diffusers.models.cache_utils import CacheMixin |
28 | 27 | from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings, PixArtAlphaTextProjection |
|
41 | 40 |
|
42 | 41 | from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata |
43 | 42 | from vllm_omni.diffusion.attention.layer import Attention |
| 43 | +from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelInput, SequenceParallelOutput |
| 44 | +from vllm_omni.diffusion.forward_context import get_forward_context, is_forward_context_available |
44 | 45 |
|
45 | 46 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
46 | 47 |
|
@@ -204,10 +205,27 @@ def __call__( |
204 | 205 | ) |
205 | 206 |
|
206 | 207 | if attention_mask is not None: |
207 | | - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) |
208 | | - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) |
209 | | - if attn.attn.attn_backend.get_name().upper() == "FLASH_ATTN": |
210 | | - attention_mask = self._to_padding_mask(attention_mask) |
| 208 | + sp_enabled = False |
| 209 | + if is_forward_context_available(): |
| 210 | + try: |
| 211 | + od_config = get_forward_context().omni_diffusion_config |
| 212 | + parallel_config = getattr(od_config, "parallel_config", None) if od_config is not None else None |
| 213 | + sp_enabled = getattr(parallel_config, "sequence_parallel_size", 1) > 1 |
| 214 | + except Exception: |
| 215 | + sp_enabled = False |
| 216 | + |
| 217 | + if sp_enabled: |
| 218 | + # In SP, Ulysses expects a 2D padding mask that matches query length. |
| 219 | + # For cross-attention, encoder sequence length != query length, so drop the mask. |
| 220 | + if encoder_hidden_states is not None and encoder_hidden_states.shape[1] != hidden_states.shape[1]: |
| 221 | + attention_mask = None |
| 222 | + else: |
| 223 | + attention_mask = self._to_padding_mask(attention_mask) |
| 224 | + else: |
| 225 | + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) |
| 226 | + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) |
| 227 | + if attn.attn.attn_backend.get_name().upper() == "FLASH_ATTN": |
| 228 | + attention_mask = self._to_padding_mask(attention_mask) |
211 | 229 |
|
212 | 230 | if is_self_attention: |
213 | 231 | encoder_hidden_states = hidden_states |
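Both mask paths above end in `self._to_padding_mask`: the SP branch feeds it the raw caller-supplied mask, while the non-SP FLASH_ATTN branch feeds it the prepared 4D mask. A minimal sketch of what such a helper might do, assuming it reduces whatever layout arrives to a 2D boolean `(batch, kv_len)` padding mask (the real `_to_padding_mask` is defined elsewhere in this file and may differ):

```python
import torch


def to_padding_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch, not the module's actual helper: collapse a 2D keep-mask
    # or a 3D/4D additive mask into a 2D boolean (batch, kv_len) padding mask,
    # True where tokens should be attended to.
    if attention_mask.dim() == 4:      # (B, heads, q_len, kv_len) additive mask
        attention_mask = attention_mask[:, 0, 0, :]
    elif attention_mask.dim() == 3:    # (B, q_len, kv_len) additive mask
        attention_mask = attention_mask[:, 0, :]
    if attention_mask.dtype == torch.bool:
        return attention_mask
    return attention_mask >= 0         # additive masks mark dropped tokens with large negatives
```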
@@ -953,18 +971,66 @@ class LTX2VideoTransformer3DModel( |
953 | 971 | _supports_gradient_checkpointing = True |
954 | 972 | _skip_layerwise_casting_patterns = ["norm"] |
955 | 973 | _repeated_blocks = ["LTX2VideoTransformerBlock"] |
956 | | - _cp_plan = { |
957 | | - "": { |
958 | | - "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), |
959 | | - "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), |
960 | | - "encoder_attention_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), |
961 | | - }, |
962 | | - "rope": { |
963 | | - 0: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True), |
964 | | - 1: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True), |
965 | | - }, |
966 | | - "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), |
967 | | - } |
| 974 | + _sp_plan: dict[str, Any] | None = None |
| 975 | + |
| 976 | + @staticmethod |
| 977 | + def _build_sp_plan(rope_type: str) -> dict[str, Any]: |
| 978 | + if rope_type == "split": |
| 979 | + # split RoPE returns (B, H, T, D/2) -> shard along T dim |
| 980 | + rope_expected_dims = 4 |
| 981 | + rope_split_dim = 2 |
| 982 | + else: |
| 983 | + # interleaved RoPE returns (B, T, D) -> shard along T dim |
| 984 | + rope_expected_dims = 3 |
| 985 | + rope_split_dim = 1 |
| 986 | + |
| 987 | + return { |
| 988 | + "": { |
| 989 | + # Shard video/audio latents across sequence |
| 990 | + "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3, split_output=False), |
| 991 | + "audio_hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3, split_output=False), |
| 992 | + # Shard prompt embeds across sequence |
| 993 | + "encoder_hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3, split_output=False), |
| 994 | + "audio_encoder_hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3, split_output=False), |
| 995 | + # Shard video timestep when provided as (B, seq_len) |
| 996 | + "timestep": SequenceParallelInput(split_dim=1, expected_dims=2, split_output=False), |
| 997 | + }, |
| 998 | + "rope": { |
| 999 | + 0: SequenceParallelInput( |
| 1000 | + split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True |
| 1001 | + ), |
| 1002 | + 1: SequenceParallelInput( |
| 1003 | + split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True |
| 1004 | + ), |
| 1005 | + }, |
| 1006 | + "audio_rope": { |
| 1007 | + 0: SequenceParallelInput( |
| 1008 | + split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True |
| 1009 | + ), |
| 1010 | + 1: SequenceParallelInput( |
| 1011 | + split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True |
| 1012 | + ), |
| 1013 | + }, |
| 1014 | + "cross_attn_rope": { |
| 1015 | + 0: SequenceParallelInput( |
| 1016 | + split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True |
| 1017 | + ), |
| 1018 | + 1: SequenceParallelInput( |
| 1019 | + split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True |
| 1020 | + ), |
| 1021 | + }, |
| 1022 | + "cross_attn_audio_rope": { |
| 1023 | + 0: SequenceParallelInput( |
| 1024 | + split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True |
| 1025 | + ), |
| 1026 | + 1: SequenceParallelInput( |
| 1027 | + split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True |
| 1028 | + ), |
| 1029 | + }, |
| 1030 | + # Gather outputs before returning |
| 1031 | + "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), |
| 1032 | + "audio_proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), |
| 1033 | + } |
968 | 1034 |
|
969 | 1035 | @register_to_config |
970 | 1036 | def __init__( |
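The `_sp_plan` entries mirror the diffusers `_cp_plan` they replace: each `SequenceParallelInput` names the dimension to shard before the tagged module runs, and each `SequenceParallelOutput` names the dimension to gather afterwards. A minimal sketch of how a runtime might apply those fields across SP ranks, assuming simple chunk-based sharding and an all-gather; the `shard_for_sp`/`gather_from_sp` helpers below are illustrative, not part of vllm_omni's API:

```python
import torch
import torch.distributed as dist


def shard_for_sp(t: torch.Tensor, split_dim: int, expected_dims: int,
                 rank: int, world_size: int) -> torch.Tensor:
    # Keep only this rank's slice of the sequence dimension named by the plan.
    assert t.dim() == expected_dims, "tensor rank must match the plan's expected_dims"
    return t.chunk(world_size, dim=split_dim)[rank]


def gather_from_sp(t: torch.Tensor, gather_dim: int, world_size: int) -> torch.Tensor:
    # Reassemble the full sequence from every rank's shard before returning outputs.
    parts = [torch.empty_like(t) for _ in range(world_size)]
    dist.all_gather(parts, t.contiguous())
    return torch.cat(parts, dim=gather_dim)
```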
@@ -1153,6 +1219,7 @@ def __init__( |
1153 | 1219 | self.audio_proj_out = nn.Linear(audio_inner_dim, audio_out_channels) |
1154 | 1220 |
|
1155 | 1221 | self.gradient_checkpointing = False |
| 1222 | + self._sp_plan = self._build_sp_plan(rope_type) |
1156 | 1223 |
|
1157 | 1224 | def forward( |
1158 | 1225 | self, |
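Because the RoPE output layout depends on `rope_type`, the plan can only be finalized once the config is known, which is why `_sp_plan` is built per instance at the end of `__init__` instead of being a fixed class attribute like the old `_cp_plan`. A quick illustrative check of the two branches, assuming `SequenceParallelInput`/`SequenceParallelOutput` expose their constructor kwargs as attributes and that any value other than `"split"` (the `"interleaved"` string here is only an example) takes the 3D branch:

```python
split_plan = LTX2VideoTransformer3DModel._build_sp_plan("split")
other_plan = LTX2VideoTransformer3DModel._build_sp_plan("interleaved")

assert split_plan["rope"][0].split_dim == 2 and split_plan["rope"][0].expected_dims == 4
assert other_plan["rope"][0].split_dim == 1 and other_plan["rope"][0].expected_dims == 3
assert split_plan["proj_out"].gather_dim == 1
```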
|