Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/visual_gen/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class DiffusionRequest:
"""Request for diffusion inference with explicit model-specific parameters."""

request_id: int
prompt: str
prompt: List[str]
negative_prompt: Optional[str] = None
height: int = 720
width: int = 1280
Expand Down
74 changes: 51 additions & 23 deletions tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""FLUX.1 Pipeline implementation following WAN pattern."""

import time
from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -238,29 +238,37 @@ def infer(self, req):
@torch.inference_mode()
def forward(
self,
prompt: str,
prompt: Union[str, List[str]],
height: int = 1024,
width: int = 1024,
num_inference_steps: int = 50,
guidance_scale: float = 3.5,
seed: int = 42,
max_sequence_length: int = 512,
):
"""Generate image from text prompt.
"""Generate image(s) from text prompt(s).

Args:
prompt: Text prompt for image generation
prompt: Text prompt or list of prompts for image generation.
When a list is provided, generates one image per prompt in a
single batched forward pass.
height: Output image height (default: 1024)
width: Output image width (default: 1024)
num_inference_steps: Number of denoising steps (50 for dev, 4 for schnell)
guidance_scale: Embedded guidance scale (3.5 for dev)
seed: Random seed for reproducibility
seed: Random seed for reproducibility.
max_sequence_length: Maximum text sequence length

Returns:
MediaOutput with image tensor
MediaOutput with image tensor (B, H, W, C).
"""
pipeline_start = time.time()

# Determine batch size
if isinstance(prompt, str):
prompt = [prompt]
batch_size = len(prompt)

generator = torch.Generator(device=self.device).manual_seed(seed)

# Encode prompt
Expand All @@ -271,8 +279,7 @@ def forward(
)
logger.info(f"Prompt encoding completed in {time.time() - encode_start:.2f}s")

# Prepare latents
latents, latent_ids = self._prepare_latents(height, width, generator)
latents, latent_ids = self._prepare_latents(batch_size, height, width, generator)
logger.info(f"Latents shape: {latents.shape}")

# Prepare timesteps with dynamic shifting (FLUX uses mu parameter)
Expand Down Expand Up @@ -336,19 +343,19 @@ def forward_fn(

def _encode_prompt(
self,
prompt: str,
prompt: List[str],
max_sequence_length: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Encode prompt using CLIP and T5.
"""Encode prompt(s) using CLIP and T5.

Args:
prompt: Text prompt
max_sequence_length: Maximum T5 sequence length
prompt: List of prompts.
max_sequence_length: Maximum T5 sequence length.

Returns:
Tuple of (T5 embeddings, CLIP pooled embeddings, text position IDs)
Tuple of (T5 embeddings [B, seq, dim], CLIP pooled [B, dim],
text position IDs [seq, 3])
"""
prompt = [prompt] if isinstance(prompt, str) else prompt

# CLIP encoding (pooled embeddings)
clip_inputs = self.tokenizer(
Expand Down Expand Up @@ -478,33 +485,54 @@ def _unpack_latents(self, latents: torch.Tensor, height: int, width: int) -> tor

def _prepare_latents(
self,
batch_size: int,
height: int,
width: int,
generator: torch.Generator,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Prepare random latents in FLUX packed format and position IDs."""
"""Prepare random latents in FLUX packed format and position IDs.

Args:
batch_size: Number of images to generate.
height: Output image height.
width: Output image width.
generator: Random generator.

Returns:
Tuple of (latents [B, seq, 64], latent_ids [seq, 3])
"""
latent_height = height // self.vae_scale_factor
latent_width = width // self.vae_scale_factor

# Use VAE channels (16), not transformer channels (64)
# The packing will convert 16 -> 64
vae_channels = self.vae.config.latent_channels # 16
shape = (batch_size, vae_channels, latent_height, latent_width)

# Create random latents in VAE spatial format [B, 16, H, W]
shape = (1, vae_channels, latent_height, latent_width)
latents = randn_tensor(shape, generator=generator, device=self.device, dtype=self.dtype)
latents_4d = randn_tensor(shape, generator=generator, device=self.device, dtype=self.dtype)

# Prepare position IDs for packed format
# Prepare position IDs for packed format (shared across batch)
latent_ids = self._prepare_latent_ids(height, width)
latent_ids = latent_ids.to(self.device)

# Pack latents to FLUX sequence format [B, seq_len, 64]
latents = self._pack_latents(latents, 1, vae_channels, latent_height, latent_width)
latents = self._pack_latents(
latents_4d, batch_size, vae_channels, latent_height, latent_width
)

return latents, latent_ids

def _decode_latents(self, latents: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""Decode latents to image tensor."""
"""Decode latents to image tensor.

Args:
latents: Packed latents [B, seq, 64].
height: Output image height.
width: Output image width.

Returns:
Image tensor (B, H, W, C).
"""
# Unpack latents: (batch, seq_len, channels) -> (batch, channels, h, w)
latents = self._unpack_latents(latents, height, width)

Expand All @@ -515,9 +543,9 @@ def _decode_latents(self, latents: torch.Tensor, height: int, width: int) -> tor
latents = latents.to(self.vae.dtype)
image = self.vae.decode(latents, return_dict=False)[0]

# Post-process to tensor (H, W, C) uint8
# Post-process to tensor uint8
image = (image / 2 + 0.5).clamp(0, 1)
image = image.permute(0, 2, 3, 1) # (B, C, H, W) -> (B, H, W, C)
image = (image * 255).round().to(torch.uint8)

return image[0] # Remove batch dimension
return image # (B, H, W, C)
72 changes: 52 additions & 20 deletions tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import json
import os
import time
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -331,29 +331,37 @@ def infer(self, req):
@torch.inference_mode()
def forward(
self,
prompt: str,
prompt: Union[str, List[str]],
height: int = 1024,
width: int = 1024,
num_inference_steps: int = 50,
guidance_scale: float = 3.5,
seed: int = 42,
max_sequence_length: int = 512,
):
"""Generate image from text prompt.
"""Generate image(s) from text prompt(s).

Args:
prompt: Text prompt for image generation
prompt: Text prompt or list of prompts for image generation.
When a list is provided, generates one image per prompt in a
single batched forward pass.
height: Output image height (default: 1024)
width: Output image width (default: 1024)
num_inference_steps: Number of denoising steps
guidance_scale: Embedded guidance scale
seed: Random seed for reproducibility
seed: Random seed for reproducibility.
max_sequence_length: Maximum text sequence length

Returns:
Dict with "image" key containing PIL.Image
MediaOutput with image tensor (B, H, W, C).
"""
pipeline_start = time.time()

# Determine batch size
if isinstance(prompt, str):
prompt = [prompt]
batch_size = len(prompt)

generator = torch.Generator(device=self.device).manual_seed(seed)

# Encode prompt using Mistral3 multi-layer extraction
Expand All @@ -362,8 +370,7 @@ def forward(
prompt_embeds, text_ids = self._encode_prompt(prompt, max_sequence_length)
logger.info(f"Prompt encoding completed in {time.time() - encode_start:.2f}s")

# Prepare latents
latents, latent_ids = self._prepare_latents(height, width, generator)
latents, latent_ids = self._prepare_latents(batch_size, height, width, generator)
logger.info(f"Latents shape: {latents.shape}")

# Prepare timesteps with dynamic shifting
Expand Down Expand Up @@ -427,19 +434,22 @@ def forward_fn(

def _encode_prompt(
self,
prompt: str,
prompt: List[str],
max_sequence_length: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encode prompt using multi-layer hidden state extraction.
"""Encode prompt(s) using multi-layer hidden state extraction.

Supports both text encoder types:
- Mistral3: system message + PixtralProcessor chat template
- Qwen3: simple user message + Qwen2TokenizerFast chat template

Args:
prompt: List of prompts.
max_sequence_length: Maximum text sequence length.

Returns:
Tuple of (prompt_embeds, text_ids)
Tuple of (prompt_embeds [B, seq, dim], text_ids [seq, 4])
"""
prompt = [prompt] if isinstance(prompt, str) else prompt

# Tokenize (format depends on text encoder type)
text_encoder_class = getattr(
Expand Down Expand Up @@ -565,27 +575,37 @@ def _prepare_latent_ids(self, height: int, width: int) -> torch.Tensor:

def _prepare_latents(
self,
batch_size: int,
height: int,
width: int,
generator: torch.Generator,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Prepare random latents in FLUX.2 packed format and position IDs."""
"""Prepare random latents in FLUX.2 packed format and position IDs.

Args:
batch_size: Number of images to generate.
height: Output image height.
width: Output image width.
generator: Random generator.

Returns:
Tuple of (latents [B, seq, C], latent_ids [seq, 4])
"""
# FLUX.2: in_channels=128, VAE scale=8, 2x2 packing
latent_height = 2 * (height // (self.vae_scale_factor * 2))
latent_width = 2 * (width // (self.vae_scale_factor * 2))

in_channels = self.transformer.config.in_channels # 128
latent_shape = (batch_size, in_channels, latent_height // 2, latent_width // 2)

# Create 4D latents then pack (matches HF for seed reproducibility)
latent_shape = (1, in_channels, latent_height // 2, latent_width // 2)
latents_4d = randn_tensor(
latent_shape, generator=generator, device=self.device, dtype=self.dtype
)

# Pack latents: [B, C, H, W] -> [B, H*W, C]
latents = self._pack_latents(latents_4d)

# Prepare position IDs
# Prepare position IDs (shared across batch)
latent_ids = self._prepare_latent_ids(height, width)

return latents, latent_ids
Expand Down Expand Up @@ -633,8 +653,20 @@ def _unpatchify_latents(self, latents: torch.Tensor) -> torch.Tensor:

return latents

def _decode_latents(self, latents: torch.Tensor, latent_ids: torch.Tensor) -> torch.Tensor:
"""Decode latents to image tensor."""
def _decode_latents(
self,
latents: torch.Tensor,
latent_ids: torch.Tensor,
) -> torch.Tensor:
"""Decode latents to image tensor.

Args:
latents: Packed latents [B, seq, C].
latent_ids: Position IDs [seq, 4].

Returns:
Image tensor (B, H, W, C).
"""
# Unpack latents using position IDs
latents = self._unpack_latents_with_ids(latents, latent_ids)

Expand All @@ -656,9 +688,9 @@ def _decode_latents(self, latents: torch.Tensor, latent_ids: torch.Tensor) -> to
latents = latents.to(self.vae.dtype)
image = self.vae.decode(latents, return_dict=False)[0]

# Post-process to tensor (H, W, C) uint8
# Post-process to tensor uint8
image = (image / 2 + 0.5).clamp(0, 1)
image = image.permute(0, 2, 3, 1) # (B, C, H, W) -> (B, H, W, C)
image = (image * 255).round().to(torch.uint8)

return image[0] # Remove batch dimension
return image # (B, H, W, C)
Original file line number Diff line number Diff line change
Expand Up @@ -1402,7 +1402,7 @@ def decode_video_fn(vid_latents):
)
)
video = torch.cat(chunks, dim=2)
video = postprocess_video_tensor(video, remove_batch_dim=True)
video = postprocess_video_tensor(video)
return video

def decode_audio_fn(aud_latents):
Expand Down
Loading
Loading