Skip to content

Commit e81069d

Browse files
committed
Clean up _decode_latents, fix LTX2 compat, add LTX2 batch support
- Remove unused batch_size param from _decode_latents in all pipelines (FLUX1, FLUX2, WAN T2V, WAN I2V)
- Fix outdated docstrings to reflect always-batched output
- Fix LTX2 postprocess_video_tensor call (remove_batch_dim removed)
- Add batch generation support to LTX2 pipeline (batch=batch_size)
- Add lightweight LTX2 batch unit tests (no model loading required)

Signed-off-by: Kanghwan Jang <861393+karljang@users.noreply.github.com>
1 parent d3243b9 commit e81069d

File tree

6 files changed

+80
-19
lines changed

6 files changed

+80
-19
lines changed

tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -333,9 +333,7 @@ def forward_fn(
333333
# Decode
334334
logger.info("Decoding image...")
335335
decode_start = time.time()
336-
image = self.decode_latents(
337-
latents, lambda lat: self._decode_latents(lat, height, width, batch_size)
338-
)
336+
image = self.decode_latents(latents, lambda lat: self._decode_latents(lat, height, width))
339337

340338
if self.rank == 0:
341339
logger.info(f"Image decoded in {time.time() - decode_start:.2f}s")
@@ -524,19 +522,16 @@ def _prepare_latents(
524522

525523
return latents, latent_ids
526524

527-
def _decode_latents(
528-
self, latents: torch.Tensor, height: int, width: int, batch_size: int = 1
529-
) -> torch.Tensor:
525+
def _decode_latents(self, latents: torch.Tensor, height: int, width: int) -> torch.Tensor:
530526
"""Decode latents to image tensor.
531527
532528
Args:
533529
latents: Packed latents [B, seq, 64].
534530
height: Output image height.
535531
width: Output image width.
536-
batch_size: Number of images in batch.
537532
538533
Returns:
539-
Image tensor (H, W, C) for single image, (B, H, W, C) for batch.
534+
Image tensor (B, H, W, C).
540535
"""
541536
# Unpack latents: (batch, seq_len, channels) -> (batch, channels, h, w)
542537
latents = self._unpack_latents(latents, height, width)

tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -424,9 +424,7 @@ def forward_fn(
424424
# Decode
425425
logger.info("Decoding image...")
426426
decode_start = time.time()
427-
image = self.decode_latents(
428-
latents, lambda lat: self._decode_latents(lat, latent_ids, batch_size)
429-
)
427+
image = self.decode_latents(latents, lambda lat: self._decode_latents(lat, latent_ids))
430428

431429
if self.rank == 0:
432430
logger.info(f"Image decoded in {time.time() - decode_start:.2f}s")
@@ -659,17 +657,15 @@ def _decode_latents(
659657
self,
660658
latents: torch.Tensor,
661659
latent_ids: torch.Tensor,
662-
batch_size: int = 1,
663660
) -> torch.Tensor:
664661
"""Decode latents to image tensor.
665662
666663
Args:
667664
latents: Packed latents [B, seq, C].
668665
latent_ids: Position IDs [seq, 4].
669-
batch_size: Number of images in batch.
670666
671667
Returns:
672-
Image tensor (H, W, C) for single image, (B, H, W, C) for batch.
668+
Image tensor (B, H, W, C).
673669
"""
674670
# Unpack latents using position IDs
675671
latents = self._unpack_latents_with_ids(latents, latent_ids)

tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1402,7 +1402,7 @@ def decode_video_fn(vid_latents):
14021402
)
14031403
)
14041404
video = torch.cat(chunks, dim=2)
1405-
video = postprocess_video_tensor(video, remove_batch_dim=True)
1405+
video = postprocess_video_tensor(video)
14061406
return video
14071407

14081408
def decode_audio_fn(aud_latents):

tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ def forward_fn(
483483
# Decode
484484
logger.info("Decoding video...")
485485
decode_start = time.time()
486-
video = self.decode_latents(latents, lambda lat: self._decode_latents(lat, batch_size))
486+
video = self.decode_latents(latents, self._decode_latents)
487487

488488
if self.rank == 0:
489489
logger.info(f"Video decoded in {time.time() - decode_start:.2f}s")
@@ -566,7 +566,7 @@ def _prepare_latents(
566566
return randn_tensor(shape, generator=generator, device=self.device, dtype=self.dtype)
567567

568568
@nvtx_range("_decode_latents", color="blue")
569-
def _decode_latents(self, latents: torch.Tensor, batch_size: int = 1) -> torch.Tensor:
569+
def _decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
570570
"""Decode latents to video tensor."""
571571
latents = latents.to(self.vae.dtype)
572572

tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,7 @@ def forward_fn(
647647
# Decode
648648
logger.info("Decoding video...")
649649
decode_start = time.time()
650-
video = self.decode_latents(latents, lambda lat: self._decode_latents(lat, batch_size))
650+
video = self.decode_latents(latents, self._decode_latents)
651651

652652
if self.rank == 0:
653653
logger.info(f"Video decoded in {time.time() - decode_start:.2f}s")
@@ -828,7 +828,7 @@ def _prepare_latents(
828828

829829
return latents, condition
830830

831-
def _decode_latents(self, latents, batch_size=1):
831+
def _decode_latents(self, latents):
832832
"""Decode latents to video."""
833833
latents = latents.to(self.vae.dtype)
834834

tests/unittest/_torch/visual_gen/test_ltx2_pipeline.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,5 +413,75 @@ def test_attention_backend_comparison(self, ltx2_bf16_checkpoint_exists):
413413
torch.cuda.empty_cache()
414414

415415

416+
# ============================================================================
417+
# Batch Support Unit Tests (no model loading required)
418+
# ============================================================================
419+
420+
421+
class TestLTX2BatchSupport:
422+
"""Test batch support logic without loading the full pipeline."""
423+
424+
def test_video_pixel_shape_batch_propagation(self):
425+
"""VideoPixelShape(batch=N) propagates through VideoLatentShape."""
426+
from tensorrt_llm._torch.visual_gen.models.ltx2.ltx2_core.types import (
427+
VideoLatentShape,
428+
VideoPixelShape,
429+
)
430+
431+
for batch_size in [1, 2, 4]:
432+
pixel_shape = VideoPixelShape(
433+
batch=batch_size, frames=9, height=512, width=768, fps=24.0
434+
)
435+
video_shape = VideoLatentShape.from_pixel_shape(pixel_shape, latent_channels=128)
436+
assert video_shape.batch == batch_size
437+
torch_shape = video_shape.to_torch_shape()
438+
assert torch_shape[0] == batch_size
439+
440+
def test_prompt_normalization(self):
441+
"""forward() normalizes str prompt to List[str] and computes batch_size."""
442+
# Simulate the normalization logic from forward()
443+
for prompt_input, expected_batch in [
444+
("a cat", 1),
445+
(["a cat"], 1),
446+
(["a cat", "a dog"], 2),
447+
]:
448+
prompt = prompt_input
449+
if isinstance(prompt, str):
450+
prompt = [prompt]
451+
assert len(prompt) == expected_batch
452+
453+
def test_negative_prompt_expansion(self):
454+
"""Negative prompt is expanded to match batch_size."""
455+
# Simulate the negative prompt expansion logic from forward()
456+
for neg_input, batch_size, expected_len in [
457+
("bad quality", 1, 1),
458+
("bad quality", 3, 3),
459+
(["bad quality"], 3, 3),
460+
(["bad 1", "bad 2", "bad 3"], 3, 3),
461+
]:
462+
negative_prompt = neg_input
463+
if isinstance(negative_prompt, str):
464+
neg_prompt_list = [negative_prompt] * batch_size
465+
else:
466+
neg_prompt_list = list(negative_prompt)
467+
if len(neg_prompt_list) == 1 and batch_size > 1:
468+
neg_prompt_list = neg_prompt_list * batch_size
469+
assert len(neg_prompt_list) == expected_len
470+
471+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
472+
def test_latent_shape_matches_batch(self):
473+
"""Latents created from VideoLatentShape have correct batch dim."""
474+
from tensorrt_llm._torch.visual_gen.models.ltx2.ltx2_core.types import (
475+
VideoLatentShape,
476+
VideoPixelShape,
477+
)
478+
479+
batch_size = 2
480+
pixel_shape = VideoPixelShape(batch=batch_size, frames=9, height=512, width=768, fps=24.0)
481+
video_shape = VideoLatentShape.from_pixel_shape(pixel_shape, latent_channels=128)
482+
latents = torch.randn(video_shape.to_torch_shape(), device="cuda", dtype=torch.float32)
483+
assert latents.shape[0] == batch_size
484+
485+
416486
if __name__ == "__main__":
417487
pytest.main([__file__, "-v"])

0 commit comments

Comments (0)