Commit 72bb6c8

support sp
Signed-off-by: David Chen <530634352@qq.com>
1 parent 96b1075

File tree

3 files changed: 115 additions, 19 deletions
examples/offline_inference/image_to_video/image_to_video.py

Lines changed: 19 additions & 0 deletions
@@ -34,6 +34,7 @@
 import PIL.Image
 import torch

+from vllm_omni.diffusion.data import DiffusionParallelConfig
 from vllm_omni.entrypoints.omni import Omni
 from vllm_omni.outputs import OmniRequestOutput
 from vllm_omni.utils.platform_utils import detect_device_type, is_npu
@@ -110,6 +111,18 @@ def parse_args() -> argparse.Namespace:
             "Default: None (no cache acceleration)."
         ),
     )
+    parser.add_argument(
+        "--ulysses_degree",
+        type=int,
+        default=1,
+        help="Number of GPUs used for ulysses sequence parallelism.",
+    )
+    parser.add_argument(
+        "--ring_degree",
+        type=int,
+        default=1,
+        help="Number of GPUs used for ring sequence parallelism.",
+    )
     return parser.parse_args()


@@ -183,6 +196,11 @@ def main():
         "rel_l1_thresh": 0.2,
     }

+    parallel_config = DiffusionParallelConfig(
+        ulysses_degree=args.ulysses_degree,
+        ring_degree=args.ring_degree,
+    )
+
     # Check if profiling is requested via environment variable
     profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR"))

@@ -196,6 +214,7 @@
         model_class_name=model_class_name,
         cache_backend=args.cache_backend,
         cache_config=cache_config,
+        parallel_config=parallel_config,
     )

     if profiler_enabled:
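
For reference, a minimal sketch of how the new flags feed into a run, mirroring the wiring added above. The command line and degree values are illustrative, and the idea that the effective sequence-parallel size is ulysses_degree * ring_degree is an assumption inferred from how the pipeline change below reads sequence_parallel_size, not something shown in this file:

    # Illustrative launch with 2-way Ulysses + 2-way ring sequence parallelism:
    #   python examples/offline_inference/image_to_video/image_to_video.py \
    #       --ulysses_degree 2 --ring_degree 2 ...

    from vllm_omni.diffusion.data import DiffusionParallelConfig

    # Same construction as the new code in main(); the config is then handed to
    # Omni(...) through the parallel_config= keyword added in the last hunk.
    parallel_config = DiffusionParallelConfig(
        ulysses_degree=2,  # GPUs used for Ulysses sequence parallelism
        ring_degree=2,     # GPUs used for ring sequence parallelism
    )
    # Assumption (not part of this diff): downstream code sees
    # sequence_parallel_size == ulysses_degree * ring_degree == 4.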

vllm_omni/diffusion/models/ltx2/ltx2_transformer.py

Lines changed: 84 additions & 17 deletions
@@ -22,7 +22,6 @@
 import torch.nn as nn
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
-from diffusers.models._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from diffusers.models.attention import AttentionMixin, AttentionModuleMixin, FeedForward
 from diffusers.models.cache_utils import CacheMixin
 from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings, PixArtAlphaTextProjection
@@ -41,6 +40,8 @@

 from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
 from vllm_omni.diffusion.attention.layer import Attention
+from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelInput, SequenceParallelOutput
+from vllm_omni.diffusion.forward_context import get_forward_context, is_forward_context_available

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -204,10 +205,27 @@ def __call__(
         )

         if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-            if attn.attn.attn_backend.get_name().upper() == "FLASH_ATTN":
-                attention_mask = self._to_padding_mask(attention_mask)
+            sp_enabled = False
+            if is_forward_context_available():
+                try:
+                    od_config = get_forward_context().omni_diffusion_config
+                    parallel_config = getattr(od_config, "parallel_config", None) if od_config is not None else None
+                    sp_enabled = getattr(parallel_config, "sequence_parallel_size", 1) > 1
+                except Exception:
+                    sp_enabled = False
+
+            if sp_enabled:
+                # In SP, Ulysses expects a 2D padding mask that matches query length.
+                # For cross-attention, encoder sequence length != query length, so drop the mask.
+                if encoder_hidden_states is not None and encoder_hidden_states.shape[1] != hidden_states.shape[1]:
+                    attention_mask = None
+                else:
+                    attention_mask = self._to_padding_mask(attention_mask)
+            else:
+                attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+                attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+                if attn.attn.attn_backend.get_name().upper() == "FLASH_ATTN":
+                    attention_mask = self._to_padding_mask(attention_mask)

         if is_self_attention:
             encoder_hidden_states = hidden_states
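
The SP branch above hands everything to _to_padding_mask, which in this file is expected to produce a 2D (batch, seq) padding mask. As a rough illustration only (the repo's actual helper may accept different layouts and semantics), a conversion of that kind could look like:

    import torch

    def to_padding_mask_sketch(attention_mask: torch.Tensor) -> torch.Tensor:
        # Sketch of the idea, not the repo's implementation: reduce an additive
        # or boolean attention mask to a boolean (batch, kv_len) padding mask.
        # A key token is kept if at least one query in at least one head is
        # allowed to attend to it.
        if attention_mask.dtype.is_floating_point:
            # Additive mask: allowed positions are ~0, masked ones a large negative bias.
            keep = attention_mask > torch.finfo(attention_mask.dtype).min / 2
        else:
            keep = attention_mask.bool()
        while keep.dim() > 2:
            keep = keep.any(dim=1)  # collapse head/query dims down to (batch, kv_len)
        return keep
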
@@ -953,18 +971,66 @@ class LTX2VideoTransformer3DModel(
     _supports_gradient_checkpointing = True
     _skip_layerwise_casting_patterns = ["norm"]
     _repeated_blocks = ["LTX2VideoTransformerBlock"]
-    _cp_plan = {
-        "": {
-            "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
-            "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
-            "encoder_attention_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
-        },
-        "rope": {
-            0: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True),
-            1: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True),
-        },
-        "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
-    }
+    _sp_plan: dict[str, Any] | None = None
+
+    @staticmethod
+    def _build_sp_plan(rope_type: str) -> dict[str, Any]:
+        if rope_type == "split":
+            # split RoPE returns (B, H, T, D/2) -> shard along T dim
+            rope_expected_dims = 4
+            rope_split_dim = 2
+        else:
+            # interleaved RoPE returns (B, T, D) -> shard along T dim
+            rope_expected_dims = 3
+            rope_split_dim = 1
+
+        return {
+            "": {
+                # Shard video/audio latents across sequence
+                "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3, split_output=False),
+                "audio_hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3, split_output=False),
+                # Shard prompt embeds across sequence
+                "encoder_hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3, split_output=False),
+                "audio_encoder_hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3, split_output=False),
+                # Shard video timestep when provided as (B, seq_len)
+                "timestep": SequenceParallelInput(split_dim=1, expected_dims=2, split_output=False),
+            },
+            "rope": {
+                0: SequenceParallelInput(
+                    split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True
+                ),
+                1: SequenceParallelInput(
+                    split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True
+                ),
+            },
+            "audio_rope": {
+                0: SequenceParallelInput(
+                    split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True
+                ),
+                1: SequenceParallelInput(
+                    split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True
+                ),
+            },
+            "cross_attn_rope": {
+                0: SequenceParallelInput(
+                    split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True
+                ),
+                1: SequenceParallelInput(
+                    split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True
+                ),
+            },
+            "cross_attn_audio_rope": {
+                0: SequenceParallelInput(
+                    split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True
+                ),
+                1: SequenceParallelInput(
+                    split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True
+                ),
+            },
+            # Gather outputs before returning
+            "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
+            "audio_proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
+        }

     @register_to_config
     def __init__(
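
The _sp_plan mapping above only declares intent: which inputs get sharded along which dimension and which outputs get gathered back. A minimal sketch of how a runtime could apply one SequenceParallelInput entry; the names sp_rank and sp_world_size and the use of torch.tensor_split are assumptions for illustration, not this repo's implementation:

    import torch

    def shard_for_sp(tensor: torch.Tensor, split_dim: int, expected_dims: int,
                     sp_rank: int, sp_world_size: int) -> torch.Tensor:
        # Illustrative only: split one plan input into per-rank slices along the
        # declared dimension and keep this rank's slice.
        assert tensor.dim() == expected_dims, "plan dims must match the tensor"
        chunks = torch.tensor_split(tensor, sp_world_size, dim=split_dim)
        return chunks[sp_rank]

    # e.g. hidden_states (B, T, D) declared with split_dim=1, expected_dims=3:
    hidden_states = torch.randn(2, 10, 64)
    local = shard_for_sp(hidden_states, split_dim=1, expected_dims=3,
                         sp_rank=0, sp_world_size=2)  # -> (2, 5, 64)
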
@@ -1153,6 +1219,7 @@ def __init__(
         self.audio_proj_out = nn.Linear(audio_inner_dim, audio_out_channels)

         self.gradient_checkpointing = False
+        self._sp_plan = self._build_sp_plan(rope_type)

     def forward(
         self,
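
Because the plan is now built per instance from rope_type, the RoPE entries adapt to the tensor layout. A quick illustration of the two branches of _build_sp_plan, assuming SequenceParallelInput exposes its constructor arguments as attributes and using "interleaved" merely as a stand-in for "anything other than split":

    from vllm_omni.diffusion.models.ltx2.ltx2_transformer import LTX2VideoTransformer3DModel

    # "split" RoPE tensors are (B, H, T, D/2): the sequence axis is dim 2 ...
    split_plan = LTX2VideoTransformer3DModel._build_sp_plan("split")
    assert split_plan["rope"][0].split_dim == 2 and split_plan["rope"][0].expected_dims == 4

    # ... while interleaved RoPE tensors are (B, T, D), sharded along dim 1.
    other_plan = LTX2VideoTransformer3DModel._build_sp_plan("interleaved")
    assert other_plan["rope"][0].split_dim == 1 and other_plan["rope"][0].expected_dims == 3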

vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py

Lines changed: 12 additions & 2 deletions
@@ -539,11 +539,21 @@ def prepare_audio_latents(
         latents_per_second = float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
         latent_length = round(duration_s * latents_per_second)

+        latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
+
+        sp_size = getattr(self.od_config.parallel_config, "sequence_parallel_size", 1)
+        if sp_size > 1 and latent_length < sp_size:
+            pad_len = sp_size - latent_length
+            if latents is not None:
+                pad_shape = list(latents.shape)
+                pad_shape[2] = pad_len
+                padding = torch.zeros(pad_shape, dtype=latents.dtype, device=latents.device)
+                latents = torch.cat([latents, padding], dim=2)
+            latent_length = sp_size
+
         if latents is not None:
             return latents.to(device=device, dtype=dtype), latent_length

-        latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
-
         shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)

         if isinstance(generator, list) and len(generator) != batch_size:
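
As a concrete example of the guard above: with sp_size = 4 and a clip short enough that latent_length = 3, one zero frame is appended along dim 2 so every SP rank still receives at least one latent frame. A standalone sketch of the same padding arithmetic (tensor shapes are illustrative):

    import torch

    sp_size = 4
    latents = torch.randn(1, 8, 3, 16)         # (batch, channels, latent_length, mel_bins)
    latent_length = latents.shape[2]

    if sp_size > 1 and latent_length < sp_size:
        pad_len = sp_size - latent_length      # 4 - 3 = 1 zero frame to append
        padding = torch.zeros(1, 8, pad_len, 16, dtype=latents.dtype, device=latents.device)
        latents = torch.cat([latents, padding], dim=2)
        latent_length = sp_size

    print(latents.shape)  # torch.Size([1, 8, 4, 16])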
