Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,9 @@
"LongCatImageEditPipeline",
"LongCatImagePipeline",
"LTX2ConditionPipeline",
"LTX2HDRPipeline",
"LTX2ImageToVideoPipeline",
"LTX2InContextPipeline",
"LTX2LatentUpsamplePipeline",
"LTX2Pipeline",
"LTXConditionPipeline",
Expand Down Expand Up @@ -1407,7 +1409,9 @@
LongCatImageEditPipeline,
LongCatImagePipeline,
LTX2ConditionPipeline,
LTX2HDRPipeline,
LTX2ImageToVideoPipeline,
LTX2InContextPipeline,
LTX2LatentUpsamplePipeline,
LTX2Pipeline,
LTXConditionPipeline,
Expand Down
15 changes: 13 additions & 2 deletions src/diffusers/models/transformers/transformer_ltx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,6 +1343,7 @@ def forward(
perturbation_mask: torch.Tensor | None = None,
use_cross_timestep: bool = False,
attention_kwargs: dict[str, Any] | None = None,
video_self_attention_mask: torch.Tensor | None = None,
return_dict: bool = True,
) -> torch.Tensor:
"""
Expand Down Expand Up @@ -1408,6 +1409,11 @@ def forward(
`False` is the legacy LTX-2.0 behavior.
attention_kwargs (`dict[str, Any]`, *optional*):
Optional dict of keyword args to be passed to the attention processor.
video_self_attention_mask (`torch.Tensor`, *optional*):
Optional multiplicative self-attention mask of shape `(batch_size, num_video_tokens, num_video_tokens)`
applied to the video self-attention in each transformer block. Values in `[0, 1]` where `1` means full
attention and `0` means masked. Used e.g. by the IC-LoRA pipeline to control attention strength between
noisy tokens and appended reference tokens. Audio self-attention is not affected.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a dict-like structured output of type `AudioVisualModelOutput` or a tuple.

Expand All @@ -1430,6 +1436,11 @@ def forward(
audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)) * -10000.0
audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1)

# Convert video_self_attention_mask from multiplicative mask ([0, 1]) to additive bias form (0 / -10000)
# matching the encoder_attention_mask convention above. Shape is preserved: (B, T_v, T_v).
if video_self_attention_mask is not None:
video_self_attention_mask = (1 - video_self_attention_mask.to(hidden_states.dtype)) * -10000.0

batch_size = hidden_states.size(0)

# 1. Prepare RoPE positional embeddings
Expand Down Expand Up @@ -1569,7 +1580,7 @@ def forward(
audio_cross_attn_rotary_emb,
encoder_attention_mask,
audio_encoder_attention_mask,
None, # self_attention_mask
video_self_attention_mask, # self_attention_mask (video-only)
None, # audio_self_attention_mask
None, # a2v_cross_attention_mask
None, # v2a_cross_attention_mask
Expand Down Expand Up @@ -1598,7 +1609,7 @@ def forward(
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
encoder_attention_mask=encoder_attention_mask,
audio_encoder_attention_mask=audio_encoder_attention_mask,
self_attention_mask=None,
self_attention_mask=video_self_attention_mask,
audio_self_attention_mask=None,
a2v_cross_attention_mask=None,
v2a_cross_attention_mask=None,
Expand Down
11 changes: 10 additions & 1 deletion src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,8 @@
_import_structure["ltx2"] = [
"LTX2Pipeline",
"LTX2ConditionPipeline",
"LTX2HDRPipeline",
"LTX2InContextPipeline",
"LTX2ImageToVideoPipeline",
"LTX2LatentUpsamplePipeline",
]
Expand Down Expand Up @@ -780,7 +782,14 @@
LTXLatentUpsamplePipeline,
LTXPipeline,
)
from .ltx2 import LTX2ConditionPipeline, LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline
from .ltx2 import (
LTX2ConditionPipeline,
LTX2HDRPipeline,
LTX2ImageToVideoPipeline,
LTX2InContextPipeline,
LTX2LatentUpsamplePipeline,
LTX2Pipeline,
)
from .lucy import LucyEditPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
Expand Down
6 changes: 6 additions & 0 deletions src/diffusers/pipelines/ltx2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["connectors"] = ["LTX2TextConnectors"]
_import_structure["image_processor"] = ["LTX2VideoHDRProcessor"]
_import_structure["latent_upsampler"] = ["LTX2LatentUpsamplerModel"]
_import_structure["pipeline_ltx2"] = ["LTX2Pipeline"]
_import_structure["pipeline_ltx2_condition"] = ["LTX2ConditionPipeline"]
_import_structure["pipeline_ltx2_hdr_lora"] = ["LTX2HDRPipeline", "LTX2HDRReferenceCondition"]
_import_structure["pipeline_ltx2_ic_lora"] = ["LTX2InContextPipeline", "LTX2ReferenceCondition"]
_import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"]
_import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"]
_import_structure["vocoder"] = ["LTX2Vocoder", "LTX2VocoderWithBWE"]
Expand All @@ -39,9 +42,12 @@
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .connectors import LTX2TextConnectors
from .image_processor import LTX2VideoHDRProcessor
from .latent_upsampler import LTX2LatentUpsamplerModel
from .pipeline_ltx2 import LTX2Pipeline
from .pipeline_ltx2_condition import LTX2ConditionPipeline
from .pipeline_ltx2_hdr_lora import LTX2HDRPipeline, LTX2HDRReferenceCondition
from .pipeline_ltx2_ic_lora import LTX2InContextPipeline, LTX2ReferenceCondition
from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline
from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE
Expand Down
Loading
Loading