From 1b0a6d4950220c05b8c1bd09e8e5767dffb82885 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 23 Apr 2026 11:51:47 +0200 Subject: [PATCH 01/14] LTX-2.X IC LoRA + HDR IC LoRA pipeline implementation draft from Claude --- src/diffusers/__init__.py | 4 + .../models/transformers/transformer_ltx2.py | 16 +- src/diffusers/pipelines/__init__.py | 11 +- src/diffusers/pipelines/ltx2/__init__.py | 6 + src/diffusers/pipelines/ltx2/export_utils.py | 180 ++ .../pipelines/ltx2/image_processor.py | 155 ++ .../pipelines/ltx2/pipeline_ltx2_condition.py | 465 +++- .../pipelines/ltx2/pipeline_ltx2_hdr_lora.py | 1440 +++++++++++ .../pipelines/ltx2/pipeline_ltx2_ic_lora.py | 2214 +++++++++++++++++ 9 files changed, 4369 insertions(+), 122 deletions(-) create mode 100644 src/diffusers/pipelines/ltx2/image_processor.py create mode 100644 src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py create mode 100644 src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 2cbfd6e29305..8027eb64be91 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -608,6 +608,8 @@ "LongCatImageEditPipeline", "LongCatImagePipeline", "LTX2ConditionPipeline", + "LTX2HDRLoraPipeline", + "LTX2ICLoraPipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline", "LTX2Pipeline", @@ -1393,6 +1395,8 @@ LongCatImageEditPipeline, LongCatImagePipeline, LTX2ConditionPipeline, + LTX2HDRLoraPipeline, + LTX2ICLoraPipeline, LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline, diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index a4915ccfb96a..efd75100ee67 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -1331,6 +1331,7 @@ def forward( audio_sigma: torch.Tensor | None = None, encoder_attention_mask: torch.Tensor | None = None, audio_encoder_attention_mask: torch.Tensor | None = None, + video_self_attention_mask: torch.Tensor | None = None, num_frames: int | None = None, height: int | None = None, width: int | None = None, @@ -1374,6 +1375,12 @@ def forward( Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`. audio_encoder_attention_mask (`torch.Tensor`, *optional*): Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)` for audio modeling. + video_self_attention_mask (`torch.Tensor`, *optional*): + Optional multiplicative self-attention mask of shape `(batch_size, num_video_tokens, + num_video_tokens)` applied to the video self-attention in each transformer block. Values in `[0, 1]` + where `1` means full attention and `0` means masked. Used e.g. by the IC-LoRA pipeline to control + attention strength between noisy tokens and appended reference tokens. Audio self-attention is not + affected. num_frames (`int`, *optional*): The number of latent video frames. Used if calculating the video coordinates for RoPE. height (`int`, *optional*): @@ -1430,6 +1437,11 @@ def forward( audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)) * -10000.0 audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1) + # Convert video_self_attention_mask from multiplicative mask ([0, 1]) to additive bias form (0 / -10000) + # matching the encoder_attention_mask convention above. Shape is preserved: (B, T_v, T_v). 
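+        # For example, a mask value of 1.0 maps to a bias of 0.0 (attend normally) and a value of 0.0 maps to -10000.0 (effectively masked).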
+ if video_self_attention_mask is not None: + video_self_attention_mask = (1 - video_self_attention_mask.to(hidden_states.dtype)) * -10000.0 + batch_size = hidden_states.size(0) # 1. Prepare RoPE positional embeddings @@ -1569,7 +1581,7 @@ def forward( audio_cross_attn_rotary_emb, encoder_attention_mask, audio_encoder_attention_mask, - None, # self_attention_mask + video_self_attention_mask, # self_attention_mask (video-only) None, # audio_self_attention_mask None, # a2v_cross_attention_mask None, # v2a_cross_attention_mask @@ -1598,7 +1610,7 @@ def forward( ca_audio_rotary_emb=audio_cross_attn_rotary_emb, encoder_attention_mask=encoder_attention_mask, audio_encoder_attention_mask=audio_encoder_attention_mask, - self_attention_mask=None, + self_attention_mask=video_self_attention_mask, audio_self_attention_mask=None, a2v_cross_attention_mask=None, v2a_cross_attention_mask=None, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index ae1849a587e8..eaf985e7ac87 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -324,6 +324,8 @@ _import_structure["ltx2"] = [ "LTX2Pipeline", "LTX2ConditionPipeline", + "LTX2HDRLoraPipeline", + "LTX2ICLoraPipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline", ] @@ -768,7 +770,14 @@ LTXLatentUpsamplePipeline, LTXPipeline, ) - from .ltx2 import LTX2ConditionPipeline, LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline + from .ltx2 import ( + LTX2ConditionPipeline, + LTX2HDRLoraPipeline, + LTX2ICLoraPipeline, + LTX2ImageToVideoPipeline, + LTX2LatentUpsamplePipeline, + LTX2Pipeline, + ) from .lucy import LucyEditPipeline from .lumina import LuminaPipeline, LuminaText2ImgPipeline from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline diff --git a/src/diffusers/pipelines/ltx2/__init__.py b/src/diffusers/pipelines/ltx2/__init__.py index 7177faaf3486..3781f556acae 100644 --- a/src/diffusers/pipelines/ltx2/__init__.py +++ b/src/diffusers/pipelines/ltx2/__init__.py @@ -26,6 +26,9 @@ _import_structure["latent_upsampler"] = ["LTX2LatentUpsamplerModel"] _import_structure["pipeline_ltx2"] = ["LTX2Pipeline"] _import_structure["pipeline_ltx2_condition"] = ["LTX2ConditionPipeline"] + _import_structure["pipeline_ltx2_ic_lora"] = ["LTX2ICLoraPipeline", "LTX2ReferenceCondition"] + _import_structure["pipeline_ltx2_hdr_lora"] = ["LTX2HDRLoraPipeline", "LTX2HDRReferenceCondition"] + _import_structure["image_processor"] = ["LTX2VideoHDRProcessor"] _import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"] _import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"] _import_structure["vocoder"] = ["LTX2Vocoder", "LTX2VocoderWithBWE"] @@ -42,6 +45,9 @@ from .latent_upsampler import LTX2LatentUpsamplerModel from .pipeline_ltx2 import LTX2Pipeline from .pipeline_ltx2_condition import LTX2ConditionPipeline + from .image_processor import LTX2VideoHDRProcessor + from .pipeline_ltx2_hdr_lora import LTX2HDRLoraPipeline, LTX2HDRReferenceCondition + from .pipeline_ltx2_ic_lora import LTX2ICLoraPipeline, LTX2ReferenceCondition from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE diff --git a/src/diffusers/pipelines/ltx2/export_utils.py b/src/diffusers/pipelines/ltx2/export_utils.py index f0287506b8db..6b7220d08eba 100644 --- a/src/diffusers/pipelines/ltx2/export_utils.py +++ 
b/src/diffusers/pipelines/ltx2/export_utils.py @@ -16,6 +16,7 @@ from collections.abc import Iterator from fractions import Fraction from itertools import chain +from pathlib import Path import numpy as np import PIL.Image @@ -189,3 +190,182 @@ def encode_video( _write_audio(container, audio_stream, audio, audio_sample_rate) container.close() + + +# --------------------------------------------------------------------------- +# HDR export helpers (used with LTX2HDRLoraPipeline). +# +# These mirror the reference CLI's `save_exr_tensor`, `_linear_to_srgb`, and +# `encode_exr_sequence_to_mp4` in `ltx_pipelines.utils.media_io`. +# --------------------------------------------------------------------------- + + +def save_exr_tensor( + tensor: torch.Tensor | np.ndarray, + file_path: str | Path, + half: bool = False, +) -> None: + r""" + Save a single linear-HDR frame tensor to an OpenEXR file. + + Args: + tensor (`torch.Tensor` or `np.ndarray`): + A float frame of shape `(H, W, C)` or `(C, H, W)` with linear HDR values in `[0, ∞)`. Channels are + assumed to be RGB. + file_path (`str` or `pathlib.Path`): + Output EXR path (e.g. `frame_00000.exr`). + half (`bool`, *optional*, defaults to `False`): + When `True`, writes the file as `float16` (HALF) with ZIP compression. `float16` tensors are always + saved as HALF regardless of this flag. + + The resulting EXR is tagged with Rec.709/sRGB chromaticities and `colorSpace=sRGB` to match the reference. + Requires [OpenImageIO](https://openimageio.readthedocs.io) with OpenEXR support: + `pip install OpenImageIO` (or `pip install oiio`). + """ + try: + import OpenImageIO + except ImportError as e: # pragma: no cover - optional dep + raise ImportError( + "`save_exr_tensor` requires `OpenImageIO`. Install with `pip install OpenImageIO` (with OpenEXR support)." + ) from e + + if isinstance(tensor, torch.Tensor): + use_half = half or tensor.dtype in (torch.float16, torch.half) + if tensor.dim() == 3 and tensor.shape[0] == 3: + tensor = tensor.permute(1, 2, 0) + arr = np.ascontiguousarray(tensor.detach().cpu().numpy().astype(np.float32)) + else: + use_half = half or tensor.dtype == np.float16 + if tensor.ndim == 3 and tensor.shape[0] == 3: + tensor = np.transpose(tensor, (1, 2, 0)) + arr = np.ascontiguousarray(tensor.astype(np.float32)) + + file_path = str(file_path) + h, w = arr.shape[:2] + fmt = OpenImageIO.HALF if use_half else OpenImageIO.FLOAT + spec = OpenImageIO.ImageSpec(w, h, 3, fmt) + spec.channelnames = ("R", "G", "B") + spec.attribute("compression", "zip") + spec.attribute( + "chromaticities", "float[8]", (0.64, 0.33, 0.30, 0.60, 0.15, 0.06, 0.3127, 0.3290) + ) + spec.attribute("colorSpace", "sRGB") + + out = OpenImageIO.ImageOutput.create(file_path) + if out is None: + raise RuntimeError( + f"Failed to create EXR writer for '{file_path}'. Ensure OpenImageIO is built with OpenEXR support." + ) + try: + if not out.open(file_path, spec): + raise RuntimeError(f"Failed to open EXR file '{file_path}': {out.geterror()}") + if not out.write_image(arr): + raise RuntimeError(f"Failed to write EXR image '{file_path}': {out.geterror()}") + finally: + out.close() + + +def _linear_to_srgb(x: np.ndarray) -> np.ndarray: + r""" + Apply the sRGB OETF (IEC 61966-2-1) to a linear image. Input values must be in `[0, 1]`; values outside are + clipped. 
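+    The transfer is piecewise: `12.92 * x` for `x <= 0.0031308`, and `1.055 * x ** (1 / 2.4) - 0.055` above that threshold.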
+ """ + x = np.clip(x, 0.0, 1.0) + return np.where(x <= 0.0031308, x * 12.92, 1.055 * np.power(x, 1.0 / 2.4) - 0.055) + + +def encode_exr_sequence_to_mp4( + exr_dir: str | Path, + output_mp4: str | Path, + frame_rate: float, + crf: int = 18, +) -> None: + r""" + Convert a linear-HDR EXR frame sequence into an sRGB-tonemapped H.264 `.mp4` preview. + + Each EXR frame is loaded, clipped to `[0, 1]`, passed through the sRGB OETF (no exposure/gain, EV=0), quantized + to 8-bit BGR, and fed into a libx264 stream at the supplied `frame_rate`. This mirrors the reference CLI's + `encode_exr_sequence_to_mp4`. + + Args: + exr_dir (`str` or `pathlib.Path`): + Directory containing `frame_*.exr` files (sorted lexicographically). + output_mp4 (`str` or `pathlib.Path`): + Output MP4 path. + frame_rate (`float`): + Frame rate for the output video. + crf (`int`, *optional*, defaults to `18`): + libx264 CRF quality factor. Lower values produce higher quality. + + Requires `opencv-python` (for EXR reading via `OPENCV_IO_ENABLE_OPENEXR`). + """ + import os + + os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" + try: + import cv2 + except ImportError as e: # pragma: no cover - optional dep + raise ImportError( + "`encode_exr_sequence_to_mp4` requires `opencv-python`. Install with `pip install opencv-python`." + ) from e + + exr_dir = Path(exr_dir) + exr_files = sorted(exr_dir.glob("frame_*.exr")) + if not exr_files: + raise FileNotFoundError(f"No EXR frames found in {exr_dir}") + + container = av.open(str(output_mp4), mode="w") + stream = container.add_stream("libx264", rate=Fraction(frame_rate).limit_denominator(1000)) + stream.pix_fmt = "yuv420p" + stream.options = {"crf": str(crf), "movflags": "+faststart"} + + try: + for i, exr_path in enumerate(exr_files): + hdr = cv2.imread(str(exr_path), cv2.IMREAD_UNCHANGED).astype(np.float32) + sdr = _linear_to_srgb(np.maximum(hdr, 0.0)) + bgr8 = (sdr * 255.0 + 0.5).astype(np.uint8) + + if i == 0: + stream.height = bgr8.shape[0] + stream.width = bgr8.shape[1] + + frame = av.VideoFrame.from_ndarray(bgr8, format="bgr24") + for packet in stream.encode(frame): + container.mux(packet) + + for packet in stream.encode(): + container.mux(packet) + finally: + container.close() + + +def save_hdr_video_frames_as_exr( + frames: torch.Tensor | np.ndarray, + exr_dir: str | Path, + *, + half: bool = False, +) -> list[Path]: + r""" + Save a batch of linear-HDR frames to a directory as `frame_{idx:05d}.exr` files. + + Args: + frames (`torch.Tensor` or `np.ndarray`): + HDR video tensor of shape `(F, H, W, C)` or `(F, C, H, W)` with linear HDR values in `[0, ∞)`. + exr_dir (`str` or `pathlib.Path`): + Output directory. Created if missing. + half (`bool`, *optional*, defaults to `False`): + Forwarded to [`save_exr_tensor`] — when `True`, writes EXR files as `float16`. + + Returns: + `list[pathlib.Path]`: Paths of the written EXR files, in frame order. + """ + exr_dir = Path(exr_dir) + exr_dir.mkdir(parents=True, exist_ok=True) + paths: list[Path] = [] + num_frames = frames.shape[0] + for i in range(num_frames): + frame = frames[i] + path = exr_dir / f"frame_{i:05d}.exr" + save_exr_tensor(frame, path, half=half) + paths.append(path) + return paths diff --git a/src/diffusers/pipelines/ltx2/image_processor.py b/src/diffusers/pipelines/ltx2/image_processor.py new file mode 100644 index 000000000000..2de3bb4998e9 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/image_processor.py @@ -0,0 +1,155 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F + +from ...configuration_utils import register_to_config +from ...video_processor import VideoProcessor + + +class LTX2VideoHDRProcessor(VideoProcessor): + r""" + Video processor for the LTX-2 HDR IC-LoRA pipeline. + + Inherits standard video preprocessing from [`VideoProcessor`] and additionally supports: + + - `preprocess_reference_video_hdr`: aspect-ratio-preserving resize followed by reflect-padding to the target size. + For LDR (SDR Rec.709) reference videos, `LogC3.compress_ldr` is an identity clamp, so the numerical output is + equivalent to the standard [-1, 1] normalization used by [`VideoProcessor.preprocess_video`] — only the resize + strategy differs (reflect-pad vs center-crop). + + - `postprocess_hdr_video`: applies the LogC3 inverse transform to the VAE's decoded output, mapping `[0, 1]` → + linear HDR `[0, ∞)`. This is the caller-facing counterpart to `apply_hdr_decode_postprocess` in the reference + `ltx_core.hdr` module. + + Args: + vae_scale_factor (`int`, *optional*, defaults to `32`): + VAE (spatial) scale factor for the LTX-2 video VAE. + resample (`str`, *optional*, defaults to `"bilinear"`): + Resampling filter used by the base [`VaeImageProcessor`] for PIL/tensor resizing. + hdr_transform (`str`, *optional*, defaults to `"logc3"`): + HDR transform identifier. Only `"logc3"` (ARRI EI 800) is currently supported. + """ + + # LogC3 (ARRI EI 800) coefficients, ported from `ltx_core.hdr.LogC3`. + _LOGC3_A = 5.555556 + _LOGC3_B = 0.052272 + _LOGC3_C = 0.247190 + _LOGC3_D = 0.385537 + _LOGC3_E = 5.367655 + _LOGC3_F = 0.092809 + _LOGC3_CUT = 0.010591 + + @register_to_config + def __init__( + self, + vae_scale_factor: int = 32, + resample: str = "bilinear", + hdr_transform: str = "logc3", + ): + super().__init__( + do_resize=True, + vae_scale_factor=vae_scale_factor, + resample=resample, + ) + if hdr_transform != "logc3": + raise ValueError(f"Unsupported HDR transform {hdr_transform!r}. Only 'logc3' is supported.") + + @classmethod + def _logc3_decompress(cls, logc: torch.Tensor) -> torch.Tensor: + r"""Decompress LogC3 `[0, 1]` → linear HDR `[0, ∞)`.""" + logc = torch.clamp(logc, 0.0, 1.0) + cut_log = cls._LOGC3_E * cls._LOGC3_CUT + cls._LOGC3_F + lin_from_log = (torch.pow(10.0, (logc - cls._LOGC3_D) / cls._LOGC3_C) - cls._LOGC3_B) / cls._LOGC3_A + lin_from_lin = (logc - cls._LOGC3_F) / cls._LOGC3_E + return torch.where(logc >= cut_log, lin_from_log, lin_from_lin) + + @staticmethod + def _resize_and_reflect_pad_video(video: torch.Tensor, height: int, width: int) -> torch.Tensor: + r""" + Resize a video tensor preserving aspect ratio, then reflect-pad to the exact target dimensions. + + Mirrors `resize_and_reflect_pad` in the reference `ltx_pipelines.utils.media_io`. When the source is already + at least as large as the target in both dimensions, the interpolation step is skipped entirely. + + Args: + video (`torch.Tensor`): Input of shape `(B, C, F, H, W)`. 
+ height (`int`), width (`int`): Target spatial dimensions. + + Returns: + `torch.Tensor`: Resized and padded video of shape `(B, C, F, height, width)`. + """ + b, c, f, src_h, src_w = video.shape + + if height >= src_h and width >= src_w: + new_h, new_w = src_h, src_w + else: + scale = min(height / src_h, width / src_w) + new_h = round(src_h * scale) + new_w = round(src_w * scale) + # (B, C, F, H, W) → (B, F, C, H, W) → (B*F, C, H, W) for 2D per-frame interpolation. + video = video.permute(0, 2, 1, 3, 4).reshape(b * f, c, src_h, src_w) + video = F.interpolate(video, size=(new_h, new_w), mode="bilinear", align_corners=False) + video = video.reshape(b, f, c, new_h, new_w).permute(0, 2, 1, 3, 4) + + pad_bottom = height - new_h + pad_right = width - new_w + if pad_bottom > 0 or pad_right > 0: + # `reflect` pad requires the pad amount to be strictly less than the corresponding input dim. + pad_mode = "reflect" if pad_bottom < new_h and pad_right < new_w else "replicate" + video = video.permute(0, 2, 1, 3, 4).reshape(b * f, c, new_h, new_w) + video = F.pad(video, (0, pad_right, 0, pad_bottom), mode=pad_mode) + video = video.reshape(b, f, c, height, width).permute(0, 2, 1, 3, 4) + + return video + + def preprocess_reference_video_hdr( + self, + video, + height: int, + width: int, + ) -> torch.Tensor: + r""" + Preprocess a reference (SDR) video for HDR IC-LoRA conditioning. + + Runs the input through the standard video preprocessing (normalization to `[-1, 1]`) without resizing, then + applies reflect-pad resize to the target dimensions. For LDR inputs this is numerically equivalent to + `load_video_conditioning_hdr` in the reference implementation (since `LogC3.compress_ldr` is an identity clamp + on `[0, 1]` inputs). + + Args: + video: Input accepted by `VideoProcessor.preprocess_video` (list of PIL images, 4D/5D tensor/array, etc.). + height (`int`), width (`int`): Target spatial dimensions. + + Returns: + `torch.Tensor`: Preprocessed video of shape `(B, C, F, height, width)` with values in `[-1, 1]`. + """ + video = self.preprocess_video(video, height=None, width=None) # (B, C, F, src_h, src_w) in [-1, 1] + video = self._resize_and_reflect_pad_video(video, height, width) + return video + + def postprocess_hdr_video(self, video: torch.Tensor) -> torch.Tensor: + r""" + Postprocess the VAE's decoded output to linear HDR. + + Args: + video (`torch.Tensor`): + VAE decoded output in VAE range `[-1, 1]`, shape `(B, C, F, H, W)`. + + Returns: + `torch.Tensor`: Linear HDR video `[0, ∞)`, shape `(B, C, F, H, W)`, dtype `float32`. + """ + video = (video.float() / 2.0 + 0.5).clamp(0.0, 1.0) + return self._logc3_decompress(video) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index a80d011015cf..f0e8d035d47a 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -678,7 +678,7 @@ def preprocess_conditions( width: int = 768, num_frames: int = 121, device: torch.device | None = None, - ) -> tuple[list[torch.Tensor], list[float], list[int]]: + ) -> tuple[list[torch.Tensor], list[float], list[int], list[int]]: """ Preprocesses the condition images/videos to torch tensors. @@ -695,14 +695,16 @@ def preprocess_conditions( The device on which to put the preprocessed image/video tensors. 
Returns: - `Tuple[List[torch.Tensor], List[float], List[int]]`: - Returns a 3-tuple of lists of length `len(conditions)` as follows: + `Tuple[List[torch.Tensor], List[float], List[int], List[int]]`: + Returns a 4-tuple of lists of length `len(conditions)` as follows: 1. The first list is a list of preprocessed video tensors of shape [batch_size=1, num_channels, num_frames, height, width]. 2. The second list is a list of conditioning strengths. - 3. The third list is a list of indices in latent space to insert the corresponding condition. + 3. The third list is a list of latent-space indices for each condition. + 4. The fourth list is a list of (trimmed) pixel-space frame counts per condition. This is needed + for keyframe coord semantics (single-pixel-frame keyframes have a clamped temporal extent). """ - conditioning_frames, conditioning_strengths, conditioning_indices = [], [], [] + conditioning_frames, conditioning_strengths, conditioning_indices, conditioning_pixel_frames = [], [], [], [] if conditions is None: conditions = [] @@ -750,10 +752,11 @@ def preprocess_conditions( conditioning_frames.append(condition_pixels.to(dtype=self.vae.dtype, device=device)) conditioning_strengths.append(condition.strength) conditioning_indices.append(latent_start_idx) + conditioning_pixel_frames.append(truncated_cond_frames) - return conditioning_frames, conditioning_strengths, conditioning_indices + return conditioning_frames, conditioning_strengths, conditioning_indices, conditioning_pixel_frames - def apply_visual_conditioning( + def apply_first_frame_conditioning( self, latents: torch.Tensor, conditioning_mask: torch.Tensor, @@ -764,38 +767,107 @@ def apply_visual_conditioning( latent_width: int, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ - Applies visual conditioning frames to an initial latent. + Apply first-frame visual conditioning by overwriting tokens at the first-frame positions. + + Only conditions with `latent_idx == 0` are applied here (matching `VideoConditionByLatentIndex` in the + reference implementation). Conditions at non-zero latent indices are appended as separate keyframe tokens + via `prepare_keyframe_extras` (matching `VideoConditionByKeyframeIndex`) and are skipped here. Args: latents (`torch.Tensor`): Initial packed (patchified) latents of shape [batch_size, patch_seq_len, hidden_dim]. - conditioning_mask (`torch.Tensor`, *optional*): + conditioning_mask (`torch.Tensor`): Initial packed (patchified) conditioning mask of shape [batch_size, patch_seq_len, 1] with values in - [0, 1] where 0 means that the denoising model output will be fully used and 1 means that the condition - will be fully used (with intermediate values specifying a blend of the denoised and latent values). + [0, 1] where 0 means the denoising model output will be fully used and 1 means the condition will be + fully used. Returns: `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: Returns a 3-tuple of tensors where: - 1. The first element is the packed video latents (with unchanged shape [batch_size, patch_seq_len, - hidden_dim]) with the conditions applied - 2. The second element is the packed conditioning mask with conditioning strengths applied - 3. The third element holds the clean conditioning latents. + 1. The packed video latents with first-frame conditions applied. + 2. The packed conditioning mask with first-frame strengths applied. + 3. The clean conditioning latents at first-frame positions (zeros elsewhere). 
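+                With the default patch size of 1, a first-frame condition whose latent spans `k` latent frames overwrites tokens `[0, k * latent_height * latent_width)` of the noisy sequence.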
""" - # Latents-like tensor which holds the clean conditioning latents clean_latents = torch.zeros_like(latents) for cond, strength, latent_idx in zip(condition_latents, condition_strengths, condition_indices): + if latent_idx != 0: + # Non-first-frame conditions are handled as keyframe extras (appended tokens) instead. + continue num_cond_tokens = cond.size(1) start_token_idx = latent_idx * latent_height * latent_width end_token_idx = start_token_idx + num_cond_tokens - # Overwrite the portion of latents starting with start_token_idx with the condition latents[:, start_token_idx:end_token_idx] = cond conditioning_mask[:, start_token_idx:end_token_idx] = strength clean_latents[:, start_token_idx:end_token_idx] = cond return latents, conditioning_mask, clean_latents + def _prepare_keyframe_coords( + self, + keyframe_latent_num_frames: int, + keyframe_latent_height: int, + keyframe_latent_width: int, + pixel_frame_idx: int, + num_pixel_frames: int, + fps: float, + device: torch.device, + ) -> torch.Tensor: + """ + Compute positional coordinates for a keyframe condition being appended as extra tokens. + + Mirrors `VideoConditionByKeyframeIndex.apply_to` in the reference implementation: + - Latent coords scaled to pixel space *without* the causal fix (since non-zero-index keyframes don't need + the first-frame causal adjustment). + - Temporal axis offset by `pixel_frame_idx` (the pixel-space index at which the keyframe appears). + - For single-pixel-frame keyframes, the per-patch temporal extent is clamped to `[idx, idx + 1)` so the + keyframe occupies a single pixel timestep rather than the VAE-scaled range. + - Temporal coords divided by `fps` to produce seconds. + """ + patch_size = self.transformer_spatial_patch_size + patch_size_t = self.transformer_temporal_patch_size + scale_factors = ( + self.vae_temporal_compression_ratio, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, + ) + + grid_f = torch.arange( + start=0, end=keyframe_latent_num_frames, step=patch_size_t, dtype=torch.float32, device=device + ) + grid_h = torch.arange( + start=0, end=keyframe_latent_height, step=patch_size, dtype=torch.float32, device=device + ) + grid_w = torch.arange( + start=0, end=keyframe_latent_width, step=patch_size, dtype=torch.float32, device=device + ) + grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij") + grid = torch.stack(grid, dim=0) + + patch_size_delta = torch.tensor((patch_size_t, patch_size, patch_size), dtype=grid.dtype, device=device) + patch_ends = grid + patch_size_delta.view(3, 1, 1, 1) + + latent_coords = torch.stack([grid, patch_ends], dim=-1) # [3, N_F, N_H, N_W, 2] + latent_coords = latent_coords.flatten(1, 3) # [3, num_patches, 2] + latent_coords = latent_coords.unsqueeze(0) # [1, 3, num_patches, 2] + + scale_tensor = torch.tensor(scale_factors, device=device, dtype=latent_coords.dtype) + broadcast_shape = [1] * latent_coords.ndim + broadcast_shape[1] = -1 + pixel_coords = latent_coords * scale_tensor.view(*broadcast_shape) + + # No causal fix: keyframe coords place the keyframe at `pixel_frame_idx` without the first-frame adjustment. + pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] + pixel_frame_idx + + if num_pixel_frames == 1: + # Single-pixel-frame keyframe: clamp temporal extent to [idx, idx + 1). 
+ pixel_coords[:, 0, :, 1:] = pixel_coords[:, 0, :, :1] + 1 + + pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps + + return pixel_coords + + def prepare_latents( self, conditions: LTX2VideoCondition | list[LTX2VideoCondition] | None = None, @@ -804,12 +876,28 @@ def prepare_latents( height: int = 512, width: int = 768, num_frames: int = 121, + frame_rate: float = 24.0, noise_scale: float = 1.0, dtype: torch.dtype | None = None, device: torch.device | None = None, generator: torch.Generator | None = None, latents: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None]: + """ + Prepare noisy video latents, applying frame conditions. + + First-frame conditions (`latent_idx == 0`) are applied by overwriting tokens at the first-frame positions + (`VideoConditionByLatentIndex` semantics). Non-first-frame conditions (`latent_idx > 0`) are packaged as + keyframe extras to be appended to the latent sequence during the transformer forward pass + (`VideoConditionByKeyframeIndex` semantics). + + Returns a 4-tuple: + - `latents`: packed noisy latents (with first-frame replacement applied if applicable). + - `conditioning_mask`: packed conditioning mask (non-zero only at first-frame positions). + - `clean_latents`: clean first-frame conditions at first-frame positions (zeros elsewhere). + - `keyframe_extras`: `(keyframe_latents, keyframe_coords, keyframe_denoise_factors)` for keyframe + conditions at non-zero latent indices, or `None` if there are none. + """ latent_height = height // self.vae_spatial_compression_ratio latent_width = width // self.vae_spatial_compression_ratio latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 @@ -847,41 +935,95 @@ def prepare_latents( ) generator = generator[0] - condition_frames, condition_strengths, condition_indices = self.preprocess_conditions( + condition_frames, condition_strengths, condition_indices, condition_pixel_frames = self.preprocess_conditions( conditions, height, width, num_frames, device=device ) - condition_latents = [] + # Encode each condition through the VAE. We keep both the 5D latent (for coord computation) and the packed + # 3D latent (for first-frame replacement or keyframe append). + condition_latents_5d = [] + condition_latents_packed = [] for condition_tensor in condition_frames: - condition_latent = retrieve_latents( + condition_latent_5d = retrieve_latents( self.vae.encode(condition_tensor), generator=generator, sample_mode="argmax" ) - condition_latent = self._normalize_latents( - condition_latent, self.vae.latents_mean, self.vae.latents_std + condition_latent_5d = self._normalize_latents( + condition_latent_5d, self.vae.latents_mean, self.vae.latents_std ).to(device=device, dtype=dtype) - condition_latent = self._pack_latents( - condition_latent, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + condition_latent_packed = self._pack_latents( + condition_latent_5d, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ) - condition_latents.append(condition_latent) + condition_latents_5d.append(condition_latent_5d) + condition_latents_packed.append(condition_latent_packed) + # First-frame conditions (latent_idx == 0): replace tokens at the first-frame positions. # NOTE: following the I2V pipeline, we return a conditioning mask. 
The original LTX 2 code uses a denoising - # mask, which is the inverse of the conditioning mask (`denoise_mask = 1 - conditioning_mask`) - latents, conditioning_mask, clean_latents = self.apply_visual_conditioning( + # mask, which is the inverse of the conditioning mask (`denoise_mask = 1 - conditioning_mask`). + latents, conditioning_mask, clean_latents = self.apply_first_frame_conditioning( latents, conditioning_mask, - condition_latents, + condition_latents_packed, condition_strengths, condition_indices, latent_height=latent_height, latent_width=latent_width, ) + # Non-first-frame conditions (latent_idx > 0): append as keyframe extras with offset pixel coords. + frame_scale_factor = self.vae_temporal_compression_ratio + keyframe_tokens = [] + keyframe_coords = [] + keyframe_denoise_factors = [] + for cond_5d, cond_packed, strength, latent_idx, num_pixel_frames in zip( + condition_latents_5d, + condition_latents_packed, + condition_strengths, + condition_indices, + condition_pixel_frames, + ): + if latent_idx == 0: + continue + + _, _, kf_latent_frames, kf_latent_height, kf_latent_width = cond_5d.shape + # Pixel-space frame index at which the keyframe is placed. Matches the `start_idx` formula in + # `preprocess_conditions` used for trimming. + pixel_frame_idx = (latent_idx - 1) * frame_scale_factor + 1 + + coords = self._prepare_keyframe_coords( + keyframe_latent_num_frames=kf_latent_frames, + keyframe_latent_height=kf_latent_height, + keyframe_latent_width=kf_latent_width, + pixel_frame_idx=pixel_frame_idx, + num_pixel_frames=num_pixel_frames, + fps=frame_rate, + device=device, + ) + + num_tokens = cond_packed.shape[1] + denoise_factor = torch.full( + (1, num_tokens), 1.0 - strength, device=device, dtype=torch.float32 + ) + + keyframe_tokens.append(cond_packed) + keyframe_coords.append(coords) + keyframe_denoise_factors.append(denoise_factor) + + keyframe_extras: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None + if keyframe_tokens: + keyframe_extras = ( + torch.cat(keyframe_tokens, dim=1), + torch.cat(keyframe_coords, dim=2), + torch.cat(keyframe_denoise_factors, dim=1), + ) + else: + keyframe_extras = None + # Sample from the standard Gaussian prior (or an intermediate Gaussian distribution if noise_scale < 1.0). noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) scaled_mask = (1.0 - conditioning_mask) * noise_scale # Add noise to the `latents` so that it is at the noise level specified by `noise_scale`. 
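+        # Tokens with conditioning_mask == 1 keep their clean values; fully unconditioned tokens become pure noise when noise_scale == 1.0.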
latents = noise * scaled_mask + latents * (1 - scaled_mask) - return latents, conditioning_mask, clean_latents + return latents, conditioning_mask, clean_latents, keyframe_extras def prepare_audio_latents( self, @@ -994,6 +1136,90 @@ def attention_kwargs(self): def interrupt(self): return self._interrupt + def _run_transformer( + self, + latent_model_input: torch.Tensor, + audio_latent_model_input: torch.Tensor, + video_timestep: torch.Tensor, + audio_timestep: torch.Tensor, + sigma: torch.Tensor, + video_coords: torch.Tensor, + audio_coords: torch.Tensor, + connector_prompt_embeds: torch.Tensor, + connector_audio_prompt_embeds: torch.Tensor, + connector_attention_mask: torch.Tensor, + latent_num_frames: int, + latent_height: int, + latent_width: int, + frame_rate: float, + audio_num_frames: int, + use_cross_timestep: bool, + attention_kwargs: dict[str, Any] | None, + cache_context: str, + extra_latents: torch.Tensor | None = None, + extra_coords: torch.Tensor | None = None, + extra_timestep: torch.Tensor | None = None, + isolate_modalities: bool = False, + spatio_temporal_guidance_blocks: list[int] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Run a single transformer forward pass, optionally concatenating extra tokens (keyframe conditions) to the + video hidden states. + + When `extra_latents` is provided, the extra tokens are concatenated to the video hidden states, video coords, + and video timesteps. After the transformer forward pass, the extra tokens are stripped from the video output + so only the noisy-token predictions are returned. + + Returns: + `(noise_pred_video, noise_pred_audio)` where `noise_pred_video` has the same sequence length as the input + `latent_model_input` (extras are stripped). + """ + video_seq_len = latent_model_input.shape[1] + + if extra_latents is not None: + batch_size = latent_model_input.shape[0] + extra_batch = extra_latents.to(latent_model_input.dtype).expand(batch_size, -1, -1) + combined_hidden = torch.cat([latent_model_input, extra_batch], dim=1) + + extra_coords_batch = extra_coords.expand(batch_size, -1, -1, -1) + combined_coords = torch.cat([video_coords, extra_coords_batch], dim=2) + + extra_ts_batch = extra_timestep.expand(batch_size, -1) + combined_timestep = torch.cat([video_timestep, extra_ts_batch], dim=1) + else: + combined_hidden = latent_model_input + combined_coords = video_coords + combined_timestep = video_timestep + + with self.transformer.cache_context(cache_context): + noise_pred_combined, noise_pred_audio = self.transformer( + hidden_states=combined_hidden, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, + timestep=combined_timestep, + audio_timestep=audio_timestep, + sigma=sigma, + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=combined_coords, + audio_coords=audio_coords, + isolate_modalities=isolate_modalities, + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + + noise_pred_video = noise_pred_combined[:, :video_seq_len] + return noise_pred_video, noise_pred_audio + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -1271,18 
+1497,19 @@ def __call__( # video_sequence_length = latent_num_frames * latent_height * latent_width num_channels_latents = self.transformer.config.in_channels - latents, conditioning_mask, clean_latents = self.prepare_latents( - conditions, - batch_size * num_videos_per_prompt, - num_channels_latents, - height, - width, - num_frames, - noise_scale, - torch.float32, - device, - generator, - latents, + latents, conditioning_mask, clean_latents, keyframe_extras = self.prepare_latents( + conditions=conditions, + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames=num_frames, + frame_rate=frame_rate, + noise_scale=noise_scale, + dtype=torch.float32, + device=device, + generator=generator, + latents=latents, ) if self.do_classifier_free_guidance: conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) @@ -1377,31 +1604,35 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]) video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask.squeeze(-1)) - with self.transformer.cache_context("cond_uncond"): - noise_pred_video, noise_pred_audio = self.transformer( - hidden_states=latent_model_input, - audio_hidden_states=audio_latent_model_input, - encoder_hidden_states=connector_prompt_embeds, - audio_encoder_hidden_states=connector_audio_prompt_embeds, - timestep=video_timestep, - audio_timestep=timestep, - sigma=timestep, # Used by LTX-2.3 - encoder_attention_mask=connector_attention_mask, - audio_encoder_attention_mask=connector_attention_mask, - num_frames=latent_num_frames, - height=latent_height, - width=latent_width, - fps=frame_rate, - audio_num_frames=audio_num_frames, - video_coords=video_coords, - audio_coords=audio_coords, - isolate_modalities=False, - spatio_temporal_guidance_blocks=None, - perturbation_mask=None, - use_cross_timestep=use_cross_timestep, - attention_kwargs=attention_kwargs, - return_dict=False, - ) + # Per-token timestep for keyframe extras: sigma * (1 - strength), i.e. 0 for fully-clean keyframes. 
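+                # With strength == 1.0 the appended keyframe tokens are treated as fully denoised at every step; strength == 0.0 exposes them at the same noise level as the video tokens.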
+ extra_latents_in = extra_coords_in = extra_timestep_in = None + if keyframe_extras is not None: + extra_latents_in, extra_coords_in, keyframe_denoise_factors = keyframe_extras + extra_timestep_in = t * keyframe_denoise_factors + + noise_pred_video, noise_pred_audio = self._run_transformer( + latent_model_input=latent_model_input, + audio_latent_model_input=audio_latent_model_input, + video_timestep=video_timestep, + audio_timestep=timestep, + sigma=timestep, + video_coords=video_coords, + audio_coords=audio_coords, + connector_prompt_embeds=connector_prompt_embeds, + connector_audio_prompt_embeds=connector_audio_prompt_embeds, + connector_attention_mask=connector_attention_mask, + latent_num_frames=latent_num_frames, + latent_height=latent_height, + latent_width=latent_width, + frame_rate=frame_rate, + audio_num_frames=audio_num_frames, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + cache_context="cond_uncond", + extra_latents=extra_latents_in, + extra_coords=extra_coords_in, + extra_timestep=extra_timestep_in, + ) noise_pred_video = noise_pred_video.float() noise_pred_audio = noise_pred_audio.float() @@ -1451,32 +1682,30 @@ def __call__( noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler) if self.do_spatio_temporal_guidance: - with self.transformer.cache_context("uncond_stg"): - noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer( - hidden_states=latents.to(dtype=prompt_embeds.dtype), - audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype), - encoder_hidden_states=video_prompt_embeds, - audio_encoder_hidden_states=audio_prompt_embeds, - timestep=video_timestep, - audio_timestep=timestep, - sigma=timestep, # Used by LTX-2.3 - encoder_attention_mask=prompt_attn_mask, - audio_encoder_attention_mask=prompt_attn_mask, - num_frames=latent_num_frames, - height=latent_height, - width=latent_width, - fps=frame_rate, - audio_num_frames=audio_num_frames, - video_coords=video_pos_ids, - audio_coords=audio_pos_ids, - isolate_modalities=False, - # Use STG at given blocks to perturb model - spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, - perturbation_mask=None, - use_cross_timestep=use_cross_timestep, - attention_kwargs=attention_kwargs, - return_dict=False, - ) + noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self._run_transformer( + latent_model_input=latents.to(dtype=prompt_embeds.dtype), + audio_latent_model_input=audio_latents.to(dtype=prompt_embeds.dtype), + video_timestep=video_timestep, + audio_timestep=timestep, + sigma=timestep, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + connector_prompt_embeds=video_prompt_embeds, + connector_audio_prompt_embeds=audio_prompt_embeds, + connector_attention_mask=prompt_attn_mask, + latent_num_frames=latent_num_frames, + latent_height=latent_height, + latent_width=latent_width, + frame_rate=frame_rate, + audio_num_frames=audio_num_frames, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + cache_context="uncond_stg", + extra_latents=extra_latents_in, + extra_coords=extra_coords_in, + extra_timestep=extra_timestep_in, + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + ) noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float() noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float() noise_pred_video_uncond_stg = self.convert_velocity_to_x0( @@ -1492,32 +1721,30 @@ def __call__( video_stg_delta = audio_stg_delta = 0 if 
self.do_modality_isolation_guidance: - with self.transformer.cache_context("uncond_modality"): - noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self.transformer( - hidden_states=latents.to(dtype=prompt_embeds.dtype), - audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype), - encoder_hidden_states=video_prompt_embeds, - audio_encoder_hidden_states=audio_prompt_embeds, - timestep=video_timestep, - audio_timestep=timestep, - sigma=timestep, # Used by LTX-2.3 - encoder_attention_mask=prompt_attn_mask, - audio_encoder_attention_mask=prompt_attn_mask, - num_frames=latent_num_frames, - height=latent_height, - width=latent_width, - fps=frame_rate, - audio_num_frames=audio_num_frames, - video_coords=video_pos_ids, - audio_coords=audio_pos_ids, - # Turn off A2V and V2A cross attn to isolate video and audio modalities - isolate_modalities=True, - spatio_temporal_guidance_blocks=None, - perturbation_mask=None, - use_cross_timestep=use_cross_timestep, - attention_kwargs=attention_kwargs, - return_dict=False, - ) + noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self._run_transformer( + latent_model_input=latents.to(dtype=prompt_embeds.dtype), + audio_latent_model_input=audio_latents.to(dtype=prompt_embeds.dtype), + video_timestep=video_timestep, + audio_timestep=timestep, + sigma=timestep, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + connector_prompt_embeds=video_prompt_embeds, + connector_audio_prompt_embeds=audio_prompt_embeds, + connector_attention_mask=prompt_attn_mask, + latent_num_frames=latent_num_frames, + latent_height=latent_height, + latent_width=latent_width, + frame_rate=frame_rate, + audio_num_frames=audio_num_frames, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + cache_context="uncond_modality", + extra_latents=extra_latents_in, + extra_coords=extra_coords_in, + extra_timestep=extra_timestep_in, + isolate_modalities=True, + ) noise_pred_video_uncond_modality = noise_pred_video_uncond_modality.float() noise_pred_audio_uncond_modality = noise_pred_audio_uncond_modality.float() noise_pred_video_uncond_modality = self.convert_velocity_to_x0( diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py new file mode 100644 index 000000000000..886ac094233e --- /dev/null +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py @@ -0,0 +1,1440 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import inspect +from dataclasses import dataclass +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin +from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video +from ...models.transformers import LTX2VideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .connectors import LTX2TextConnectors +from .image_processor import LTX2VideoHDRProcessor +from .pipeline_output import LTX2PipelineOutput +from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class LTX2HDRReferenceCondition: + r""" + A reference video condition for HDR IC-LoRA conditioning. + + The reference video is encoded into latent tokens and concatenated to the noisy latent sequence during + denoising, allowing the HDR IC-LoRA adapter to condition the generation on the reference video content. + + Matches the `(video_path, strength)` tuples consumed by the reference `HDRICLoraPipeline`'s + `video_conditioning` argument. + + Attributes: + frames (`PIL.Image.Image` or `List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`): + The reference video frames. Accepts any type handled by `VideoProcessor.preprocess_video`. + strength (`float`, defaults to `1.0`): + Controls how "clean" the reference tokens appear to the model. A value of `1.0` means fully clean + (per-token timestep=0), `0.0` means fully noisy. + """ + + frames: PIL.Image.Image | list[PIL.Image.Image] | np.ndarray | torch.Tensor + strength: float = 1.0 + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTX2HDRLoraPipeline + >>> from diffusers.pipelines.ltx2.pipeline_ltx2_hdr_lora import LTX2HDRReferenceCondition + >>> from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES + >>> from diffusers.pipelines.ltx2.export_utils import save_hdr_video_frames_as_exr + >>> from diffusers.utils import load_video + + >>> pipe = LTX2HDRLoraPipeline.from_pretrained( + ... "rootonchair/LTX-2-19b-distilled", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.enable_sequential_cpu_offload(device="cuda") + >>> pipe.load_lora_weights("path/to/hdr_ic_lora.safetensors", adapter_name="hdr_lora") + >>> pipe.set_adapters("hdr_lora", 1.0) + + >>> reference_video = load_video("reference.mp4") + >>> ref_cond = LTX2HDRReferenceCondition(frames=reference_video, strength=1.0) + + >>> prompt = "A cinematic landscape at sunset" + >>> hdr_video = pipe( + ... prompt=prompt, + ... reference_conditions=[ref_cond], + ... width=768, + ... height=512, + ... num_frames=121, + ... frame_rate=24.0, + ... num_inference_steps=8, + ... sigmas=DISTILLED_SIGMA_VALUES, + ... guidance_scale=1.0, + ... output_type="pt", + ... return_dict=False, + ... )[0] + + >>> # `hdr_video` is a linear HDR tensor of shape (batch, frames, H, W, C). 
+ >>> save_hdr_video_frames_as_exr(hdr_video[0], "hdr_output/") + ``` +""" + + +# Copied from diffusers.pipelines.ltx2.pipeline_ltx2_ic_lora.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. 
+ """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class LTX2HDRLoraPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): + r""" + Pipeline for HDR IC-LoRA video generation with reference video conditioning. + + This is a video-only HDR counterpart to [`LTX2ICLoraPipeline`]. The HDR IC-LoRA adapter (loaded as a standard + LoRA via `load_lora_weights`) conditions generation on a reference video, and the pipeline's postprocessing + applies the LogC3 inverse transform to produce linear HDR output in `[0, ∞)`. + + Compared to [`LTX2ICLoraPipeline`], the HDR pipeline drops: + + - Frame-level keyframe conditioning (the reference HDR pipeline does not support this). + - The `conditioning_attention_strength` / `conditioning_attention_mask` knobs. + - Audio output (video-only). The transformer's audio branch is still run since the diffusers transformer API + requires audio inputs, but the decoded audio is discarded and audio-specific guidance scales are fixed to + no-op values to avoid wasted compute. + + Two-stage inference is supported through separate calls to `__call__`: + + - **Stage 1**: generate video latents at target resolution with HDR IC-LoRA conditioning + (`output_type="latent"`). + - **Stage 2**: upsample via [`LTX2LatentUpsamplePipeline`] and refine with this same pipeline (or + [`LTX2Pipeline`]) by passing `latents=upsampled_latents`. The reference HDR stage-2 additionally supports + spatial/temporal tiling of the refinement pass — that optimization is not yet implemented here. + + Reference: https://github.com/Lightricks/LTX-2 + + Args: + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + Scheduler used in the denoising loop. + vae ([`AutoencoderKLLTX2Video`]): + Video VAE. + audio_vae ([`AutoencoderKLLTX2Audio`]): + Audio VAE. Required for transformer compatibility; its outputs are discarded. + text_encoder ([`transformers.Gemma3ForConditionalGeneration`]): + Text encoder. + tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`): + Tokenizer for the text encoder. + connectors ([`LTX2TextConnectors`]): + Text connector stack for the transformer. + transformer ([`LTX2VideoTransformer3DModel`]): + Transformer backbone. + vocoder ([`LTX2Vocoder`] or [`LTX2VocoderWithBWE`]): + Vocoder. Required for transformer compatibility; its outputs are discarded. + hdr_transform (`str`, *optional*, defaults to `"logc3"`): + HDR transform identifier applied during postprocessing. Currently only `"logc3"` is supported. 
+ """ + + model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTX2Video, + audio_vae: AutoencoderKLLTX2Audio, + text_encoder: Gemma3ForConditionalGeneration, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + connectors: LTX2TextConnectors, + transformer: LTX2VideoTransformer3DModel, + vocoder: LTX2Vocoder | LTX2VocoderWithBWE, + hdr_transform: str = "logc3", + ): + super().__init__() + + self.register_modules( + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + connectors=connectors, + transformer=transformer, + vocoder=vocoder, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.audio_vae_mel_compression_ratio = ( + self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.audio_vae_temporal_compression_ratio = ( + self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1 + ) + + self.audio_sampling_rate = ( + self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000 + ) + self.audio_hop_length = ( + self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160 + ) + + self.hdr_video_processor = LTX2VideoHDRProcessor( + vae_scale_factor=self.vae_spatial_compression_ratio, + hdr_transform=hdr_transform, + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 + ) + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "left" + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + prompt = [p.strip() for p in prompt] + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + text_input_ids = text_input_ids.to(device) + prompt_attention_mask = prompt_attention_mask.to(device) + + text_encoder_outputs = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = 
torch.stack(text_encoder_hidden_states, dim=-1) + prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + latents=None, + spatio_temporal_guidance_blocks=None, + stg_scale=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if latents is not None and latents.ndim != 5: + raise ValueError( + f"Only unpacked (5D) video latents of shape `[batch_size, latent_channels, latent_frames," + f" latent_height, latent_width] are supported, but got {latents.ndim} dims." + ) + + if (stg_scale is not None and stg_scale > 0.0) and not spatio_temporal_guidance_blocks: + raise ValueError( + "Spatio-Temporal Guidance (STG) is specified but no STG blocks are supplied. 
Please supply a list of" + " block indices at which to apply STG in `spatio_temporal_guidance_blocks`" + ) + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._create_noised_state + def _create_noised_state( + latents: torch.Tensor, noise_scale: float | torch.Tensor, generator: torch.Generator | None = None + ): + noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) + noised_latents = noise_scale * noise + (1 - noise_scale) * latents + return noised_latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._normalize_audio_latents + def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents - latents_mean) / latents_std + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_audio_latents + def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents * latents_std) + 
latents_mean + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_audio_latents + def _pack_audio_latents( + latents: torch.Tensor, patch_size: int | None = None, patch_size_t: int | None = None + ) -> torch.Tensor: + if patch_size is not None and patch_size_t is not None: + batch_size, num_channels, latent_length, latent_mel_bins = latents.shape + post_patch_latent_length = latent_length / patch_size_t + post_patch_mel_bins = latent_mel_bins / patch_size + latents = latents.reshape( + batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size + ) + latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2) + else: + latents = latents.transpose(1, 2).flatten(2, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_audio_latents + def _unpack_audio_latents( + latents: torch.Tensor, + latent_length: int, + num_mel_bins: int, + patch_size: int | None = None, + patch_size_t: int | None = None, + ) -> torch.Tensor: + if patch_size is not None and patch_size_t is not None: + batch_size = latents.size(0) + latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size) + latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + else: + latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) + return latents + + def prepare_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 768, + num_frames: int = 121, + noise_scale: float = 0.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + r""" + Prepare noisy video latents. Either allocates fresh noise (Stage 1) or noises supplied latents from a + previous stage (Stage 2 after [`LTX2LatentUpsamplePipeline`]). + """ + if latents is not None: + if latents.ndim == 5: + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size," + f" num_seq, num_features]." + ) + latents = self._create_noised_state(latents, noise_scale, generator) + return latents.to(device=device, dtype=dtype) + + height = height // self.vae_spatial_compression_ratio + width = width // self.vae_spatial_compression_ratio + num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + return latents + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.prepare_audio_latents + def prepare_audio_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 8, + audio_latent_length: int = 1, + num_mel_bins: int = 64, + noise_scale: float = 0.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + if latents.ndim == 4: + latents = self._pack_audio_latents(latents) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size," + f" num_seq, num_features]." + ) + latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) + latents = self._create_noised_state(latents, noise_scale, generator) + return latents.to(device=device, dtype=dtype) + + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + + shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_audio_latents(latents) + return latents + + def prepare_reference_latents( + self, + reference_conditions: list[LTX2HDRReferenceCondition], + height: int, + width: int, + num_frames: int, + reference_downscale_factor: int = 1, + frame_rate: float = 24.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + r""" + Encode reference videos with HDR preprocessing into packed latent tokens and compute positional coordinates. + + Each reference video is preprocessed via [`LTX2VideoHDRProcessor.preprocess_reference_video_hdr`] (reflect-pad + resize at the reference resolution), VAE-encoded, packed into tokens, and paired with positional coordinates + computed at the reference latent dimensions and scaled by `reference_downscale_factor`. + + Returns a 3-tuple `(reference_latents, reference_coords, reference_denoise_factors)` with the same shapes as + [`LTX2ICLoraPipeline.prepare_reference_latents`]. + """ + ref_height = height // reference_downscale_factor + ref_width = width // reference_downscale_factor + + if reference_downscale_factor != 1 and (height % reference_downscale_factor != 0 or width % reference_downscale_factor != 0): + raise ValueError( + f"Output dimensions ({height}x{width}) must be divisible by reference_downscale_factor " + f"({reference_downscale_factor})." 
+ ) + + all_ref_latents = [] + all_ref_coords = [] + all_ref_denoise_factors = [] + + for ref_cond in reference_conditions: + if isinstance(ref_cond.frames, PIL.Image.Image): + video_like = [ref_cond.frames] + elif isinstance(ref_cond.frames, np.ndarray) and ref_cond.frames.ndim == 3: + video_like = np.expand_dims(ref_cond.frames, axis=0) + elif isinstance(ref_cond.frames, torch.Tensor) and ref_cond.frames.ndim == 3: + video_like = ref_cond.frames.unsqueeze(0) + else: + video_like = ref_cond.frames + + # HDR-specific preprocessing: reflect-pad resize (vs center-crop in the standard IC-LoRA pipeline). + # For LDR reference videos the numerical output of `preprocess_reference_video_hdr` is identical to the + # standard [-1, 1] normalization since LogC3's `compress_ldr` is an identity clamp. + ref_pixels = self.hdr_video_processor.preprocess_reference_video_hdr( + video_like, ref_height, ref_width + ) + ref_pixels = ref_pixels[:, :, :num_frames] + ref_pixels = ref_pixels.to(dtype=self.vae.dtype, device=device) + + ref_latent = retrieve_latents( + self.vae.encode(ref_pixels), generator=generator, sample_mode="argmax" + ) + ref_latent = self._normalize_latents( + ref_latent, self.vae.latents_mean, self.vae.latents_std + ).to(device=device, dtype=dtype) + + _, _, ref_latent_frames, ref_latent_height, ref_latent_width = ref_latent.shape + + ref_latent_packed = self._pack_latents( + ref_latent, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + ref_coords = self.transformer.rope.prepare_video_coords( + batch_size=1, + num_frames=ref_latent_frames, + height=ref_latent_height, + width=ref_latent_width, + device=device, + fps=frame_rate, + ) + if reference_downscale_factor != 1: + ref_coords[:, 1, :, :] = ref_coords[:, 1, :, :] * reference_downscale_factor + ref_coords[:, 2, :, :] = ref_coords[:, 2, :, :] * reference_downscale_factor + + num_tokens = ref_latent_packed.shape[1] + denoise_factor = torch.full( + (1, num_tokens), 1.0 - ref_cond.strength, device=device, dtype=torch.float32 + ) + + all_ref_latents.append(ref_latent_packed) + all_ref_coords.append(ref_coords) + all_ref_denoise_factors.append(denoise_factor) + + reference_latents = torch.cat(all_ref_latents, dim=1) + reference_coords = torch.cat(all_ref_coords, dim=2) + reference_denoise_factors = torch.cat(all_ref_denoise_factors, dim=1) + + return reference_latents, reference_coords, reference_denoise_factors + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_condition.LTX2ConditionPipeline.convert_velocity_to_x0 + def convert_velocity_to_x0( + self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None + ) -> torch.Tensor: + if scheduler is None: + scheduler = self.scheduler + + sample_x0 = sample - denoised_output * scheduler.sigmas[step_idx] + return sample_x0 + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_condition.LTX2ConditionPipeline.convert_x0_to_velocity + def convert_x0_to_velocity( + self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None + ) -> torch.Tensor: + if scheduler is None: + scheduler = self.scheduler + + sample_v = (sample - denoised_output) / scheduler.sigmas[step_idx] + return sample_v + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def stg_scale(self): + return self._stg_scale + + @property + def modality_scale(self): + return self._modality_scale + + @property + def 
do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def do_spatio_temporal_guidance(self): + return self._stg_scale > 0.0 + + @property + def do_modality_isolation_guidance(self): + return self._modality_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_ic_lora.LTX2ICLoraPipeline._run_transformer + def _run_transformer( + self, + latent_model_input: torch.Tensor, + audio_latent_model_input: torch.Tensor, + video_timestep: torch.Tensor, + audio_timestep: torch.Tensor, + sigma: torch.Tensor, + video_coords: torch.Tensor, + audio_coords: torch.Tensor, + connector_prompt_embeds: torch.Tensor, + connector_audio_prompt_embeds: torch.Tensor, + connector_attention_mask: torch.Tensor, + latent_num_frames: int, + latent_height: int, + latent_width: int, + frame_rate: float, + audio_num_frames: int, + use_cross_timestep: bool, + attention_kwargs: dict[str, Any] | None, + cache_context: str, + extra_latents: torch.Tensor | None = None, + extra_coords: torch.Tensor | None = None, + extra_timestep: torch.Tensor | None = None, + video_self_attention_mask: torch.Tensor | None = None, + isolate_modalities: bool = False, + spatio_temporal_guidance_blocks: list[int] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + video_seq_len = latent_model_input.shape[1] + + if extra_latents is not None: + batch_size = latent_model_input.shape[0] + extra_batch = extra_latents.to(latent_model_input.dtype).expand(batch_size, -1, -1) + combined_hidden = torch.cat([latent_model_input, extra_batch], dim=1) + + extra_coords_batch = extra_coords.expand(batch_size, -1, -1, -1) + combined_coords = torch.cat([video_coords, extra_coords_batch], dim=2) + + extra_ts_batch = extra_timestep.expand(batch_size, -1) + combined_timestep = torch.cat([video_timestep, extra_ts_batch], dim=1) + else: + combined_hidden = latent_model_input + combined_coords = video_coords + combined_timestep = video_timestep + + if video_self_attention_mask is not None: + video_self_attention_mask = video_self_attention_mask.expand(combined_hidden.shape[0], -1, -1) + + with self.transformer.cache_context(cache_context): + noise_pred_combined, noise_pred_audio = self.transformer( + hidden_states=combined_hidden, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, + timestep=combined_timestep, + audio_timestep=audio_timestep, + sigma=sigma, + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, + video_self_attention_mask=video_self_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=combined_coords, + audio_coords=audio_coords, + isolate_modalities=isolate_modalities, + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + + noise_pred_video = noise_pred_combined[:, :video_seq_len] + return noise_pred_video, noise_pred_audio + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, 
+ prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + reference_conditions: LTX2HDRReferenceCondition | list[LTX2HDRReferenceCondition] | None = None, + reference_downscale_factor: int = 1, + height: int = 512, + width: int = 768, + num_frames: int = 121, + frame_rate: float = 24.0, + num_inference_steps: int = 8, + sigmas: list[float] | None = None, + timesteps: list[float] | None = None, + guidance_scale: float = 1.0, + stg_scale: float = 0.0, + modality_scale: float = 1.0, + guidance_rescale: float = 0.0, + spatio_temporal_guidance_blocks: list[int] | None = None, + noise_scale: float | None = None, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + decode_timestep: float | list[float] = 0.0, + decode_noise_scale: float | list[float] | None = None, + use_cross_timestep: bool = False, + output_type: str = "pt", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 1024, + ): + r""" + Run HDR IC-LoRA video generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt(s) to guide generation. Either `prompt` or `prompt_embeds` must be supplied. + negative_prompt (`str` or `List[str]`, *optional*): + The negative prompt(s). Ignored when `guidance_scale <= 1.0`. + reference_conditions (`LTX2HDRReferenceCondition` or `List[LTX2HDRReferenceCondition]`, *optional*): + Reference video conditions for HDR IC-LoRA conditioning. + reference_downscale_factor (`int`, *optional*, defaults to `1`): + Ratio between target and reference video resolutions. IC-LoRA models trained with downscaled + reference videos store this factor in their safetensors metadata. + height (`int`, *optional*, defaults to `512`): + Output video height in pixels. Must be divisible by 32. + width (`int`, *optional*, defaults to `768`): + Output video width in pixels. Must be divisible by 32. + num_frames (`int`, *optional*, defaults to `121`): + Number of frames to generate. Must satisfy `(n - 1) % 8 == 0`. + frame_rate (`float`, *optional*, defaults to `24.0`): + Output frame rate (used for temporal positional encoding). + num_inference_steps (`int`, *optional*, defaults to `8`): + Number of denoising steps. Default matches the distilled model schedule. + sigmas (`List[float]`, *optional*): + Custom sigma schedule. Overrides `num_inference_steps` when set. + timesteps (`List[float]`, *optional*): + Custom timesteps schedule. Overrides `num_inference_steps` when set. + guidance_scale (`float`, *optional*, defaults to `1.0`): + Classifier-Free Guidance scale for video. Default `1.0` disables CFG (matches the distilled model). + stg_scale (`float`, *optional*, defaults to `0.0`): + Spatio-Temporal Guidance scale for video. + modality_scale (`float`, *optional*, defaults to `1.0`): + Modality isolation guidance scale for video. + guidance_rescale (`float`, *optional*, defaults to `0.0`): + Video guidance rescale factor. + spatio_temporal_guidance_blocks (`list[int]`, *optional*): + Transformer block indices at which to apply STG. 
+ noise_scale (`float`, *optional*): + Noise scale used when preparing the initial latents. Inferred from the sigma schedule when unset. + num_videos_per_prompt (`int`, *optional*, defaults to `1`): + Number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + Random generator(s) for reproducibility. + latents (`torch.Tensor`, *optional*): + Pre-generated video latents. Pass output from [`LTX2LatentUpsamplePipeline`] here for Stage 2. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Bypasses `prompt`/`tokenizer`/`text_encoder` if supplied. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for `prompt_embeds`. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for `negative_prompt_embeds`. + decode_timestep, decode_noise_scale: + VAE-decode timestep conditioning (only used by VAE configs with `timestep_conditioning=True`). + use_cross_timestep (`bool`, *optional*, defaults to `False`): + Whether to use cross-modality sigma for cross-attention modulation. + output_type (`str`, *optional*, defaults to `"pt"`): + One of `"pt"`, `"np"`, or `"latent"`. `"pt"` returns a linear HDR torch tensor in `[0, ∞)` of shape + `(batch_size, num_frames, height, width, channels)`; `"np"` returns the equivalent `float32` NumPy + array; `"latent"` returns the raw denoised latents (skip the HDR decode). + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return an [`LTX2PipelineOutput`] instead of a plain tuple. + attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length: + Standard hooks and arguments, same as [`LTX2ICLoraPipeline`]. + + Examples: + + Returns: + [`LTX2PipelineOutput`] or `tuple`. When `return_dict=False`, returns `(frames, None)` — the audio slot is + always `None` since this pipeline is video-only. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + latents=latents, + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + stg_scale=stg_scale, + ) + + # Video-only guidance state. + self._guidance_scale = guidance_scale + self._stg_scale = stg_scale + self._modality_scale = modality_scale + self._guidance_rescale = guidance_rescale + + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if reference_conditions is not None and not isinstance(reference_conditions, list): + reference_conditions = [reference_conditions] + + if noise_scale is None: + noise_scale = sigmas[0] if sigmas is not None else 1.0 + + device = self._execution_device + + # 3. 
Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + tokenizer_padding_side = "left" + if getattr(self, "tokenizer", None) is not None: + tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left") + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( + prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side + ) + + # 4. Prepare video latents + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + if latents is not None: + logger.info( + "Got pre-supplied latents of shape %s; `latent_num_frames`, `latent_height`, and `latent_width` will" + " be inferred.", + tuple(latents.shape), + ) + _, _, latent_num_frames, latent_height, latent_width = latents.shape + + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames=num_frames, + noise_scale=noise_scale, + dtype=torch.float32, + device=device, + generator=generator, + latents=latents, + ) + + # 4b. Prepare reference extras for HDR IC-LoRA conditioning. + extra_latents = extra_coords = extra_denoise_factors = None + if reference_conditions is not None and len(reference_conditions) > 0: + extra_latents, extra_coords, extra_denoise_factors = self.prepare_reference_latents( + reference_conditions=reference_conditions, + height=height, + width=width, + num_frames=num_frames, + reference_downscale_factor=reference_downscale_factor, + frame_rate=frame_rate, + dtype=torch.float32, + device=device, + generator=generator, + ) + + # 5. Prepare audio latents. Audio is discarded at the end, but the transformer's audio branch still runs so + # we need well-formed audio inputs. Audio guidance is fixed so no extra audio-only forward passes fire. 
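+        # Worked example of the audio length computed below (a sketch assuming the fallback defaults from
+        # __init__, i.e. sample_rate=16000, mel_hop_length=160, audio temporal compression 4): num_frames=121
+        # at frame_rate=24.0 is ~5.04 s of video, 16000 / 160 / 4 = 25 audio latents per second, so
+        # audio_num_frames = round(5.04 * 25) = 126.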
+ duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + num_channels_latents_audio = ( + self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 + ) + audio_latents = self.prepare_audio_latents( + batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, + num_mel_bins=num_mel_bins, + noise_scale=noise_scale, + dtype=torch.float32, + device=device, + generator=generator, + latents=None, + ) + + # 6. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + mu = calculate_shift( + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_image_seq_len", 1024), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.95), + self.scheduler.config.get("max_shift", 2.05), + ) + audio_scheduler = copy.deepcopy(self.scheduler) + _, _ = retrieve_timesteps( + audio_scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 7. Prepare positional coordinates + video_coords = self.transformer.rope.prepare_video_coords( + latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + audio_latents.shape[0], audio_num_frames, audio_latents.device + ) + if self.do_classifier_free_guidance: + video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1)) + audio_coords = audio_coords.repeat((2,) + (1,) * (audio_coords.ndim - 1)) + + # 8. 
Denoising loop + video_seq_len = latents.shape[1] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + audio_latent_model_input = ( + torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + ) + audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) + + timestep_scalar = t.expand(latent_model_input.shape[0]) + video_timestep = timestep_scalar.unsqueeze(-1).expand(-1, video_seq_len) + + extra_timestep_in = t * extra_denoise_factors if extra_denoise_factors is not None else None + + # --- Main forward pass (cond + uncond for CFG) --- + noise_pred_video, noise_pred_audio = self._run_transformer( + latent_model_input=latent_model_input, + audio_latent_model_input=audio_latent_model_input, + video_timestep=video_timestep, + audio_timestep=timestep_scalar, + sigma=timestep_scalar, + video_coords=video_coords, + audio_coords=audio_coords, + connector_prompt_embeds=connector_prompt_embeds, + connector_audio_prompt_embeds=connector_audio_prompt_embeds, + connector_attention_mask=connector_attention_mask, + latent_num_frames=latent_num_frames, + latent_height=latent_height, + latent_width=latent_width, + frame_rate=frame_rate, + audio_num_frames=audio_num_frames, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + cache_context="cond_uncond", + extra_latents=extra_latents, + extra_coords=extra_coords, + extra_timestep=extra_timestep_in, + ) + noise_pred_video = noise_pred_video.float() + + if self.do_classifier_free_guidance: + noise_pred_video_uncond_text, noise_pred_video = noise_pred_video.chunk(2) + noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler) + noise_pred_video_uncond_text = self.convert_velocity_to_x0( + latents, noise_pred_video_uncond_text, i, self.scheduler + ) + video_cfg_delta = (self.guidance_scale - 1) * (noise_pred_video - noise_pred_video_uncond_text) + + if self.do_spatio_temporal_guidance or self.do_modality_isolation_guidance: + if i == 0: + video_prompt_embeds = connector_prompt_embeds.chunk(2, dim=0)[1] + audio_prompt_embeds = connector_audio_prompt_embeds.chunk(2, dim=0)[1] + prompt_attn_mask = connector_attention_mask.chunk(2, dim=0)[1] + video_pos_ids = video_coords.chunk(2, dim=0)[0] + audio_pos_ids = audio_coords.chunk(2, dim=0)[0] + timestep_scalar_single = timestep_scalar.chunk(2, dim=0)[0] + video_timestep_single = timestep_scalar_single.unsqueeze(-1).expand(-1, video_seq_len) + else: + video_cfg_delta = 0 + + video_prompt_embeds = connector_prompt_embeds + audio_prompt_embeds = connector_audio_prompt_embeds + prompt_attn_mask = connector_attention_mask + video_pos_ids = video_coords + audio_pos_ids = audio_coords + + timestep_scalar_single = timestep_scalar + video_timestep_single = video_timestep + + noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler) + + # --- STG forward pass (video only — audio output discarded) --- + if self.do_spatio_temporal_guidance: + noise_pred_video_uncond_stg, _ = self._run_transformer( + latent_model_input=latents.to(dtype=prompt_embeds.dtype), + audio_latent_model_input=audio_latents.to(dtype=prompt_embeds.dtype), + video_timestep=video_timestep_single, + audio_timestep=timestep_scalar_single, + 
sigma=timestep_scalar_single, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + connector_prompt_embeds=video_prompt_embeds, + connector_audio_prompt_embeds=audio_prompt_embeds, + connector_attention_mask=prompt_attn_mask, + latent_num_frames=latent_num_frames, + latent_height=latent_height, + latent_width=latent_width, + frame_rate=frame_rate, + audio_num_frames=audio_num_frames, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + cache_context="uncond_stg", + extra_latents=extra_latents, + extra_coords=extra_coords, + extra_timestep=extra_timestep_in, + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + ) + noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float() + noise_pred_video_uncond_stg = self.convert_velocity_to_x0( + latents, noise_pred_video_uncond_stg, i, self.scheduler + ) + video_stg_delta = self.stg_scale * (noise_pred_video - noise_pred_video_uncond_stg) + else: + video_stg_delta = 0 + + # --- Modality isolation guidance forward pass --- + if self.do_modality_isolation_guidance: + noise_pred_video_uncond_mod, _ = self._run_transformer( + latent_model_input=latents.to(dtype=prompt_embeds.dtype), + audio_latent_model_input=audio_latents.to(dtype=prompt_embeds.dtype), + video_timestep=video_timestep_single, + audio_timestep=timestep_scalar_single, + sigma=timestep_scalar_single, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + connector_prompt_embeds=video_prompt_embeds, + connector_audio_prompt_embeds=audio_prompt_embeds, + connector_attention_mask=prompt_attn_mask, + latent_num_frames=latent_num_frames, + latent_height=latent_height, + latent_width=latent_width, + frame_rate=frame_rate, + audio_num_frames=audio_num_frames, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + cache_context="uncond_modality", + extra_latents=extra_latents, + extra_coords=extra_coords, + extra_timestep=extra_timestep_in, + isolate_modalities=True, + ) + noise_pred_video_uncond_mod = noise_pred_video_uncond_mod.float() + noise_pred_video_uncond_mod = self.convert_velocity_to_x0( + latents, noise_pred_video_uncond_mod, i, self.scheduler + ) + video_modality_delta = (self.modality_scale - 1) * ( + noise_pred_video - noise_pred_video_uncond_mod + ) + else: + video_modality_delta = 0 + + noise_pred_video_g = noise_pred_video + video_cfg_delta + video_stg_delta + video_modality_delta + + if self.guidance_rescale > 0: + noise_pred_video = rescale_noise_cfg( + noise_pred_video_g, noise_pred_video, guidance_rescale=self.guidance_rescale + ) + else: + noise_pred_video = noise_pred_video_g + + noise_pred_video = self.convert_x0_to_velocity(latents, noise_pred_video, i, self.scheduler) + + latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0] + # Step the audio scheduler so its internal state stays in sync with the video scheduler (audio + # output is discarded at the end, but keeping schedulers aligned avoids surprising behavior if the + # scheduler writes internal indices during `.step()`). 
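+                # A zero velocity is passed below so the audio latents themselves stay unchanged: for a
+                # flow-matching Euler step, prev_sample = sample + (sigma_next - sigma) * model_output,
+                # which reduces to prev_sample = sample when model_output == 0, while the call still
+                # advances the scheduler's internal step index. (Sketch of the reasoning, assuming the
+                # stock FlowMatchEulerDiscreteScheduler update rule.)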
+ _ = audio_scheduler.step( + torch.zeros_like(audio_latents), t, audio_latents, return_dict=False + )[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # Silence unused-variable lints for `audio_latent_model_input` / `latent_mel_bins`. + del audio_latent_model_input, latent_mel_bins + + # 9. Decode + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + + if output_type == "latent": + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + video = latents + else: + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(self.vae.dtype) + + # VAE decode returns a video tensor in the VAE's native range ([-1, 1]). + decoded = self.vae.decode(latents, timestep, return_dict=False)[0] + # HDR postprocess: LogC3 decompress → linear HDR [0, ∞). Always float32 for HDR fidelity. + video = self.hdr_video_processor.postprocess_hdr_video(decoded) + + # Format output (batch, frames, H, W, channels). + video = video.permute(0, 2, 3, 4, 1).contiguous() + if output_type == "np": + video = video.cpu().numpy() + elif output_type != "pt": + raise ValueError( + f"Unsupported `output_type` {output_type!r} for HDR pipeline. Choose one of 'pt', 'np', or" + f" 'latent'." + ) + + # Audio is always None for this video-only pipeline. + self.maybe_free_model_hooks() + + if not return_dict: + return (video, None) + + return LTX2PipelineOutput(frames=video, audio=None) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py new file mode 100644 index 000000000000..024af2f3b210 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py @@ -0,0 +1,2214 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +from dataclasses import dataclass +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin +from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video +from ...models.transformers import LTX2VideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .connectors import LTX2TextConnectors +from .pipeline_ltx2_condition import LTX2VideoCondition +from .pipeline_output import LTX2PipelineOutput +from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class LTX2ReferenceCondition: + """ + A reference video condition for IC-LoRA (In-Context LoRA) conditioning. + + The reference video is encoded into latent tokens and concatenated to the noisy latent sequence during denoising. + The transformer attends to these extra tokens, allowing the IC-LoRA adapter to condition the generation on the + reference video content (e.g. style, structure, depth, pose). + + Attributes: + frames (`PIL.Image.Image` or `List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`): + The reference video frames. Accepts any type handled by `VideoProcessor.preprocess_video`. + strength (`float`, defaults to `1.0`): + Controls how "clean" the reference tokens appear to the model. A value of `1.0` means fully clean + (timestep=0 for reference tokens), `0.0` means fully noisy (same as denoising tokens). + """ + + frames: PIL.Image.Image | list[PIL.Image.Image] | np.ndarray | torch.Tensor + strength: float = 1.0 + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTX2ICLoraPipeline + >>> from diffusers.pipelines.ltx2.pipeline_ltx2_ic_lora import LTX2ReferenceCondition + >>> from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES + >>> from diffusers.pipelines.ltx2.export_utils import encode_video + >>> from diffusers.utils import load_video + + >>> pipe = LTX2ICLoraPipeline.from_pretrained( + ... "rootonchair/LTX-2-19b-distilled", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.enable_sequential_cpu_offload(device="cuda") + >>> pipe.load_lora_weights("path/to/ic_lora.safetensors", adapter_name="ic_lora") + >>> pipe.set_adapters("ic_lora", 1.0) + + >>> reference_video = load_video("reference.mp4") + >>> ref_cond = LTX2ReferenceCondition(frames=reference_video, strength=1.0) + + >>> prompt = "A flowing river in a forest" + >>> frame_rate = 24.0 + >>> video, audio = pipe( + ... prompt=prompt, + ... 
reference_conditions=[ref_cond], + ... width=768, + ... height=512, + ... num_frames=121, + ... frame_rate=frame_rate, + ... num_inference_steps=8, + ... sigmas=DISTILLED_SIGMA_VALUES, + ... guidance_scale=1.0, + ... output_type="np", + ... return_dict=False, + ... ) + + >>> encode_video( + ... video[0], + ... fps=frame_rate, + ... audio=audio[0].float().cpu(), + ... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, + ... output_path="ic_lora_output.mp4", + ... ) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class LTX2ICLoraPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): + r""" + Pipeline for IC-LoRA (In-Context LoRA) video generation with reference video conditioning. + + IC-LoRA conditions the generation on a reference video by encoding it into latent tokens and concatenating them + to the noisy latent sequence during denoising. The IC-LoRA adapter (loaded as a standard LoRA) learns to use this + in-context reference to guide generation (e.g. for style transfer, depth-conditioned generation, etc.). + + This pipeline also supports frame-level conditioning via the `conditions` parameter (same as + [`LTX2ConditionPipeline`]), allowing both reference video and frame conditions to be used together. + + Two-stage inference is supported through separate calls to `__call__`: + - **Stage 1**: Generate at target resolution with IC-LoRA conditioning (`output_type="latent"`). + - **Stage 2**: Upsample via [`LTX2LatentUpsamplePipeline`], then refine with a distilled LoRA (no IC-LoRA + reference conditioning needed for Stage 2). + + Reference: https://github.com/Lightricks/LTX-Video + + Args: + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLLTX2Video`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. 
+ audio_vae ([`AutoencoderKLLTX2Audio`]): + Audio VAE to encode and decode audio spectrograms. + text_encoder ([`Gemma3ForConditionalGeneration`]): + Text encoder model. + tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`): + Tokenizer for the text encoder. + connectors ([`LTX2TextConnectors`]): + Text connector stack used to adapt text encoder hidden states for the video and audio branches. + transformer ([`LTX2VideoTransformer3DModel`]): + Conditional Transformer architecture to denoise the encoded video latents. + vocoder ([`LTX2Vocoder`] or [`LTX2VocoderWithBWE`]): + Vocoder to convert mel spectrograms to audio waveforms. + """ + + model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTX2Video, + audio_vae: AutoencoderKLLTX2Audio, + text_encoder: Gemma3ForConditionalGeneration, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + connectors: LTX2TextConnectors, + transformer: LTX2VideoTransformer3DModel, + vocoder: LTX2Vocoder | LTX2VocoderWithBWE, + ): + super().__init__() + + self.register_modules( + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + connectors=connectors, + transformer=transformer, + vocoder=vocoder, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.audio_vae_mel_compression_ratio = ( + self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.audio_vae_temporal_compression_ratio = ( + self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1 + ) + + self.audio_sampling_rate = ( + self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000 + ) + self.audio_hop_length = ( + self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio, resample="bilinear") + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 + ) + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`str` or `torch.device`): + torch device to place the resulting embeddings on + dtype: (`torch.dtype`): + torch dtype to cast the prompt embeds to + max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt. 
+ """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if getattr(self, "tokenizer", None) is not None: + # Gemma expects left padding for chat-style prompts + self.tokenizer.padding_side = "left" + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + prompt = [p.strip() for p in prompt] + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + text_input_ids = text_input_ids.to(device) + prompt_attention_mask = prompt_attention_mask.to(device) + + text_encoder_outputs = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) + prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) # Pack to 3D + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
+ device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + latents=None, + audio_latents=None, + spatio_temporal_guidance_blocks=None, + stg_scale=None, + audio_stg_scale=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if prompt_embeds is not None and prompt_attention_mask is None:
+            raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+        if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+            raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+            if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+                raise ValueError(
+                    "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+                    f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+                    f" {negative_prompt_attention_mask.shape}."
+                )
+
+        if latents is not None and latents.ndim != 5:
+            raise ValueError(
+                f"Only unpacked (5D) video latents of shape `[batch_size, latent_channels, latent_frames,"
+                f" latent_height, latent_width]` are supported, but got {latents.ndim} dims. If you have packed (3D)"
+                f" latents, please unpack them (e.g. using the `_unpack_latents` method)."
+            )
+        if audio_latents is not None and audio_latents.ndim != 4:
+            raise ValueError(
+                f"Only unpacked (4D) audio latents of shape `[batch_size, num_channels, audio_length, mel_bins]` are"
+                f" supported, but got {audio_latents.ndim} dims. If you have packed (3D) latents, please unpack them"
+                f" (e.g. using the `_unpack_audio_latents` method)."
+            )
+
+        if ((stg_scale > 0.0) or (audio_stg_scale > 0.0)) and not spatio_temporal_guidance_blocks:
+            raise ValueError(
+                "Spatio-Temporal Guidance (STG) is specified but no STG blocks are supplied. Please supply a list of"
+                " block indices at which to apply STG in `spatio_temporal_guidance_blocks`."
+            )
+
+    @staticmethod
+    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents
+    def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
+        # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
+        # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
+        # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
+ # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. + batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._create_noised_state + def _create_noised_state( + latents: torch.Tensor, noise_scale: float | torch.Tensor, generator: torch.Generator | None = None + ): + noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) + noised_latents = noise_scale * noise + (1 - noise_scale) * latents + return noised_latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._normalize_audio_latents + def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents - latents_mean) / latents_std + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_audio_latents + def _denormalize_audio_latents(latents: 
torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents * latents_std) + latents_mean + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_audio_latents + def _pack_audio_latents( + latents: torch.Tensor, patch_size: int | None = None, patch_size_t: int | None = None + ) -> torch.Tensor: + # Audio latents shape: [B, C, L, M], where L is the latent audio length and M is the number of mel bins + if patch_size is not None and patch_size_t is not None: + batch_size, num_channels, latent_length, latent_mel_bins = latents.shape + post_patch_latent_length = latent_length / patch_size_t + post_patch_mel_bins = latent_mel_bins / patch_size + latents = latents.reshape( + batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size + ) + latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2) + else: + latents = latents.transpose(1, 2).flatten(2, 3) # [B, C, L, M] --> [B, L, C * M] + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_audio_latents + def _unpack_audio_latents( + latents: torch.Tensor, + latent_length: int, + num_mel_bins: int, + patch_size: int | None = None, + patch_size_t: int | None = None, + ) -> torch.Tensor: + if patch_size is not None and patch_size_t is not None: + batch_size = latents.size(0) + latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size) + latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + else: + latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) + return latents + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_condition.LTX2ConditionPipeline.trim_conditioning_sequence + def trim_conditioning_sequence(self, start_frame: int, sequence_num_frames: int, target_num_frames: int) -> int: + """ + Trim a conditioning sequence to the allowed number of frames. + """ + scale_factor = self.vae_temporal_compression_ratio + num_frames = min(sequence_num_frames, target_num_frames - start_frame) + num_frames = (num_frames - 1) // scale_factor * scale_factor + 1 + return num_frames + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_condition.LTX2ConditionPipeline.preprocess_conditions + def preprocess_conditions( + self, + conditions: LTX2VideoCondition | list[LTX2VideoCondition] | None = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + device: torch.device | None = None, + ) -> tuple[list[torch.Tensor], list[float], list[int], list[int]]: + """ + Preprocesses the condition images/videos to torch tensors. + + Args: + conditions (`LTX2VideoCondition` or `List[LTX2VideoCondition]`, *optional*, defaults to `None`): + A list of image/video condition instances. + height (`int`, *optional*, defaults to `512`): + The desired height in pixels. + width (`int`, *optional*, defaults to `768`): + The desired width in pixels. + num_frames (`int`, *optional*, defaults to `121`): + The desired number of frames in the generated video. + device (`torch.device`, *optional*, defaults to `None`): + The device on which to put the preprocessed image/video tensors. + + Returns: + `Tuple[List[torch.Tensor], List[float], List[int], List[int]]`: + Returns a 4-tuple of lists of length `len(conditions)` as follows: + 1. 
The first list is a list of preprocessed video tensors of shape [batch_size=1, num_channels, + num_frames, height, width]. + 2. The second list is a list of conditioning strengths. + 3. The third list is a list of latent-space indices for each condition. + 4. The fourth list is a list of (trimmed) pixel-space frame counts per condition. This is needed + for keyframe coord semantics (single-pixel-frame keyframes have a clamped temporal extent). + """ + conditioning_frames, conditioning_strengths, conditioning_indices, conditioning_pixel_frames = [], [], [], [] + + if conditions is None: + conditions = [] + if isinstance(conditions, LTX2VideoCondition): + conditions = [conditions] + + frame_scale_factor = self.vae_temporal_compression_ratio + latent_num_frames = (num_frames - 1) // frame_scale_factor + 1 + for i, condition in enumerate(conditions): + if isinstance(condition.frames, PIL.Image.Image): + # Single image, convert to List[PIL.Image.Image] + video_like_cond = [condition.frames] + elif isinstance(condition.frames, np.ndarray) and condition.frames.ndim == 3: + # Image-like ndarray of shape (H, W, C), insert frame dim in first axis + video_like_cond = np.expand_dims(condition.frames, axis=0) + elif isinstance(condition.frames, torch.Tensor) and condition.frames.ndim == 3: + # Image-like tensor of shape (C, H, W), insert frame dim in first dim + video_like_cond = condition.frames.unsqueeze(0) + else: + # Treat all other as videos. Note that this means 4D ndarrays and tensors will be treated as videos of + # shape (F, H, W, C) and (F, C, H, W), respectively. + video_like_cond = condition.frames + condition_pixels = self.video_processor.preprocess_video( + video_like_cond, height, width, resize_mode="crop" + ) + + # Interpret the index as a latent index, following the original LTX-2 code. + latent_start_idx = condition.index + # Support negative latent indices (e.g. -1 for the last latent index) + if latent_start_idx < 0: + # latent_start_idx will be positive because latent_num_frames is positive + latent_start_idx = latent_start_idx % latent_num_frames + if latent_start_idx >= latent_num_frames: + logger.warning( + f"The starting latent index {latent_start_idx} of condition {i} is too big for the specified number" + f" of latent frames {latent_num_frames}. This condition will be skipped." + ) + continue + + cond_num_frames = condition_pixels.size(2) + start_idx = max((latent_start_idx - 1) * frame_scale_factor + 1, 0) + truncated_cond_frames = self.trim_conditioning_sequence(start_idx, cond_num_frames, num_frames) + condition_pixels = condition_pixels[:, :, :truncated_cond_frames] + + conditioning_frames.append(condition_pixels.to(dtype=self.vae.dtype, device=device)) + conditioning_strengths.append(condition.strength) + conditioning_indices.append(latent_start_idx) + conditioning_pixel_frames.append(truncated_cond_frames) + + return conditioning_frames, conditioning_strengths, conditioning_indices, conditioning_pixel_frames + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_condition.LTX2ConditionPipeline.apply_first_frame_conditioning + def apply_first_frame_conditioning( + self, + latents: torch.Tensor, + conditioning_mask: torch.Tensor, + condition_latents: list[torch.Tensor], + condition_strengths: list[float], + condition_indices: list[int], + latent_height: int, + latent_width: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Apply first-frame visual conditioning by overwriting tokens at the first-frame positions. 
+ + Only conditions with `latent_idx == 0` are applied here (matching `VideoConditionByLatentIndex` in the + reference implementation). Conditions at non-zero latent indices are appended as separate keyframe tokens + via `prepare_keyframe_extras` (matching `VideoConditionByKeyframeIndex`) and are skipped here. + + Args: + latents (`torch.Tensor`): + Initial packed (patchified) latents of shape [batch_size, patch_seq_len, hidden_dim]. + conditioning_mask (`torch.Tensor`): + Initial packed (patchified) conditioning mask of shape [batch_size, patch_seq_len, 1] with values in + [0, 1] where 0 means the denoising model output will be fully used and 1 means the condition will be + fully used. + + Returns: + `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: + Returns a 3-tuple of tensors where: + 1. The packed video latents with first-frame conditions applied. + 2. The packed conditioning mask with first-frame strengths applied. + 3. The clean conditioning latents at first-frame positions (zeros elsewhere). + """ + clean_latents = torch.zeros_like(latents) + for cond, strength, latent_idx in zip(condition_latents, condition_strengths, condition_indices): + if latent_idx != 0: + # Non-first-frame conditions are handled as keyframe extras (appended tokens) instead. + continue + num_cond_tokens = cond.size(1) + start_token_idx = latent_idx * latent_height * latent_width + end_token_idx = start_token_idx + num_cond_tokens + + latents[:, start_token_idx:end_token_idx] = cond + conditioning_mask[:, start_token_idx:end_token_idx] = strength + clean_latents[:, start_token_idx:end_token_idx] = cond + + return latents, conditioning_mask, clean_latents + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_condition.LTX2ConditionPipeline._prepare_keyframe_coords + def _prepare_keyframe_coords( + self, + keyframe_latent_num_frames: int, + keyframe_latent_height: int, + keyframe_latent_width: int, + pixel_frame_idx: int, + num_pixel_frames: int, + fps: float, + device: torch.device, + ) -> torch.Tensor: + """ + Compute positional coordinates for a keyframe condition being appended as extra tokens. + + Mirrors `VideoConditionByKeyframeIndex.apply_to` in the reference implementation: + - Latent coords scaled to pixel space *without* the causal fix (since non-zero-index keyframes don't need + the first-frame causal adjustment). + - Temporal axis offset by `pixel_frame_idx` (the pixel-space index at which the keyframe appears). + - For single-pixel-frame keyframes, the per-patch temporal extent is clamped to `[idx, idx + 1)` so the + keyframe occupies a single pixel timestep rather than the VAE-scaled range. + - Temporal coords divided by `fps` to produce seconds. 
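+
+        Illustrative example (hypothetical values): with a VAE temporal compression ratio of 8, a keyframe placed at
+        latent index 2 has `pixel_frame_idx = (2 - 1) * 8 + 1 = 9`; if it is a single-pixel-frame keyframe rendered
+        at `fps = 24`, its temporal coordinate window becomes `[9 / 24, 10 / 24)` seconds.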
+ """ + patch_size = self.transformer_spatial_patch_size + patch_size_t = self.transformer_temporal_patch_size + scale_factors = ( + self.vae_temporal_compression_ratio, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, + ) + + grid_f = torch.arange( + start=0, end=keyframe_latent_num_frames, step=patch_size_t, dtype=torch.float32, device=device + ) + grid_h = torch.arange( + start=0, end=keyframe_latent_height, step=patch_size, dtype=torch.float32, device=device + ) + grid_w = torch.arange( + start=0, end=keyframe_latent_width, step=patch_size, dtype=torch.float32, device=device + ) + grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij") + grid = torch.stack(grid, dim=0) + + patch_size_delta = torch.tensor((patch_size_t, patch_size, patch_size), dtype=grid.dtype, device=device) + patch_ends = grid + patch_size_delta.view(3, 1, 1, 1) + + latent_coords = torch.stack([grid, patch_ends], dim=-1) # [3, N_F, N_H, N_W, 2] + latent_coords = latent_coords.flatten(1, 3) # [3, num_patches, 2] + latent_coords = latent_coords.unsqueeze(0) # [1, 3, num_patches, 2] + + scale_tensor = torch.tensor(scale_factors, device=device, dtype=latent_coords.dtype) + broadcast_shape = [1] * latent_coords.ndim + broadcast_shape[1] = -1 + pixel_coords = latent_coords * scale_tensor.view(*broadcast_shape) + + # No causal fix: keyframe coords place the keyframe at `pixel_frame_idx` without the first-frame adjustment. + pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] + pixel_frame_idx + + if num_pixel_frames == 1: + # Single-pixel-frame keyframe: clamp temporal extent to [idx, idx + 1). + pixel_coords[:, 0, :, 1:] = pixel_coords[:, 0, :, :1] + 1 + + pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps + + return pixel_coords + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_condition.LTX2ConditionPipeline.prepare_latents + def prepare_latents( + self, + conditions: LTX2VideoCondition | list[LTX2VideoCondition] | None = None, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 768, + num_frames: int = 121, + frame_rate: float = 24.0, + noise_scale: float = 1.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None]: + """ + Prepare noisy video latents, applying frame conditions. + + First-frame conditions (`latent_idx == 0`) are applied by overwriting tokens at the first-frame positions + (`VideoConditionByLatentIndex` semantics). Non-first-frame conditions (`latent_idx > 0`) are packaged as + keyframe extras to be appended to the latent sequence during the transformer forward pass + (`VideoConditionByKeyframeIndex` semantics). + + Returns a 4-tuple: + - `latents`: packed noisy latents (with first-frame replacement applied if applicable). + - `conditioning_mask`: packed conditioning mask (non-zero only at first-frame positions). + - `clean_latents`: clean first-frame conditions at first-frame positions (zeros elsewhere). + - `keyframe_extras`: `(keyframe_latents, keyframe_coords, keyframe_denoise_factors)` for keyframe + conditions at non-zero latent indices, or `None` if there are none. 
+ """ + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, latent_num_frames, latent_height, latent_width) + mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width) + + if latents is not None: + # Latents are expected to be unpacked (5D) with shape [B, F, C, H, W] + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + else: + # NOTE: we set the initial latents to zeros rather a sample from the standard Gaussian prior because we + # will sample from the prior later once we have calculated the conditioning mask + latents = torch.zeros(shape, device=device, dtype=dtype) + + conditioning_mask = latents.new_zeros(mask_shape) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) # [B, seq_len, 1] + + if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape[:2]: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape[:2] + (num_channels_latents,)}." + ) + + if isinstance(generator, list): + logger.warning( + f"{self.__class__.__name__} does not support using a list of generators. The first generator in the" + f" list will be used for all (pseudo-)random operations." + ) + generator = generator[0] + + condition_frames, condition_strengths, condition_indices, condition_pixel_frames = self.preprocess_conditions( + conditions, height, width, num_frames, device=device + ) + # Encode each condition through the VAE. We keep both the 5D latent (for coord computation) and the packed + # 3D latent (for first-frame replacement or keyframe append). + condition_latents_5d = [] + condition_latents_packed = [] + for condition_tensor in condition_frames: + condition_latent_5d = retrieve_latents( + self.vae.encode(condition_tensor), generator=generator, sample_mode="argmax" + ) + condition_latent_5d = self._normalize_latents( + condition_latent_5d, self.vae.latents_mean, self.vae.latents_std + ).to(device=device, dtype=dtype) + condition_latent_packed = self._pack_latents( + condition_latent_5d, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + condition_latents_5d.append(condition_latent_5d) + condition_latents_packed.append(condition_latent_packed) + + # First-frame conditions (latent_idx == 0): replace tokens at the first-frame positions. + # NOTE: following the I2V pipeline, we return a conditioning mask. The original LTX 2 code uses a denoising + # mask, which is the inverse of the conditioning mask (`denoise_mask = 1 - conditioning_mask`). + latents, conditioning_mask, clean_latents = self.apply_visual_conditioning( + latents, + conditioning_mask, + condition_latents_packed, + condition_strengths, + condition_indices, + latent_height=latent_height, + latent_width=latent_width, + ) + + # Non-first-frame conditions (latent_idx > 0): append as keyframe extras with offset pixel coords. 
+ frame_scale_factor = self.vae_temporal_compression_ratio + keyframe_tokens = [] + keyframe_coords = [] + keyframe_denoise_factors = [] + for cond_5d, cond_packed, strength, latent_idx, num_pixel_frames in zip( + condition_latents_5d, + condition_latents_packed, + condition_strengths, + condition_indices, + condition_pixel_frames, + ): + if latent_idx == 0: + continue + + _, _, kf_latent_frames, kf_latent_height, kf_latent_width = cond_5d.shape + # Pixel-space frame index at which the keyframe is placed. Matches the `start_idx` formula in + # `preprocess_conditions` used for trimming. + pixel_frame_idx = (latent_idx - 1) * frame_scale_factor + 1 + + coords = self._prepare_keyframe_coords( + keyframe_latent_num_frames=kf_latent_frames, + keyframe_latent_height=kf_latent_height, + keyframe_latent_width=kf_latent_width, + pixel_frame_idx=pixel_frame_idx, + num_pixel_frames=num_pixel_frames, + fps=frame_rate, + device=device, + ) + + num_tokens = cond_packed.shape[1] + denoise_factor = torch.full( + (1, num_tokens), 1.0 - strength, device=device, dtype=torch.float32 + ) + + keyframe_tokens.append(cond_packed) + keyframe_coords.append(coords) + keyframe_denoise_factors.append(denoise_factor) + + keyframe_extras: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None + if keyframe_tokens: + keyframe_extras = ( + torch.cat(keyframe_tokens, dim=1), + torch.cat(keyframe_coords, dim=2), + torch.cat(keyframe_denoise_factors, dim=1), + ) + else: + keyframe_extras = None + + # Sample from the standard Gaussian prior (or an intermediate Gaussian distribution if noise_scale < 1.0). + noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) + scaled_mask = (1.0 - conditioning_mask) * noise_scale + # Add noise to the `latents` so that it is at the noise level specified by `noise_scale`. + latents = noise * scaled_mask + latents * (1 - scaled_mask) + + return latents, conditioning_mask, clean_latents, keyframe_extras + + def prepare_audio_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 8, + audio_latent_length: int = 1, + num_mel_bins: int = 64, + noise_scale: float = 0.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + latents = self._pack_audio_latents(latents) + latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) + latents = self._create_noised_state(latents, noise_scale, generator) + return latents.to(device=device, dtype=dtype) + + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + + shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_audio_latents(latents) + return latents + + def prepare_reference_latents( + self, + reference_conditions: list[LTX2ReferenceCondition], + height: int, + width: int, + num_frames: int, + reference_downscale_factor: int = 1, + frame_rate: float = 24.0, + conditioning_attention_strength: float = 1.0, + conditioning_attention_mask: torch.Tensor | None = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]: + """ + Encode reference videos into packed latent tokens and compute their positional coordinates. + + Each reference video is independently encoded by the VAE, packed into tokens, and its positional coordinates + are computed with spatial scaling by `reference_downscale_factor` to match the target coordinate space. + + All reference tokens are concatenated into a single sequence to be appended to the noisy video latents + during denoising. When `conditioning_attention_strength < 1.0` or `conditioning_attention_mask` is provided, + a per-token cross-attention mask is also computed for each reference video (downsampled to the reference + video's latent dimensions) and returned so the pipeline can build a self-attention mask over the full video + sequence. + + Args: + reference_conditions (`list[LTX2ReferenceCondition]`): + The reference video conditions. + height (`int`): + Target video height in pixels (used to determine reference video preprocessing size with + `reference_downscale_factor`). + width (`int`): + Target video width in pixels. + num_frames (`int`): + Number of target video frames. + reference_downscale_factor (`int`, defaults to `1`): + Ratio between target and reference resolutions. A factor of 2 means the reference video is + preprocessed at half the target resolution. Spatial positional coordinates are scaled by this factor + to map reference tokens into the target coordinate space. + frame_rate (`float`, defaults to `24.0`): + Video frame rate (used for temporal coordinate computation). + conditioning_attention_strength (`float`, defaults to `1.0`): + Scalar in `[0, 1]` controlling how strongly reference tokens attend to noisy tokens (and vice versa) + in the self-attention mask. `1.0` means full attention (no masking), `0.0` means reference tokens + are effectively ignored by the noisy tokens. + conditioning_attention_mask (`torch.Tensor`, *optional*): + Optional pixel-space mask of shape `(1, 1, F_pix, H_pix, W_pix)` with values in `[0, 1]` that provides + spatially-varying attention strength. Downsampled to latent space per reference video and multiplied + by `conditioning_attention_strength`. + dtype (`torch.dtype`, *optional*): + Data type for the latents. + device (`torch.device`, *optional*): + Device for the latents. + generator (`torch.Generator`, *optional*): + Random generator for VAE encoding. 
+ + Returns: + A 4-tuple of `(reference_latents, reference_coords, reference_denoise_factors, reference_cross_mask)`: + - `reference_latents`: `[1, total_ref_tokens, hidden_dim]` + - `reference_coords`: `[1, 3, total_ref_tokens, 2]` + - `reference_denoise_factors`: `[1, total_ref_tokens]` — per-token `(1 - strength)` factors + - `reference_cross_mask`: `[1, total_ref_tokens]` per-token noisy↔reference attention strengths in + `[0, 1]`, or `None` when `conditioning_attention_strength == 1.0` and no pixel-space mask is + provided (in which case attention is unmasked). + """ + ref_height = height // reference_downscale_factor + ref_width = width // reference_downscale_factor + + mask_needed = conditioning_attention_strength < 1.0 or conditioning_attention_mask is not None + + all_ref_latents = [] + all_ref_coords = [] + all_ref_denoise_factors = [] + all_ref_cross_masks = [] + + for ref_cond in reference_conditions: + # Preprocess reference video frames to the (possibly downscaled) resolution + if isinstance(ref_cond.frames, PIL.Image.Image): + video_like = [ref_cond.frames] + elif isinstance(ref_cond.frames, np.ndarray) and ref_cond.frames.ndim == 3: + video_like = np.expand_dims(ref_cond.frames, axis=0) + elif isinstance(ref_cond.frames, torch.Tensor) and ref_cond.frames.ndim == 3: + video_like = ref_cond.frames.unsqueeze(0) + else: + video_like = ref_cond.frames + + ref_pixels = self.video_processor.preprocess_video( + video_like, ref_height, ref_width, resize_mode="crop" + ) + # Trim to num_frames + ref_pixels = ref_pixels[:, :, :num_frames] + ref_pixels = ref_pixels.to(dtype=self.vae.dtype, device=device) + + # Encode through VAE + ref_latent = retrieve_latents( + self.vae.encode(ref_pixels), generator=generator, sample_mode="argmax" + ) + ref_latent = self._normalize_latents( + ref_latent, self.vae.latents_mean, self.vae.latents_std + ).to(device=device, dtype=dtype) + + # Get latent dimensions for coordinate computation + _, _, ref_latent_frames, ref_latent_height, ref_latent_width = ref_latent.shape + + # Pack into tokens + ref_latent_packed = self._pack_latents( + ref_latent, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + # Compute positional coordinates for the reference tokens. We use the transformer's + # prepare_video_coords at the reference video's latent dimensions, then scale spatial coords + # by downscale_factor so they map to the target coordinate space. + ref_coords = self.transformer.rope.prepare_video_coords( + batch_size=1, + num_frames=ref_latent_frames, + height=ref_latent_height, + width=ref_latent_width, + device=device, + fps=frame_rate, + ) + if reference_downscale_factor != 1: + # Scale spatial coordinates (height=axis 1, width=axis 2) to match target space + ref_coords[:, 1, :, :] = ref_coords[:, 1, :, :] * reference_downscale_factor + ref_coords[:, 2, :, :] = ref_coords[:, 2, :, :] * reference_downscale_factor + + num_tokens = ref_latent_packed.shape[1] + denoise_factor = torch.full( + (1, num_tokens), 1.0 - ref_cond.strength, device=device, dtype=torch.float32 + ) + + all_ref_latents.append(ref_latent_packed) + all_ref_coords.append(ref_coords) + all_ref_denoise_factors.append(denoise_factor) + + if mask_needed: + # Per-reference cross-attention mask. Start from either a downsampled pixel-space mask or a full-1 + # tensor, then scale by conditioning_attention_strength. 
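+                # For example (illustrative values): with no pixel-space mask and a conditioning_attention_strength
+                # of 0.5, every reference token gets a uniform noisy<->reference attention strength of 0.5; with a
+                # binary pixel-space mask, only the masked-in regions keep that strength and the rest drop to 0.0.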
+ if conditioning_attention_mask is not None: + ref_cross = self._downsample_mask_to_latent( + mask=conditioning_attention_mask, + latent_num_frames=ref_latent_frames, + latent_height=ref_latent_height, + latent_width=ref_latent_width, + ).to(device=device, dtype=torch.float32) + else: + ref_cross = torch.ones((1, num_tokens), device=device, dtype=torch.float32) + ref_cross = ref_cross * conditioning_attention_strength + all_ref_cross_masks.append(ref_cross) + + # Concatenate all reference tokens into a single sequence + reference_latents = torch.cat(all_ref_latents, dim=1) # [1, total_ref_tokens, D] + reference_coords = torch.cat(all_ref_coords, dim=2) # [1, 3, total_ref_tokens, 2] + reference_denoise_factors = torch.cat(all_ref_denoise_factors, dim=1) # [1, total_ref_tokens] + reference_cross_mask = torch.cat(all_ref_cross_masks, dim=1) if mask_needed else None + + return reference_latents, reference_coords, reference_denoise_factors, reference_cross_mask + + @staticmethod + def _downsample_mask_to_latent( + mask: torch.Tensor, + latent_num_frames: int, + latent_height: int, + latent_width: int, + ) -> torch.Tensor: + """ + Downsample a pixel-space attention mask to a flattened per-token latent-space mask. + + Mirrors `ICLoraPipeline._downsample_mask_to_latent` in the reference implementation: + - Spatial downsampling via `area` interpolation per frame. + - Causal temporal downsampling: the first frame is kept as-is (the VAE encodes the first frame + independently with temporal stride 1), remaining frames are downsampled by group-mean using factor + `(F_pix - 1) // (F_lat - 1)`. + - Flattened to token order `(F, H, W)` matching the patchifier. + + Args: + mask (`torch.Tensor`): + Pixel-space mask of shape `(B, 1, F_pix, H_pix, W_pix)` with values in `[0, 1]`. + latent_num_frames (`int`), latent_height (`int`), latent_width (`int`): + Target latent dimensions. + + Returns: + Flattened latent-space mask of shape `(B, latent_num_frames * latent_height * latent_width)`. + """ + if mask.ndim != 5 or mask.shape[1] != 1: + raise ValueError( + f"Expected `conditioning_attention_mask` of shape (B, 1, F, H, W), got {tuple(mask.shape)}." + ) + b, _, f_pix, _, _ = mask.shape + + # 1. Spatial downsampling (area interpolation per frame). + mask_2d = mask.reshape(b * f_pix, 1, mask.shape[-2], mask.shape[-1]) + spatial_down = torch.nn.functional.interpolate( + mask_2d, size=(latent_height, latent_width), mode="area" + ) + spatial_down = spatial_down.reshape(b, 1, f_pix, latent_height, latent_width) + + # 2. Causal temporal downsampling. + first_frame = spatial_down[:, :, :1, :, :] # (B, 1, 1, H_lat, W_lat) + if f_pix > 1 and latent_num_frames > 1: + t = (f_pix - 1) // (latent_num_frames - 1) + if (f_pix - 1) % (latent_num_frames - 1) != 0: + raise ValueError( + f"Pixel frames ({f_pix}) not compatible with latent frames ({latent_num_frames}): " + f"(f_pix - 1) must be divisible by (latent_num_frames - 1)." + ) + rest = spatial_down[:, :, 1:, :, :] + rest = rest.reshape(b, 1, latent_num_frames - 1, t, latent_height, latent_width).mean(dim=3) + latent_mask = torch.cat([first_frame, rest], dim=2) + else: + latent_mask = first_frame + + # 3. Flatten to token order (f, h, w). 
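+        # Illustrative example (hypothetical sizes): a (1, 1, 121, 512, 768) pixel mask targeting a 16x24 latent grid
+        # with 16 latent frames is downsampled spatially to (1, 1, 121, 16, 24), then temporally by keeping the first
+        # frame and mean-pooling the remaining 120 frames in groups of (121 - 1) // (16 - 1) = 8, and finally
+        # flattened to (1, 16 * 16 * 24) = (1, 6144).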
+ return latent_mask.reshape(b, latent_num_frames * latent_height * latent_width) + + @staticmethod + def _build_video_self_attention_mask( + num_noisy_tokens: int, + extras_cross_masks: list[torch.Tensor], + device: torch.device, + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """ + Build the `(1, T_video, T_video)` self-attention mask over `noisy + extras` tokens, where `extras` is a + concatenation of one or more conditioning groups (e.g. keyframes, IC-LoRA references). + + Block structure (mirrors the reference `update_attention_mask` / `ConditioningItemAttentionStrengthWrapper`): + - noisy ↔ noisy: 1.0 (full attention) + - noisy ↔ group_i: `extras_cross_masks[i]` broadcast across the noisy-token axis + - group_i ↔ noisy: `extras_cross_masks[i]` broadcast across the noisy-token axis (symmetric) + - group_i ↔ group_i: 1.0 (tokens in a group fully attend to themselves) + - group_i ↔ group_j (i != j): 0.0 (different conditioning groups don't cross-attend) + + Args: + num_noisy_tokens (`int`): + Number of noisy video tokens. + extras_cross_masks (`list[torch.Tensor]`): + List of per-token cross-attention strengths, one per conditioning group. Each entry has shape + `(1, num_tokens_in_group)` with values in `[0, 1]`. Groups must appear in the same order as their + tokens in the extras block. + device, dtype: + Tensor device and dtype. + + Returns: + Multiplicative self-attention mask of shape `(1, num_noisy_tokens + sum(group_sizes), + num_noisy_tokens + sum(group_sizes))` with values in `[0, 1]`. + """ + total_extras = sum(m.shape[1] for m in extras_cross_masks) + total = num_noisy_tokens + total_extras + + # Initialize to 0 so that between-group blocks remain masked without explicit assignment. + attn_mask = torch.zeros((1, total, total), device=device, dtype=dtype) + attn_mask[:, :num_noisy_tokens, :num_noisy_tokens] = 1.0 # noisy ↔ noisy + + offset = num_noisy_tokens + for cross_mask in extras_cross_masks: + n = cross_mask.shape[1] + cross = cross_mask.to(device=device, dtype=dtype) + # noisy (rows) ↔ this group (cols) + attn_mask[:, :num_noisy_tokens, offset:offset + n] = cross.unsqueeze(1) + # this group (rows) ↔ noisy (cols) + attn_mask[:, offset:offset + n, :num_noisy_tokens] = cross.unsqueeze(2) + # this group ↔ this group (self-attention within the group) + attn_mask[:, offset:offset + n, offset:offset + n] = 1.0 + offset += n + return attn_mask + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_condition.LTX2ConditionPipeline.convert_velocity_to_x0 + def convert_velocity_to_x0( + self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None + ) -> torch.Tensor: + if scheduler is None: + scheduler = self.scheduler + + sample_x0 = sample - denoised_output * scheduler.sigmas[step_idx] + return sample_x0 + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_condition.LTX2ConditionPipeline.convert_x0_to_velocity + def convert_x0_to_velocity( + self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None + ) -> torch.Tensor: + if scheduler is None: + scheduler = self.scheduler + + sample_v = (sample - denoised_output) / scheduler.sigmas[step_idx] + return sample_v + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def stg_scale(self): + return self._stg_scale + + @property + def modality_scale(self): + return self._modality_scale + + @property + def 
audio_guidance_scale(self): + return self._audio_guidance_scale + + @property + def audio_guidance_rescale(self): + return self._audio_guidance_rescale + + @property + def audio_stg_scale(self): + return self._audio_stg_scale + + @property + def audio_modality_scale(self): + return self._audio_modality_scale + + @property + def do_classifier_free_guidance(self): + return (self._guidance_scale > 1.0) or (self._audio_guidance_scale > 1.0) + + @property + def do_spatio_temporal_guidance(self): + return (self._stg_scale > 0.0) or (self._audio_stg_scale > 0.0) + + @property + def do_modality_isolation_guidance(self): + return (self._modality_scale > 1.0) or (self._audio_modality_scale > 1.0) + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + def _run_transformer( + self, + latent_model_input: torch.Tensor, + audio_latent_model_input: torch.Tensor, + video_timestep: torch.Tensor, + audio_timestep: torch.Tensor, + sigma: torch.Tensor, + video_coords: torch.Tensor, + audio_coords: torch.Tensor, + connector_prompt_embeds: torch.Tensor, + connector_audio_prompt_embeds: torch.Tensor, + connector_attention_mask: torch.Tensor, + latent_num_frames: int, + latent_height: int, + latent_width: int, + frame_rate: float, + audio_num_frames: int, + use_cross_timestep: bool, + attention_kwargs: dict[str, Any] | None, + cache_context: str, + extra_latents: torch.Tensor | None = None, + extra_coords: torch.Tensor | None = None, + extra_timestep: torch.Tensor | None = None, + video_self_attention_mask: torch.Tensor | None = None, + isolate_modalities: bool = False, + spatio_temporal_guidance_blocks: list[int] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Run a single transformer forward pass, optionally concatenating extra tokens (keyframe/reference conditions) + to the video hidden states and/or applying a video self-attention mask. + + When `extra_latents` is provided, the extra tokens are concatenated to the video hidden states, video coords, + and video timesteps. After the transformer forward pass, the extra tokens are stripped from the video output + so only the noisy-token predictions are returned. + + When `video_self_attention_mask` is provided (shape `(1, T_video, T_video)`), it is expanded to the current + batch size and passed to the transformer to control attention between noisy and appended extra tokens. + + Returns: + `(noise_pred_video, noise_pred_audio)` where `noise_pred_video` has the same sequence length as the input + `latent_model_input` (extras are stripped). 
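+
+        Illustrative example (hypothetical sizes): if `latent_model_input` holds 3072 noisy tokens and 384 keyframe
+        plus 768 reference tokens are appended, the transformer processes 4224 video tokens and the returned
+        `noise_pred_video` is sliced back down to the first 3072 tokens.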
+ """ + video_seq_len = latent_model_input.shape[1] + + if extra_latents is not None: + batch_size = latent_model_input.shape[0] + extra_batch = extra_latents.to(latent_model_input.dtype).expand(batch_size, -1, -1) + combined_hidden = torch.cat([latent_model_input, extra_batch], dim=1) + + extra_coords_batch = extra_coords.expand(batch_size, -1, -1, -1) + combined_coords = torch.cat([video_coords, extra_coords_batch], dim=2) + + extra_ts_batch = extra_timestep.expand(batch_size, -1) + combined_timestep = torch.cat([video_timestep, extra_ts_batch], dim=1) + else: + combined_hidden = latent_model_input + combined_coords = video_coords + combined_timestep = video_timestep + + if video_self_attention_mask is not None: + video_self_attention_mask = video_self_attention_mask.expand(combined_hidden.shape[0], -1, -1) + + with self.transformer.cache_context(cache_context): + noise_pred_combined, noise_pred_audio = self.transformer( + hidden_states=combined_hidden, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, + timestep=combined_timestep, + audio_timestep=audio_timestep, + sigma=sigma, + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, + video_self_attention_mask=video_self_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=combined_coords, + audio_coords=audio_coords, + isolate_modalities=isolate_modalities, + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + + noise_pred_video = noise_pred_combined[:, :video_seq_len] + return noise_pred_video, noise_pred_audio + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + reference_conditions: LTX2ReferenceCondition | list[LTX2ReferenceCondition] | None = None, + conditions: LTX2VideoCondition | list[LTX2VideoCondition] | None = None, + reference_downscale_factor: int = 1, + conditioning_attention_strength: float = 1.0, + conditioning_attention_mask: torch.Tensor | None = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + frame_rate: float = 24.0, + num_inference_steps: int = 40, + sigmas: list[float] | None = None, + timesteps: list[float] | None = None, + guidance_scale: float = 4.0, + stg_scale: float = 0.0, + modality_scale: float = 1.0, + guidance_rescale: float = 0.0, + audio_guidance_scale: float | None = None, + audio_stg_scale: float | None = None, + audio_modality_scale: float | None = None, + audio_guidance_rescale: float | None = None, + spatio_temporal_guidance_blocks: list[int] | None = None, + noise_scale: float | None = None, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + audio_latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + decode_timestep: float | list[float] = 0.0, + decode_noise_scale: float | list[float] | None = None, + use_cross_timestep: bool = False, + 
output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 1024, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide video generation. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than `1`). + reference_conditions (`LTX2ReferenceCondition` or `List[LTX2ReferenceCondition]`, *optional*): + Reference video conditions for IC-LoRA conditioning. Each reference video is encoded into latent tokens + and concatenated to the noisy latent sequence during denoising, allowing the IC-LoRA adapter to + condition the generation on the reference video content. + conditions (`LTX2VideoCondition` or `List[LTX2VideoCondition]`, *optional*): + Frame-level conditioning (same as [`LTX2ConditionPipeline`]). Conditions are inserted at specific + latent positions and blended with the denoised output during each denoising step. + reference_downscale_factor (`int`, *optional*, defaults to `1`): + Ratio between target and reference video resolutions. IC-LoRA models trained with downscaled reference + videos store this factor in their safetensors metadata (`reference_downscale_factor` key). A factor of + `2` means the reference video is preprocessed at half the target resolution and spatial positional + coordinates are scaled accordingly. + conditioning_attention_strength (`float`, *optional*, defaults to `1.0`): + Scalar in `[0, 1]` controlling how strongly noisy tokens and appended reference tokens attend to each + other in the video self-attention. `1.0` = full attention (no masking, same as the base IC-LoRA + behavior). `0.0` = reference tokens are fully masked out of the noisy-token attention (and vice + versa). Only takes effect when `reference_conditions` is provided. + conditioning_attention_mask (`torch.Tensor`, *optional*): + Optional pixel-space spatial attention mask of shape `(1, 1, F_pix, H_pix, W_pix)` with values in + `[0, 1]` that provides per-region attention strength. The mask's spatial-temporal dimensions must + match the reference video's pixel dimensions. Downsampled to latent space using VAE scale factors + (with causal temporal handling for the first frame) and multiplied by + `conditioning_attention_strength` to form the final cross-attention mask between noisy and reference + tokens. Only takes effect when `reference_conditions` is provided. + height (`int`, *optional*, defaults to `512`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `768`): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to `121`): + The number of video frames to generate. + frame_rate (`float`, *optional*, defaults to `24.0`): + The frames per second (FPS) of the generated video. + num_inference_steps (`int`, *optional*, defaults to 40): + The number of denoising steps. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. + guidance_scale (`float`, *optional*, defaults to `4.0`): + Classifier-Free Guidance scale for video. 
+ stg_scale (`float`, *optional*, defaults to `0.0`): + Spatio-Temporal Guidance scale for video. `0.0` disables STG. + modality_scale (`float`, *optional*, defaults to `1.0`): + Modality isolation guidance scale for video. `1.0` disables modality guidance. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor for video. + audio_guidance_scale (`float`, *optional*, defaults to `None`): + CFG scale for audio. If `None`, defaults to `guidance_scale`. + audio_stg_scale (`float`, *optional*, defaults to `None`): + STG scale for audio. If `None`, defaults to `stg_scale`. + audio_modality_scale (`float`, *optional*, defaults to `None`): + Modality guidance scale for audio. If `None`, defaults to `modality_scale`. + audio_guidance_rescale (`float`, *optional*, defaults to `None`): + Guidance rescale for audio. If `None`, defaults to `guidance_rescale`. + spatio_temporal_guidance_blocks (`list[int]`, *optional*): + Transformer block indices at which to apply STG. + noise_scale (`float`, *optional*): + Noise scale for latent initialization. If not set, inferred from the sigma schedule. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + Random generator(s) for reproducibility. + latents (`torch.Tensor`, *optional*): + Pre-generated video latents (5D unpacked). + audio_latents (`torch.Tensor`, *optional*): + Pre-generated audio latents (4D unpacked). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + Noise scale at decode time. + use_cross_timestep (`bool`, *optional*, defaults to `False`): + Whether to use cross-modality sigma for cross attention modulation. `True` for LTX-2.3+. + output_type (`str`, *optional*, defaults to `"pil"`): + Output format. Choose `"pil"`, `"np"`, or `"latent"`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`LTX2PipelineOutput`] or a plain tuple. + attention_kwargs (`dict`, *optional*): + Additional kwargs passed to the attention processor. + callback_on_step_end (`Callable`, *optional*): + A function called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List`, *optional*, defaults to `["latents"]`): + Tensor inputs for the callback function. + max_sequence_length (`int`, *optional*, defaults to `1024`): + Maximum sequence length for the text prompt. + + Examples: + + Returns: + [`LTX2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`LTX2PipelineOutput`] is returned, otherwise a `tuple` of + `(video, audio)` is returned. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + audio_guidance_scale = audio_guidance_scale or guidance_scale + audio_stg_scale = audio_stg_scale or stg_scale + audio_modality_scale = audio_modality_scale or modality_scale + audio_guidance_rescale = audio_guidance_rescale or guidance_rescale + + # 1. 
Check inputs + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + latents=latents, + audio_latents=audio_latents, + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + stg_scale=stg_scale, + audio_stg_scale=audio_stg_scale, + ) + + # Per-modality guidance scales + self._guidance_scale = guidance_scale + self._stg_scale = stg_scale + self._modality_scale = modality_scale + self._guidance_rescale = guidance_rescale + self._audio_guidance_scale = audio_guidance_scale + self._audio_stg_scale = audio_stg_scale + self._audio_modality_scale = audio_modality_scale + self._audio_guidance_rescale = audio_guidance_rescale + + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if conditions is not None and not isinstance(conditions, list): + conditions = [conditions] + if reference_conditions is not None and not isinstance(reference_conditions, list): + reference_conditions = [reference_conditions] + + # Infer noise scale from sigma schedule if not provided + if noise_scale is None: + noise_scale = sigmas[0] if sigmas is not None else 1.0 + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + tokenizer_padding_side = "left" + if getattr(self, "tokenizer", None) is not None: + tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left") + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( + prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side + ) + + # 4. Prepare latent variables + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + if latents is not None: + logger.info( + "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width]," + " `latent_num_frames`, `latent_height`, `latent_width` will be inferred." 
+ ) + _, _, latent_num_frames, latent_height, latent_width = latents.shape + + num_channels_latents = self.transformer.config.in_channels + latents, conditioning_mask, clean_latents, keyframe_extras = self.prepare_latents( + conditions=conditions, + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames=num_frames, + frame_rate=frame_rate, + noise_scale=noise_scale, + dtype=torch.float32, + device=device, + generator=generator, + latents=latents, + ) + has_conditions = conditions is not None and len(conditions) > 0 + if self.do_classifier_free_guidance and has_conditions: + conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + + # 4b. Prepare reference extras for IC-LoRA conditioning. Reference extras are packaged in the same format + # as keyframe extras (tokens, coords, per-token denoise factors) so both can be concatenated into a single + # block of extra tokens before being fed to the transformer. The reference path also produces an optional + # per-token cross-attention mask when `conditioning_attention_strength < 1.0` or + # `conditioning_attention_mask` is provided. + reference_extras: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None + reference_cross_mask: torch.Tensor | None = None + if reference_conditions is not None and len(reference_conditions) > 0: + ref_latents, ref_coords, ref_denoise, reference_cross_mask = self.prepare_reference_latents( + reference_conditions=reference_conditions, + height=height, + width=width, + num_frames=num_frames, + reference_downscale_factor=reference_downscale_factor, + frame_rate=frame_rate, + conditioning_attention_strength=conditioning_attention_strength, + conditioning_attention_mask=conditioning_attention_mask, + dtype=torch.float32, + device=device, + generator=generator, + ) + reference_extras = (ref_latents, ref_coords, ref_denoise) + + # Combine keyframe extras + reference extras into a single extras block. Keyframes come first (matching + # the reference implementation's ordering in `_create_conditionings`: image conditions first, reference + # video conditions appended last). + extras_parts = [e for e in (keyframe_extras, reference_extras) if e is not None] + if extras_parts: + extra_latents_all = torch.cat([e[0] for e in extras_parts], dim=1) + extra_coords_all = torch.cat([e[1] for e in extras_parts], dim=2) + extra_denoise_factors_all = torch.cat([e[2] for e in extras_parts], dim=1) + else: + extra_latents_all = extra_coords_all = extra_denoise_factors_all = None + + # Build the video self-attention mask over `noisy + extras` when any extras group needs non-trivial + # attention strength (currently: only IC-LoRA references). Keyframes are always included with full + # cross-attention (cross_mask=1.0) so the resulting block structure correctly isolates keyframes from + # references (different-group blocks are 0). When no reference mask is needed, we leave + # `video_self_attention_mask=None` so attention is fully unmasked. 
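For illustration, the block structure described in the comment above can be sketched as a standalone helper. This is a hypothetical reimplementation for reading purposes only; the pipeline's actual `_build_video_self_attention_mask` helper is not shown in this hunk and may differ in details. It assumes, as above, that each per-group cross mask has shape `(1, n_g)` with values in `[0, 1]`.

```py
import torch


def build_block_mask_sketch(
    num_noisy_tokens: int,
    extras_cross_masks: list[torch.Tensor],
    device: torch.device,
) -> torch.Tensor:
    # Multiplicative mask of shape (1, T_v, T_v) over [noisy | extras_0 | extras_1 | ...] tokens,
    # with 1 = full attention and 0 = masked (converted to an additive bias inside the transformer).
    group_sizes = [m.shape[1] for m in extras_cross_masks]
    total = num_noisy_tokens + sum(group_sizes)
    mask = torch.zeros((1, total, total), device=device, dtype=torch.float32)
    # Noisy tokens always attend fully to each other.
    mask[:, :num_noisy_tokens, :num_noisy_tokens] = 1.0
    offset = num_noisy_tokens
    for strength in extras_cross_masks:
        strength = strength.to(device=device, dtype=torch.float32)  # (1, n_g)
        n_g = strength.shape[1]
        # Noisy <-> group cross blocks use the per-token strength (all-ones for keyframes,
        # conditioning_attention_strength and/or the spatial mask for IC-LoRA references).
        mask[:, :num_noisy_tokens, offset : offset + n_g] = strength.unsqueeze(1)
        mask[:, offset : offset + n_g, :num_noisy_tokens] = strength.unsqueeze(-1)
        # Tokens within the same extras group attend fully to each other.
        mask[:, offset : offset + n_g, offset : offset + n_g] = 1.0
        offset += n_g
    # Blocks between *different* extras groups stay 0, isolating keyframes from references.
    return mask
```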
+ video_self_attention_mask: torch.Tensor | None = None + if reference_cross_mask is not None: + extras_cross_masks: list[torch.Tensor] = [] + if keyframe_extras is not None: + num_kf_tokens = keyframe_extras[0].shape[1] + extras_cross_masks.append(torch.ones((1, num_kf_tokens), device=device, dtype=torch.float32)) + extras_cross_masks.append(reference_cross_mask) + video_self_attention_mask = self._build_video_self_attention_mask( + num_noisy_tokens=latents.shape[1], + extras_cross_masks=extras_cross_masks, + device=device, + ) + + # 5. Prepare audio latents + duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + if audio_latents is not None: + logger.info( + "Got audio_latents of shape [batch_size, num_channels, audio_num_frames, mel_bins]," + " `audio_num_frames` will be inferred." + ) + _, _, audio_num_frames, _ = audio_latents.shape + + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + num_channels_latents_audio = ( + self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 + ) + audio_latents = self.prepare_audio_latents( + batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, + num_mel_bins=num_mel_bins, + noise_scale=noise_scale, + dtype=torch.float32, + device=device, + generator=generator, + latents=audio_latents, + ) + + # 6. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + mu = calculate_shift( + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_image_seq_len", 1024), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.95), + self.scheduler.config.get("max_shift", 2.05), + ) + audio_scheduler = copy.deepcopy(self.scheduler) + _, _ = retrieve_timesteps( + audio_scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 7. Prepare positional coordinates + video_coords = self.transformer.rope.prepare_video_coords( + latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + audio_latents.shape[0], audio_num_frames, audio_latents.device + ) + if self.do_classifier_free_guidance: + video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1)) + audio_coords = audio_coords.repeat((2,) + (1,) * (audio_coords.ndim - 1)) + + # 8. 
Denoising loop + video_seq_len = latents.shape[1] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + audio_latent_model_input = ( + torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + ) + audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) + + timestep_scalar = t.expand(latent_model_input.shape[0]) + + # Per-token video timestep: conditioned positions (from frame conditions) get timestep 0, + # unconditioned positions get the current sigma. + if has_conditions: + video_timestep = timestep_scalar.unsqueeze(-1) * (1 - conditioning_mask.squeeze(-1)) + else: + video_timestep = timestep_scalar.unsqueeze(-1).expand(-1, video_seq_len) + + # Per-token timestep for the combined extras block (keyframes + references). Each extra token's + # timestep is sigma * denoise_factor, where denoise_factor = 1 - strength (0 for fully-clean tokens). + extra_timestep_in = ( + t * extra_denoise_factors_all if extra_denoise_factors_all is not None else None + ) + + # --- Main transformer forward pass (conditional + unconditional for CFG) --- + noise_pred_video, noise_pred_audio = self._run_transformer( + latent_model_input=latent_model_input, + audio_latent_model_input=audio_latent_model_input, + video_timestep=video_timestep, + audio_timestep=timestep_scalar, + sigma=timestep_scalar, + video_coords=video_coords, + audio_coords=audio_coords, + connector_prompt_embeds=connector_prompt_embeds, + connector_audio_prompt_embeds=connector_audio_prompt_embeds, + connector_attention_mask=connector_attention_mask, + latent_num_frames=latent_num_frames, + latent_height=latent_height, + latent_width=latent_width, + frame_rate=frame_rate, + audio_num_frames=audio_num_frames, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + cache_context="cond_uncond", + extra_latents=extra_latents_all, + extra_coords=extra_coords_all, + extra_timestep=extra_timestep_in, + video_self_attention_mask=video_self_attention_mask, + ) + noise_pred_video = noise_pred_video.float() + noise_pred_audio = noise_pred_audio.float() + + if self.do_classifier_free_guidance: + noise_pred_video_uncond_text, noise_pred_video = noise_pred_video.chunk(2) + noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler) + noise_pred_video_uncond_text = self.convert_velocity_to_x0( + latents, noise_pred_video_uncond_text, i, self.scheduler + ) + video_cfg_delta = (self.guidance_scale - 1) * (noise_pred_video - noise_pred_video_uncond_text) + + noise_pred_audio_uncond_text, noise_pred_audio = noise_pred_audio.chunk(2) + noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler) + noise_pred_audio_uncond_text = self.convert_velocity_to_x0( + audio_latents, noise_pred_audio_uncond_text, i, audio_scheduler + ) + audio_cfg_delta = (self.audio_guidance_scale - 1) * ( + noise_pred_audio - noise_pred_audio_uncond_text + ) + + if self.do_spatio_temporal_guidance or self.do_modality_isolation_guidance: + if i == 0: + video_prompt_embeds = connector_prompt_embeds.chunk(2, dim=0)[1] + audio_prompt_embeds = connector_audio_prompt_embeds.chunk(2, dim=0)[1] + prompt_attn_mask = connector_attention_mask.chunk(2, dim=0)[1] + + video_pos_ids = 
video_coords.chunk(2, dim=0)[0] + audio_pos_ids = audio_coords.chunk(2, dim=0)[0] + + timestep_scalar_single = timestep_scalar.chunk(2, dim=0)[0] + if has_conditions: + video_timestep_single = video_timestep.chunk(2, dim=0)[0] + else: + video_timestep_single = timestep_scalar_single.unsqueeze(-1).expand(-1, video_seq_len) + else: + video_cfg_delta = audio_cfg_delta = 0 + + video_prompt_embeds = connector_prompt_embeds + audio_prompt_embeds = connector_audio_prompt_embeds + prompt_attn_mask = connector_attention_mask + + video_pos_ids = video_coords + audio_pos_ids = audio_coords + + timestep_scalar_single = timestep_scalar + if has_conditions: + video_timestep_single = video_timestep + else: + video_timestep_single = timestep_scalar.unsqueeze(-1).expand(-1, video_seq_len) + + noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler) + noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler) + + # --- STG forward pass --- + if self.do_spatio_temporal_guidance: + noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self._run_transformer( + latent_model_input=latents.to(dtype=prompt_embeds.dtype), + audio_latent_model_input=audio_latents.to(dtype=prompt_embeds.dtype), + video_timestep=video_timestep_single, + audio_timestep=timestep_scalar_single, + sigma=timestep_scalar_single, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + connector_prompt_embeds=video_prompt_embeds, + connector_audio_prompt_embeds=audio_prompt_embeds, + connector_attention_mask=prompt_attn_mask, + latent_num_frames=latent_num_frames, + latent_height=latent_height, + latent_width=latent_width, + frame_rate=frame_rate, + audio_num_frames=audio_num_frames, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + cache_context="uncond_stg", + extra_latents=extra_latents_all, + extra_coords=extra_coords_all, + extra_timestep=extra_timestep_in, + video_self_attention_mask=video_self_attention_mask, + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + ) + noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float() + noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float() + noise_pred_video_uncond_stg = self.convert_velocity_to_x0( + latents, noise_pred_video_uncond_stg, i, self.scheduler + ) + noise_pred_audio_uncond_stg = self.convert_velocity_to_x0( + audio_latents, noise_pred_audio_uncond_stg, i, audio_scheduler + ) + + video_stg_delta = self.stg_scale * (noise_pred_video - noise_pred_video_uncond_stg) + audio_stg_delta = self.audio_stg_scale * (noise_pred_audio - noise_pred_audio_uncond_stg) + else: + video_stg_delta = audio_stg_delta = 0 + + # --- Modality isolation guidance forward pass --- + if self.do_modality_isolation_guidance: + noise_pred_video_uncond_mod, noise_pred_audio_uncond_mod = self._run_transformer( + latent_model_input=latents.to(dtype=prompt_embeds.dtype), + audio_latent_model_input=audio_latents.to(dtype=prompt_embeds.dtype), + video_timestep=video_timestep_single, + audio_timestep=timestep_scalar_single, + sigma=timestep_scalar_single, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + connector_prompt_embeds=video_prompt_embeds, + connector_audio_prompt_embeds=audio_prompt_embeds, + connector_attention_mask=prompt_attn_mask, + latent_num_frames=latent_num_frames, + latent_height=latent_height, + latent_width=latent_width, + frame_rate=frame_rate, + audio_num_frames=audio_num_frames, + use_cross_timestep=use_cross_timestep, + 
attention_kwargs=attention_kwargs, + cache_context="uncond_modality", + extra_latents=extra_latents_all, + extra_coords=extra_coords_all, + extra_timestep=extra_timestep_in, + video_self_attention_mask=video_self_attention_mask, + isolate_modalities=True, + ) + noise_pred_video_uncond_mod = noise_pred_video_uncond_mod.float() + noise_pred_audio_uncond_mod = noise_pred_audio_uncond_mod.float() + noise_pred_video_uncond_mod = self.convert_velocity_to_x0( + latents, noise_pred_video_uncond_mod, i, self.scheduler + ) + noise_pred_audio_uncond_mod = self.convert_velocity_to_x0( + audio_latents, noise_pred_audio_uncond_mod, i, audio_scheduler + ) + + video_modality_delta = (self.modality_scale - 1) * ( + noise_pred_video - noise_pred_video_uncond_mod + ) + audio_modality_delta = (self.audio_modality_scale - 1) * ( + noise_pred_audio - noise_pred_audio_uncond_mod + ) + else: + video_modality_delta = audio_modality_delta = 0 + + # Apply all guidance terms + noise_pred_video_g = noise_pred_video + video_cfg_delta + video_stg_delta + video_modality_delta + noise_pred_audio_g = noise_pred_audio + audio_cfg_delta + audio_stg_delta + audio_modality_delta + + # Apply guidance rescaling + if self.guidance_rescale > 0: + noise_pred_video = rescale_noise_cfg( + noise_pred_video_g, noise_pred_video, guidance_rescale=self.guidance_rescale + ) + else: + noise_pred_video = noise_pred_video_g + + if self.audio_guidance_rescale > 0: + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio_g, noise_pred_audio, guidance_rescale=self.audio_guidance_rescale + ) + else: + noise_pred_audio = noise_pred_audio_g + + # Apply frame conditioning mask: blend denoised x0 with clean condition latents + if has_conditions: + bsz = noise_pred_video.size(0) + denoised_sample_cond = ( + noise_pred_video * (1 - conditioning_mask[:bsz]) + + clean_latents.float() * conditioning_mask[:bsz] + ).to(noise_pred_video.dtype) + noise_pred_video = denoised_sample_cond + + # Convert back to velocity for scheduler + noise_pred_video = self.convert_x0_to_velocity(latents, noise_pred_video, i, self.scheduler) + noise_pred_audio = self.convert_x0_to_velocity(audio_latents, noise_pred_audio, i, audio_scheduler) + + # Compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0] + audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # 9. 
Decode + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + + audio_latents = self._denormalize_audio_latents( + audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std + ) + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + + if output_type == "latent": + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + video = latents + audio = audio_latents + else: + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + + latents = latents.to(self.vae.dtype) + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + audio_latents = audio_latents.to(self.audio_vae.dtype) + generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] + audio = self.vocoder(generated_mel_spectrograms) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video, audio) + + return LTX2PipelineOutput(frames=video, audio=audio) From f18211d702c1956c9fd9c2f50c1f56a318b7c230 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 25 Apr 2026 09:10:18 +0200 Subject: [PATCH 02/14] Refactor HDR export to accept custom tone-mapping functions --- src/diffusers/pipelines/ltx2/export_utils.py | 53 +++++++++++++++----- 1 file changed, 41 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/export_utils.py b/src/diffusers/pipelines/ltx2/export_utils.py index 6b7220d08eba..fb5a48fd4e44 100644 --- a/src/diffusers/pipelines/ltx2/export_utils.py +++ b/src/diffusers/pipelines/ltx2/export_utils.py @@ -17,6 +17,7 @@ from fractions import Fraction from itertools import chain from pathlib import Path +from typing import Callable import numpy as np import PIL.Image @@ -265,12 +266,20 @@ def save_exr_tensor( out.close() -def _linear_to_srgb(x: np.ndarray) -> np.ndarray: +def simple_tone_map(x: np.ndarray) -> np.ndarray: r""" - Apply the sRGB OETF (IEC 61966-2-1) to a linear image. Input values must be in `[0, 1]`; values outside are - clipped. + Applies a very simple tone-mapping function on (scene-referred) linear light which simply clips values above `1.0` + to `1.0`. This is what the original LTX-2.X code does, but you probably want to do some non-trivial tone-mapping + to make the sample look better. 
+ """ + return np.clip(x, 0.0, 1.0) + + +def linear_to_srgb(x: np.ndarray) -> np.ndarray: + r""" + Apply the sRGB (Rec.709) transfer function (OETF; IEC 61966-2-1) to a linear light image. Input values must be in + `[0, 1]`. """ - x = np.clip(x, 0.0, 1.0) return np.where(x <= 0.0031308, x * 12.92, 1.055 * np.power(x, 1.0 / 2.4) - 0.055) @@ -278,14 +287,15 @@ def encode_exr_sequence_to_mp4( exr_dir: str | Path, output_mp4: str | Path, frame_rate: float, + tone_mapping_fn: Callable[[np.ndarray], np.ndarray] | None = None, + tone_map_in_rgb: bool = False, crf: int = 18, ) -> None: r""" Convert a linear-HDR EXR frame sequence into an sRGB-tonemapped H.264 `.mp4` preview. Each EXR frame is loaded, clipped to `[0, 1]`, passed through the sRGB OETF (no exposure/gain, EV=0), quantized - to 8-bit BGR, and fed into a libx264 stream at the supplied `frame_rate`. This mirrors the reference CLI's - `encode_exr_sequence_to_mp4`. + to 8-bit, and fed into a libx264 stream at the supplied `frame_rate`. Args: exr_dir (`str` or `pathlib.Path`): @@ -294,14 +304,26 @@ def encode_exr_sequence_to_mp4( Output MP4 path. frame_rate (`float`): Frame rate for the output video. + tone_mapping_fn (`Callable[[np.ndarray], np.ndarray]`, *optional*, defaults to `None`): + An optional tone mapping function which takes a float32 NumPy array of shape `(H, W, 3)` containing + linear HDR values in `[0, ∞)` and returns tone-mapped linear values in `[0, 1]`. The sRGB transfer + function (OETF) is applied afterwards — do **not** pre-apply gamma inside this function. If `None`, + defaults to [`simple_tone_map`], which clips values above `1.0`. The channel ordering of the input + array is controlled by `tone_map_in_rgb`: BGR by default (matching `opencv-python` conventions), or + RGB when `tone_map_in_rgb=True` (matching `colour-science` and most other libraries). + tone_map_in_rgb (`bool`, *optional*, defaults to `False`): + When `True`, each EXR frame is converted from BGR to RGB before being passed to `tone_mapping_fn`, + and the output frame is tagged as `rgb24`. Use this when `tone_mapping_fn` expects RGB input (e.g. + operators from `colour-science`). When `False` (default), frames are passed as BGR, which is the + native format for `opencv-python` tone mappers (e.g. `cv2.createTonemapReinhard().process`). crf (`int`, *optional*, defaults to `18`): libx264 CRF quality factor. Lower values produce higher quality. Requires `opencv-python` (for EXR reading via `OPENCV_IO_ENABLE_OPENEXR`). 
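For reference, a non-trivial preview export using OpenCV's Reinhard operator (the kind of BGR tone mapper the default path is designed for) could look like the sketch below. The directory and file names are placeholders, and the direct module import assumes a `diffusers` install that includes this branch's `export_utils.py`.

```py
import os

# OpenCV gates EXR reading behind this variable; set it before cv2 is imported.
os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "1")

import cv2
import numpy as np

from diffusers.pipelines.ltx2.export_utils import encode_exr_sequence_to_mp4

# Reinhard operates on float32 BGR frames, matching the default tone_map_in_rgb=False.
# gamma=1.0 keeps the output linear; encode_exr_sequence_to_mp4 applies the sRGB OETF afterwards.
reinhard = cv2.createTonemapReinhard(1.0)

encode_exr_sequence_to_mp4(
    exr_dir="hdr_exr_frames",      # placeholder directory of *.exr frames
    output_mp4="preview_sdr.mp4",  # placeholder output path
    frame_rate=24.0,
    tone_mapping_fn=lambda hdr: np.clip(reinhard.process(hdr), 0.0, 1.0),
    crf=18,
)
```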
""" import os - os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" + try: import cv2 except ImportError as e: # pragma: no cover - optional dep @@ -319,17 +341,24 @@ def encode_exr_sequence_to_mp4( stream.pix_fmt = "yuv420p" stream.options = {"crf": str(crf), "movflags": "+faststart"} + pix_fmt = "rgb24" if tone_map_in_rgb else "bgr24" + if tone_mapping_fn is None: + tone_mapping_fn = simple_tone_map + try: for i, exr_path in enumerate(exr_files): hdr = cv2.imread(str(exr_path), cv2.IMREAD_UNCHANGED).astype(np.float32) - sdr = _linear_to_srgb(np.maximum(hdr, 0.0)) - bgr8 = (sdr * 255.0 + 0.5).astype(np.uint8) + if tone_map_in_rgb: + hdr = hdr[..., ::-1] + hdr_mapped = tone_mapping_fn(hdr) + sdr = linear_to_srgb(np.maximum(hdr_mapped, 0.0)) + out8 = (sdr * 255.0 + 0.5).astype(np.uint8) if i == 0: - stream.height = bgr8.shape[0] - stream.width = bgr8.shape[1] + stream.height = out8.shape[0] + stream.width = out8.shape[1] - frame = av.VideoFrame.from_ndarray(bgr8, format="bgr24") + frame = av.VideoFrame.from_ndarray(out8, format=pix_fmt) for packet in stream.encode(frame): container.mux(packet) From f5f9656bfcf3b32bf198cfc9982ba6abc315cb90 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 27 Apr 2026 06:05:27 +0200 Subject: [PATCH 03/14] Apply parity fixes + refactor + allow HDR LoRA pipeline to accept connector embeddings --- .../pipelines/ltx2/pipeline_ltx2_condition.py | 383 ++++----- .../pipelines/ltx2/pipeline_ltx2_hdr_lora.py | 681 ++++++++++------ .../pipelines/ltx2/pipeline_ltx2_ic_lora.py | 759 ++++++++++-------- .../dummy_torch_and_transformers_objects.py | 30 + 4 files changed, 1035 insertions(+), 818 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index f0e8d035d47a..56ca0add74a9 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -14,6 +14,7 @@ import copy import inspect +import math from dataclasses import dataclass from typing import Any, Callable @@ -714,22 +715,50 @@ def preprocess_conditions( frame_scale_factor = self.vae_temporal_compression_ratio latent_num_frames = (num_frames - 1) // frame_scale_factor + 1 for i, condition in enumerate(conditions): + # Create a channels-last video-like array of shape (F, H, W, C) in preparation for resizing. 
if isinstance(condition.frames, PIL.Image.Image): - # Single image, convert to List[PIL.Image.Image] - video_like_cond = [condition.frames] - elif isinstance(condition.frames, np.ndarray) and condition.frames.ndim == 3: - # Image-like ndarray of shape (H, W, C), insert frame dim in first axis - video_like_cond = np.expand_dims(condition.frames, axis=0) - elif isinstance(condition.frames, torch.Tensor) and condition.frames.ndim == 3: - # Image-like tensor of shape (C, H, W), insert frame dim in first dim - video_like_cond = condition.frames.unsqueeze(0) + arr = np.array(condition.frames.convert("RGB"))[None] # (1, H, W, 3) + elif isinstance(condition.frames, list) and all( + isinstance(f, PIL.Image.Image) for f in condition.frames + ): + arr = np.stack([np.array(f.convert("RGB")) for f in condition.frames]) # (F, H, W, 3) + elif isinstance(condition.frames, np.ndarray): + arr = condition.frames if condition.frames.ndim == 4 else condition.frames[None] + elif isinstance(condition.frames, torch.Tensor): + t = condition.frames if condition.frames.ndim == 4 else condition.frames.unsqueeze(0) + # Reference layout for video tensors is (F, C, H, W); convert to (F, H, W, C) for the + # resize logic, which expects channels-last. + arr = t.detach().cpu().permute(0, 2, 3, 1).numpy() else: - # Treat all other as videos. Note that this means 4D ndarrays and tensors will be treated as videos of - # shape (F, H, W, C) and (F, C, H, W), respectively. - video_like_cond = condition.frames - condition_pixels = self.video_processor.preprocess_video( - video_like_cond, height, width, resize_mode="crop" + raise TypeError( + f"Unsupported `frames` type for condition {i}: {type(condition.frames)}" + ) + + src_h, src_w = arr.shape[1], arr.shape[2] + num_cond_frames = arr.shape[0] + # Convert the NumPy array to a channels-first tensor of shape (1, C, F, H, W) + pixels = torch.from_numpy(np.ascontiguousarray(arr)).to(torch.float32) + pixels = pixels.permute(3, 0, 1, 2).unsqueeze(0).to(device) # (1, C, F, H, W) + + # Resize so the longer side fills the target, then center-crop to exact (height, width). + scale = max(height / src_h, width / src_w) + new_h = math.ceil(src_h * scale) + new_w = math.ceil(src_w * scale) + # Flatten (B, C, F, H, W) → (B*F, C, H, W) for the per-frame interpolation + pixels = pixels.permute(0, 2, 1, 3, 4).reshape(num_cond_frames, 3, src_h, src_w) + # NOTE: we avoid using VideoProcessor.preprocess_video here because it uses PIL.Image.resize under the + # hood, which will apply an anti-aliasing pre-filter when downsampling. The original LTX-2.X code simply + # uses F.interpolate, which is reproduced here. + pixels = torch.nn.functional.interpolate( + pixels, size=(new_h, new_w), mode="bilinear", align_corners=False ) + top = (new_h - height) // 2 + left = (new_w - width) // 2 + pixels = pixels[:, :, top : top + height, left : left + width] + pixels = pixels.reshape(1, num_cond_frames, 3, height, width).permute(0, 2, 1, 3, 4) + + # Map [0, 255] → [-1, 1] (VAE input convention). + condition_pixels = pixels / 127.5 - 1.0 # Interpret the index as a latent index, following the original LTX-2 code. latent_start_idx = condition.index @@ -887,16 +916,19 @@ def prepare_latents( Prepare noisy video latents, applying frame conditions. First-frame conditions (`latent_idx == 0`) are applied by overwriting tokens at the first-frame positions - (`VideoConditionByLatentIndex` semantics). 
Non-first-frame conditions (`latent_idx > 0`) are packaged as - keyframe extras to be appended to the latent sequence during the transformer forward pass - (`VideoConditionByKeyframeIndex` semantics). + (`VideoConditionByLatentIndex` semantics). Non-first-frame conditions (`latent_idx > 0`) are concatenated + onto the main latent sequence with per-token `conditioning_mask = strength` + (`VideoConditionByKeyframeIndex` semantics) — the denoising loop's existing timestep formula + `t * (1 - conditioning_mask)` and post-process blend + `denoised * (1 - conditioning_mask) + clean * conditioning_mask` then drive them across steps. Returns a 4-tuple: - - `latents`: packed noisy latents (with first-frame replacement applied if applicable). - - `conditioning_mask`: packed conditioning mask (non-zero only at first-frame positions). - - `clean_latents`: clean first-frame conditions at first-frame positions (zeros elsewhere). - - `keyframe_extras`: `(keyframe_latents, keyframe_coords, keyframe_denoise_factors)` for keyframe - conditions at non-zero latent indices, or `None` if there are none. + - `latents`: packed noisy latents (base tokens + any keyframe tokens cat'd onto the sequence dim). + - `conditioning_mask`: packed conditioning mask with values in `[0, 1]` — `1` at first-frame positions, + `strength` at keyframe positions, `0` elsewhere. + - `clean_latents`: clean condition values at conditioned positions (zeros elsewhere); same shape as `latents`. + - `keyframe_coords`: `[B, 3, num_keyframe_patches, 2]` positional coordinates to append to `video_coords`, + or `None` if there are no non-first-frame conditions. """ latent_height = height // self.vae_spatial_compression_ratio latent_width = width // self.vae_spatial_compression_ratio @@ -968,11 +1000,12 @@ def prepare_latents( latent_width=latent_width, ) - # Non-first-frame conditions (latent_idx > 0): append as keyframe extras with offset pixel coords. + # Non-first-frame ("keyframe") conditions (latent_idx > 0): append as extra latent tokens to the noisy latent. + # Each condition gets a all-`strength` conditioning mask and pos ids, which are also appended to those of the + # noisy latent. At each denoising step i, the keyframe conditions get an effective noise level of + # (1 - conditioning_strength) * sigma_i. frame_scale_factor = self.vae_temporal_compression_ratio - keyframe_tokens = [] - keyframe_coords = [] - keyframe_denoise_factors = [] + kf_tokens_list, kf_coords_list, kf_mask_list, kf_clean_list = [], [], [], [] for cond_5d, cond_packed, strength, latent_idx, num_pixel_frames in zip( condition_latents_5d, condition_latents_packed, @@ -984,8 +1017,6 @@ def prepare_latents( continue _, _, kf_latent_frames, kf_latent_height, kf_latent_width = cond_5d.shape - # Pixel-space frame index at which the keyframe is placed. Matches the `start_idx` formula in - # `preprocess_conditions` used for trimming. 
pixel_frame_idx = (latent_idx - 1) * frame_scale_factor + 1 coords = self._prepare_keyframe_coords( @@ -999,31 +1030,33 @@ def prepare_latents( ) num_tokens = cond_packed.shape[1] - denoise_factor = torch.full( - (1, num_tokens), 1.0 - strength, device=device, dtype=torch.float32 - ) - - keyframe_tokens.append(cond_packed) - keyframe_coords.append(coords) - keyframe_denoise_factors.append(denoise_factor) - - keyframe_extras: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None - if keyframe_tokens: - keyframe_extras = ( - torch.cat(keyframe_tokens, dim=1), - torch.cat(keyframe_coords, dim=2), - torch.cat(keyframe_denoise_factors, dim=1), + kf_mask = torch.full( + (cond_packed.shape[0], num_tokens, 1), + float(strength), + device=device, + dtype=conditioning_mask.dtype, ) - else: - keyframe_extras = None - # Sample from the standard Gaussian prior (or an intermediate Gaussian distribution if noise_scale < 1.0). + kf_tokens_list.append(cond_packed) + kf_clean_list.append(cond_packed) + kf_mask_list.append(kf_mask) + kf_coords_list.append(coords) + + if kf_tokens_list: + keyframe_coords = torch.cat(kf_coords_list, dim=2) + latents = torch.cat([latents, torch.cat(kf_tokens_list, dim=1)], dim=1) + conditioning_mask = torch.cat([conditioning_mask, torch.cat(kf_mask_list, dim=1)], dim=1) + clean_latents = torch.cat([clean_latents, torch.cat(kf_clean_list, dim=1)], dim=1) + + # The conditioning_mask values have the following semantics: + # - mask=0: fully noise tokens (e.g. noisy latents) + # - mask=1: keep fully clean (e.g. I2V first-frame condition, conditions with strength=1) + # - mask in (0, 1): use intermediate noise level mask * sigma_i (noise_scale == sigma_0) noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) - scaled_mask = (1.0 - conditioning_mask) * noise_scale - # Add noise to the `latents` so that it is at the noise level specified by `noise_scale`. + scaled_mask = (1.0 - conditioning_mask) * noise_scale # noise to initial noise level `noise_scale` latents = noise * scaled_mask + latents * (1 - scaled_mask) - return latents, conditioning_mask, clean_latents, keyframe_extras + return latents, conditioning_mask, clean_latents, keyframe_coords def prepare_audio_latents( self, @@ -1046,16 +1079,15 @@ def prepare_audio_latents( latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) - if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self._pack_audio_latents(latents) + # Sample in packed shape (B, L, C * M), following the original LTX-2.X code + packed_shape = (batch_size, audio_latent_length, num_channels_latents * latent_mel_bins) + latents = randn_tensor(packed_shape, generator=generator, device=device, dtype=dtype) return latents def convert_velocity_to_x0( @@ -1136,90 +1168,6 @@ def attention_kwargs(self): def interrupt(self): return self._interrupt - def _run_transformer( - self, - latent_model_input: torch.Tensor, - audio_latent_model_input: torch.Tensor, - video_timestep: torch.Tensor, - audio_timestep: torch.Tensor, - sigma: torch.Tensor, - video_coords: torch.Tensor, - audio_coords: torch.Tensor, - connector_prompt_embeds: torch.Tensor, - connector_audio_prompt_embeds: torch.Tensor, - connector_attention_mask: torch.Tensor, - latent_num_frames: int, - latent_height: int, - latent_width: int, - frame_rate: float, - audio_num_frames: int, - use_cross_timestep: bool, - attention_kwargs: dict[str, Any] | None, - cache_context: str, - extra_latents: torch.Tensor | None = None, - extra_coords: torch.Tensor | None = None, - extra_timestep: torch.Tensor | None = None, - isolate_modalities: bool = False, - spatio_temporal_guidance_blocks: list[int] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Run a single transformer forward pass, optionally concatenating extra tokens (keyframe conditions) to the - video hidden states. - - When `extra_latents` is provided, the extra tokens are concatenated to the video hidden states, video coords, - and video timesteps. After the transformer forward pass, the extra tokens are stripped from the video output - so only the noisy-token predictions are returned. - - Returns: - `(noise_pred_video, noise_pred_audio)` where `noise_pred_video` has the same sequence length as the input - `latent_model_input` (extras are stripped). 
- """ - video_seq_len = latent_model_input.shape[1] - - if extra_latents is not None: - batch_size = latent_model_input.shape[0] - extra_batch = extra_latents.to(latent_model_input.dtype).expand(batch_size, -1, -1) - combined_hidden = torch.cat([latent_model_input, extra_batch], dim=1) - - extra_coords_batch = extra_coords.expand(batch_size, -1, -1, -1) - combined_coords = torch.cat([video_coords, extra_coords_batch], dim=2) - - extra_ts_batch = extra_timestep.expand(batch_size, -1) - combined_timestep = torch.cat([video_timestep, extra_ts_batch], dim=1) - else: - combined_hidden = latent_model_input - combined_coords = video_coords - combined_timestep = video_timestep - - with self.transformer.cache_context(cache_context): - noise_pred_combined, noise_pred_audio = self.transformer( - hidden_states=combined_hidden, - audio_hidden_states=audio_latent_model_input, - encoder_hidden_states=connector_prompt_embeds, - audio_encoder_hidden_states=connector_audio_prompt_embeds, - timestep=combined_timestep, - audio_timestep=audio_timestep, - sigma=sigma, - encoder_attention_mask=connector_attention_mask, - audio_encoder_attention_mask=connector_attention_mask, - num_frames=latent_num_frames, - height=latent_height, - width=latent_width, - fps=frame_rate, - audio_num_frames=audio_num_frames, - video_coords=combined_coords, - audio_coords=audio_coords, - isolate_modalities=isolate_modalities, - spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, - perturbation_mask=None, - use_cross_timestep=use_cross_timestep, - attention_kwargs=attention_kwargs, - return_dict=False, - ) - - noise_pred_video = noise_pred_combined[:, :video_seq_len] - return noise_pred_video, noise_pred_audio - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -1497,7 +1445,7 @@ def __call__( # video_sequence_length = latent_num_frames * latent_height * latent_width num_channels_latents = self.transformer.config.in_channels - latents, conditioning_mask, clean_latents, keyframe_extras = self.prepare_latents( + latents, conditioning_mask, clean_latents, keyframe_coords = self.prepare_latents( conditions=conditions, batch_size=batch_size * num_videos_per_prompt, num_channels_latents=num_channels_latents, @@ -1581,6 +1529,8 @@ def __call__( audio_coords = self.transformer.audio_rope.prepare_audio_coords( audio_latents.shape[0], audio_num_frames, audio_latents.device ) + if keyframe_coords is not None: + video_coords = torch.cat([video_coords, keyframe_coords], dim=2) # Duplicate the positional ids as well if using CFG if self.do_classifier_free_guidance: video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1)) # Repeat twice in batch dim @@ -1604,35 +1554,31 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]) video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask.squeeze(-1)) - # Per-token timestep for keyframe extras: sigma * (1 - strength), i.e. 0 for fully-clean keyframes. 
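As a concrete illustration of the per-token semantics used here (the same formulas as in the `prepare_latents` docstring: timestep `sigma * (1 - mask)` and x0 blend `denoised * (1 - mask) + clean * mask`), consider a toy three-token step with made-up values:

```py
import torch

sigma_i = torch.tensor(0.8)
# token 0: ordinary noisy token           -> mask 0.0
# token 1: keyframe with strength 0.75    -> mask 0.75
# token 2: first-frame condition          -> mask 1.0
conditioning_mask = torch.tensor([[0.0, 0.75, 1.0]])

video_timestep = sigma_i * (1 - conditioning_mask)
# tensor([[0.8000, 0.2000, 0.0000]]): clean tokens see timestep 0, partially
# conditioned tokens see a reduced noise level.

denoised = torch.tensor([[0.10, 0.20, 0.30]])  # hypothetical x0 predictions
clean = torch.tensor([[1.00, 1.00, 1.00]])     # hypothetical clean condition latents
blended = denoised * (1 - conditioning_mask) + clean * conditioning_mask
# tensor([[0.1000, 0.8000, 1.0000]]): conditioned positions are pulled toward the
# clean latents in x0 space before converting back to velocity for the scheduler.
```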
- extra_latents_in = extra_coords_in = extra_timestep_in = None - if keyframe_extras is not None: - extra_latents_in, extra_coords_in, keyframe_denoise_factors = keyframe_extras - extra_timestep_in = t * keyframe_denoise_factors - - noise_pred_video, noise_pred_audio = self._run_transformer( - latent_model_input=latent_model_input, - audio_latent_model_input=audio_latent_model_input, - video_timestep=video_timestep, - audio_timestep=timestep, - sigma=timestep, - video_coords=video_coords, - audio_coords=audio_coords, - connector_prompt_embeds=connector_prompt_embeds, - connector_audio_prompt_embeds=connector_audio_prompt_embeds, - connector_attention_mask=connector_attention_mask, - latent_num_frames=latent_num_frames, - latent_height=latent_height, - latent_width=latent_width, - frame_rate=frame_rate, - audio_num_frames=audio_num_frames, - use_cross_timestep=use_cross_timestep, - attention_kwargs=attention_kwargs, - cache_context="cond_uncond", - extra_latents=extra_latents_in, - extra_coords=extra_coords_in, - extra_timestep=extra_timestep_in, - ) + with self.transformer.cache_context("cond_uncond"): + noise_pred_video, noise_pred_audio = self.transformer( + hidden_states=latent_model_input, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=timestep, + sigma=timestep, # Used by LTX-2.3 + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, + isolate_modalities=False, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + return_dict=False, + ) noise_pred_video = noise_pred_video.float() noise_pred_audio = noise_pred_audio.float() @@ -1682,30 +1628,32 @@ def __call__( noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler) if self.do_spatio_temporal_guidance: - noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self._run_transformer( - latent_model_input=latents.to(dtype=prompt_embeds.dtype), - audio_latent_model_input=audio_latents.to(dtype=prompt_embeds.dtype), - video_timestep=video_timestep, - audio_timestep=timestep, - sigma=timestep, - video_coords=video_pos_ids, - audio_coords=audio_pos_ids, - connector_prompt_embeds=video_prompt_embeds, - connector_audio_prompt_embeds=audio_prompt_embeds, - connector_attention_mask=prompt_attn_mask, - latent_num_frames=latent_num_frames, - latent_height=latent_height, - latent_width=latent_width, - frame_rate=frame_rate, - audio_num_frames=audio_num_frames, - use_cross_timestep=use_cross_timestep, - attention_kwargs=attention_kwargs, - cache_context="uncond_stg", - extra_latents=extra_latents_in, - extra_coords=extra_coords_in, - extra_timestep=extra_timestep_in, - spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, - ) + with self.transformer.cache_context("uncond_stg"): + noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer( + hidden_states=latents.to(dtype=prompt_embeds.dtype), + audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype), + encoder_hidden_states=video_prompt_embeds, + audio_encoder_hidden_states=audio_prompt_embeds, + timestep=video_timestep, + 
audio_timestep=timestep, + sigma=timestep, # Used by LTX-2.3 + encoder_attention_mask=prompt_attn_mask, + audio_encoder_attention_mask=prompt_attn_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + isolate_modalities=False, + # Use STG at given blocks to perturb model + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + return_dict=False, + ) noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float() noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float() noise_pred_video_uncond_stg = self.convert_velocity_to_x0( @@ -1721,30 +1669,32 @@ def __call__( video_stg_delta = audio_stg_delta = 0 if self.do_modality_isolation_guidance: - noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self._run_transformer( - latent_model_input=latents.to(dtype=prompt_embeds.dtype), - audio_latent_model_input=audio_latents.to(dtype=prompt_embeds.dtype), - video_timestep=video_timestep, - audio_timestep=timestep, - sigma=timestep, - video_coords=video_pos_ids, - audio_coords=audio_pos_ids, - connector_prompt_embeds=video_prompt_embeds, - connector_audio_prompt_embeds=audio_prompt_embeds, - connector_attention_mask=prompt_attn_mask, - latent_num_frames=latent_num_frames, - latent_height=latent_height, - latent_width=latent_width, - frame_rate=frame_rate, - audio_num_frames=audio_num_frames, - use_cross_timestep=use_cross_timestep, - attention_kwargs=attention_kwargs, - cache_context="uncond_modality", - extra_latents=extra_latents_in, - extra_coords=extra_coords_in, - extra_timestep=extra_timestep_in, - isolate_modalities=True, - ) + with self.transformer.cache_context("uncond_modality"): + noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self.transformer( + hidden_states=latents.to(dtype=prompt_embeds.dtype), + audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype), + encoder_hidden_states=video_prompt_embeds, + audio_encoder_hidden_states=audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=timestep, + sigma=timestep, # Used by LTX-2.3 + encoder_attention_mask=prompt_attn_mask, + audio_encoder_attention_mask=prompt_attn_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + # Turn off A2V and V2A cross attn to isolate video and audio modalities + isolate_modalities=True, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + return_dict=False, + ) noise_pred_video_uncond_modality = noise_pred_video_uncond_modality.float() noise_pred_audio_uncond_modality = noise_pred_audio_uncond_modality.float() noise_pred_video_uncond_modality = self.convert_velocity_to_x0( @@ -1790,7 +1740,8 @@ def __call__( # NOTE: this operation should be applied in sample (x0) space and not velocity space (which is the # space the denoising model outputs are in) denoised_sample_cond = ( - noise_pred_video * (1 - conditioning_mask[:bsz]) + clean_latents.float() * conditioning_mask[:bsz] + noise_pred_video * (1 - conditioning_mask[:bsz]) + + clean_latents * conditioning_mask[:bsz] ).to(noise_pred_video.dtype) # Convert the denoised (x0) sample back to a velocity for the 
scheduler @@ -1820,6 +1771,10 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + # Remove any appended keyframe (non-first-frame) condition tokens from the final latent + base_token_count = latent_num_frames * latent_height * latent_width + latents = latents[:, :base_token_count] + latents = self._unpack_latents( latents, latent_num_frames, diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py index 886ac094233e..80d1fecd7243 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py @@ -150,6 +150,25 @@ def retrieve_timesteps( r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") @@ -182,11 +201,26 @@ def retrieve_timesteps( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): r""" - Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. 
""" std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg return noise_cfg @@ -311,6 +345,18 @@ def _get_gemma_prompt_embeds( device: torch.device | None = None, dtype: torch.dtype | None = None, ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`str` or `torch.device`): + torch device to place the resulting embeddings on + dtype: (`torch.dtype`): + torch dtype to cast the prompt embeds to + max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt. + """ device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -318,6 +364,7 @@ def _get_gemma_prompt_embeds( batch_size = len(prompt) if getattr(self, "tokenizer", None) is not None: + # Gemma expects left padding for chat-style prompts self.tokenizer.padding_side = "left" if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token @@ -341,8 +388,9 @@ def _get_gemma_prompt_embeds( ) text_encoder_hidden_states = text_encoder_outputs.hidden_states text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) - prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) + prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) # Pack to 3D + # duplicate text embeddings for each generation per prompt, using mps friendly method _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) @@ -368,6 +416,32 @@ def encode_prompt( device: torch.device | None = None, dtype: torch.dtype | None = None, ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
+ device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt @@ -423,6 +497,8 @@ def check_inputs( negative_prompt_embeds=None, prompt_attention_mask=None, negative_prompt_attention_mask=None, + connector_video_embeds=None, + connector_audio_embeds=None, latents=None, spatio_temporal_guidance_blocks=None, stg_scale=None, @@ -443,9 +519,10 @@ def check_inputs( " only forward one of the two." ) elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) + if connector_video_embeds is None or connector_audio_embeds is None: + raise ValueError( + "Provide a `prompt`, `prompt_embeds` or `connector_video_embeds` and `connector_audio_embeds`" + ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") @@ -470,6 +547,10 @@ def check_inputs( @staticmethod # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features batch_size, num_channels, num_frames, height, width = latents.shape post_patch_num_frames = num_frames // patch_size_t post_patch_height = height // patch_size @@ -492,6 +573,9 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int def _unpack_latents( latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. 
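In the `patch_size = patch_size_t = 1` case used by this pipeline, the packing convention documented in these comments reduces to a channels-last flatten. A minimal standalone round-trip with arbitrary example dimensions (it does not call the static methods themselves, only mirrors the documented layout):

```py
import torch

B, C, F, H, W = 1, 128, 4, 16, 24
latents = torch.randn(B, C, F, H, W)

# Pack: [B, C, F, H, W] -> [B, F * H * W, C] (tokens ordered frame-major, then height, then width).
packed = latents.permute(0, 2, 3, 4, 1).reshape(B, F * H * W, C)

# Unpack: the inverse reshape/permute, mirroring `_unpack_latents` with patch sizes of 1.
unpacked = packed.reshape(B, F, H, W, C).permute(0, 4, 1, 2, 3)

assert packed.shape == (B, F * H * W, C)
assert torch.equal(unpacked, latents)
```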
        batch_size = latents.size(0)
        latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
        latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
@@ -502,6 +586,7 @@ def _unpack_latents(
    def _normalize_latents(
        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
    ) -> torch.Tensor:
+        # Normalize latents across the channel dimension [B, C, F, H, W]
        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        latents = (latents - latents_mean) * scaling_factor / latents_std
@@ -512,6 +597,7 @@ def _normalize_latents(
    def _denormalize_latents(
        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
    ) -> torch.Tensor:
+        # Denormalize latents across the channel dimension [B, C, F, H, W]
        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        latents = latents * latents_std / scaling_factor + latents_mean
@@ -545,7 +631,10 @@ def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor
    def _pack_audio_latents(
        latents: torch.Tensor, patch_size: int | None = None, patch_size_t: int | None = None
    ) -> torch.Tensor:
+        # Audio latents shape: [B, C, L, M], where L is the latent audio length and M is the number of mel bins
        if patch_size is not None and patch_size_t is not None:
+            # Packs the latents into a patch sequence of shape [B, L // p_t * M // p, C * p_t * p] (an ndim=3 tensor).
+            # dim=1 is the effective audio sequence length and dim=2 is the effective audio input feature size.
            batch_size, num_channels, latent_length, latent_mel_bins = latents.shape
            post_patch_latent_length = latent_length / patch_size_t
            post_patch_mel_bins = latent_mel_bins / patch_size
            latents = latents.reshape(
            )
            latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
        else:
-            latents = latents.transpose(1, 2).flatten(2, 3)
+            # Packs the latents into a patch sequence of shape [B, L, C * M]. This implicitly assumes a (mel)
+            # patch_size of M (all mel bins constitute a single patch) and a patch_size_t of 1.
+            latents = latents.transpose(1, 2).flatten(2, 3)  # [B, C, L, M] --> [B, L, C * M]
        return latents
    @staticmethod
@@ -566,31 +657,66 @@ def _unpack_audio_latents(
        patch_size: int | None = None,
        patch_size_t: int | None = None,
    ) -> torch.Tensor:
+        # Unpacks an audio patch sequence of shape [B, S, D] into a latent spectrogram tensor of shape [B, C, L, M],
+        # where L is the latent audio length and M is the number of mel bins.
        if patch_size is not None and patch_size_t is not None:
            batch_size = latents.size(0)
            latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size)
            latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
        else:
+            # Assume [B, S, D] = [B, L, C * M], which implies that patch_size = M and patch_size_t = 1.
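+            # Illustrative example (values chosen only for this comment): with [B, S, D] = [2, 50, 512] and
+            # num_mel_bins = 64, the unflatten below yields [2, 50, 8, 64] and the transpose restores
+            # [B, C, L, M] = [2, 8, 50, 64], exactly inverting the `transpose(1, 2).flatten(2, 3)` used in
+            # `_pack_audio_latents`.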
latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) return latents def prepare_latents( self, + reference_conditions: list[LTX2HDRReferenceCondition] | None = None, + reference_downscale_factor: int = 1, batch_size: int = 1, num_channels_latents: int = 128, height: int = 512, width: int = 768, num_frames: int = 121, + frame_rate: float = 24.0, noise_scale: float = 0.0, dtype: torch.dtype | None = None, device: torch.device | None = None, generator: torch.Generator | None = None, latents: torch.Tensor | None = None, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None, int, torch.Tensor | None]: r""" - Prepare noisy video latents. Either allocates fresh noise (Stage 1) or noises supplied latents from a - previous stage (Stage 2 after [`LTX2LatentUpsamplePipeline`]). + Prepare noisy video latents, applying HDR IC-LoRA reference-video conditioning. + + Builds a packed latent sequence in the order `[base | reference]`: + - Base: either fresh noise (Stage 1, `latents=None`) or pre-existing upsampled latents (Stage 2). + - Reference: HDR-encoded reference-video tokens appended with per-token `conditioning_mask = strength`, + following the same pattern as [`LTX2ICLoraPipeline.prepare_latents`]. (HDR LoRA does not currently + take per-frame `conditions`, so there is no first-frame / keyframe block in between.) + + Returns a 6-tuple matching [`LTX2ICLoraPipeline.prepare_latents`]: + - `latents`: packed noisy latents `(B, base + n_ref, C)`. + - `conditioning_mask`: `(B, seq_len, 1)` with `strength` at reference positions, `0` elsewhere. + - `clean_latents`: clean reference values at reference positions (zeros elsewhere); same shape as + `latents`. + - `appended_coords`: `[1, 3, n_ref, 2]` reference coordinates to concat onto `video_coords`, or + `None` when no reference conditions are provided. + - `num_ref_tokens`: count of reference tokens at the END of `latents`. + - `ref_cross_mask`: always `None` for HDR LoRA (no cross-attention masking support). """ + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective" + f" batch size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + generator = generator[0] + + # Build the base noisy latents at the maximum sigma (zeros for Stage 1 fresh noise; normalized provided latents + # for Stage 2). The noise mixing at the bottom converts these into the right partial-denoise state. if latents is not None: if latents.ndim == 5: latents = self._normalize_latents( @@ -604,33 +730,74 @@ def prepare_latents( f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size," f" num_seq, num_features]." 
) - latents = self._create_noised_state(latents, noise_scale, generator) - return latents.to(device=device, dtype=dtype) - - height = height // self.vae_spatial_compression_ratio - width = width // self.vae_spatial_compression_ratio - num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + else: + shape = (batch_size, num_channels_latents, latent_num_frames, latent_height, latent_width) + latents = torch.zeros(shape, device=device, dtype=dtype) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + latents = latents.to(device=device, dtype=dtype) - shape = (batch_size, num_channels_latents, num_frames, height, width) + # Build conditioning_mask and clean_latents over the base token sequence (zeros — base is unconditioned). + base_seq_len = latents.shape[1] + conditioning_mask = torch.zeros((batch_size, base_seq_len, 1), device=device, dtype=dtype) + clean_latents = torch.zeros_like(latents) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." + # Append reference tokens (if any) as a contiguous block at the end of the sequence with per-token + # `conditioning_mask = strength` and `clean_latents = encoded_ref`. + ref_coords: torch.Tensor | None = None + num_ref_tokens = 0 + if reference_conditions is not None and len(reference_conditions) > 0: + ref_latents_packed, ref_coords, _ = self._encode_reference_conditions( + reference_conditions=reference_conditions, + num_frames=num_frames, + height=height, + width=width, + reference_downscale_factor=reference_downscale_factor, + frame_rate=frame_rate, + dtype=dtype, + device=device, + generator=generator, ) + num_ref_tokens = ref_latents_packed.shape[1] + + # All reference videos preprocess to the same shape, so split tokens evenly across conditions. + n_per_ref = num_ref_tokens // len(reference_conditions) + ref_mask_chunks = [ + torch.full( + (batch_size, n_per_ref, 1), + float(ref_cond.strength), + device=device, + dtype=conditioning_mask.dtype, + ) + for ref_cond in reference_conditions + ] + ref_mask_full = torch.cat(ref_mask_chunks, dim=1) + + ref_latents_packed_b = ref_latents_packed.expand(batch_size, -1, -1) + latents = torch.cat([latents, ref_latents_packed_b], dim=1) + conditioning_mask = torch.cat([conditioning_mask, ref_mask_full], dim=1) + clean_latents = torch.cat([clean_latents, ref_latents_packed_b], dim=1) + + # HDR LoRA has no keyframe conditions, so the only appended tokens are reference tokens. + appended_coords = ref_coords + + # The conditioning_mask values have the following semantics: + # - mask=0: fully noise tokens (e.g. noisy latents) + # - mask=1: keep fully clean (e.g. 
I2V first-frame condition, conditions with strength=1) + # - mask in (0, 1): use intermediate noise level mask * sigma_i (noise_scale == sigma_0) + noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) + scaled_mask = (1.0 - conditioning_mask) * noise_scale # noise to initial noise level `noise_scale` + latents = noise * scaled_mask + latents * (1 - scaled_mask) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self._pack_latents( - latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size - ) - return latents + return latents, conditioning_mask, clean_latents, appended_coords, num_ref_tokens, None - # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.prepare_audio_latents + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_condition.LTX2ConditionPipeline.prepare_audio_latents def prepare_audio_latents( self, batch_size: int = 1, num_channels_latents: int = 8, - audio_latent_length: int = 1, + audio_latent_length: int = 1, # 1 is just a dummy value num_mel_bins: int = 64, noise_scale: float = 0.0, dtype: torch.dtype | None = None, @@ -639,32 +806,26 @@ def prepare_audio_latents( latents: torch.Tensor | None = None, ) -> torch.Tensor: if latents is not None: - if latents.ndim == 4: - latents = self._pack_audio_latents(latents) - if latents.ndim != 3: - raise ValueError( - f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size," - f" num_seq, num_features]." - ) + # latents expected to be unpacked (4D) with shape [B, C, L, M] + latents = self._pack_audio_latents(latents) latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) latents = self._create_noised_state(latents, noise_scale, generator) return latents.to(device=device, dtype=dtype) latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) - if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self._pack_audio_latents(latents) + # Sample in packed shape (B, L, C * M), following the original LTX-2.X code + packed_shape = (batch_size, audio_latent_length, num_channels_latents * latent_mel_bins) + latents = randn_tensor(packed_shape, generator=generator, device=device, dtype=dtype) return latents - def prepare_reference_latents( + def _encode_reference_conditions( self, reference_conditions: list[LTX2HDRReferenceCondition], height: int, @@ -675,16 +836,12 @@ def prepare_reference_latents( dtype: torch.dtype | None = None, device: torch.device | None = None, generator: torch.Generator | None = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - r""" - Encode reference videos with HDR preprocessing into packed latent tokens and compute positional coordinates. - - Each reference video is preprocessed via [`LTX2VideoHDRProcessor.preprocess_reference_video_hdr`] (reflect-pad - resize at the reference resolution), VAE-encoded, packed into tokens, and paired with positional coordinates - computed at the reference latent dimensions and scaled by `reference_downscale_factor`. 
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: + """Encode HDR IC-LoRA reference videos into `(reference_latents, reference_coords, reference_cross_mask)`. - Returns a 3-tuple `(reference_latents, reference_coords, reference_denoise_factors)` with the same shapes as - [`LTX2ICLoraPipeline.prepare_reference_latents`]. + Shared encoding core used by both `prepare_latents` (which folds reference tokens into the main noisy + sequence) and the back-compat shim `prepare_reference_latents`. HDR LoRA does not currently support + cross-attention masking for reference tokens, so the third return is always `None`. """ ref_height = height // reference_downscale_factor ref_width = width // reference_downscale_factor @@ -697,7 +854,6 @@ def prepare_reference_latents( all_ref_latents = [] all_ref_coords = [] - all_ref_denoise_factors = [] for ref_cond in reference_conditions: if isinstance(ref_cond.frames, PIL.Image.Image): @@ -743,19 +899,69 @@ def prepare_reference_latents( ref_coords[:, 1, :, :] = ref_coords[:, 1, :, :] * reference_downscale_factor ref_coords[:, 2, :, :] = ref_coords[:, 2, :, :] * reference_downscale_factor - num_tokens = ref_latent_packed.shape[1] - denoise_factor = torch.full( - (1, num_tokens), 1.0 - ref_cond.strength, device=device, dtype=torch.float32 - ) - all_ref_latents.append(ref_latent_packed) all_ref_coords.append(ref_coords) - all_ref_denoise_factors.append(denoise_factor) reference_latents = torch.cat(all_ref_latents, dim=1) reference_coords = torch.cat(all_ref_coords, dim=2) - reference_denoise_factors = torch.cat(all_ref_denoise_factors, dim=1) + return reference_latents, reference_coords, None + + def prepare_reference_latents( + self, + reference_conditions: list[LTX2HDRReferenceCondition], + height: int, + width: int, + num_frames: int, + reference_downscale_factor: int = 1, + frame_rate: float = 24.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + r""" + Encode reference videos with HDR preprocessing into packed latent tokens and compute positional coordinates. + + Each reference video is preprocessed via [`LTX2VideoHDRProcessor.preprocess_reference_video_hdr`] (reflect-pad + resize at the reference resolution), VAE-encoded, packed into tokens, and paired with positional coordinates + computed at the reference latent dimensions and scaled by `reference_downscale_factor`. + + NOTE: As of the HDR LoRA reference-token refactor, this method is a back-compat shim — the canonical + encoding helper is `_encode_reference_conditions` and reference tokens are folded into the main noisy + sequence by `prepare_latents`. This method exists for callers that want the standalone encoding output + (e.g. for downstream parity instrumentation). The `reference_denoise_factors` it returns are derivable + as `1 - strength` per token; in the integrated path the equivalent information lives in + `conditioning_mask` produced by `prepare_latents`. + + Returns a 3-tuple `(reference_latents, reference_coords, reference_denoise_factors)` with the same shapes as + [`LTX2ICLoraPipeline.prepare_reference_latents`]. 
+ """ + reference_latents, reference_coords, _ = self._encode_reference_conditions( + reference_conditions=reference_conditions, + height=height, + width=width, + num_frames=num_frames, + reference_downscale_factor=reference_downscale_factor, + frame_rate=frame_rate, + dtype=dtype, + device=device, + generator=generator, + ) + + # Materialize per-token denoise factors for callers that still expect the 3-tuple. Each ref video has + # `1 - strength` for all of its tokens; we rebuild this from the per-video token counts. All ref videos + # preprocess to the same shape, so total token count divides equally across them. + n_total = reference_latents.shape[1] + n_per_ref = n_total // max(len(reference_conditions), 1) + denoise_chunks = [ + torch.full( + (1, n_per_ref), 1.0 - ref_cond.strength, device=reference_latents.device, dtype=torch.float32 + ) + for ref_cond in reference_conditions + ] + reference_denoise_factors = ( + torch.cat(denoise_chunks, dim=1) if denoise_chunks else reference_latents.new_zeros((1, 0)) + ) return reference_latents, reference_coords, reference_denoise_factors # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_condition.LTX2ConditionPipeline.convert_velocity_to_x0 @@ -822,84 +1028,6 @@ def attention_kwargs(self): def interrupt(self): return self._interrupt - # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_ic_lora.LTX2ICLoraPipeline._run_transformer - def _run_transformer( - self, - latent_model_input: torch.Tensor, - audio_latent_model_input: torch.Tensor, - video_timestep: torch.Tensor, - audio_timestep: torch.Tensor, - sigma: torch.Tensor, - video_coords: torch.Tensor, - audio_coords: torch.Tensor, - connector_prompt_embeds: torch.Tensor, - connector_audio_prompt_embeds: torch.Tensor, - connector_attention_mask: torch.Tensor, - latent_num_frames: int, - latent_height: int, - latent_width: int, - frame_rate: float, - audio_num_frames: int, - use_cross_timestep: bool, - attention_kwargs: dict[str, Any] | None, - cache_context: str, - extra_latents: torch.Tensor | None = None, - extra_coords: torch.Tensor | None = None, - extra_timestep: torch.Tensor | None = None, - video_self_attention_mask: torch.Tensor | None = None, - isolate_modalities: bool = False, - spatio_temporal_guidance_blocks: list[int] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - video_seq_len = latent_model_input.shape[1] - - if extra_latents is not None: - batch_size = latent_model_input.shape[0] - extra_batch = extra_latents.to(latent_model_input.dtype).expand(batch_size, -1, -1) - combined_hidden = torch.cat([latent_model_input, extra_batch], dim=1) - - extra_coords_batch = extra_coords.expand(batch_size, -1, -1, -1) - combined_coords = torch.cat([video_coords, extra_coords_batch], dim=2) - - extra_ts_batch = extra_timestep.expand(batch_size, -1) - combined_timestep = torch.cat([video_timestep, extra_ts_batch], dim=1) - else: - combined_hidden = latent_model_input - combined_coords = video_coords - combined_timestep = video_timestep - - if video_self_attention_mask is not None: - video_self_attention_mask = video_self_attention_mask.expand(combined_hidden.shape[0], -1, -1) - - with self.transformer.cache_context(cache_context): - noise_pred_combined, noise_pred_audio = self.transformer( - hidden_states=combined_hidden, - audio_hidden_states=audio_latent_model_input, - encoder_hidden_states=connector_prompt_embeds, - audio_encoder_hidden_states=connector_audio_prompt_embeds, - timestep=combined_timestep, - audio_timestep=audio_timestep, - sigma=sigma, - 
encoder_attention_mask=connector_attention_mask, - audio_encoder_attention_mask=connector_attention_mask, - video_self_attention_mask=video_self_attention_mask, - num_frames=latent_num_frames, - height=latent_height, - width=latent_width, - fps=frame_rate, - audio_num_frames=audio_num_frames, - video_coords=combined_coords, - audio_coords=audio_coords, - isolate_modalities=isolate_modalities, - spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, - perturbation_mask=None, - use_cross_timestep=use_cross_timestep, - attention_kwargs=attention_kwargs, - return_dict=False, - ) - - noise_pred_video = noise_pred_combined[:, :video_seq_len] - return noise_pred_video, noise_pred_audio - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -928,6 +1056,8 @@ def __call__( prompt_attention_mask: torch.Tensor | None = None, negative_prompt_embeds: torch.Tensor | None = None, negative_prompt_attention_mask: torch.Tensor | None = None, + connector_video_embeds: torch.Tensor | None = None, + connector_audio_embeds: torch.Tensor | None = None, decode_timestep: float | list[float] = 0.0, decode_noise_scale: float | list[float] | None = None, use_cross_timestep: bool = False, @@ -991,6 +1121,12 @@ def __call__( Pre-generated negative text embeddings. negative_prompt_attention_mask (`torch.Tensor`, *optional*): Attention mask for `negative_prompt_embeds`. + connector_video_embeds (`torch.Tensor`, *optional*): + Optional pre-computed connector outputs for the video modality. Used by the HDR LoRA pipeline; if + supplied, will override any `prompt`/`prompt_embeds`. + connector_audio_embeds (`torch.Tensor`, *optional*): + Optional pre-computed connector outputs for the audio modality. Used by the HDR LoRA pipeline; if + supplied, will override any `prompt`/`prompt_embeds`. decode_timestep, decode_noise_scale: VAE-decode timestep conditioning (only used by VAE configs with `timestep_conditioning=True`). use_cross_timestep (`bool`, *optional*, defaults to `False`): @@ -1024,6 +1160,8 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, + connector_video_embeds=connector_video_embeds, + connector_audio_embeds=connector_audio_embeds, latents=latents, spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, stg_scale=stg_scale, @@ -1044,8 +1182,10 @@ def __call__( batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) - else: + elif prompt_embeds is not None: batch_size = prompt_embeds.shape[0] + else: + batch_size = connector_video_embeds.shape[0] if reference_conditions is not None and not isinstance(reference_conditions, list): reference_conditions = [reference_conditions] @@ -1056,33 +1196,38 @@ def __call__( device = self._execution_device # 3. 
Prepare text embeddings - ( - prompt_embeds, - prompt_attention_mask, - negative_prompt_embeds, - negative_prompt_attention_mask, - ) = self.encode_prompt( - prompt=prompt, - negative_prompt=negative_prompt, - do_classifier_free_guidance=self.do_classifier_free_guidance, - num_videos_per_prompt=num_videos_per_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - prompt_attention_mask=prompt_attention_mask, - negative_prompt_attention_mask=negative_prompt_attention_mask, - max_sequence_length=max_sequence_length, - device=device, - ) - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) - - tokenizer_padding_side = "left" - if getattr(self, "tokenizer", None) is not None: - tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left") - connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( - prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side - ) + if connector_video_embeds is None or connector_audio_embeds is None: + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + tokenizer_padding_side = "left" + if getattr(self, "tokenizer", None) is not None: + tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left") + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( + prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side + ) + else: + connector_prompt_embeds = connector_video_embeds.to(device=device, dtype=self.transformer.dtype) + connector_audio_prompt_embeds = connector_audio_embeds.to(device=device, dtype=self.transformer.dtype) + connector_attention_mask = None # 4. Prepare video latents latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 @@ -1097,33 +1242,26 @@ def __call__( _, _, latent_num_frames, latent_height, latent_width = latents.shape num_channels_latents = self.transformer.config.in_channels - latents = self.prepare_latents( + latents, conditioning_mask, clean_latents, appended_coords, num_ref_tokens, _ = self.prepare_latents( + reference_conditions=reference_conditions, + reference_downscale_factor=reference_downscale_factor, batch_size=batch_size * num_videos_per_prompt, num_channels_latents=num_channels_latents, height=height, width=width, num_frames=num_frames, + frame_rate=frame_rate, noise_scale=noise_scale, dtype=torch.float32, device=device, generator=generator, latents=latents, ) - - # 4b. Prepare reference extras for HDR IC-LoRA conditioning. 
- extra_latents = extra_coords = extra_denoise_factors = None - if reference_conditions is not None and len(reference_conditions) > 0: - extra_latents, extra_coords, extra_denoise_factors = self.prepare_reference_latents( - reference_conditions=reference_conditions, - height=height, - width=width, - num_frames=num_frames, - reference_downscale_factor=reference_downscale_factor, - frame_rate=frame_rate, - dtype=torch.float32, - device=device, - generator=generator, - ) + # Track the base (non-reference) token count so we can trim the appended reference tokens off + # `latents` before unpack/decode at the end. + base_token_count = latents.shape[1] - num_ref_tokens + if self.do_classifier_free_guidance and num_ref_tokens > 0: + conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) # 5. Prepare audio latents. Audio is discarded at the end, but the transformer's audio branch still runs so # we need well-formed audio inputs. Audio guidance is fixed so no extra audio-only forward passes fire. @@ -1183,6 +1321,8 @@ def __call__( video_coords = self.transformer.rope.prepare_video_coords( latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate ) + if appended_coords is not None: + video_coords = torch.cat([video_coords, appended_coords], dim=2) audio_coords = self.transformer.audio_rope.prepare_audio_coords( audio_latents.shape[0], audio_num_frames, audio_latents.device ) @@ -1201,41 +1341,45 @@ def __call__( self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = latent_model_input.to(prompt_embeds.dtype) + latent_model_input = latent_model_input.to(connector_prompt_embeds.dtype) audio_latent_model_input = ( torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents ) - audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) + audio_latent_model_input = audio_latent_model_input.to(connector_prompt_embeds.dtype) timestep_scalar = t.expand(latent_model_input.shape[0]) - video_timestep = timestep_scalar.unsqueeze(-1).expand(-1, video_seq_len) - - extra_timestep_in = t * extra_denoise_factors if extra_denoise_factors is not None else None + if num_ref_tokens > 0: + video_timestep = timestep_scalar.unsqueeze(-1) * (1 - conditioning_mask.squeeze(-1)) + else: + video_timestep = timestep_scalar.unsqueeze(-1).expand(-1, video_seq_len) # --- Main forward pass (cond + uncond for CFG) --- - noise_pred_video, noise_pred_audio = self._run_transformer( - latent_model_input=latent_model_input, - audio_latent_model_input=audio_latent_model_input, - video_timestep=video_timestep, - audio_timestep=timestep_scalar, - sigma=timestep_scalar, - video_coords=video_coords, - audio_coords=audio_coords, - connector_prompt_embeds=connector_prompt_embeds, - connector_audio_prompt_embeds=connector_audio_prompt_embeds, - connector_attention_mask=connector_attention_mask, - latent_num_frames=latent_num_frames, - latent_height=latent_height, - latent_width=latent_width, - frame_rate=frame_rate, - audio_num_frames=audio_num_frames, - use_cross_timestep=use_cross_timestep, - attention_kwargs=attention_kwargs, - cache_context="cond_uncond", - extra_latents=extra_latents, - extra_coords=extra_coords, - extra_timestep=extra_timestep_in, - ) + with self.transformer.cache_context("cond_uncond"): + noise_pred_video, noise_pred_audio = self.transformer( + hidden_states=latent_model_input, + audio_hidden_states=audio_latent_model_input, + 
encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=timestep_scalar, + sigma=timestep_scalar, # Used by LTX-2.3 + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, + video_self_attention_mask=None, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, + isolate_modalities=False, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + return_dict=False, + ) noise_pred_video = noise_pred_video.float() if self.do_classifier_free_guidance: @@ -1254,7 +1398,10 @@ def __call__( video_pos_ids = video_coords.chunk(2, dim=0)[0] audio_pos_ids = audio_coords.chunk(2, dim=0)[0] timestep_scalar_single = timestep_scalar.chunk(2, dim=0)[0] - video_timestep_single = timestep_scalar_single.unsqueeze(-1).expand(-1, video_seq_len) + if num_ref_tokens > 0: + video_timestep_single = video_timestep.chunk(2, dim=0)[0] + else: + video_timestep_single = timestep_scalar_single.unsqueeze(-1).expand(-1, video_seq_len) else: video_cfg_delta = 0 @@ -1265,36 +1412,42 @@ def __call__( audio_pos_ids = audio_coords timestep_scalar_single = timestep_scalar - video_timestep_single = video_timestep + if num_ref_tokens > 0: + video_timestep_single = video_timestep + else: + video_timestep_single = timestep_scalar.unsqueeze(-1).expand(-1, video_seq_len) noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler) # --- STG forward pass (video only — audio output discarded) --- if self.do_spatio_temporal_guidance: - noise_pred_video_uncond_stg, _ = self._run_transformer( - latent_model_input=latents.to(dtype=prompt_embeds.dtype), - audio_latent_model_input=audio_latents.to(dtype=prompt_embeds.dtype), - video_timestep=video_timestep_single, - audio_timestep=timestep_scalar_single, - sigma=timestep_scalar_single, - video_coords=video_pos_ids, - audio_coords=audio_pos_ids, - connector_prompt_embeds=video_prompt_embeds, - connector_audio_prompt_embeds=audio_prompt_embeds, - connector_attention_mask=prompt_attn_mask, - latent_num_frames=latent_num_frames, - latent_height=latent_height, - latent_width=latent_width, - frame_rate=frame_rate, - audio_num_frames=audio_num_frames, - use_cross_timestep=use_cross_timestep, - attention_kwargs=attention_kwargs, - cache_context="uncond_stg", - extra_latents=extra_latents, - extra_coords=extra_coords, - extra_timestep=extra_timestep_in, - spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, - ) + with self.transformer.cache_context("uncond_stg"): + noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer( + hidden_states=latents.to(dtype=connector_prompt_embeds.dtype), + audio_hidden_states=audio_latents.to(dtype=connector_prompt_embeds.dtype), + encoder_hidden_states=video_prompt_embeds, + audio_encoder_hidden_states=audio_prompt_embeds, + timestep=video_timestep_single, + audio_timestep=timestep_scalar_single, + sigma=timestep_scalar_single, # Used by LTX-2.3 + encoder_attention_mask=prompt_attn_mask, + audio_encoder_attention_mask=prompt_attn_mask, + video_self_attention_mask=None, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_pos_ids, + 
audio_coords=audio_pos_ids, + isolate_modalities=False, + # Use STG at given blocks to perturb model + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + return_dict=False, + ) noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float() noise_pred_video_uncond_stg = self.convert_velocity_to_x0( latents, noise_pred_video_uncond_stg, i, self.scheduler @@ -1305,30 +1458,33 @@ def __call__( # --- Modality isolation guidance forward pass --- if self.do_modality_isolation_guidance: - noise_pred_video_uncond_mod, _ = self._run_transformer( - latent_model_input=latents.to(dtype=prompt_embeds.dtype), - audio_latent_model_input=audio_latents.to(dtype=prompt_embeds.dtype), - video_timestep=video_timestep_single, - audio_timestep=timestep_scalar_single, - sigma=timestep_scalar_single, - video_coords=video_pos_ids, - audio_coords=audio_pos_ids, - connector_prompt_embeds=video_prompt_embeds, - connector_audio_prompt_embeds=audio_prompt_embeds, - connector_attention_mask=prompt_attn_mask, - latent_num_frames=latent_num_frames, - latent_height=latent_height, - latent_width=latent_width, - frame_rate=frame_rate, - audio_num_frames=audio_num_frames, - use_cross_timestep=use_cross_timestep, - attention_kwargs=attention_kwargs, - cache_context="uncond_modality", - extra_latents=extra_latents, - extra_coords=extra_coords, - extra_timestep=extra_timestep_in, - isolate_modalities=True, - ) + with self.transformer.cache_context("uncond_modality"): + noise_pred_video_uncond_mod, noise_pred_audio_uncond_mod = self.transformer( + hidden_states=latents.to(dtype=connector_prompt_embeds.dtype), + audio_hidden_states=audio_latents.to(dtype=connector_prompt_embeds.dtype), + encoder_hidden_states=video_prompt_embeds, + audio_encoder_hidden_states=audio_prompt_embeds, + timestep=video_timestep_single, + audio_timestep=timestep_scalar_single, + sigma=timestep_scalar_single, # Used by LTX-2.3 + encoder_attention_mask=prompt_attn_mask, + audio_encoder_attention_mask=prompt_attn_mask, + video_self_attention_mask=None, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + # Turn off A2V and V2A cross attn to isolate video and audio modalities + isolate_modalities=True, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + return_dict=False, + ) noise_pred_video_uncond_mod = noise_pred_video_uncond_mod.float() noise_pred_video_uncond_mod = self.convert_velocity_to_x0( latents, noise_pred_video_uncond_mod, i, self.scheduler @@ -1348,6 +1504,15 @@ def __call__( else: noise_pred_video = noise_pred_video_g + # Apply the conditioning mask to apply the reference conditions at the specified strength. + if num_ref_tokens > 0: + bsz = noise_pred_video.size(0) + denoised_sample_cond = ( + noise_pred_video * (1 - conditioning_mask[:bsz]) + + clean_latents.float() * conditioning_mask[:bsz] + ).to(noise_pred_video.dtype) + noise_pred_video = denoised_sample_cond + noise_pred_video = self.convert_x0_to_velocity(latents, noise_pred_video, i, self.scheduler) latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0] @@ -1377,6 +1542,8 @@ def __call__( del audio_latent_model_input, latent_mel_bins # 9. 
Decode
+        # Trim any appended reference tokens from the latents to recover the generated video only.
+        latents = latents[:, :base_token_count]
        latents = self._unpack_latents(
            latents,
            latent_num_frames,
@@ -1392,7 +1559,7 @@ def __call__(
            )
            video = latents
        else:
-            latents = latents.to(prompt_embeds.dtype)
+            latents = latents.to(connector_prompt_embeds.dtype)
            if not self.vae.config.timestep_conditioning:
                timestep = None
diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py
index 024af2f3b210..01ad07babf77 100644
--- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py
+++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py
@@ -14,6 +14,7 @@
 import copy
 import inspect
+import math
 from dataclasses import dataclass
 from typing import Any, Callable
@@ -644,6 +645,8 @@ def _pack_audio_latents(
    ) -> torch.Tensor:
        # Audio latents shape: [B, C, L, M], where L is the latent audio length and M is the number of mel bins
        if patch_size is not None and patch_size_t is not None:
+            # Packs the latents into a patch sequence of shape [B, L // p_t * M // p, C * p_t * p] (an ndim=3 tensor).
+            # dim=1 is the effective audio sequence length and dim=2 is the effective audio input feature size.
            batch_size, num_channels, latent_length, latent_mel_bins = latents.shape
            post_patch_latent_length = latent_length / patch_size_t
            post_patch_mel_bins = latent_mel_bins / patch_size
@@ -652,6 +655,8 @@ def _pack_audio_latents(
            )
            latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
        else:
+            # Packs the latents into a patch sequence of shape [B, L, C * M]. This implicitly assumes a (mel)
+            # patch_size of M (all mel bins constitute a single patch) and a patch_size_t of 1.
            latents = latents.transpose(1, 2).flatten(2, 3)  # [B, C, L, M] --> [B, L, C * M]
        return latents
@@ -664,11 +669,14 @@ def _unpack_audio_latents(
        patch_size: int | None = None,
        patch_size_t: int | None = None,
    ) -> torch.Tensor:
+        # Unpacks an audio patch sequence of shape [B, S, D] into a latent spectrogram tensor of shape [B, C, L, M],
+        # where L is the latent audio length and M is the number of mel bins.
        if patch_size is not None and patch_size_t is not None:
            batch_size = latents.size(0)
            latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size)
            latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
        else:
+            # Assume [B, S, D] = [B, L, C * M], which implies that patch_size = M and patch_size_t = 1.
            latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2)
        return latents
@@ -676,9 +684,17 @@ def _unpack_audio_latents(
    def trim_conditioning_sequence(self, start_frame: int, sequence_num_frames: int, target_num_frames: int) -> int:
        """
        Trim a conditioning sequence to the allowed number of frames.
+
+        Args:
+            start_frame (int): The target frame number of the first frame in the sequence.
+            sequence_num_frames (int): The number of frames in the sequence.
+            target_num_frames (int): The target number of frames in the generated video.
+ Returns: + int: updated sequence length """ scale_factor = self.vae_temporal_compression_ratio num_frames = min(sequence_num_frames, target_num_frames - start_frame) + # Trim down to a multiple of temporal_scale_factor frames plus 1 num_frames = (num_frames - 1) // scale_factor * scale_factor + 1 return num_frames @@ -726,22 +742,50 @@ def preprocess_conditions( frame_scale_factor = self.vae_temporal_compression_ratio latent_num_frames = (num_frames - 1) // frame_scale_factor + 1 for i, condition in enumerate(conditions): + # Create a channels-last video-like array of shape (F, H, W, C) in preparation for resizing. if isinstance(condition.frames, PIL.Image.Image): - # Single image, convert to List[PIL.Image.Image] - video_like_cond = [condition.frames] - elif isinstance(condition.frames, np.ndarray) and condition.frames.ndim == 3: - # Image-like ndarray of shape (H, W, C), insert frame dim in first axis - video_like_cond = np.expand_dims(condition.frames, axis=0) - elif isinstance(condition.frames, torch.Tensor) and condition.frames.ndim == 3: - # Image-like tensor of shape (C, H, W), insert frame dim in first dim - video_like_cond = condition.frames.unsqueeze(0) + arr = np.array(condition.frames.convert("RGB"))[None] # (1, H, W, 3) + elif isinstance(condition.frames, list) and all( + isinstance(f, PIL.Image.Image) for f in condition.frames + ): + arr = np.stack([np.array(f.convert("RGB")) for f in condition.frames]) # (F, H, W, 3) + elif isinstance(condition.frames, np.ndarray): + arr = condition.frames if condition.frames.ndim == 4 else condition.frames[None] + elif isinstance(condition.frames, torch.Tensor): + t = condition.frames if condition.frames.ndim == 4 else condition.frames.unsqueeze(0) + # Reference layout for video tensors is (F, C, H, W); convert to (F, H, W, C) for the + # resize logic, which expects channels-last. + arr = t.detach().cpu().permute(0, 2, 3, 1).numpy() else: - # Treat all other as videos. Note that this means 4D ndarrays and tensors will be treated as videos of - # shape (F, H, W, C) and (F, C, H, W), respectively. - video_like_cond = condition.frames - condition_pixels = self.video_processor.preprocess_video( - video_like_cond, height, width, resize_mode="crop" + raise TypeError( + f"Unsupported `frames` type for condition {i}: {type(condition.frames)}" + ) + + src_h, src_w = arr.shape[1], arr.shape[2] + num_cond_frames = arr.shape[0] + # Convert the NumPy array to a channels-first tensor of shape (1, C, F, H, W) + pixels = torch.from_numpy(np.ascontiguousarray(arr)).to(torch.float32) + pixels = pixels.permute(3, 0, 1, 2).unsqueeze(0).to(device) # (1, C, F, H, W) + + # Resize so the longer side fills the target, then center-crop to exact (height, width). + scale = max(height / src_h, width / src_w) + new_h = math.ceil(src_h * scale) + new_w = math.ceil(src_w * scale) + # Flatten (B, C, F, H, W) → (B*F, C, H, W) for the per-frame interpolation + pixels = pixels.permute(0, 2, 1, 3, 4).reshape(num_cond_frames, 3, src_h, src_w) + # NOTE: we avoid using VideoProcessor.preprocess_video here because it uses PIL.Image.resize under the + # hood, which will apply an anti-aliasing pre-filter when downsampling. The original LTX-2.X code simply + # uses F.interpolate, which is reproduced here. 
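+            # Illustrative example of the aspect-fill math (values chosen only for this comment): for a
+            # 1080x1920 source and a 512x768 target, scale = max(512 / 1080, 768 / 1920) = 0.4740..., so
+            # new_h = 512 and new_w = 911; the center crop below then keeps rows [0:512] and columns [71:839]
+            # to land exactly on (height, width).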
+ pixels = torch.nn.functional.interpolate( + pixels, size=(new_h, new_w), mode="bilinear", align_corners=False ) + top = (new_h - height) // 2 + left = (new_w - width) // 2 + pixels = pixels[:, :, top : top + height, left : left + width] + pixels = pixels.reshape(1, num_cond_frames, 3, height, width).permute(0, 2, 1, 3, 4) + + # Map [0, 255] → [-1, 1] (VAE input convention). + condition_pixels = pixels / 127.5 - 1.0 # Interpret the index as a latent index, following the original LTX-2 code. latent_start_idx = condition.index @@ -881,10 +925,13 @@ def _prepare_keyframe_coords( return pixel_coords - # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_condition.LTX2ConditionPipeline.prepare_latents def prepare_latents( self, conditions: LTX2VideoCondition | list[LTX2VideoCondition] | None = None, + reference_conditions: list[LTX2ReferenceCondition] | None = None, + reference_downscale_factor: int = 1, + conditioning_attention_strength: float = 1.0, + conditioning_attention_mask: torch.Tensor | None = None, batch_size: int = 1, num_channels_latents: int = 128, height: int = 512, @@ -896,21 +943,39 @@ def prepare_latents( device: torch.device | None = None, generator: torch.Generator | None = None, latents: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None, int, torch.Tensor | None]: """ - Prepare noisy video latents, applying frame conditions. - - First-frame conditions (`latent_idx == 0`) are applied by overwriting tokens at the first-frame positions - (`VideoConditionByLatentIndex` semantics). Non-first-frame conditions (`latent_idx > 0`) are packaged as - keyframe extras to be appended to the latent sequence during the transformer forward pass - (`VideoConditionByKeyframeIndex` semantics). - - Returns a 4-tuple: - - `latents`: packed noisy latents (with first-frame replacement applied if applicable). - - `conditioning_mask`: packed conditioning mask (non-zero only at first-frame positions). - - `clean_latents`: clean first-frame conditions at first-frame positions (zeros elsewhere). - - `keyframe_extras`: `(keyframe_latents, keyframe_coords, keyframe_denoise_factors)` for keyframe - conditions at non-zero latent indices, or `None` if there are none. + Prepare noisy video latents, applying frame and reference-video conditioning. + + Conditioning sources are unified into a single packed sequence in the order + `[base | keyframe | reference]`: + + - First-frame conditions (`conditions` with `latent_idx == 0`) overwrite tokens at the first-frame positions + (`VideoConditionByLatentIndex` semantics). + - Non-first-frame conditions (`conditions` with `latent_idx > 0`) are concatenated onto the main latent + sequence with per-token `conditioning_mask = strength` + (`VideoConditionByKeyframeIndex` semantics). + - IC-LoRA `reference_conditions` (if any) are encoded by the VAE and appended after the keyframes with + per-token `conditioning_mask = strength` (matching the reference repo's + `VideoConditionByReferenceLatent` semantics). + + For all appended tokens the noise mixing below blends them to noise level `(1 - strength) * sigma_max`, + and the existing per-token timestep formula `t * (1 - conditioning_mask)` and the post-process blend + `denoised * (1 - cond_mask) + clean * cond_mask` drive them through the loop. + + Returns a 6-tuple: + - `latents`: packed noisy latents `(B, base + n_keyframe + n_ref, C)`. 
+            - `conditioning_mask`: `(B, seq_len, 1)` with values in `[0, 1]` — `1` at first-frame positions,
+              `strength` at keyframe / reference positions, `0` elsewhere.
+            - `clean_latents`: clean condition values at conditioned positions (zeros elsewhere); same shape as
+              `latents`.
+            - `appended_coords`: `[1, 3, n_keyframe + n_ref, 2]` positional coordinates to concat onto
+              `video_coords`, or `None` if no keyframe/reference conditions are provided.
+            - `num_ref_tokens`: count of reference tokens at the END of `latents` (used by the call site to
+              build the unified self-attention mask).
+            - `ref_cross_mask`: `[1, num_ref_tokens]` per-reference-token cross-attention strengths in `[0, 1]`,
+              or `None` when `conditioning_attention_strength == 1.0` and no pixel-space mask is provided
+              (in which case attention is uniform).
        """
        latent_height = height // self.vae_spatial_compression_ratio
        latent_width = width // self.vae_spatial_compression_ratio
@@ -972,7 +1037,7 @@ def prepare_latents(
        # First-frame conditions (latent_idx == 0): replace tokens at the first-frame positions.
        # NOTE: following the I2V pipeline, we return a conditioning mask. The original LTX 2 code uses a denoising
        # mask, which is the inverse of the conditioning mask (`denoise_mask = 1 - conditioning_mask`).
-        latents, conditioning_mask, clean_latents = self.apply_visual_conditioning(
+        latents, conditioning_mask, clean_latents = self.apply_first_frame_conditioning(
            latents,
            conditioning_mask,
            condition_latents_packed,
@@ -982,11 +1047,12 @@ def prepare_latents(
            latent_width=latent_width,
        )
-        # Non-first-frame conditions (latent_idx > 0): append as keyframe extras with offset pixel coords.
+        # Non-first-frame ("keyframe") conditions (latent_idx > 0): append as extra latent tokens to the noisy latent.
+        # Each condition gets an all-`strength` conditioning mask and pos ids, which are also appended to those of the
+        # noisy latent. At each denoising step i, the keyframe conditions get an effective noise level of
+        # (1 - conditioning_strength) * sigma_i.
        frame_scale_factor = self.vae_temporal_compression_ratio
-        keyframe_tokens = []
-        keyframe_coords = []
-        keyframe_denoise_factors = []
+        kf_tokens_list, kf_coords_list, kf_mask_list, kf_clean_list = [], [], [], []
        for cond_5d, cond_packed, strength, latent_idx, num_pixel_frames in zip(
            condition_latents_5d,
            condition_latents_packed,
@@ -998,8 +1064,6 @@ def prepare_latents(
                continue
            _, _, kf_latent_frames, kf_latent_height, kf_latent_width = cond_5d.shape
-            # Pixel-space frame index at which the keyframe is placed. Matches the `start_idx` formula in
-            # `preprocess_conditions` used for trimming.
pixel_frame_idx = (latent_idx - 1) * frame_scale_factor + 1 coords = self._prepare_keyframe_coords( @@ -1013,37 +1077,95 @@ def prepare_latents( ) num_tokens = cond_packed.shape[1] - denoise_factor = torch.full( - (1, num_tokens), 1.0 - strength, device=device, dtype=torch.float32 + kf_mask = torch.full( + (cond_packed.shape[0], num_tokens, 1), + float(strength), + device=device, + dtype=conditioning_mask.dtype, ) - keyframe_tokens.append(cond_packed) - keyframe_coords.append(coords) - keyframe_denoise_factors.append(denoise_factor) - - keyframe_extras: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None - if keyframe_tokens: - keyframe_extras = ( - torch.cat(keyframe_tokens, dim=1), - torch.cat(keyframe_coords, dim=2), - torch.cat(keyframe_denoise_factors, dim=1), + kf_tokens_list.append(cond_packed) + kf_clean_list.append(cond_packed) + kf_mask_list.append(kf_mask) + kf_coords_list.append(coords) + + if kf_tokens_list: + keyframe_coords = torch.cat(kf_coords_list, dim=2) + latents = torch.cat([latents, torch.cat(kf_tokens_list, dim=1)], dim=1) + conditioning_mask = torch.cat([conditioning_mask, torch.cat(kf_mask_list, dim=1)], dim=1) + clean_latents = torch.cat([clean_latents, torch.cat(kf_clean_list, dim=1)], dim=1) + + # IC-LoRA reference-video conditions: encode each reference video, then append it to the main packed + # sequence with per-token `conditioning_mask = strength`. This is the same architectural pattern as + # for non-first-frame conditions above, but we need to keep keyframe and reference conditions separate + # for attention masking. + ref_cross_mask: torch.Tensor | None = None + ref_coords: torch.Tensor | None = None + num_ref_tokens = 0 + if reference_conditions is not None and len(reference_conditions) > 0: + ref_latents_packed, ref_coords, ref_cross_mask = self._encode_reference_conditions( + reference_conditions=reference_conditions, + num_frames=num_frames, + height=height, + width=width, + reference_downscale_factor=reference_downscale_factor, + frame_rate=frame_rate, + conditioning_attention_strength=conditioning_attention_strength, + conditioning_attention_mask=conditioning_attention_mask, + dtype=dtype, + device=device, + generator=generator, ) + num_ref_tokens = ref_latents_packed.shape[1] + + # All reference videos preprocess to the same (ref_height, ref_width, num_frames), so their packed + # token counts are identical. Split `num_ref_tokens` evenly across the conditions and materialize + # the per-token strength mask in `reference_conditions` order, matching the layout the encoder + # emitted. + n_per_ref = num_ref_tokens // len(reference_conditions) + ref_mask_chunks = [ + torch.full( + (batch_size, n_per_ref, 1), + float(ref_cond.strength), + device=device, + dtype=conditioning_mask.dtype, + ) + for ref_cond in reference_conditions + ] + ref_mask_full = torch.cat(ref_mask_chunks, dim=1) + + ref_latents_packed_b = ref_latents_packed.expand(batch_size, -1, -1) + latents = torch.cat([latents, ref_latents_packed_b], dim=1) + conditioning_mask = torch.cat([conditioning_mask, ref_mask_full], dim=1) + clean_latents = torch.cat([clean_latents, ref_latents_packed_b], dim=1) + + # Combine keyframe + reference appended-coords into a single block to concat onto `video_coords` at + # the call site. 
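+        # Both appended blocks share the coordinate layout [1, 3, n_tokens, 2], so concatenating along dim=2
+        # (the token axis) yields [1, 3, n_keyframe + n_ref, 2], matching the `appended_coords` shape documented
+        # in the docstring above.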
+ if keyframe_coords is not None and ref_coords is not None: + appended_coords = torch.cat([keyframe_coords, ref_coords], dim=2) + elif keyframe_coords is not None: + appended_coords = keyframe_coords + elif ref_coords is not None: + appended_coords = ref_coords else: - keyframe_extras = None + appended_coords = None - # Sample from the standard Gaussian prior (or an intermediate Gaussian distribution if noise_scale < 1.0). + # The conditioning_mask values have the following semantics: + # - mask=0: fully noise tokens (e.g. noisy latents) + # - mask=1: keep fully clean (e.g. I2V first-frame condition, conditions with strength=1) + # - mask in (0, 1): use intermediate noise level mask * sigma_i (noise_scale == sigma_0) noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) - scaled_mask = (1.0 - conditioning_mask) * noise_scale - # Add noise to the `latents` so that it is at the noise level specified by `noise_scale`. + scaled_mask = (1.0 - conditioning_mask) * noise_scale # noise to initial noise level `noise_scale` latents = noise * scaled_mask + latents * (1 - scaled_mask) - return latents, conditioning_mask, clean_latents, keyframe_extras + return latents, conditioning_mask, clean_latents, appended_coords, num_ref_tokens, ref_cross_mask + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_condition.LTX2ConditionPipeline.prepare_audio_latents def prepare_audio_latents( self, batch_size: int = 1, num_channels_latents: int = 8, - audio_latent_length: int = 1, + audio_latent_length: int = 1, # 1 is just a dummy value num_mel_bins: int = 64, noise_scale: float = 0.0, dtype: torch.dtype | None = None, @@ -1052,6 +1174,7 @@ def prepare_audio_latents( latents: torch.Tensor | None = None, ) -> torch.Tensor: if latents is not None: + # latents expected to be unpacked (4D) with shape [B, C, L, M] latents = self._pack_audio_latents(latents) latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) latents = self._create_noised_state(latents, noise_scale, generator) @@ -1059,19 +1182,18 @@ def prepare_audio_latents( latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) - if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self._pack_audio_latents(latents) + # Sample in packed shape (B, L, C * M), following the original LTX-2.X code + packed_shape = (batch_size, audio_latent_length, num_channels_latents * latent_mel_bins) + latents = randn_tensor(packed_shape, generator=generator, device=device, dtype=dtype) return latents - def prepare_reference_latents( + def _encode_reference_conditions( self, reference_conditions: list[LTX2ReferenceCondition], height: int, @@ -1084,58 +1206,12 @@ def prepare_reference_latents( dtype: torch.dtype | None = None, device: torch.device | None = None, generator: torch.Generator | None = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]: - """ - Encode reference videos into packed latent tokens and compute their positional coordinates. 
- - Each reference video is independently encoded by the VAE, packed into tokens, and its positional coordinates - are computed with spatial scaling by `reference_downscale_factor` to match the target coordinate space. - - All reference tokens are concatenated into a single sequence to be appended to the noisy video latents - during denoising. When `conditioning_attention_strength < 1.0` or `conditioning_attention_mask` is provided, - a per-token cross-attention mask is also computed for each reference video (downsampled to the reference - video's latent dimensions) and returned so the pipeline can build a self-attention mask over the full video - sequence. - - Args: - reference_conditions (`list[LTX2ReferenceCondition]`): - The reference video conditions. - height (`int`): - Target video height in pixels (used to determine reference video preprocessing size with - `reference_downscale_factor`). - width (`int`): - Target video width in pixels. - num_frames (`int`): - Number of target video frames. - reference_downscale_factor (`int`, defaults to `1`): - Ratio between target and reference resolutions. A factor of 2 means the reference video is - preprocessed at half the target resolution. Spatial positional coordinates are scaled by this factor - to map reference tokens into the target coordinate space. - frame_rate (`float`, defaults to `24.0`): - Video frame rate (used for temporal coordinate computation). - conditioning_attention_strength (`float`, defaults to `1.0`): - Scalar in `[0, 1]` controlling how strongly reference tokens attend to noisy tokens (and vice versa) - in the self-attention mask. `1.0` means full attention (no masking), `0.0` means reference tokens - are effectively ignored by the noisy tokens. - conditioning_attention_mask (`torch.Tensor`, *optional*): - Optional pixel-space mask of shape `(1, 1, F_pix, H_pix, W_pix)` with values in `[0, 1]` that provides - spatially-varying attention strength. Downsampled to latent space per reference video and multiplied - by `conditioning_attention_strength`. - dtype (`torch.dtype`, *optional*): - Data type for the latents. - device (`torch.device`, *optional*): - Device for the latents. - generator (`torch.Generator`, *optional*): - Random generator for VAE encoding. + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: + """Encode IC-LoRA reference videos into `(reference_latents, reference_coords, reference_cross_mask)`. - Returns: - A 4-tuple of `(reference_latents, reference_coords, reference_denoise_factors, reference_cross_mask)`: - - `reference_latents`: `[1, total_ref_tokens, hidden_dim]` - - `reference_coords`: `[1, 3, total_ref_tokens, 2]` - - `reference_denoise_factors`: `[1, total_ref_tokens]` — per-token `(1 - strength)` factors - - `reference_cross_mask`: `[1, total_ref_tokens]` per-token noisy↔reference attention strengths in - `[0, 1]`, or `None` when `conditioning_attention_strength == 1.0` and no pixel-space mask is - provided (in which case attention is unmasked). + This is the shared encoding core used by both `prepare_latents` (which folds reference tokens into the + main noisy sequence) and the back-compat shim `prepare_reference_latents` (which exposes the legacy + 4-tuple output). See `prepare_reference_latents` for parameter documentation. 
""" ref_height = height // reference_downscale_factor ref_width = width // reference_downscale_factor @@ -1144,7 +1220,6 @@ def prepare_reference_latents( all_ref_latents = [] all_ref_coords = [] - all_ref_denoise_factors = [] all_ref_cross_masks = [] for ref_cond in reference_conditions: @@ -1198,13 +1273,9 @@ def prepare_reference_latents( ref_coords[:, 2, :, :] = ref_coords[:, 2, :, :] * reference_downscale_factor num_tokens = ref_latent_packed.shape[1] - denoise_factor = torch.full( - (1, num_tokens), 1.0 - ref_cond.strength, device=device, dtype=torch.float32 - ) all_ref_latents.append(ref_latent_packed) all_ref_coords.append(ref_coords) - all_ref_denoise_factors.append(denoise_factor) if mask_needed: # Per-reference cross-attention mask. Start from either a downsampled pixel-space mask or a full-1 @@ -1224,9 +1295,116 @@ def prepare_reference_latents( # Concatenate all reference tokens into a single sequence reference_latents = torch.cat(all_ref_latents, dim=1) # [1, total_ref_tokens, D] reference_coords = torch.cat(all_ref_coords, dim=2) # [1, 3, total_ref_tokens, 2] - reference_denoise_factors = torch.cat(all_ref_denoise_factors, dim=1) # [1, total_ref_tokens] reference_cross_mask = torch.cat(all_ref_cross_masks, dim=1) if mask_needed else None + return reference_latents, reference_coords, reference_cross_mask + + def prepare_reference_latents( + self, + reference_conditions: list[LTX2ReferenceCondition], + height: int, + width: int, + num_frames: int, + reference_downscale_factor: int = 1, + frame_rate: float = 24.0, + conditioning_attention_strength: float = 1.0, + conditioning_attention_mask: torch.Tensor | None = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]: + """ + Encode reference videos into packed latent tokens and compute their positional coordinates. + + Each reference video is independently encoded by the VAE, packed into tokens, and its positional coordinates + are computed with spatial scaling by `reference_downscale_factor` to match the target coordinate space. + + All reference tokens are concatenated into a single sequence. When `conditioning_attention_strength < 1.0` + or `conditioning_attention_mask` is provided, a per-token cross-attention mask is also computed for each + reference video (downsampled to the reference video's latent dimensions) and returned so callers can build + a self-attention mask over the full video sequence. + + NOTE: As of the IC-LoRA reference-token refactor, this method is a back-compat shim — the canonical encoding + helper is `_encode_reference_conditions` and reference tokens are folded into the main noisy sequence by + `prepare_latents`. This method exists for callers that want the standalone encoding output (e.g. for + downstream parity instrumentation). The `reference_denoise_factors` it returns are derivable as + `1 - strength` per token; in the integrated path the equivalent information lives in + `conditioning_mask` produced by `prepare_latents`. + + Args: + reference_conditions (`list[LTX2ReferenceCondition]`): + The reference video conditions. + height (`int`): + Target video height in pixels (used to determine reference video preprocessing size with + `reference_downscale_factor`). + width (`int`): + Target video width in pixels. + num_frames (`int`): + Number of target video frames. 
+ reference_downscale_factor (`int`, defaults to `1`): + Ratio between target and reference resolutions. A factor of 2 means the reference video is + preprocessed at half the target resolution. Spatial positional coordinates are scaled by this factor + to map reference tokens into the target coordinate space. + frame_rate (`float`, defaults to `24.0`): + Video frame rate (used for temporal coordinate computation). + conditioning_attention_strength (`float`, defaults to `1.0`): + Scalar in `[0, 1]` controlling how strongly reference tokens attend to noisy tokens (and vice versa) + in the self-attention mask. `1.0` means full attention (no masking), `0.0` means reference tokens + are effectively ignored by the noisy tokens. + conditioning_attention_mask (`torch.Tensor`, *optional*): + Optional pixel-space mask of shape `(1, 1, F_pix, H_pix, W_pix)` with values in `[0, 1]` that provides + spatially-varying attention strength. Downsampled to latent space per reference video and multiplied + by `conditioning_attention_strength`. + dtype (`torch.dtype`, *optional*): + Data type for the latents. + device (`torch.device`, *optional*): + Device for the latents. + generator (`torch.Generator`, *optional*): + Random generator for VAE encoding. + + Returns: + A 4-tuple of `(reference_latents, reference_coords, reference_denoise_factors, reference_cross_mask)`: + - `reference_latents`: `[1, total_ref_tokens, hidden_dim]` + - `reference_coords`: `[1, 3, total_ref_tokens, 2]` + - `reference_denoise_factors`: `[1, total_ref_tokens]` — per-token `(1 - strength)` factors + - `reference_cross_mask`: `[1, total_ref_tokens]` per-token noisy↔reference attention strengths in + `[0, 1]`, or `None` when `conditioning_attention_strength == 1.0` and no pixel-space mask is + provided (in which case attention is unmasked). + """ + reference_latents, reference_coords, reference_cross_mask = self._encode_reference_conditions( + reference_conditions=reference_conditions, + height=height, + width=width, + num_frames=num_frames, + reference_downscale_factor=reference_downscale_factor, + frame_rate=frame_rate, + conditioning_attention_strength=conditioning_attention_strength, + conditioning_attention_mask=conditioning_attention_mask, + dtype=dtype, + device=device, + generator=generator, + ) + + # Materialize per-token denoise factors for callers that still expect the 4-tuple. Each ref video has + # `1 - strength` for all of its tokens; we rebuild this from the per-video token counts which we can + # back out from `reference_latents.shape[1]` and the input `reference_conditions` order. + ref_denoise_chunks: list[torch.Tensor] = [] + idx = 0 + # Walk the encoded ref tokens video-by-video. Each ref's token count is fixed by the ref video's latent + # shape, which equals (num_frames -> ref_latent_frames) * ref_latent_h * ref_latent_w. Computing it here + # would duplicate the encoding math; instead we rely on the shape match across all refs being identical + # (same `num_frames`, same downscaled height/width) so we can split equally. 
+ n_total = reference_latents.shape[1] + n_per_ref = n_total // max(len(reference_conditions), 1) + for ref_cond in reference_conditions: + ref_denoise_chunks.append( + torch.full( + (1, n_per_ref), 1.0 - ref_cond.strength, device=reference_latents.device, dtype=torch.float32 + ) + ) + idx += n_per_ref + reference_denoise_factors = torch.cat(ref_denoise_chunks, dim=1) if ref_denoise_chunks else reference_latents.new_zeros((1, 0)) + return reference_latents, reference_coords, reference_denoise_factors, reference_cross_mask @staticmethod @@ -1418,98 +1596,6 @@ def attention_kwargs(self): def interrupt(self): return self._interrupt - def _run_transformer( - self, - latent_model_input: torch.Tensor, - audio_latent_model_input: torch.Tensor, - video_timestep: torch.Tensor, - audio_timestep: torch.Tensor, - sigma: torch.Tensor, - video_coords: torch.Tensor, - audio_coords: torch.Tensor, - connector_prompt_embeds: torch.Tensor, - connector_audio_prompt_embeds: torch.Tensor, - connector_attention_mask: torch.Tensor, - latent_num_frames: int, - latent_height: int, - latent_width: int, - frame_rate: float, - audio_num_frames: int, - use_cross_timestep: bool, - attention_kwargs: dict[str, Any] | None, - cache_context: str, - extra_latents: torch.Tensor | None = None, - extra_coords: torch.Tensor | None = None, - extra_timestep: torch.Tensor | None = None, - video_self_attention_mask: torch.Tensor | None = None, - isolate_modalities: bool = False, - spatio_temporal_guidance_blocks: list[int] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Run a single transformer forward pass, optionally concatenating extra tokens (keyframe/reference conditions) - to the video hidden states and/or applying a video self-attention mask. - - When `extra_latents` is provided, the extra tokens are concatenated to the video hidden states, video coords, - and video timesteps. After the transformer forward pass, the extra tokens are stripped from the video output - so only the noisy-token predictions are returned. - - When `video_self_attention_mask` is provided (shape `(1, T_video, T_video)`), it is expanded to the current - batch size and passed to the transformer to control attention between noisy and appended extra tokens. - - Returns: - `(noise_pred_video, noise_pred_audio)` where `noise_pred_video` has the same sequence length as the input - `latent_model_input` (extras are stripped). 
- """ - video_seq_len = latent_model_input.shape[1] - - if extra_latents is not None: - batch_size = latent_model_input.shape[0] - extra_batch = extra_latents.to(latent_model_input.dtype).expand(batch_size, -1, -1) - combined_hidden = torch.cat([latent_model_input, extra_batch], dim=1) - - extra_coords_batch = extra_coords.expand(batch_size, -1, -1, -1) - combined_coords = torch.cat([video_coords, extra_coords_batch], dim=2) - - extra_ts_batch = extra_timestep.expand(batch_size, -1) - combined_timestep = torch.cat([video_timestep, extra_ts_batch], dim=1) - else: - combined_hidden = latent_model_input - combined_coords = video_coords - combined_timestep = video_timestep - - if video_self_attention_mask is not None: - video_self_attention_mask = video_self_attention_mask.expand(combined_hidden.shape[0], -1, -1) - - with self.transformer.cache_context(cache_context): - noise_pred_combined, noise_pred_audio = self.transformer( - hidden_states=combined_hidden, - audio_hidden_states=audio_latent_model_input, - encoder_hidden_states=connector_prompt_embeds, - audio_encoder_hidden_states=connector_audio_prompt_embeds, - timestep=combined_timestep, - audio_timestep=audio_timestep, - sigma=sigma, - encoder_attention_mask=connector_attention_mask, - audio_encoder_attention_mask=connector_attention_mask, - video_self_attention_mask=video_self_attention_mask, - num_frames=latent_num_frames, - height=latent_height, - width=latent_width, - fps=frame_rate, - audio_num_frames=audio_num_frames, - video_coords=combined_coords, - audio_coords=audio_coords, - isolate_modalities=isolate_modalities, - spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, - perturbation_mask=None, - use_cross_timestep=use_cross_timestep, - attention_kwargs=attention_kwargs, - return_dict=False, - ) - - noise_pred_video = noise_pred_combined[:, :video_seq_len] - return noise_pred_video, noise_pred_audio - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -1765,73 +1851,44 @@ def __call__( _, _, latent_num_frames, latent_height, latent_width = latents.shape num_channels_latents = self.transformer.config.in_channels - latents, conditioning_mask, clean_latents, keyframe_extras = self.prepare_latents( - conditions=conditions, - batch_size=batch_size * num_videos_per_prompt, - num_channels_latents=num_channels_latents, - height=height, - width=width, - num_frames=num_frames, - frame_rate=frame_rate, - noise_scale=noise_scale, - dtype=torch.float32, - device=device, - generator=generator, - latents=latents, - ) - has_conditions = conditions is not None and len(conditions) > 0 - if self.do_classifier_free_guidance and has_conditions: - conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) - - # 4b. Prepare reference extras for IC-LoRA conditioning. Reference extras are packaged in the same format - # as keyframe extras (tokens, coords, per-token denoise factors) so both can be concatenated into a single - # block of extra tokens before being fed to the transformer. The reference path also produces an optional - # per-token cross-attention mask when `conditioning_attention_strength < 1.0` or - # `conditioning_attention_mask` is provided. 
- reference_extras: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None - reference_cross_mask: torch.Tensor | None = None - if reference_conditions is not None and len(reference_conditions) > 0: - ref_latents, ref_coords, ref_denoise, reference_cross_mask = self.prepare_reference_latents( + latents, conditioning_mask, clean_latents, appended_coords, num_ref_tokens, ref_cross_mask = ( + self.prepare_latents( + conditions=conditions, reference_conditions=reference_conditions, + reference_downscale_factor=reference_downscale_factor, + conditioning_attention_strength=conditioning_attention_strength, + conditioning_attention_mask=conditioning_attention_mask, + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents, height=height, width=width, num_frames=num_frames, - reference_downscale_factor=reference_downscale_factor, frame_rate=frame_rate, - conditioning_attention_strength=conditioning_attention_strength, - conditioning_attention_mask=conditioning_attention_mask, + noise_scale=noise_scale, dtype=torch.float32, device=device, generator=generator, + latents=latents, ) - reference_extras = (ref_latents, ref_coords, ref_denoise) - - # Combine keyframe extras + reference extras into a single extras block. Keyframes come first (matching - # the reference implementation's ordering in `_create_conditionings`: image conditions first, reference - # video conditions appended last). - extras_parts = [e for e in (keyframe_extras, reference_extras) if e is not None] - if extras_parts: - extra_latents_all = torch.cat([e[0] for e in extras_parts], dim=1) - extra_coords_all = torch.cat([e[1] for e in extras_parts], dim=2) - extra_denoise_factors_all = torch.cat([e[2] for e in extras_parts], dim=1) - else: - extra_latents_all = extra_coords_all = extra_denoise_factors_all = None + ) + # Track the base token count in the generated video, excluding any appended keyframe and reference-video + # condition tokens. + base_token_count = latents.shape[1] - (appended_coords.shape[2] if appended_coords is not None else 0) + + has_conditions = conditions is not None and len(conditions) > 0 + has_appended_tokens = appended_coords is not None + if self.do_classifier_free_guidance and (has_conditions or num_ref_tokens > 0): + conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) - # Build the video self-attention mask over `noisy + extras` when any extras group needs non-trivial - # attention strength (currently: only IC-LoRA references). Keyframes are always included with full - # cross-attention (cross_mask=1.0) so the resulting block structure correctly isolates keyframes from - # references (different-group blocks are 0). When no reference mask is needed, we leave - # `video_self_attention_mask=None` so attention is fully unmasked. + # Build a video self-attention mask over three groups: (1) the noisy latents (2) keyframe conditions, if any + # and (3) reference conditions, if any. 
Tokens attend to each other across groups as follows:
+        #   - Noisy and keyframe tokens (groups 1 and 2) attend fully to each other and among themselves.
+        #   - Reference tokens attend fully among themselves.
+        #   - Attention between reference tokens and the noisy/keyframe tokens is scaled by the per-token
+        #     strengths in `ref_cross_mask`. When no reference mask is needed, the mask stays `None` and
+        #     attention is fully unmasked.
        video_self_attention_mask: torch.Tensor | None = None
-        if reference_cross_mask is not None:
-            extras_cross_masks: list[torch.Tensor] = []
-            if keyframe_extras is not None:
-                num_kf_tokens = keyframe_extras[0].shape[1]
-                extras_cross_masks.append(torch.ones((1, num_kf_tokens), device=device, dtype=torch.float32))
-            extras_cross_masks.append(reference_cross_mask)
+        if ref_cross_mask is not None:
+            num_noisy_tokens = latents.shape[1] - num_ref_tokens
            video_self_attention_mask = self._build_video_self_attention_mask(
-                num_noisy_tokens=latents.shape[1],
-                extras_cross_masks=extras_cross_masks,
+                num_noisy_tokens=num_noisy_tokens,
+                extras_cross_masks=[ref_cross_mask],
                device=device,
            )
@@ -1898,6 +1955,8 @@ def __call__(
        video_coords = self.transformer.rope.prepare_video_coords(
            latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate
        )
+        if appended_coords is not None:
+            video_coords = torch.cat([video_coords, appended_coords], dim=2)
        audio_coords = self.transformer.audio_rope.prepare_audio_coords(
            audio_latents.shape[0], audio_num_frames, audio_latents.device
        )
@@ -1924,44 +1983,40 @@ def __call__(
                timestep_scalar = t.expand(latent_model_input.shape[0])

-                # Per-token video timestep: conditioned positions (from frame conditions) get timestep 0,
-                # unconditioned positions get the current sigma.
-                if has_conditions:
+                if has_conditions or num_ref_tokens > 0:
                    video_timestep = timestep_scalar.unsqueeze(-1) * (1 - conditioning_mask.squeeze(-1))
                else:
                    video_timestep = timestep_scalar.unsqueeze(-1).expand(-1, video_seq_len)

-                # Per-token timestep for the combined extras block (keyframes + references). Each extra token's
-                # timestep is sigma * denoise_factor, where denoise_factor = 1 - strength (0 for fully-clean tokens).
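
For intuition, the block structure that `_build_video_self_attention_mask` is asked to produce for a single reference group can be sketched as follows; the shapes and values are illustrative assumptions, not the helper's actual implementation:

```py
import torch

n_noisy, n_ref = 4, 2  # noisy + keyframe tokens, reference tokens
ref_cross_mask = torch.tensor([[0.5, 0.5]])  # per-reference-token attention strength in [0, 1]

mask = torch.ones(1, n_noisy + n_ref, n_noisy + n_ref)
mask[:, :n_noisy, n_noisy:] = ref_cross_mask.unsqueeze(1)   # noisy/keyframe -> reference attention
mask[:, n_noisy:, :n_noisy] = ref_cross_mask.unsqueeze(-1)  # reference -> noisy/keyframe attention
# The diagonal blocks (noisy-noisy and reference-reference) stay at 1.0, i.e. full attention.
```
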
- extra_timestep_in = ( - t * extra_denoise_factors_all if extra_denoise_factors_all is not None else None - ) - # --- Main transformer forward pass (conditional + unconditional for CFG) --- - noise_pred_video, noise_pred_audio = self._run_transformer( - latent_model_input=latent_model_input, - audio_latent_model_input=audio_latent_model_input, - video_timestep=video_timestep, - audio_timestep=timestep_scalar, - sigma=timestep_scalar, - video_coords=video_coords, - audio_coords=audio_coords, - connector_prompt_embeds=connector_prompt_embeds, - connector_audio_prompt_embeds=connector_audio_prompt_embeds, - connector_attention_mask=connector_attention_mask, - latent_num_frames=latent_num_frames, - latent_height=latent_height, - latent_width=latent_width, - frame_rate=frame_rate, - audio_num_frames=audio_num_frames, - use_cross_timestep=use_cross_timestep, - attention_kwargs=attention_kwargs, - cache_context="cond_uncond", - extra_latents=extra_latents_all, - extra_coords=extra_coords_all, - extra_timestep=extra_timestep_in, - video_self_attention_mask=video_self_attention_mask, - ) + if video_self_attention_mask is not None: + video_self_attention_mask = video_self_attention_mask.expand(latent_model_input.shape[0], -1, -1) + with self.transformer.cache_context("cond_uncond"): + noise_pred_video, noise_pred_audio = self.transformer( + hidden_states=latent_model_input, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=timestep_scalar, + sigma=timestep_scalar, # Used by LTX-2.3 + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, + video_self_attention_mask=video_self_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, + isolate_modalities=False, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + return_dict=False, + ) noise_pred_video = noise_pred_video.float() noise_pred_audio = noise_pred_audio.float() @@ -1992,7 +2047,7 @@ def __call__( audio_pos_ids = audio_coords.chunk(2, dim=0)[0] timestep_scalar_single = timestep_scalar.chunk(2, dim=0)[0] - if has_conditions: + if has_conditions or num_ref_tokens > 0: video_timestep_single = video_timestep.chunk(2, dim=0)[0] else: video_timestep_single = timestep_scalar_single.unsqueeze(-1).expand(-1, video_seq_len) @@ -2007,7 +2062,7 @@ def __call__( audio_pos_ids = audio_coords timestep_scalar_single = timestep_scalar - if has_conditions: + if has_conditions or num_ref_tokens > 0: video_timestep_single = video_timestep else: video_timestep_single = timestep_scalar.unsqueeze(-1).expand(-1, video_seq_len) @@ -2017,31 +2072,35 @@ def __call__( # --- STG forward pass --- if self.do_spatio_temporal_guidance: - noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self._run_transformer( - latent_model_input=latents.to(dtype=prompt_embeds.dtype), - audio_latent_model_input=audio_latents.to(dtype=prompt_embeds.dtype), - video_timestep=video_timestep_single, - audio_timestep=timestep_scalar_single, - sigma=timestep_scalar_single, - video_coords=video_pos_ids, - audio_coords=audio_pos_ids, - connector_prompt_embeds=video_prompt_embeds, - connector_audio_prompt_embeds=audio_prompt_embeds, - 
connector_attention_mask=prompt_attn_mask, - latent_num_frames=latent_num_frames, - latent_height=latent_height, - latent_width=latent_width, - frame_rate=frame_rate, - audio_num_frames=audio_num_frames, - use_cross_timestep=use_cross_timestep, - attention_kwargs=attention_kwargs, - cache_context="uncond_stg", - extra_latents=extra_latents_all, - extra_coords=extra_coords_all, - extra_timestep=extra_timestep_in, - video_self_attention_mask=video_self_attention_mask, - spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, - ) + if video_self_attention_mask is not None: + video_self_attention_mask = video_self_attention_mask.expand(latents.shape[0], -1, -1) + with self.transformer.cache_context("uncond_stg"): + noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer( + hidden_states=latents.to(dtype=prompt_embeds.dtype), + audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype), + encoder_hidden_states=video_prompt_embeds, + audio_encoder_hidden_states=audio_prompt_embeds, + timestep=video_timestep_single, + audio_timestep=timestep_scalar_single, + sigma=timestep_scalar_single, # Used by LTX-2.3 + encoder_attention_mask=prompt_attn_mask, + audio_encoder_attention_mask=prompt_attn_mask, + video_self_attention_mask=video_self_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + isolate_modalities=False, + # Use STG at given blocks to perturb model + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + return_dict=False, + ) noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float() noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float() noise_pred_video_uncond_stg = self.convert_velocity_to_x0( @@ -2058,31 +2117,35 @@ def __call__( # --- Modality isolation guidance forward pass --- if self.do_modality_isolation_guidance: - noise_pred_video_uncond_mod, noise_pred_audio_uncond_mod = self._run_transformer( - latent_model_input=latents.to(dtype=prompt_embeds.dtype), - audio_latent_model_input=audio_latents.to(dtype=prompt_embeds.dtype), - video_timestep=video_timestep_single, - audio_timestep=timestep_scalar_single, - sigma=timestep_scalar_single, - video_coords=video_pos_ids, - audio_coords=audio_pos_ids, - connector_prompt_embeds=video_prompt_embeds, - connector_audio_prompt_embeds=audio_prompt_embeds, - connector_attention_mask=prompt_attn_mask, - latent_num_frames=latent_num_frames, - latent_height=latent_height, - latent_width=latent_width, - frame_rate=frame_rate, - audio_num_frames=audio_num_frames, - use_cross_timestep=use_cross_timestep, - attention_kwargs=attention_kwargs, - cache_context="uncond_modality", - extra_latents=extra_latents_all, - extra_coords=extra_coords_all, - extra_timestep=extra_timestep_in, - video_self_attention_mask=video_self_attention_mask, - isolate_modalities=True, - ) + if video_self_attention_mask is not None: + video_self_attention_mask = video_self_attention_mask.expand(latents.shape[0], -1, -1) + with self.transformer.cache_context("uncond_modality"): + noise_pred_video_uncond_mod, noise_pred_audio_uncond_mod = self.transformer( + hidden_states=latents.to(dtype=prompt_embeds.dtype), + audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype), + encoder_hidden_states=video_prompt_embeds, + 
audio_encoder_hidden_states=audio_prompt_embeds, + timestep=video_timestep_single, + audio_timestep=timestep_scalar_single, + sigma=timestep_scalar_single, # Used by LTX-2.3 + encoder_attention_mask=prompt_attn_mask, + audio_encoder_attention_mask=prompt_attn_mask, + video_self_attention_mask=video_self_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + # Turn off A2V and V2A cross attn to isolate video and audio modalities + isolate_modalities=True, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + return_dict=False, + ) noise_pred_video_uncond_mod = noise_pred_video_uncond_mod.float() noise_pred_audio_uncond_mod = noise_pred_audio_uncond_mod.float() noise_pred_video_uncond_mod = self.convert_velocity_to_x0( @@ -2153,6 +2216,8 @@ def __call__( xm.mark_step() # 9. Decode + # Trim any appended keyframe or reference tokens from the latents to recover the generated video only. + latents = latents[:, :base_token_count] latents = self._unpack_latents( latents, latent_num_frames, diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index c95c56789e37..c41ddc24f38c 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2402,6 +2402,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LTX2HDRLoraPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class LTX2ICLoraPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LTX2ImageToVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 29be81d1b12ef002521675ff6ce2a9cbe2318969 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 27 Apr 2026 06:12:36 +0200 Subject: [PATCH 04/14] Change LTX2ConditionPipeline default __call__ parameters to match the suggested params for the LTX-2.3 model --- .../pipelines/ltx2/pipeline_ltx2_condition.py | 22 +++++++++---------- .../pipelines/ltx2/pipeline_ltx2_ic_lora.py | 22 +++++++++---------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index 56ca0add74a9..441c69995bd3 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -1179,18 +1179,18 @@ def __call__( width: int = 768, num_frames: int = 121, frame_rate: float = 24.0, - num_inference_steps: int = 40, + num_inference_steps: int = 30, sigmas: list[float] | None = None, timesteps: 
list[float] | None = None, - guidance_scale: float = 4.0, - stg_scale: float = 0.0, - modality_scale: float = 1.0, - guidance_rescale: float = 0.0, - audio_guidance_scale: float | None = None, - audio_stg_scale: float | None = None, - audio_modality_scale: float | None = None, - audio_guidance_rescale: float | None = None, - spatio_temporal_guidance_blocks: list[int] | None = None, + guidance_scale: float = 3.0, + stg_scale: float = 1.0, + modality_scale: float = 3.0, + guidance_rescale: float = 0.7, + audio_guidance_scale: float | None = 7.0, + audio_stg_scale: float | None = 1.0, + audio_modality_scale: float | None = 3.0, + audio_guidance_rescale: float | None = 0.7, + spatio_temporal_guidance_blocks: list[int] | None = [28], noise_scale: float | None = None, num_videos_per_prompt: int | None = 1, generator: torch.Generator | list[torch.Generator] | None = None, @@ -1202,7 +1202,7 @@ def __call__( negative_prompt_attention_mask: torch.Tensor | None = None, decode_timestep: float | list[float] = 0.0, decode_noise_scale: float | list[float] | None = None, - use_cross_timestep: bool = False, + use_cross_timestep: bool = True, output_type: str = "pil", return_dict: bool = True, attention_kwargs: dict[str, Any] | None = None, diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py index 01ad07babf77..4c3348c4fa7d 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py @@ -1611,18 +1611,18 @@ def __call__( width: int = 768, num_frames: int = 121, frame_rate: float = 24.0, - num_inference_steps: int = 40, + num_inference_steps: int = 30, sigmas: list[float] | None = None, timesteps: list[float] | None = None, - guidance_scale: float = 4.0, - stg_scale: float = 0.0, - modality_scale: float = 1.0, - guidance_rescale: float = 0.0, - audio_guidance_scale: float | None = None, - audio_stg_scale: float | None = None, - audio_modality_scale: float | None = None, - audio_guidance_rescale: float | None = None, - spatio_temporal_guidance_blocks: list[int] | None = None, + guidance_scale: float = 3.0, + stg_scale: float = 1.0, + modality_scale: float = 3.0, + guidance_rescale: float = 0.7, + audio_guidance_scale: float | None = 7.0, + audio_stg_scale: float | None = 1.0, + audio_modality_scale: float | None = 3.0, + audio_guidance_rescale: float | None = 0.7, + spatio_temporal_guidance_blocks: list[int] | None = [28], noise_scale: float | None = None, num_videos_per_prompt: int | None = 1, generator: torch.Generator | list[torch.Generator] | None = None, @@ -1634,7 +1634,7 @@ def __call__( negative_prompt_attention_mask: torch.Tensor | None = None, decode_timestep: float | list[float] = 0.0, decode_noise_scale: float | list[float] | None = None, - use_cross_timestep: bool = False, + use_cross_timestep: bool = True, output_type: str = "pil", return_dict: bool = True, attention_kwargs: dict[str, Any] | None = None, From 5800b0cfde05f195d2508bab61cf3168cf1457a9 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 27 Apr 2026 06:55:39 +0200 Subject: [PATCH 05/14] Improve IC LoRA example and fix some LTX2ICLoraPipeline bugs --- .../pipelines/ltx2/pipeline_ltx2_ic_lora.py | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py index 4c3348c4fa7d..9252227893ef 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py +++ 
b/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py @@ -75,32 +75,37 @@ class LTX2ReferenceCondition: >>> import torch >>> from diffusers import LTX2ICLoraPipeline >>> from diffusers.pipelines.ltx2.pipeline_ltx2_ic_lora import LTX2ReferenceCondition - >>> from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES >>> from diffusers.pipelines.ltx2.export_utils import encode_video + >>> from diffusers.pipelines.ltx2.utils import DEFAULT_NEGATIVE_PROMPT >>> from diffusers.utils import load_video >>> pipe = LTX2ICLoraPipeline.from_pretrained( - ... "rootonchair/LTX-2-19b-distilled", torch_dtype=torch.bfloat16 + ... "dg845/LTX-2.3-Diffusers", torch_dtype=torch.bfloat16 ... ) >>> pipe.enable_sequential_cpu_offload(device="cuda") - >>> pipe.load_lora_weights("path/to/ic_lora.safetensors", adapter_name="ic_lora") + >>> pipe.load_lora_weights( + >>> "Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In", + >>> adapter_name="ic_lora", + >>> weight_name="ltx-2-19b-lora-camera-control-dolly-in.safetensors", + >>> ) >>> pipe.set_adapters("ic_lora", 1.0) - >>> reference_video = load_video("reference.mp4") - >>> ref_cond = LTX2ReferenceCondition(frames=reference_video, strength=1.0) + >>> # If the IC LoRA uses reference conditions, you can specify them as follows: + >>> # reference_video = load_video("reference.mp4") + >>> # ref_cond = LTX2ReferenceCondition(frames=reference_video, strength=1.0) >>> prompt = "A flowing river in a forest" >>> frame_rate = 24.0 >>> video, audio = pipe( ... prompt=prompt, - ... reference_conditions=[ref_cond], + ... negative_prompt=DEFAULT_NEGATIVE_PROMPT, + ... # reference_conditions=[ref_cond], ... width=768, ... height=512, ... num_frames=121, ... frame_rate=frame_rate, - ... num_inference_steps=8, - ... sigmas=DISTILLED_SIGMA_VALUES, - ... guidance_scale=1.0, + ... num_inference_steps=30, + ... guidance_scale=3.0, ... output_type="np", ... return_dict=False, ... ) @@ -1094,6 +1099,8 @@ def prepare_latents( latents = torch.cat([latents, torch.cat(kf_tokens_list, dim=1)], dim=1) conditioning_mask = torch.cat([conditioning_mask, torch.cat(kf_mask_list, dim=1)], dim=1) clean_latents = torch.cat([clean_latents, torch.cat(kf_clean_list, dim=1)], dim=1) + else: + keyframe_coords = None # IC-LoRA reference-video conditions: encode each reference video, then append it to the main packed # sequence with per-token `conditioning_mask = strength`. 
This is the same architectural pattern as From a3a784d565b11c4d6346e5855e242269508135b0 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 27 Apr 2026 07:22:19 +0200 Subject: [PATCH 06/14] Improve HDR IC LoRA example --- .../pipelines/ltx2/pipeline_ltx2_hdr_lora.py | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py index 80d1fecd7243..686d929f9505 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py @@ -73,26 +73,36 @@ class LTX2HDRReferenceCondition: Examples: ```py >>> import torch + >>> from safetensors import safe_open >>> from diffusers import LTX2HDRLoraPipeline >>> from diffusers.pipelines.ltx2.pipeline_ltx2_hdr_lora import LTX2HDRReferenceCondition >>> from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES - >>> from diffusers.pipelines.ltx2.export_utils import save_hdr_video_frames_as_exr + >>> from diffusers.pipelines.ltx2.export_utils import save_hdr_video_frames_as_exr, encode_exr_sequence_to_mp4 >>> from diffusers.utils import load_video >>> pipe = LTX2HDRLoraPipeline.from_pretrained( - ... "rootonchair/LTX-2-19b-distilled", torch_dtype=torch.bfloat16 + ... "dg845/LTX-2.3-Distilled-Diffusers", torch_dtype=torch.bfloat16 ... ) >>> pipe.enable_sequential_cpu_offload(device="cuda") - >>> pipe.load_lora_weights("path/to/hdr_ic_lora.safetensors", adapter_name="hdr_lora") + >>> pipe.load_lora_weights( + >>> "Lightricks/LTX-2.3-22b-IC-LoRA-HDR", + >>> adapter_name="hdr_lora", + >>> weight_name="ltx-2.3-22b-ic-lora-hdr-0.9.safetensors", + >>> ) >>> pipe.set_adapters("hdr_lora", 1.0) - >>> reference_video = load_video("reference.mp4") + >>> reference_video = load_video("/path/to/reference.mp4") >>> ref_cond = LTX2HDRReferenceCondition(frames=reference_video, strength=1.0) - >>> prompt = "A cinematic landscape at sunset" + >>> # Load pre-computed HDR LoRA connector embeddings. + >>> with safe_open("/path/to/connector/embeds.safetensors", framework="pt", device="cuda") as f: + >>> connector_video_embeds = f.get_tensor("video_context") + >>> connector_audio_embeds = f.get_tensor("audio_context") + >>> hdr_video = pipe( - ... prompt=prompt, ... reference_conditions=[ref_cond], + ... connector_video_embeds=connector_video_embeds, + ... connector_audio_embeds=connector_audio_embeds, ... width=768, ... height=512, ... num_frames=121, @@ -105,7 +115,11 @@ class LTX2HDRReferenceCondition: ... )[0] >>> # `hdr_video` is a linear HDR tensor of shape (batch, frames, H, W, C). + >>> # Save the HDR video as per-frame EXR files in the specified directory. >>> save_hdr_video_frames_as_exr(hdr_video[0], "hdr_output/") + >>> # You can convert these EXR files to .mp4 file as below. + >>> # A custom tone-mapper can be specified via the `tone_mapping_fn` argument. 
+    >>> encode_exr_sequence_to_mp4("hdr_output/", "ltx2_hdr_lora_output.mp4", frame_rate=24.0)
     ```
 """

From 4938f67efa144bc2efdcf6713dd4e7c4938f67ef Mon Sep 17 00:00:00 2001
From: Daniel Gu
Date: Mon, 27 Apr 2026 07:23:51 +0200
Subject: [PATCH 07/14] Clean up the code a bit

---
 src/diffusers/pipelines/ltx2/export_utils.py  | 14 ++++++--------
 .../pipelines/ltx2/pipeline_ltx2_hdr_lora.py  |  7 -------
 .../pipelines/ltx2/pipeline_ltx2_ic_lora.py   |  7 -------
 3 files changed, 6 insertions(+), 22 deletions(-)

diff --git a/src/diffusers/pipelines/ltx2/export_utils.py b/src/diffusers/pipelines/ltx2/export_utils.py
index fb5a48fd4e44..4b81d449e8cf 100644
--- a/src/diffusers/pipelines/ltx2/export_utils.py
+++ b/src/diffusers/pipelines/ltx2/export_utils.py
@@ -193,14 +193,8 @@ def encode_video(
     container.close()


-# ---------------------------------------------------------------------------
-# HDR export helpers (used with LTX2HDRLoraPipeline).
-#
-# These mirror the reference CLI's `save_exr_tensor`, `_linear_to_srgb`, and
-# `encode_exr_sequence_to_mp4` in `ltx_pipelines.utils.media_io`.
-# ---------------------------------------------------------------------------
-
-
+# Adapted from ltx_pipelines.utils.media_io.save_exr_tensor
+# https://github.com/Lightricks/LTX-2/blob/41d924371612b692c0fd1e4d9d94c3dfb3c02cb3/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py#L609
 def save_exr_tensor(
     tensor: torch.Tensor | np.ndarray,
     file_path: str | Path,
@@ -275,6 +269,8 @@ def simple_tone_map(x: np.ndarray) -> np.ndarray:
     return np.clip(x, 0.0, 1.0)


+# Adapted from ltx_pipelines.utils.media_io._linear_to_srgb
+# https://github.com/Lightricks/LTX-2/blob/41d924371612b692c0fd1e4d9d94c3dfb3c02cb3/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py#L644
 def linear_to_srgb(x: np.ndarray) -> np.ndarray:
     r"""
     Apply the sRGB (Rec.709) transfer function (OETF; IEC 61966-2-1) to a linear light image. Input values must be in
@@ -283,6 +279,8 @@ def linear_to_srgb(x: np.ndarray) -> np.ndarray:
     return np.where(x <= 0.0031308, x * 12.92, 1.055 * np.power(x, 1.0 / 2.4) - 0.055)


+# Adapted from ltx_pipelines.utils.media_io.encode_exr_sequence_to_mp4
+# https://github.com/Lightricks/LTX-2/blob/41d924371612b692c0fd1e4d9d94c3dfb3c02cb3/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py#L650
 def encode_exr_sequence_to_mp4(
     exr_dir: str | Path,
     output_mp4: str | Path,
diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py
index 686d929f9505..3c8b9436cd36 100644
--- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py
+++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py
@@ -940,13 +940,6 @@ def prepare_reference_latents(
         resize at the reference resolution), VAE-encoded, packed into tokens, and paired with positional
         coordinates computed at the reference latent dimensions and scaled by `reference_downscale_factor`.

-        NOTE: As of the HDR LoRA reference-token refactor, this method is a back-compat shim — the canonical
-        encoding helper is `_encode_reference_conditions` and reference tokens are folded into the main noisy
-        sequence by `prepare_latents`. This method exists for callers that want the standalone encoding output
-        (e.g. for downstream parity instrumentation). The `reference_denoise_factors` it returns are derivable
-        as `1 - strength` per token; in the integrated path the equivalent information lives in
-        `conditioning_mask` produced by `prepare_latents`.
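
Building on the `tone_mapping_fn` hook documented above, a custom operator can replace the default clipping behavior. The Reinhard curve below is only an illustration and is not part of the patch; `encode_exr_sequence_to_mp4` applies the sRGB OETF afterwards, so the function only needs to return linear values in `[0, 1]`:

```py
import numpy as np

from diffusers.pipelines.ltx2.export_utils import encode_exr_sequence_to_mp4


def reinhard_tone_map(x: np.ndarray) -> np.ndarray:
    # Global Reinhard operator: maps linear HDR values in [0, inf) to [0, 1).
    return x / (1.0 + x)


encode_exr_sequence_to_mp4(
    "hdr_output/",
    "ltx2_hdr_lora_reinhard.mp4",  # illustrative output path
    frame_rate=24.0,
    tone_mapping_fn=reinhard_tone_map,
    tone_map_in_rgb=True,  # this curve is per-channel, so ordering does not matter; RGB is just explicit
)
```
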
- Returns a 3-tuple `(reference_latents, reference_coords, reference_denoise_factors)` with the same shapes as [`LTX2ICLoraPipeline.prepare_reference_latents`]. """ diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py index 9252227893ef..0e48b3ca958f 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py @@ -1331,13 +1331,6 @@ def prepare_reference_latents( reference video (downsampled to the reference video's latent dimensions) and returned so callers can build a self-attention mask over the full video sequence. - NOTE: As of the IC-LoRA reference-token refactor, this method is a back-compat shim — the canonical encoding - helper is `_encode_reference_conditions` and reference tokens are folded into the main noisy sequence by - `prepare_latents`. This method exists for callers that want the standalone encoding output (e.g. for - downstream parity instrumentation). The `reference_denoise_factors` it returns are derivable as - `1 - strength` per token; in the integrated path the equivalent information lives in - `conditioning_mask` produced by `prepare_latents`. - Args: reference_conditions (`list[LTX2ReferenceCondition]`): The reference video conditions. From e5414c898ca14bc01f224d224c368f98f24a00b8 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 27 Apr 2026 07:26:15 +0200 Subject: [PATCH 08/14] make style and make quality --- .../models/transformers/transformer_ltx2.py | 9 +- src/diffusers/pipelines/ltx2/__init__.py | 6 +- src/diffusers/pipelines/ltx2/export_utils.py | 47 +++-- .../pipelines/ltx2/image_processor.py | 4 +- .../pipelines/ltx2/pipeline_ltx2_condition.py | 44 ++--- .../pipelines/ltx2/pipeline_ltx2_hdr_lora.py | 89 ++++----- .../pipelines/ltx2/pipeline_ltx2_ic_lora.py | 184 ++++++++---------- 7 files changed, 169 insertions(+), 214 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index efd75100ee67..5f8c1063cfa9 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -1376,11 +1376,10 @@ def forward( audio_encoder_attention_mask (`torch.Tensor`, *optional*): Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)` for audio modeling. video_self_attention_mask (`torch.Tensor`, *optional*): - Optional multiplicative self-attention mask of shape `(batch_size, num_video_tokens, - num_video_tokens)` applied to the video self-attention in each transformer block. Values in `[0, 1]` - where `1` means full attention and `0` means masked. Used e.g. by the IC-LoRA pipeline to control - attention strength between noisy tokens and appended reference tokens. Audio self-attention is not - affected. + Optional multiplicative self-attention mask of shape `(batch_size, num_video_tokens, num_video_tokens)` + applied to the video self-attention in each transformer block. Values in `[0, 1]` where `1` means full + attention and `0` means masked. Used e.g. by the IC-LoRA pipeline to control attention strength between + noisy tokens and appended reference tokens. Audio self-attention is not affected. num_frames (`int`, *optional*): The number of latent video frames. Used if calculating the video coordinates for RoPE. 
height (`int`, *optional*): diff --git a/src/diffusers/pipelines/ltx2/__init__.py b/src/diffusers/pipelines/ltx2/__init__.py index 3781f556acae..87ab0827c246 100644 --- a/src/diffusers/pipelines/ltx2/__init__.py +++ b/src/diffusers/pipelines/ltx2/__init__.py @@ -23,12 +23,12 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["connectors"] = ["LTX2TextConnectors"] + _import_structure["image_processor"] = ["LTX2VideoHDRProcessor"] _import_structure["latent_upsampler"] = ["LTX2LatentUpsamplerModel"] _import_structure["pipeline_ltx2"] = ["LTX2Pipeline"] _import_structure["pipeline_ltx2_condition"] = ["LTX2ConditionPipeline"] - _import_structure["pipeline_ltx2_ic_lora"] = ["LTX2ICLoraPipeline", "LTX2ReferenceCondition"] _import_structure["pipeline_ltx2_hdr_lora"] = ["LTX2HDRLoraPipeline", "LTX2HDRReferenceCondition"] - _import_structure["image_processor"] = ["LTX2VideoHDRProcessor"] + _import_structure["pipeline_ltx2_ic_lora"] = ["LTX2ICLoraPipeline", "LTX2ReferenceCondition"] _import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"] _import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"] _import_structure["vocoder"] = ["LTX2Vocoder", "LTX2VocoderWithBWE"] @@ -42,10 +42,10 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .connectors import LTX2TextConnectors + from .image_processor import LTX2VideoHDRProcessor from .latent_upsampler import LTX2LatentUpsamplerModel from .pipeline_ltx2 import LTX2Pipeline from .pipeline_ltx2_condition import LTX2ConditionPipeline - from .image_processor import LTX2VideoHDRProcessor from .pipeline_ltx2_hdr_lora import LTX2HDRLoraPipeline, LTX2HDRReferenceCondition from .pipeline_ltx2_ic_lora import LTX2ICLoraPipeline, LTX2ReferenceCondition from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline diff --git a/src/diffusers/pipelines/ltx2/export_utils.py b/src/diffusers/pipelines/ltx2/export_utils.py index 4b81d449e8cf..87ebf1eab885 100644 --- a/src/diffusers/pipelines/ltx2/export_utils.py +++ b/src/diffusers/pipelines/ltx2/export_utils.py @@ -205,17 +205,17 @@ def save_exr_tensor( Args: tensor (`torch.Tensor` or `np.ndarray`): - A float frame of shape `(H, W, C)` or `(C, H, W)` with linear HDR values in `[0, ∞)`. Channels are - assumed to be RGB. + A float frame of shape `(H, W, C)` or `(C, H, W)` with linear HDR values in `[0, ∞)`. Channels are assumed + to be RGB. file_path (`str` or `pathlib.Path`): Output EXR path (e.g. `frame_00000.exr`). half (`bool`, *optional*, defaults to `False`): - When `True`, writes the file as `float16` (HALF) with ZIP compression. `float16` tensors are always - saved as HALF regardless of this flag. + When `True`, writes the file as `float16` (HALF) with ZIP compression. `float16` tensors are always saved + as HALF regardless of this flag. - The resulting EXR is tagged with Rec.709/sRGB chromaticities and `colorSpace=sRGB` to match the reference. - Requires [OpenImageIO](https://openimageio.readthedocs.io) with OpenEXR support: - `pip install OpenImageIO` (or `pip install oiio`). + The resulting EXR is tagged with Rec.709/sRGB chromaticities and `colorSpace=sRGB` to match the reference. Requires + [OpenImageIO](https://openimageio.readthedocs.io) with OpenEXR support: `pip install OpenImageIO` (or `pip install + oiio`). 
""" try: import OpenImageIO @@ -241,9 +241,7 @@ def save_exr_tensor( spec = OpenImageIO.ImageSpec(w, h, 3, fmt) spec.channelnames = ("R", "G", "B") spec.attribute("compression", "zip") - spec.attribute( - "chromaticities", "float[8]", (0.64, 0.33, 0.30, 0.60, 0.15, 0.06, 0.3127, 0.3290) - ) + spec.attribute("chromaticities", "float[8]", (0.64, 0.33, 0.30, 0.60, 0.15, 0.06, 0.3127, 0.3290)) spec.attribute("colorSpace", "sRGB") out = OpenImageIO.ImageOutput.create(file_path) @@ -263,8 +261,8 @@ def save_exr_tensor( def simple_tone_map(x: np.ndarray) -> np.ndarray: r""" Applies a very simple tone-mapping function on (scene-referred) linear light which simply clips values above `1.0` - to `1.0`. This is what the original LTX-2.X code does, but you probably want to do some non-trivial tone-mapping - to make the sample look better. + to `1.0`. This is what the original LTX-2.X code does, but you probably want to do some non-trivial tone-mapping to + make the sample look better. """ return np.clip(x, 0.0, 1.0) @@ -292,8 +290,8 @@ def encode_exr_sequence_to_mp4( r""" Convert a linear-HDR EXR frame sequence into an sRGB-tonemapped H.264 `.mp4` preview. - Each EXR frame is loaded, clipped to `[0, 1]`, passed through the sRGB OETF (no exposure/gain, EV=0), quantized - to 8-bit, and fed into a libx264 stream at the supplied `frame_rate`. + Each EXR frame is loaded, clipped to `[0, 1]`, passed through the sRGB OETF (no exposure/gain, EV=0), quantized to + 8-bit, and fed into a libx264 stream at the supplied `frame_rate`. Args: exr_dir (`str` or `pathlib.Path`): @@ -303,23 +301,24 @@ def encode_exr_sequence_to_mp4( frame_rate (`float`): Frame rate for the output video. tone_mapping_fn (`Callable[[np.ndarray], np.ndarray]`, *optional*, defaults to `None`): - An optional tone mapping function which takes a float32 NumPy array of shape `(H, W, 3)` containing - linear HDR values in `[0, ∞)` and returns tone-mapped linear values in `[0, 1]`. The sRGB transfer - function (OETF) is applied afterwards — do **not** pre-apply gamma inside this function. If `None`, - defaults to [`simple_tone_map`], which clips values above `1.0`. The channel ordering of the input - array is controlled by `tone_map_in_rgb`: BGR by default (matching `opencv-python` conventions), or - RGB when `tone_map_in_rgb=True` (matching `colour-science` and most other libraries). + An optional tone mapping function which takes a float32 NumPy array of shape `(H, W, 3)` containing linear + HDR values in `[0, ∞)` and returns tone-mapped linear values in `[0, 1]`. The sRGB transfer function (OETF) + is applied afterwards — do **not** pre-apply gamma inside this function. If `None`, defaults to + [`simple_tone_map`], which clips values above `1.0`. The channel ordering of the input array is controlled + by `tone_map_in_rgb`: BGR by default (matching `opencv-python` conventions), or RGB when + `tone_map_in_rgb=True` (matching `colour-science` and most other libraries). tone_map_in_rgb (`bool`, *optional*, defaults to `False`): - When `True`, each EXR frame is converted from BGR to RGB before being passed to `tone_mapping_fn`, - and the output frame is tagged as `rgb24`. Use this when `tone_mapping_fn` expects RGB input (e.g. - operators from `colour-science`). When `False` (default), frames are passed as BGR, which is the - native format for `opencv-python` tone mappers (e.g. `cv2.createTonemapReinhard().process`). 
+ When `True`, each EXR frame is converted from BGR to RGB before being passed to `tone_mapping_fn`, and the + output frame is tagged as `rgb24`. Use this when `tone_mapping_fn` expects RGB input (e.g. operators from + `colour-science`). When `False` (default), frames are passed as BGR, which is the native format for + `opencv-python` tone mappers (e.g. `cv2.createTonemapReinhard().process`). crf (`int`, *optional*, defaults to `18`): libx264 CRF quality factor. Lower values produce higher quality. Requires `opencv-python` (for EXR reading via `OPENCV_IO_ENABLE_OPENEXR`). """ import os + os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" try: diff --git a/src/diffusers/pipelines/ltx2/image_processor.py b/src/diffusers/pipelines/ltx2/image_processor.py index 2de3bb4998e9..c9569529db27 100644 --- a/src/diffusers/pipelines/ltx2/image_processor.py +++ b/src/diffusers/pipelines/ltx2/image_processor.py @@ -81,8 +81,8 @@ def _resize_and_reflect_pad_video(video: torch.Tensor, height: int, width: int) r""" Resize a video tensor preserving aspect ratio, then reflect-pad to the exact target dimensions. - Mirrors `resize_and_reflect_pad` in the reference `ltx_pipelines.utils.media_io`. When the source is already - at least as large as the target in both dimensions, the interpolation step is skipped entirely. + Mirrors `resize_and_reflect_pad` in the reference `ltx_pipelines.utils.media_io`. When the source is already at + least as large as the target in both dimensions, the interpolation step is skipped entirely. Args: video (`torch.Tensor`): Input of shape `(B, C, F, H, W)`. diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index 441c69995bd3..37ba28de4b88 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -718,9 +718,7 @@ def preprocess_conditions( # Create a channels-last video-like array of shape (F, H, W, C) in preparation for resizing. if isinstance(condition.frames, PIL.Image.Image): arr = np.array(condition.frames.convert("RGB"))[None] # (1, H, W, 3) - elif isinstance(condition.frames, list) and all( - isinstance(f, PIL.Image.Image) for f in condition.frames - ): + elif isinstance(condition.frames, list) and all(isinstance(f, PIL.Image.Image) for f in condition.frames): arr = np.stack([np.array(f.convert("RGB")) for f in condition.frames]) # (F, H, W, 3) elif isinstance(condition.frames, np.ndarray): arr = condition.frames if condition.frames.ndim == 4 else condition.frames[None] @@ -730,9 +728,7 @@ def preprocess_conditions( # resize logic, which expects channels-last. arr = t.detach().cpu().permute(0, 2, 3, 1).numpy() else: - raise TypeError( - f"Unsupported `frames` type for condition {i}: {type(condition.frames)}" - ) + raise TypeError(f"Unsupported `frames` type for condition {i}: {type(condition.frames)}") src_h, src_w = arr.shape[1], arr.shape[2] num_cond_frames = arr.shape[0] @@ -749,9 +745,7 @@ def preprocess_conditions( # NOTE: we avoid using VideoProcessor.preprocess_video here because it uses PIL.Image.resize under the # hood, which will apply an anti-aliasing pre-filter when downsampling. The original LTX-2.X code simply # uses F.interpolate, which is reproduced here. 
- pixels = torch.nn.functional.interpolate( - pixels, size=(new_h, new_w), mode="bilinear", align_corners=False - ) + pixels = torch.nn.functional.interpolate(pixels, size=(new_h, new_w), mode="bilinear", align_corners=False) top = (new_h - height) // 2 left = (new_w - width) // 2 pixels = pixels[:, :, top : top + height, left : left + width] @@ -799,8 +793,8 @@ def apply_first_frame_conditioning( Apply first-frame visual conditioning by overwriting tokens at the first-frame positions. Only conditions with `latent_idx == 0` are applied here (matching `VideoConditionByLatentIndex` in the - reference implementation). Conditions at non-zero latent indices are appended as separate keyframe tokens - via `prepare_keyframe_extras` (matching `VideoConditionByKeyframeIndex`) and are skipped here. + reference implementation). Conditions at non-zero latent indices are appended as separate keyframe tokens via + `prepare_keyframe_extras` (matching `VideoConditionByKeyframeIndex`) and are skipped here. Args: latents (`torch.Tensor`): @@ -846,8 +840,8 @@ def _prepare_keyframe_coords( Compute positional coordinates for a keyframe condition being appended as extra tokens. Mirrors `VideoConditionByKeyframeIndex.apply_to` in the reference implementation: - - Latent coords scaled to pixel space *without* the causal fix (since non-zero-index keyframes don't need - the first-frame causal adjustment). + - Latent coords scaled to pixel space *without* the causal fix (since non-zero-index keyframes don't need the + first-frame causal adjustment). - Temporal axis offset by `pixel_frame_idx` (the pixel-space index at which the keyframe appears). - For single-pixel-frame keyframes, the per-patch temporal extent is clamped to `[idx, idx + 1)` so the keyframe occupies a single pixel timestep rather than the VAE-scaled range. @@ -864,12 +858,8 @@ def _prepare_keyframe_coords( grid_f = torch.arange( start=0, end=keyframe_latent_num_frames, step=patch_size_t, dtype=torch.float32, device=device ) - grid_h = torch.arange( - start=0, end=keyframe_latent_height, step=patch_size, dtype=torch.float32, device=device - ) - grid_w = torch.arange( - start=0, end=keyframe_latent_width, step=patch_size, dtype=torch.float32, device=device - ) + grid_h = torch.arange(start=0, end=keyframe_latent_height, step=patch_size, dtype=torch.float32, device=device) + grid_w = torch.arange(start=0, end=keyframe_latent_width, step=patch_size, dtype=torch.float32, device=device) grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij") grid = torch.stack(grid, dim=0) @@ -896,7 +886,6 @@ def _prepare_keyframe_coords( return pixel_coords - def prepare_latents( self, conditions: LTX2VideoCondition | list[LTX2VideoCondition] | None = None, @@ -916,17 +905,17 @@ def prepare_latents( Prepare noisy video latents, applying frame conditions. First-frame conditions (`latent_idx == 0`) are applied by overwriting tokens at the first-frame positions - (`VideoConditionByLatentIndex` semantics). Non-first-frame conditions (`latent_idx > 0`) are concatenated - onto the main latent sequence with per-token `conditioning_mask = strength` - (`VideoConditionByKeyframeIndex` semantics) — the denoising loop's existing timestep formula - `t * (1 - conditioning_mask)` and post-process blend - `denoised * (1 - conditioning_mask) + clean * conditioning_mask` then drive them across steps. + (`VideoConditionByLatentIndex` semantics). 
Non-first-frame conditions (`latent_idx > 0`) are concatenated onto + the main latent sequence with per-token `conditioning_mask = strength` (`VideoConditionByKeyframeIndex` + semantics) — the denoising loop's existing timestep formula `t * (1 - conditioning_mask)` and post-process + blend `denoised * (1 - conditioning_mask) + clean * conditioning_mask` then drive them across steps. Returns a 4-tuple: - `latents`: packed noisy latents (base tokens + any keyframe tokens cat'd onto the sequence dim). - `conditioning_mask`: packed conditioning mask with values in `[0, 1]` — `1` at first-frame positions, `strength` at keyframe positions, `0` elsewhere. - - `clean_latents`: clean condition values at conditioned positions (zeros elsewhere); same shape as `latents`. + - `clean_latents`: clean condition values at conditioned positions (zeros elsewhere); same shape as + `latents`. - `keyframe_coords`: `[B, 3, num_keyframe_patches, 2]` positional coordinates to append to `video_coords`, or `None` if there are no non-first-frame conditions. """ @@ -1740,8 +1729,7 @@ def __call__( # NOTE: this operation should be applied in sample (x0) space and not velocity space (which is the # space the denoising model outputs are in) denoised_sample_cond = ( - noise_pred_video * (1 - conditioning_mask[:bsz]) - + clean_latents * conditioning_mask[:bsz] + noise_pred_video * (1 - conditioning_mask[:bsz]) + clean_latents * conditioning_mask[:bsz] ).to(noise_pred_video.dtype) # Convert the denoised (x0) sample back to a velocity for the scheduler diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py index 3c8b9436cd36..a4b249a42d46 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py @@ -51,11 +51,11 @@ class LTX2HDRReferenceCondition: r""" A reference video condition for HDR IC-LoRA conditioning. - The reference video is encoded into latent tokens and concatenated to the noisy latent sequence during - denoising, allowing the HDR IC-LoRA adapter to condition the generation on the reference video content. + The reference video is encoded into latent tokens and concatenated to the noisy latent sequence during denoising, + allowing the HDR IC-LoRA adapter to condition the generation on the reference video content. - Matches the `(video_path, strength)` tuples consumed by the reference `HDRICLoraPipeline`'s - `video_conditioning` argument. + Matches the `(video_path, strength)` tuples consumed by the reference `HDRICLoraPipeline`'s `video_conditioning` + argument. Attributes: frames (`PIL.Image.Image` or `List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`): @@ -80,15 +80,13 @@ class LTX2HDRReferenceCondition: >>> from diffusers.pipelines.ltx2.export_utils import save_hdr_video_frames_as_exr, encode_exr_sequence_to_mp4 >>> from diffusers.utils import load_video - >>> pipe = LTX2HDRLoraPipeline.from_pretrained( - ... "dg845/LTX-2.3-Distilled-Diffusers", torch_dtype=torch.bfloat16 - ... ) + >>> pipe = LTX2HDRLoraPipeline.from_pretrained("dg845/LTX-2.3-Distilled-Diffusers", torch_dtype=torch.bfloat16) >>> pipe.enable_sequential_cpu_offload(device="cuda") >>> pipe.load_lora_weights( - >>> "Lightricks/LTX-2.3-22b-IC-LoRA-HDR", - >>> adapter_name="hdr_lora", - >>> weight_name="ltx-2.3-22b-ic-lora-hdr-0.9.safetensors", - >>> ) + ... "Lightricks/LTX-2.3-22b-IC-LoRA-HDR", + ... adapter_name="hdr_lora", + ... weight_name="ltx-2.3-22b-ic-lora-hdr-0.9.safetensors", + ... 
) >>> pipe.set_adapters("hdr_lora", 1.0) >>> reference_video = load_video("/path/to/reference.mp4") @@ -96,8 +94,8 @@ class LTX2HDRReferenceCondition: >>> # Load pre-computed HDR LoRA connector embeddings. >>> with safe_open("/path/to/connector/embeds.safetensors", framework="pt", device="cuda") as f: - >>> connector_video_embeds = f.get_tensor("video_context") - >>> connector_audio_embeds = f.get_tensor("audio_context") + ... connector_video_embeds = f.get_tensor("video_context") + ... connector_audio_embeds = f.get_tensor("audio_context") >>> hdr_video = pipe( ... reference_conditions=[ref_cond], @@ -243,25 +241,24 @@ class LTX2HDRLoraPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoader r""" Pipeline for HDR IC-LoRA video generation with reference video conditioning. - This is a video-only HDR counterpart to [`LTX2ICLoraPipeline`]. The HDR IC-LoRA adapter (loaded as a standard - LoRA via `load_lora_weights`) conditions generation on a reference video, and the pipeline's postprocessing - applies the LogC3 inverse transform to produce linear HDR output in `[0, ∞)`. + This is a video-only HDR counterpart to [`LTX2ICLoraPipeline`]. The HDR IC-LoRA adapter (loaded as a standard LoRA + via `load_lora_weights`) conditions generation on a reference video, and the pipeline's postprocessing applies the + LogC3 inverse transform to produce linear HDR output in `[0, ∞)`. Compared to [`LTX2ICLoraPipeline`], the HDR pipeline drops: - Frame-level keyframe conditioning (the reference HDR pipeline does not support this). - The `conditioning_attention_strength` / `conditioning_attention_mask` knobs. - Audio output (video-only). The transformer's audio branch is still run since the diffusers transformer API - requires audio inputs, but the decoded audio is discarded and audio-specific guidance scales are fixed to - no-op values to avoid wasted compute. + requires audio inputs, but the decoded audio is discarded and audio-specific guidance scales are fixed to no-op + values to avoid wasted compute. Two-stage inference is supported through separate calls to `__call__`: - - **Stage 1**: generate video latents at target resolution with HDR IC-LoRA conditioning - (`output_type="latent"`). - - **Stage 2**: upsample via [`LTX2LatentUpsamplePipeline`] and refine with this same pipeline (or - [`LTX2Pipeline`]) by passing `latents=upsampled_latents`. The reference HDR stage-2 additionally supports - spatial/temporal tiling of the refinement pass — that optimization is not yet implemented here. + - **Stage 1**: generate video latents at target resolution with HDR IC-LoRA conditioning (`output_type="latent"`). + - **Stage 2**: upsample via [`LTX2LatentUpsamplePipeline`] and refine with this same pipeline (or [`LTX2Pipeline`]) + by passing `latents=upsampled_latents`. The reference HDR stage-2 additionally supports spatial/temporal tiling + of the refinement pass — that optimization is not yet implemented here. Reference: https://github.com/Lightricks/LTX-2 @@ -704,16 +701,16 @@ def prepare_latents( Builds a packed latent sequence in the order `[base | reference]`: - Base: either fresh noise (Stage 1, `latents=None`) or pre-existing upsampled latents (Stage 2). - Reference: HDR-encoded reference-video tokens appended with per-token `conditioning_mask = strength`, - following the same pattern as [`LTX2ICLoraPipeline.prepare_latents`]. (HDR LoRA does not currently - take per-frame `conditions`, so there is no first-frame / keyframe block in between.) 
+ following the same pattern as [`LTX2ICLoraPipeline.prepare_latents`]. (HDR LoRA does not currently take + per-frame `conditions`, so there is no first-frame / keyframe block in between.) Returns a 6-tuple matching [`LTX2ICLoraPipeline.prepare_latents`]: - `latents`: packed noisy latents `(B, base + n_ref, C)`. - `conditioning_mask`: `(B, seq_len, 1)` with `strength` at reference positions, `0` elsewhere. - `clean_latents`: clean reference values at reference positions (zeros elsewhere); same shape as `latents`. - - `appended_coords`: `[1, 3, n_ref, 2]` reference coordinates to concat onto `video_coords`, or - `None` when no reference conditions are provided. + - `appended_coords`: `[1, 3, n_ref, 2]` reference coordinates to concat onto `video_coords`, or `None` when + no reference conditions are provided. - `num_ref_tokens`: count of reference tokens at the END of `latents`. - `ref_cross_mask`: always `None` for HDR LoRA (no cross-attention masking support). """ @@ -853,14 +850,16 @@ def _encode_reference_conditions( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: """Encode HDR IC-LoRA reference videos into `(reference_latents, reference_coords, reference_cross_mask)`. - Shared encoding core used by both `prepare_latents` (which folds reference tokens into the main noisy - sequence) and the back-compat shim `prepare_reference_latents`. HDR LoRA does not currently support - cross-attention masking for reference tokens, so the third return is always `None`. + Shared encoding core used by both `prepare_latents` (which folds reference tokens into the main noisy sequence) + and the back-compat shim `prepare_reference_latents`. HDR LoRA does not currently support cross-attention + masking for reference tokens, so the third return is always `None`. """ ref_height = height // reference_downscale_factor ref_width = width // reference_downscale_factor - if reference_downscale_factor != 1 and (height % reference_downscale_factor != 0 or width % reference_downscale_factor != 0): + if reference_downscale_factor != 1 and ( + height % reference_downscale_factor != 0 or width % reference_downscale_factor != 0 + ): raise ValueError( f"Output dimensions ({height}x{width}) must be divisible by reference_downscale_factor " f"({reference_downscale_factor})." @@ -882,18 +881,14 @@ def _encode_reference_conditions( # HDR-specific preprocessing: reflect-pad resize (vs center-crop in the standard IC-LoRA pipeline). # For LDR reference videos the numerical output of `preprocess_reference_video_hdr` is identical to the # standard [-1, 1] normalization since LogC3's `compress_ldr` is an identity clamp. 
- ref_pixels = self.hdr_video_processor.preprocess_reference_video_hdr( - video_like, ref_height, ref_width - ) + ref_pixels = self.hdr_video_processor.preprocess_reference_video_hdr(video_like, ref_height, ref_width) ref_pixels = ref_pixels[:, :, :num_frames] ref_pixels = ref_pixels.to(dtype=self.vae.dtype, device=device) - ref_latent = retrieve_latents( - self.vae.encode(ref_pixels), generator=generator, sample_mode="argmax" + ref_latent = retrieve_latents(self.vae.encode(ref_pixels), generator=generator, sample_mode="argmax") + ref_latent = self._normalize_latents(ref_latent, self.vae.latents_mean, self.vae.latents_std).to( + device=device, dtype=dtype ) - ref_latent = self._normalize_latents( - ref_latent, self.vae.latents_mean, self.vae.latents_std - ).to(device=device, dtype=dtype) _, _, ref_latent_frames, ref_latent_height, ref_latent_width = ref_latent.shape @@ -961,9 +956,7 @@ def prepare_reference_latents( n_total = reference_latents.shape[1] n_per_ref = n_total // max(len(reference_conditions), 1) denoise_chunks = [ - torch.full( - (1, n_per_ref), 1.0 - ref_cond.strength, device=reference_latents.device, dtype=torch.float32 - ) + torch.full((1, n_per_ref), 1.0 - ref_cond.strength, device=reference_latents.device, dtype=torch.float32) for ref_cond in reference_conditions ] reference_denoise_factors = ( @@ -1086,8 +1079,8 @@ def __call__( reference_conditions (`LTX2HDRReferenceCondition` or `List[LTX2HDRReferenceCondition]`, *optional*): Reference video conditions for HDR IC-LoRA conditioning. reference_downscale_factor (`int`, *optional*, defaults to `1`): - Ratio between target and reference video resolutions. IC-LoRA models trained with downscaled - reference videos store this factor in their safetensors metadata. + Ratio between target and reference video resolutions. IC-LoRA models trained with downscaled reference + videos store this factor in their safetensors metadata. height (`int`, *optional*, defaults to `512`): Output video height in pixels. Must be divisible by 32. width (`int`, *optional*, defaults to `768`): @@ -1496,9 +1489,7 @@ def __call__( noise_pred_video_uncond_mod = self.convert_velocity_to_x0( latents, noise_pred_video_uncond_mod, i, self.scheduler ) - video_modality_delta = (self.modality_scale - 1) * ( - noise_pred_video - noise_pred_video_uncond_mod - ) + video_modality_delta = (self.modality_scale - 1) * (noise_pred_video - noise_pred_video_uncond_mod) else: video_modality_delta = 0 @@ -1526,9 +1517,7 @@ def __call__( # Step the audio scheduler so its internal state stays in sync with the video scheduler (audio # output is discarded at the end, but keeping schedulers aligned avoids surprising behavior if the # scheduler writes internal indices during `.step()`). - _ = audio_scheduler.step( - torch.zeros_like(audio_latents), t, audio_latents, return_dict=False - )[0] + _ = audio_scheduler.step(torch.zeros_like(audio_latents), t, audio_latents, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py index 0e48b3ca958f..8ad7616df133 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py @@ -79,15 +79,13 @@ class LTX2ReferenceCondition: >>> from diffusers.pipelines.ltx2.utils import DEFAULT_NEGATIVE_PROMPT >>> from diffusers.utils import load_video - >>> pipe = LTX2ICLoraPipeline.from_pretrained( - ... 
"dg845/LTX-2.3-Diffusers", torch_dtype=torch.bfloat16 - ... ) + >>> pipe = LTX2ICLoraPipeline.from_pretrained("dg845/LTX-2.3-Diffusers", torch_dtype=torch.bfloat16) >>> pipe.enable_sequential_cpu_offload(device="cuda") >>> pipe.load_lora_weights( - >>> "Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In", - >>> adapter_name="ic_lora", - >>> weight_name="ltx-2-19b-lora-camera-control-dolly-in.safetensors", - >>> ) + ... "Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In", + ... adapter_name="ic_lora", + ... weight_name="ltx-2-19b-lora-camera-control-dolly-in.safetensors", + ... ) >>> pipe.set_adapters("ic_lora", 1.0) >>> # If the IC LoRA uses reference conditions, you can specify them as follows: @@ -240,8 +238,8 @@ class LTX2ICLoraPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderM r""" Pipeline for IC-LoRA (In-Context LoRA) video generation with reference video conditioning. - IC-LoRA conditions the generation on a reference video by encoding it into latent tokens and concatenating them - to the noisy latent sequence during denoising. The IC-LoRA adapter (loaded as a standard LoRA) learns to use this + IC-LoRA conditions the generation on a reference video by encoding it into latent tokens and concatenating them to + the noisy latent sequence during denoising. The IC-LoRA adapter (loaded as a standard LoRA) learns to use this in-context reference to guide generation (e.g. for style transfer, depth-conditioned generation, etc.). This pipeline also supports frame-level conditioning via the `conditions` parameter (same as @@ -249,8 +247,8 @@ class LTX2ICLoraPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderM Two-stage inference is supported through separate calls to `__call__`: - **Stage 1**: Generate at target resolution with IC-LoRA conditioning (`output_type="latent"`). - - **Stage 2**: Upsample via [`LTX2LatentUpsamplePipeline`], then refine with a distilled LoRA (no IC-LoRA - reference conditioning needed for Stage 2). + - **Stage 2**: Upsample via [`LTX2LatentUpsamplePipeline`], then refine with a distilled LoRA (no IC-LoRA reference + conditioning needed for Stage 2). Reference: https://github.com/Lightricks/LTX-Video @@ -750,9 +748,7 @@ def preprocess_conditions( # Create a channels-last video-like array of shape (F, H, W, C) in preparation for resizing. if isinstance(condition.frames, PIL.Image.Image): arr = np.array(condition.frames.convert("RGB"))[None] # (1, H, W, 3) - elif isinstance(condition.frames, list) and all( - isinstance(f, PIL.Image.Image) for f in condition.frames - ): + elif isinstance(condition.frames, list) and all(isinstance(f, PIL.Image.Image) for f in condition.frames): arr = np.stack([np.array(f.convert("RGB")) for f in condition.frames]) # (F, H, W, 3) elif isinstance(condition.frames, np.ndarray): arr = condition.frames if condition.frames.ndim == 4 else condition.frames[None] @@ -762,9 +758,7 @@ def preprocess_conditions( # resize logic, which expects channels-last. arr = t.detach().cpu().permute(0, 2, 3, 1).numpy() else: - raise TypeError( - f"Unsupported `frames` type for condition {i}: {type(condition.frames)}" - ) + raise TypeError(f"Unsupported `frames` type for condition {i}: {type(condition.frames)}") src_h, src_w = arr.shape[1], arr.shape[2] num_cond_frames = arr.shape[0] @@ -781,9 +775,7 @@ def preprocess_conditions( # NOTE: we avoid using VideoProcessor.preprocess_video here because it uses PIL.Image.resize under the # hood, which will apply an anti-aliasing pre-filter when downsampling. 
The original LTX-2.X code simply # uses F.interpolate, which is reproduced here. - pixels = torch.nn.functional.interpolate( - pixels, size=(new_h, new_w), mode="bilinear", align_corners=False - ) + pixels = torch.nn.functional.interpolate(pixels, size=(new_h, new_w), mode="bilinear", align_corners=False) top = (new_h - height) // 2 left = (new_w - width) // 2 pixels = pixels[:, :, top : top + height, left : left + width] @@ -832,8 +824,8 @@ def apply_first_frame_conditioning( Apply first-frame visual conditioning by overwriting tokens at the first-frame positions. Only conditions with `latent_idx == 0` are applied here (matching `VideoConditionByLatentIndex` in the - reference implementation). Conditions at non-zero latent indices are appended as separate keyframe tokens - via `prepare_keyframe_extras` (matching `VideoConditionByKeyframeIndex`) and are skipped here. + reference implementation). Conditions at non-zero latent indices are appended as separate keyframe tokens via + `prepare_keyframe_extras` (matching `VideoConditionByKeyframeIndex`) and are skipped here. Args: latents (`torch.Tensor`): @@ -880,8 +872,8 @@ def _prepare_keyframe_coords( Compute positional coordinates for a keyframe condition being appended as extra tokens. Mirrors `VideoConditionByKeyframeIndex.apply_to` in the reference implementation: - - Latent coords scaled to pixel space *without* the causal fix (since non-zero-index keyframes don't need - the first-frame causal adjustment). + - Latent coords scaled to pixel space *without* the causal fix (since non-zero-index keyframes don't need the + first-frame causal adjustment). - Temporal axis offset by `pixel_frame_idx` (the pixel-space index at which the keyframe appears). - For single-pixel-frame keyframes, the per-patch temporal extent is clamped to `[idx, idx + 1)` so the keyframe occupies a single pixel timestep rather than the VAE-scaled range. @@ -898,12 +890,8 @@ def _prepare_keyframe_coords( grid_f = torch.arange( start=0, end=keyframe_latent_num_frames, step=patch_size_t, dtype=torch.float32, device=device ) - grid_h = torch.arange( - start=0, end=keyframe_latent_height, step=patch_size, dtype=torch.float32, device=device - ) - grid_w = torch.arange( - start=0, end=keyframe_latent_width, step=patch_size, dtype=torch.float32, device=device - ) + grid_h = torch.arange(start=0, end=keyframe_latent_height, step=patch_size, dtype=torch.float32, device=device) + grid_w = torch.arange(start=0, end=keyframe_latent_width, step=patch_size, dtype=torch.float32, device=device) grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij") grid = torch.stack(grid, dim=0) @@ -952,35 +940,33 @@ def prepare_latents( """ Prepare noisy video latents, applying frame and reference-video conditioning. - Conditioning sources are unified into a single packed sequence in the order - `[base | keyframe | reference]`: + Conditioning sources are unified into a single packed sequence in the order `[base | keyframe | reference]`: - First-frame conditions (`conditions` with `latent_idx == 0`) overwrite tokens at the first-frame positions (`VideoConditionByLatentIndex` semantics). - Non-first-frame conditions (`conditions` with `latent_idx > 0`) are concatenated onto the main latent - sequence with per-token `conditioning_mask = strength` - (`VideoConditionByKeyframeIndex` semantics). + sequence with per-token `conditioning_mask = strength` (`VideoConditionByKeyframeIndex` semantics). 
- IC-LoRA `reference_conditions` (if any) are encoded by the VAE and appended after the keyframes with - per-token `conditioning_mask = strength` (matching the reference repo's - `VideoConditionByReferenceLatent` semantics). + per-token `conditioning_mask = strength` (matching the reference repo's `VideoConditionByReferenceLatent` + semantics). - For all appended tokens the noise mixing below blends them to noise level `(1 - strength) * sigma_max`, - and the existing per-token timestep formula `t * (1 - conditioning_mask)` and the post-process blend - `denoised * (1 - cond_mask) + clean * cond_mask` drive them through the loop. + For all appended tokens the noise mixing below blends them to noise level `(1 - strength) * sigma_max`, and the + existing per-token timestep formula `t * (1 - conditioning_mask)` and the post-process blend `denoised * (1 - + cond_mask) + clean * cond_mask` drive them through the loop. Returns a 6-tuple: - `latents`: packed noisy latents `(B, base + n_keyframe + n_ref, C)`. - - `conditioning_mask`: `(B, seq_len, 1)` with values in `[0, 1]` — `1` at first-frame positions, - `strength` at keyframe / reference positions, `0` elsewhere. + - `conditioning_mask`: `(B, seq_len, 1)` with values in `[0, 1]` — `1` at first-frame positions, `strength` + at keyframe / reference positions, `0` elsewhere. - `clean_latents`: clean condition values at conditioned positions (zeros elsewhere); same shape as `latents`. - - `appended_coords`: `[1, 3, n_keyframe + n_ref, 2]` positional coordinates to concat onto - `video_coords`, or `None` if no keyframe/reference conditions are provided. - - `num_ref_tokens`: count of reference tokens at the END of `latents` (used by the call site to - build the unified self-attention mask). - - `ref_cross_mask`: `[1, num_ref_tokens]` per-reference-token cross-attention strengths in `[0, 1]`, - or `None` when `conditioning_attention_strength == 1.0` and no pixel-space mask is provided - (in which case attention is uniform). + - `appended_coords`: `[1, 3, n_keyframe + n_ref, 2]` positional coordinates to concat onto `video_coords`, + or `None` if no keyframe/reference conditions are provided. + - `num_ref_tokens`: count of reference tokens at the END of `latents` (used by the call site to build the + unified self-attention mask). + - `ref_cross_mask`: `[1, num_ref_tokens]` per-reference-token cross-attention strengths in `[0, 1]`, or + `None` when `conditioning_attention_strength == 1.0` and no pixel-space mask is provided (in which case + attention is uniform). """ latent_height = height // self.vae_spatial_compression_ratio latent_width = width // self.vae_spatial_compression_ratio @@ -1216,9 +1202,9 @@ def _encode_reference_conditions( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: """Encode IC-LoRA reference videos into `(reference_latents, reference_coords, reference_cross_mask)`. - This is the shared encoding core used by both `prepare_latents` (which folds reference tokens into the - main noisy sequence) and the back-compat shim `prepare_reference_latents` (which exposes the legacy - 4-tuple output). See `prepare_reference_latents` for parameter documentation. + This is the shared encoding core used by both `prepare_latents` (which folds reference tokens into the main + noisy sequence) and the back-compat shim `prepare_reference_latents` (which exposes the legacy 4-tuple output). + See `prepare_reference_latents` for parameter documentation. 
""" ref_height = height // reference_downscale_factor ref_width = width // reference_downscale_factor @@ -1240,20 +1226,16 @@ def _encode_reference_conditions( else: video_like = ref_cond.frames - ref_pixels = self.video_processor.preprocess_video( - video_like, ref_height, ref_width, resize_mode="crop" - ) + ref_pixels = self.video_processor.preprocess_video(video_like, ref_height, ref_width, resize_mode="crop") # Trim to num_frames ref_pixels = ref_pixels[:, :, :num_frames] ref_pixels = ref_pixels.to(dtype=self.vae.dtype, device=device) # Encode through VAE - ref_latent = retrieve_latents( - self.vae.encode(ref_pixels), generator=generator, sample_mode="argmax" + ref_latent = retrieve_latents(self.vae.encode(ref_pixels), generator=generator, sample_mode="argmax") + ref_latent = self._normalize_latents(ref_latent, self.vae.latents_mean, self.vae.latents_std).to( + device=device, dtype=dtype ) - ref_latent = self._normalize_latents( - ref_latent, self.vae.latents_mean, self.vae.latents_std - ).to(device=device, dtype=dtype) # Get latent dimensions for coordinate computation _, _, ref_latent_frames, ref_latent_height, ref_latent_width = ref_latent.shape @@ -1326,10 +1308,10 @@ def prepare_reference_latents( Each reference video is independently encoded by the VAE, packed into tokens, and its positional coordinates are computed with spatial scaling by `reference_downscale_factor` to match the target coordinate space. - All reference tokens are concatenated into a single sequence. When `conditioning_attention_strength < 1.0` - or `conditioning_attention_mask` is provided, a per-token cross-attention mask is also computed for each - reference video (downsampled to the reference video's latent dimensions) and returned so callers can build - a self-attention mask over the full video sequence. + All reference tokens are concatenated into a single sequence. When `conditioning_attention_strength < 1.0` or + `conditioning_attention_mask` is provided, a per-token cross-attention mask is also computed for each reference + video (downsampled to the reference video's latent dimensions) and returned so callers can build a + self-attention mask over the full video sequence. Args: reference_conditions (`list[LTX2ReferenceCondition]`): @@ -1342,19 +1324,19 @@ def prepare_reference_latents( num_frames (`int`): Number of target video frames. reference_downscale_factor (`int`, defaults to `1`): - Ratio between target and reference resolutions. A factor of 2 means the reference video is - preprocessed at half the target resolution. Spatial positional coordinates are scaled by this factor - to map reference tokens into the target coordinate space. + Ratio between target and reference resolutions. A factor of 2 means the reference video is preprocessed + at half the target resolution. Spatial positional coordinates are scaled by this factor to map + reference tokens into the target coordinate space. frame_rate (`float`, defaults to `24.0`): Video frame rate (used for temporal coordinate computation). conditioning_attention_strength (`float`, defaults to `1.0`): - Scalar in `[0, 1]` controlling how strongly reference tokens attend to noisy tokens (and vice versa) - in the self-attention mask. `1.0` means full attention (no masking), `0.0` means reference tokens - are effectively ignored by the noisy tokens. + Scalar in `[0, 1]` controlling how strongly reference tokens attend to noisy tokens (and vice versa) in + the self-attention mask. 
`1.0` means full attention (no masking), `0.0` means reference tokens are + effectively ignored by the noisy tokens. conditioning_attention_mask (`torch.Tensor`, *optional*): Optional pixel-space mask of shape `(1, 1, F_pix, H_pix, W_pix)` with values in `[0, 1]` that provides - spatially-varying attention strength. Downsampled to latent space per reference video and multiplied - by `conditioning_attention_strength`. + spatially-varying attention strength. Downsampled to latent space per reference video and multiplied by + `conditioning_attention_strength`. dtype (`torch.dtype`, *optional*): Data type for the latents. device (`torch.device`, *optional*): @@ -1367,9 +1349,9 @@ def prepare_reference_latents( - `reference_latents`: `[1, total_ref_tokens, hidden_dim]` - `reference_coords`: `[1, 3, total_ref_tokens, 2]` - `reference_denoise_factors`: `[1, total_ref_tokens]` — per-token `(1 - strength)` factors - - `reference_cross_mask`: `[1, total_ref_tokens]` per-token noisy↔reference attention strengths in - `[0, 1]`, or `None` when `conditioning_attention_strength == 1.0` and no pixel-space mask is - provided (in which case attention is unmasked). + - `reference_cross_mask`: `[1, total_ref_tokens]` per-token noisy↔reference attention strengths in `[0, + 1]`, or `None` when `conditioning_attention_strength == 1.0` and no pixel-space mask is provided (in + which case attention is unmasked). """ reference_latents, reference_coords, reference_cross_mask = self._encode_reference_conditions( reference_conditions=reference_conditions, @@ -1403,7 +1385,9 @@ def prepare_reference_latents( ) ) idx += n_per_ref - reference_denoise_factors = torch.cat(ref_denoise_chunks, dim=1) if ref_denoise_chunks else reference_latents.new_zeros((1, 0)) + reference_denoise_factors = ( + torch.cat(ref_denoise_chunks, dim=1) if ref_denoise_chunks else reference_latents.new_zeros((1, 0)) + ) return reference_latents, reference_coords, reference_denoise_factors, reference_cross_mask @@ -1419,9 +1403,9 @@ def _downsample_mask_to_latent( Mirrors `ICLoraPipeline._downsample_mask_to_latent` in the reference implementation: - Spatial downsampling via `area` interpolation per frame. - - Causal temporal downsampling: the first frame is kept as-is (the VAE encodes the first frame - independently with temporal stride 1), remaining frames are downsampled by group-mean using factor - `(F_pix - 1) // (F_lat - 1)`. + - Causal temporal downsampling: the first frame is kept as-is (the VAE encodes the first frame independently + with temporal stride 1), remaining frames are downsampled by group-mean using factor `(F_pix - 1) // (F_lat - + 1)`. - Flattened to token order `(F, H, W)` matching the patchifier. Args: @@ -1441,9 +1425,7 @@ def _downsample_mask_to_latent( # 1. Spatial downsampling (area interpolation per frame). mask_2d = mask.reshape(b * f_pix, 1, mask.shape[-2], mask.shape[-1]) - spatial_down = torch.nn.functional.interpolate( - mask_2d, size=(latent_height, latent_width), mode="area" - ) + spatial_down = torch.nn.functional.interpolate(mask_2d, size=(latent_height, latent_width), mode="area") spatial_down = spatial_down.reshape(b, 1, f_pix, latent_height, latent_width) # 2. Causal temporal downsampling. @@ -1486,15 +1468,15 @@ def _build_video_self_attention_mask( num_noisy_tokens (`int`): Number of noisy video tokens. extras_cross_masks (`list[torch.Tensor]`): - List of per-token cross-attention strengths, one per conditioning group. Each entry has shape - `(1, num_tokens_in_group)` with values in `[0, 1]`. 
Groups must appear in the same order as their - tokens in the extras block. + List of per-token cross-attention strengths, one per conditioning group. Each entry has shape `(1, + num_tokens_in_group)` with values in `[0, 1]`. Groups must appear in the same order as their tokens in + the extras block. device, dtype: Tensor device and dtype. Returns: - Multiplicative self-attention mask of shape `(1, num_noisy_tokens + sum(group_sizes), - num_noisy_tokens + sum(group_sizes))` with values in `[0, 1]`. + Multiplicative self-attention mask of shape `(1, num_noisy_tokens + sum(group_sizes), num_noisy_tokens + + sum(group_sizes))` with values in `[0, 1]`. """ total_extras = sum(m.shape[1] for m in extras_cross_masks) total = num_noisy_tokens + total_extras @@ -1508,11 +1490,11 @@ def _build_video_self_attention_mask( n = cross_mask.shape[1] cross = cross_mask.to(device=device, dtype=dtype) # noisy (rows) ↔ this group (cols) - attn_mask[:, :num_noisy_tokens, offset:offset + n] = cross.unsqueeze(1) + attn_mask[:, :num_noisy_tokens, offset : offset + n] = cross.unsqueeze(1) # this group (rows) ↔ noisy (cols) - attn_mask[:, offset:offset + n, :num_noisy_tokens] = cross.unsqueeze(2) + attn_mask[:, offset : offset + n, :num_noisy_tokens] = cross.unsqueeze(2) # this group ↔ this group (self-attention within the group) - attn_mask[:, offset:offset + n, offset:offset + n] = 1.0 + attn_mask[:, offset : offset + n, offset : offset + n] = 1.0 offset += n return attn_mask @@ -1666,15 +1648,15 @@ def __call__( conditioning_attention_strength (`float`, *optional*, defaults to `1.0`): Scalar in `[0, 1]` controlling how strongly noisy tokens and appended reference tokens attend to each other in the video self-attention. `1.0` = full attention (no masking, same as the base IC-LoRA - behavior). `0.0` = reference tokens are fully masked out of the noisy-token attention (and vice - versa). Only takes effect when `reference_conditions` is provided. + behavior). `0.0` = reference tokens are fully masked out of the noisy-token attention (and vice versa). + Only takes effect when `reference_conditions` is provided. conditioning_attention_mask (`torch.Tensor`, *optional*): - Optional pixel-space spatial attention mask of shape `(1, 1, F_pix, H_pix, W_pix)` with values in - `[0, 1]` that provides per-region attention strength. The mask's spatial-temporal dimensions must - match the reference video's pixel dimensions. Downsampled to latent space using VAE scale factors - (with causal temporal handling for the first frame) and multiplied by - `conditioning_attention_strength` to form the final cross-attention mask between noisy and reference - tokens. Only takes effect when `reference_conditions` is provided. + Optional pixel-space spatial attention mask of shape `(1, 1, F_pix, H_pix, W_pix)` with values in `[0, + 1]` that provides per-region attention strength. The mask's spatial-temporal dimensions must match the + reference video's pixel dimensions. Downsampled to latent space using VAE scale factors (with causal + temporal handling for the first frame) and multiplied by `conditioning_attention_strength` to form the + final cross-attention mask between noisy and reference tokens. Only takes effect when + `reference_conditions` is provided. height (`int`, *optional*, defaults to `512`): The height in pixels of the generated video. 
width (`int`, *optional*, defaults to `768`): @@ -1748,8 +1730,8 @@ def __call__( Returns: [`LTX2PipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`LTX2PipelineOutput`] is returned, otherwise a `tuple` of - `(video, audio)` is returned. + If `return_dict` is `True`, [`LTX2PipelineOutput`] is returned, otherwise a `tuple` of `(video, audio)` + is returned. """ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): @@ -2155,9 +2137,7 @@ def __call__( audio_latents, noise_pred_audio_uncond_mod, i, audio_scheduler ) - video_modality_delta = (self.modality_scale - 1) * ( - noise_pred_video - noise_pred_video_uncond_mod - ) + video_modality_delta = (self.modality_scale - 1) * (noise_pred_video - noise_pred_video_uncond_mod) audio_modality_delta = (self.audio_modality_scale - 1) * ( noise_pred_audio - noise_pred_audio_uncond_mod ) From b1e8907865e79d3d7fbc48a49cf202337683dd10 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 28 Apr 2026 03:52:49 +0200 Subject: [PATCH 09/14] Rename LTX2ICLoraPipeline to LTX2InContextPipeline and LTX2HDRLoraPipeline to LTX2HDRPipeline --- src/diffusers/__init__.py | 8 +++---- src/diffusers/pipelines/__init__.py | 8 +++---- src/diffusers/pipelines/ltx2/__init__.py | 8 +++---- .../pipelines/ltx2/pipeline_ltx2_hdr_lora.py | 22 +++++++++---------- .../pipelines/ltx2/pipeline_ltx2_ic_lora.py | 6 ++--- .../dummy_torch_and_transformers_objects.py | 6 ++--- 6 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8027eb64be91..9e37fc069e77 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -608,9 +608,9 @@ "LongCatImageEditPipeline", "LongCatImagePipeline", "LTX2ConditionPipeline", - "LTX2HDRLoraPipeline", - "LTX2ICLoraPipeline", + "LTX2HDRPipeline", "LTX2ImageToVideoPipeline", + "LTX2InContextPipeline", "LTX2LatentUpsamplePipeline", "LTX2Pipeline", "LTXConditionPipeline", @@ -1395,9 +1395,9 @@ LongCatImageEditPipeline, LongCatImagePipeline, LTX2ConditionPipeline, - LTX2HDRLoraPipeline, - LTX2ICLoraPipeline, + LTX2HDRPipeline, LTX2ImageToVideoPipeline, + LTX2InContextPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline, LTXConditionPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index eaf985e7ac87..bc5c291d89f8 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -324,8 +324,8 @@ _import_structure["ltx2"] = [ "LTX2Pipeline", "LTX2ConditionPipeline", - "LTX2HDRLoraPipeline", - "LTX2ICLoraPipeline", + "LTX2HDRPipeline", + "LTX2InContextPipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline", ] @@ -772,9 +772,9 @@ ) from .ltx2 import ( LTX2ConditionPipeline, - LTX2HDRLoraPipeline, - LTX2ICLoraPipeline, + LTX2HDRPipeline, LTX2ImageToVideoPipeline, + LTX2InContextPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline, ) diff --git a/src/diffusers/pipelines/ltx2/__init__.py b/src/diffusers/pipelines/ltx2/__init__.py index 87ab0827c246..92da5c55003e 100644 --- a/src/diffusers/pipelines/ltx2/__init__.py +++ b/src/diffusers/pipelines/ltx2/__init__.py @@ -27,8 +27,8 @@ _import_structure["latent_upsampler"] = ["LTX2LatentUpsamplerModel"] _import_structure["pipeline_ltx2"] = ["LTX2Pipeline"] _import_structure["pipeline_ltx2_condition"] = ["LTX2ConditionPipeline"] - _import_structure["pipeline_ltx2_hdr_lora"] = ["LTX2HDRLoraPipeline", "LTX2HDRReferenceCondition"] - _import_structure["pipeline_ltx2_ic_lora"] = ["LTX2ICLoraPipeline", "LTX2ReferenceCondition"] + 
_import_structure["pipeline_ltx2_hdr_lora"] = ["LTX2HDRPipeline", "LTX2HDRReferenceCondition"] + _import_structure["pipeline_ltx2_ic_lora"] = ["LTX2InContextPipeline", "LTX2ReferenceCondition"] _import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"] _import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"] _import_structure["vocoder"] = ["LTX2Vocoder", "LTX2VocoderWithBWE"] @@ -46,8 +46,8 @@ from .latent_upsampler import LTX2LatentUpsamplerModel from .pipeline_ltx2 import LTX2Pipeline from .pipeline_ltx2_condition import LTX2ConditionPipeline - from .pipeline_ltx2_hdr_lora import LTX2HDRLoraPipeline, LTX2HDRReferenceCondition - from .pipeline_ltx2_ic_lora import LTX2ICLoraPipeline, LTX2ReferenceCondition + from .pipeline_ltx2_hdr_lora import LTX2HDRPipeline, LTX2HDRReferenceCondition + from .pipeline_ltx2_ic_lora import LTX2InContextPipeline, LTX2ReferenceCondition from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py index a4b249a42d46..99ebde1d9515 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py @@ -74,13 +74,13 @@ class LTX2HDRReferenceCondition: ```py >>> import torch >>> from safetensors import safe_open - >>> from diffusers import LTX2HDRLoraPipeline + >>> from diffusers import LTX2HDRPipeline >>> from diffusers.pipelines.ltx2.pipeline_ltx2_hdr_lora import LTX2HDRReferenceCondition >>> from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES >>> from diffusers.pipelines.ltx2.export_utils import save_hdr_video_frames_as_exr, encode_exr_sequence_to_mp4 >>> from diffusers.utils import load_video - >>> pipe = LTX2HDRLoraPipeline.from_pretrained("dg845/LTX-2.3-Distilled-Diffusers", torch_dtype=torch.bfloat16) + >>> pipe = LTX2HDRPipeline.from_pretrained("dg845/LTX-2.3-Distilled-Diffusers", torch_dtype=torch.bfloat16) >>> pipe.enable_sequential_cpu_offload(device="cuda") >>> pipe.load_lora_weights( ... "Lightricks/LTX-2.3-22b-IC-LoRA-HDR", @@ -237,15 +237,15 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): return noise_cfg -class LTX2HDRLoraPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): +class LTX2HDRPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): r""" Pipeline for HDR IC-LoRA video generation with reference video conditioning. - This is a video-only HDR counterpart to [`LTX2ICLoraPipeline`]. The HDR IC-LoRA adapter (loaded as a standard LoRA - via `load_lora_weights`) conditions generation on a reference video, and the pipeline's postprocessing applies the - LogC3 inverse transform to produce linear HDR output in `[0, ∞)`. + This is a video-only HDR counterpart to [`LTX2InContextPipeline`]. The HDR IC-LoRA adapter (loaded as a standard + LoRA via `load_lora_weights`) conditions generation on a reference video, and the pipeline's postprocessing applies + the LogC3 inverse transform to produce linear HDR output in `[0, ∞)`. - Compared to [`LTX2ICLoraPipeline`], the HDR pipeline drops: + Compared to [`LTX2InContextPipeline`], the HDR pipeline drops: - Frame-level keyframe conditioning (the reference HDR pipeline does not support this). - The `conditioning_attention_strength` / `conditioning_attention_mask` knobs. 
@@ -701,10 +701,10 @@ def prepare_latents( Builds a packed latent sequence in the order `[base | reference]`: - Base: either fresh noise (Stage 1, `latents=None`) or pre-existing upsampled latents (Stage 2). - Reference: HDR-encoded reference-video tokens appended with per-token `conditioning_mask = strength`, - following the same pattern as [`LTX2ICLoraPipeline.prepare_latents`]. (HDR LoRA does not currently take + following the same pattern as [`LTX2InContextPipeline.prepare_latents`]. (HDR LoRA does not currently take per-frame `conditions`, so there is no first-frame / keyframe block in between.) - Returns a 6-tuple matching [`LTX2ICLoraPipeline.prepare_latents`]: + Returns a 6-tuple matching [`LTX2InContextPipeline.prepare_latents`]: - `latents`: packed noisy latents `(B, base + n_ref, C)`. - `conditioning_mask`: `(B, seq_len, 1)` with `strength` at reference positions, `0` elsewhere. - `clean_latents`: clean reference values at reference positions (zeros elsewhere); same shape as @@ -936,7 +936,7 @@ def prepare_reference_latents( computed at the reference latent dimensions and scaled by `reference_downscale_factor`. Returns a 3-tuple `(reference_latents, reference_coords, reference_denoise_factors)` with the same shapes as - [`LTX2ICLoraPipeline.prepare_reference_latents`]. + [`LTX2InContextPipeline.prepare_reference_latents`]. """ reference_latents, reference_coords, _ = self._encode_reference_conditions( reference_conditions=reference_conditions, @@ -1138,7 +1138,7 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether to return an [`LTX2PipelineOutput`] instead of a plain tuple. attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length: - Standard hooks and arguments, same as [`LTX2ICLoraPipeline`]. + Standard hooks and arguments, same as [`LTX2InContextPipeline`]. Examples: diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py index 8ad7616df133..35b9e702fc5a 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py @@ -73,13 +73,13 @@ class LTX2ReferenceCondition: Examples: ```py >>> import torch - >>> from diffusers import LTX2ICLoraPipeline + >>> from diffusers import LTX2InContextPipeline >>> from diffusers.pipelines.ltx2.pipeline_ltx2_ic_lora import LTX2ReferenceCondition >>> from diffusers.pipelines.ltx2.export_utils import encode_video >>> from diffusers.pipelines.ltx2.utils import DEFAULT_NEGATIVE_PROMPT >>> from diffusers.utils import load_video - >>> pipe = LTX2ICLoraPipeline.from_pretrained("dg845/LTX-2.3-Diffusers", torch_dtype=torch.bfloat16) + >>> pipe = LTX2InContextPipeline.from_pretrained("dg845/LTX-2.3-Diffusers", torch_dtype=torch.bfloat16) >>> pipe.enable_sequential_cpu_offload(device="cuda") >>> pipe.load_lora_weights( ... "Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In", @@ -234,7 +234,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): return noise_cfg -class LTX2ICLoraPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): +class LTX2InContextPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): r""" Pipeline for IC-LoRA (In-Context LoRA) video generation with reference video conditioning. 
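A minimal sketch of the self-attention-mask convention shared by the IC pipeline and the transformer change above: the pipeline expands `conditioning_attention_strength` into a multiplicative `(1, T_v, T_v)` mask over the `[noisy | reference]` token sequence, and the transformer converts it to the additive `-10000.0` bias used by the other attention masks. The sizes, the strength value, and the ones-initialization below are illustrative assumptions rather than the pipeline's exact code.

```py
import torch

num_noisy, num_ref, strength = 6, 4, 0.5
total = num_noisy + num_ref

# Assumption: unmasked positions default to full attention (1.0).
mask = torch.ones(1, total, total)
mask[:, :num_noisy, num_noisy:] = strength  # noisy rows -> reference columns
mask[:, num_noisy:, :num_noisy] = strength  # reference rows -> noisy columns
mask[:, num_noisy:, num_noisy:] = 1.0       # reference tokens fully attend to each other

# Inside the transformer, the multiplicative mask becomes an additive bias,
# matching the existing encoder_attention_mask convention:
bias = (1 - mask.to(torch.float32)) * -10000.0
```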
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index c41ddc24f38c..177d0cece833 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2402,7 +2402,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class LTX2HDRLoraPipeline(metaclass=DummyObject): +class LTX2HDRPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -2417,7 +2417,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class LTX2ICLoraPipeline(metaclass=DummyObject): +class LTX2ImageToVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -2432,7 +2432,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class LTX2ImageToVideoPipeline(metaclass=DummyObject): +class LTX2InContextPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): From 0be15f373c5ff4e4d4375def84aae42efd33fb23 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 28 Apr 2026 04:18:30 +0200 Subject: [PATCH 10/14] Improve LTX2InContextPipeline and LTX2HDRPipeline docstrings --- .../pipelines/ltx2/pipeline_ltx2_hdr_lora.py | 18 +++++++++--------- .../pipelines/ltx2/pipeline_ltx2_ic_lora.py | 13 ++++++------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py index 99ebde1d9515..68ca88e393fb 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py @@ -239,19 +239,19 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): class LTX2HDRPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): r""" - Pipeline for HDR IC-LoRA video generation with reference video conditioning. + Pipeline for LTX-2.X HDR video generation with reference video conditioning. - This is a video-only HDR counterpart to [`LTX2InContextPipeline`]. The HDR IC-LoRA adapter (loaded as a standard - LoRA via `load_lora_weights`) conditions generation on a reference video, and the pipeline's postprocessing applies - the LogC3 inverse transform to produce linear HDR output in `[0, ∞)`. + The pipeline accepts a reference SDR ("normal") video and generates a linear HDR output with values in `[0, ∞)` via + a LogC3 inverse transform which has the same content as the reference video. The motivating use case for this + pipeline is to support LTX-2.X HDR IC-LoRAs, but it should support any LTX-2.X-like model that operates on HDR + inputs as above. - Compared to [`LTX2InContextPipeline`], the HDR pipeline drops: + Compared to [`LTX2InContextPipeline`], the HDR pipeline has the following differences: - - Frame-level keyframe conditioning (the reference HDR pipeline does not support this). - - The `conditioning_attention_strength` / `conditioning_attention_mask` knobs. - - Audio output (video-only). The transformer's audio branch is still run since the diffusers transformer API + - Video-only (no audio output). 
The transformer's audio branch is still run since the diffusers transformer API requires audio inputs, but the decoded audio is discarded and audio-specific guidance scales are fixed to no-op values to avoid wasted compute. + - No frame-level keyframe conditioning (the reference HDR pipeline does not support this). Two-stage inference is supported through separate calls to `__call__`: @@ -260,7 +260,7 @@ class LTX2HDRPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixi by passing `latents=upsampled_latents`. The reference HDR stage-2 additionally supports spatial/temporal tiling of the refinement pass — that optimization is not yet implemented here. - Reference: https://github.com/Lightricks/LTX-2 + Reference: https://github.com/Lightricks/LTX-2 Paper: https://huggingface.co/papers/2604.11788 Args: scheduler ([`FlowMatchEulerDiscreteScheduler`]): diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py index 35b9e702fc5a..5c7ca04f8376 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py @@ -236,14 +236,13 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): class LTX2InContextPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): r""" - Pipeline for IC-LoRA (In-Context LoRA) video generation with reference video conditioning. + Pipeline for LTX-2.X models with in-context (IC) conditioning. Also supports frame-level image conditions like + `LTX2ConditionPipeline`; both frame and reference conditions can be used together. - IC-LoRA conditions the generation on a reference video by encoding it into latent tokens and concatenating them to - the noisy latent sequence during denoising. The IC-LoRA adapter (loaded as a standard LoRA) learns to use this - in-context reference to guide generation (e.g. for style transfer, depth-conditioned generation, etc.). - - This pipeline also supports frame-level conditioning via the `conditions` parameter (same as - [`LTX2ConditionPipeline`]), allowing both reference video and frame conditions to be used together. + In-context conditioning works by conditioning the generation on a reference video by encoding it into latent tokens + and concatenating them to the noisy latent tokens during denoising. The motivating use case is to support LTX-2.X + IC LoRAs, which may use reference conditions (e.g. a pose video for pose control) to guide generation, but this + pipeline is designed to work with any LTX-2.X-like model trained with in-context reference conditions. Two-stage inference is supported through separate calls to `__call__`: - **Stage 1**: Generate at target resolution with IC-LoRA conditioning (`output_type="latent"`). 
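A compact sketch of the per-token conditioning arithmetic these docstrings describe: appended reference tokens carry `conditioning_mask = strength`, their effective timestep is `t * (1 - conditioning_mask)`, and after each step the x0 prediction is blended back toward the clean condition values. The sizes, the timestep value, and the single latent channel below are illustrative simplifications, not the pipeline's exact code.

```py
import torch

num_base, num_ref, strength = 8, 4, 1.0
conditioning_mask = torch.cat(
    [torch.zeros(1, num_base, 1), torch.full((1, num_ref, 1), strength)], dim=1
)

t = torch.tensor(0.7)                             # illustrative sigma/timestep
per_token_timestep = t * (1 - conditioning_mask)  # reference tokens see a lower noise level

denoised = torch.randn(1, num_base + num_ref, 1)  # stand-in for the model's x0 prediction
clean_latents = torch.zeros_like(denoised)        # clean values live at the conditioned positions
blended = denoised * (1 - conditioning_mask) + clean_latents * conditioning_mask
```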
From cd25a51a61cd3cf8bb5545d58ea8beca81d80d06 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 28 Apr 2026 04:28:19 +0200 Subject: [PATCH 11/14] Add export function to directly convert HDR tensors to .mp4 files --- src/diffusers/pipelines/ltx2/export_utils.py | 67 +++++++++++++++++++ .../pipelines/ltx2/pipeline_ltx2_hdr_lora.py | 11 ++- 2 files changed, 72 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/export_utils.py b/src/diffusers/pipelines/ltx2/export_utils.py index 87ebf1eab885..64fc110737d2 100644 --- a/src/diffusers/pipelines/ltx2/export_utils.py +++ b/src/diffusers/pipelines/ltx2/export_utils.py @@ -395,3 +395,70 @@ def save_hdr_video_frames_as_exr( save_exr_tensor(frame, path, half=half) paths.append(path) return paths + + +def encode_hdr_tensor_to_mp4( + frames: torch.Tensor, + output_mp4: str | Path, + frame_rate: float, + tone_mapping_fn: Callable[[np.ndarray], np.ndarray] | None = None, + tone_map_in_rgb: bool = True, + crf: int = 18, +) -> None: + """ + Converts a linear HDR tensor (for example, as outputted by `LTX2HDRPipeline`) to a SDR `.mp4` file (specifically, a + sRGB-tonemapped H.264 `.mp4`). + + Args: + frames (`torch.Tensor`): + A linear HDR tensors with RGB values in `[0, ∞)` of shape `(F, H, W, 3)`. + output_mp4 (`str` or `pathlib.Path`): + Output MP4 path. + frame_rate (`float`): + Frame rate for the output video. + tone_mapping_fn (`Callable[[np.ndarray], np.ndarray]`, *optional*, defaults to `None`): + An optional tone mapping function which takes a float32 NumPy array of shape `(H, W, 3)` containing linear + HDR values in `[0, ∞)` and returns tone-mapped linear values in `[0, 1]`. The sRGB transfer function (OETF) + is applied afterwards — do **not** pre-apply gamma inside this function. If `None`, defaults to + [`simple_tone_map`], which clips values above `1.0`. The channel ordering of the input array is controlled + by `tone_map_in_rgb`: RGB by default (matching the `LTX2HDRPipeline` output), or BGR when + `tone_map_in_rgb=False`. This is the opposite default to `encode_exr_sequence_to_mp4`. + tone_map_in_rgb (`bool`, *optional*, defaults to `True`): + When `True` (default), frames are passed as RGB to `tone_mapping_fn`, and the output frame is tagged as + `rgb24`. Use this when `tone_mapping_fn` expects RGB input (e.g. operators from `colour-science`). When + `False`, the frames first have their channels flipped to BGR, which is the native format for + `opencv-python` tone mappers (e.g. `cv2.createTonemapReinhard().process`). Note that this is the opposite + default to `encode_exr_sequence_to_mp4`. + crf (`int`, *optional*, defaults to `18`): + libx264 CRF quality factor. Lower values produce higher quality. 
+ """ + frames_np = frames.cpu().float().numpy() + + container = av.open(str(output_mp4), mode="w") + stream = container.add_stream("libx264", rate=Fraction(frame_rate).limit_denominator(1000)) + stream.pix_fmt = "yuv420p" + stream.options = {"crf": str(crf), "movflags": "+faststart"} + + pix_fmt = "rgb24" if tone_map_in_rgb else "bgr24" + if tone_mapping_fn is None: + tone_mapping_fn = simple_tone_map + + try: + for i, hdr in enumerate(frames_np): + if not tone_map_in_rgb: + hdr = hdr[..., ::-1] + hdr_mapped = tone_mapping_fn(hdr) + sdr = linear_to_srgb(np.maximum(hdr_mapped, 0.0)) + out8 = (sdr * 255.0 + 0.5).astype(np.uint8) + + if i == 0: + stream.height, stream.width = out8.shape[:2] + + frame = av.VideoFrame.from_ndarray(out8, format=pix_fmt) + for packet in stream.encode(frame): + container.mux(packet) + + for packet in stream.encode(): + container.mux(packet) + finally: + container.close() diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py index 68ca88e393fb..1409cb7cd092 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py @@ -77,7 +77,7 @@ class LTX2HDRReferenceCondition: >>> from diffusers import LTX2HDRPipeline >>> from diffusers.pipelines.ltx2.pipeline_ltx2_hdr_lora import LTX2HDRReferenceCondition >>> from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES - >>> from diffusers.pipelines.ltx2.export_utils import save_hdr_video_frames_as_exr, encode_exr_sequence_to_mp4 + >>> from diffusers.pipelines.ltx2.export_utils import encode_hdr_tensor_to_mp4 >>> from diffusers.utils import load_video >>> pipe = LTX2HDRPipeline.from_pretrained("dg845/LTX-2.3-Distilled-Diffusers", torch_dtype=torch.bfloat16) @@ -97,6 +97,7 @@ class LTX2HDRReferenceCondition: ... connector_video_embeds = f.get_tensor("video_context") ... connector_audio_embeds = f.get_tensor("audio_context") + >>> # `hdr_video` is a linear HDR tensor of shape (batch, frames, H, W, C). >>> hdr_video = pipe( ... reference_conditions=[ref_cond], ... connector_video_embeds=connector_video_embeds, @@ -112,12 +113,10 @@ class LTX2HDRReferenceCondition: ... return_dict=False, ... )[0] - >>> # `hdr_video` is a linear HDR tensor of shape (batch, frames, H, W, C). - >>> # Save the HDR video as per-frame EXR files in the specified directory. - >>> save_hdr_video_frames_as_exr(hdr_video[0], "hdr_output/") - >>> # You can convert these EXR files to .mp4 file as below. + >>> # Convert the HDR video to a SDR sRGB-tonemapped `.mp4` video. + >>> # You can also save the output to EXR using `save_hdr_video_frames_as_exr`. >>> # A custom tone-mapper can be specified via the `tone_mapping_fn` argument. 
- >>> encode_exr_sequence_to_mp4("hdr_output/", "ltx2_hdr_lora_output.mp4", frame_rate=24.0) + >>> encode_hdr_tensor_to_mp4(hdr_video[0], "ltx2_hdr_lora_output.mp4", frame_rate=24.0) ``` """ From dac2e20cd3c4db28cc717b5bcc7390983911ce10 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 28 Apr 2026 05:28:18 +0200 Subject: [PATCH 12/14] Clean up the code/docstrings some more --- src/diffusers/pipelines/ltx2/image_processor.py | 7 +------ .../pipelines/ltx2/pipeline_ltx2_ic_lora.py | 12 +++--------- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/image_processor.py b/src/diffusers/pipelines/ltx2/image_processor.py index c9569529db27..a22a37fe6b85 100644 --- a/src/diffusers/pipelines/ltx2/image_processor.py +++ b/src/diffusers/pipelines/ltx2/image_processor.py @@ -29,10 +29,8 @@ class LTX2VideoHDRProcessor(VideoProcessor): For LDR (SDR Rec.709) reference videos, `LogC3.compress_ldr` is an identity clamp, so the numerical output is equivalent to the standard [-1, 1] normalization used by [`VideoProcessor.preprocess_video`] — only the resize strategy differs (reflect-pad vs center-crop). - - `postprocess_hdr_video`: applies the LogC3 inverse transform to the VAE's decoded output, mapping `[0, 1]` → - linear HDR `[0, ∞)`. This is the caller-facing counterpart to `apply_hdr_decode_postprocess` in the reference - `ltx_core.hdr` module. + linear HDR `[0, ∞)`. Args: vae_scale_factor (`int`, *optional*, defaults to `32`): @@ -81,9 +79,6 @@ def _resize_and_reflect_pad_video(video: torch.Tensor, height: int, width: int) r""" Resize a video tensor preserving aspect ratio, then reflect-pad to the exact target dimensions. - Mirrors `resize_and_reflect_pad` in the reference `ltx_pipelines.utils.media_io`. When the source is already at - least as large as the target in both dimensions, the interpolation step is skipped entirely. - Args: video (`torch.Tensor`): Input of shape `(B, C, F, H, W)`. height (`int`), width (`int`): Target spatial dimensions. diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py index 5c7ca04f8376..4fb9686496c9 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py @@ -870,7 +870,7 @@ def _prepare_keyframe_coords( """ Compute positional coordinates for a keyframe condition being appended as extra tokens. - Mirrors `VideoConditionByKeyframeIndex.apply_to` in the reference implementation: + Has the following behavior (based on the LTX-2.X original code): - Latent coords scaled to pixel space *without* the causal fix (since non-zero-index keyframes don't need the first-frame causal adjustment). - Temporal axis offset by `pixel_frame_idx` (the pixel-space index at which the keyframe appears). @@ -1398,14 +1398,8 @@ def _downsample_mask_to_latent( latent_width: int, ) -> torch.Tensor: """ - Downsample a pixel-space attention mask to a flattened per-token latent-space mask. - - Mirrors `ICLoraPipeline._downsample_mask_to_latent` in the reference implementation: - - Spatial downsampling via `area` interpolation per frame. - - Causal temporal downsampling: the first frame is kept as-is (the VAE encodes the first frame independently - with temporal stride 1), remaining frames are downsampled by group-mean using factor `(F_pix - 1) // (F_lat - - 1)`. - - Flattened to token order `(F, H, W)` matching the patchifier. + Downsample a pixel-space attention mask to a flattened per-token latent-space mask. 
Uses causal temporal + downsampling (the first frame is kept as-is). Args: mask (`torch.Tensor`): From 7fee0caec0b4c4eb36cd14a7a99982d11115cfea Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 28 Apr 2026 05:47:42 +0200 Subject: [PATCH 13/14] Move new video_self_attention_mask LTX-2.X transformer arg to end to preserved positional arg ordering --- .../models/transformers/transformer_ltx2.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index 5f8c1063cfa9..465408d94693 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -1331,7 +1331,6 @@ def forward( audio_sigma: torch.Tensor | None = None, encoder_attention_mask: torch.Tensor | None = None, audio_encoder_attention_mask: torch.Tensor | None = None, - video_self_attention_mask: torch.Tensor | None = None, num_frames: int | None = None, height: int | None = None, width: int | None = None, @@ -1344,6 +1343,7 @@ def forward( perturbation_mask: torch.Tensor | None = None, use_cross_timestep: bool = False, attention_kwargs: dict[str, Any] | None = None, + video_self_attention_mask: torch.Tensor | None = None, return_dict: bool = True, ) -> torch.Tensor: """ @@ -1375,11 +1375,6 @@ def forward( Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`. audio_encoder_attention_mask (`torch.Tensor`, *optional*): Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)` for audio modeling. - video_self_attention_mask (`torch.Tensor`, *optional*): - Optional multiplicative self-attention mask of shape `(batch_size, num_video_tokens, num_video_tokens)` - applied to the video self-attention in each transformer block. Values in `[0, 1]` where `1` means full - attention and `0` means masked. Used e.g. by the IC-LoRA pipeline to control attention strength between - noisy tokens and appended reference tokens. Audio self-attention is not affected. num_frames (`int`, *optional*): The number of latent video frames. Used if calculating the video coordinates for RoPE. height (`int`, *optional*): @@ -1414,6 +1409,11 @@ def forward( `False` is the legacy LTX-2.0 behavior. attention_kwargs (`dict[str, Any]`, *optional*): Optional dict of keyword args to be passed to the attention processor. + video_self_attention_mask (`torch.Tensor`, *optional*): + Optional multiplicative self-attention mask of shape `(batch_size, num_video_tokens, num_video_tokens)` + applied to the video self-attention in each transformer block. Values in `[0, 1]` where `1` means full + attention and `0` means masked. Used e.g. by the IC-LoRA pipeline to control attention strength between + noisy tokens and appended reference tokens. Audio self-attention is not affected. return_dict (`bool`, *optional*, defaults to `True`): Whether to return a dict-like structured output of type `AudioVisualModelOutput` or a tuple. 
From 9460fbcaf6e926435d79265c4fc9a62e48447bb2 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 28 Apr 2026 05:57:11 +0200 Subject: [PATCH 14/14] make fix-copies --- src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py index 4fb9686496c9..a4d573d8d265 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py @@ -870,7 +870,7 @@ def _prepare_keyframe_coords( """ Compute positional coordinates for a keyframe condition being appended as extra tokens. - Has the following behavior (based on the LTX-2.X original code): + Mirrors `VideoConditionByKeyframeIndex.apply_to` in the reference implementation: - Latent coords scaled to pixel space *without* the causal fix (since non-zero-index keyframes don't need the first-frame causal adjustment). - Temporal axis offset by `pixel_frame_idx` (the pixel-space index at which the keyframe appears).
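A hedged usage sketch of the new exporter with a custom tone mapper; the output path and the Reinhard operator are illustrative, and `hdr_video` stands for the `(frames, H, W, 3)` linear HDR tensor produced in the pipeline example above.

```py
import numpy as np

from diffusers.pipelines.ltx2.export_utils import encode_hdr_tensor_to_mp4


def reinhard_tone_map(hdr: np.ndarray) -> np.ndarray:
    # Global Reinhard operator: maps linear HDR values in [0, inf) to [0, 1).
    # No gamma here — encode_hdr_tensor_to_mp4 applies the sRGB OETF afterwards.
    return hdr / (1.0 + hdr)


encode_hdr_tensor_to_mp4(
    hdr_video[0],
    "ltx2_hdr_reinhard.mp4",
    frame_rate=24.0,
    tone_mapping_fn=reinhard_tone_map,  # defaults to simple clipping when omitted
    tone_map_in_rgb=True,               # the function above expects RGB input
    crf=18,
)
```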