@@ -14,3 +14,5 @@ transforms:
     stage: post_load_fusion
     enabled: true
     max_batch_size: 8
+cuda_graph_config:
+  max_batch_size: 8
3 changes: 2 additions & 1 deletion examples/auto_deploy/model_registry/configs/deepseek-r1.yaml
@@ -3,7 +3,8 @@ compile_backend: torch-cudagraph
max_batch_size: 64
max_seq_len: 8192
enable_chunked_prefill: true
-cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64]
+cuda_graph_config:
+  batch_sizes: [1, 2, 4, 8, 16, 32, 64]
transforms:
  fuse_nvfp4_moe:
    allow_different_input_scales: true
@@ -6,6 +6,8 @@ compile_backend: torch-cudagraph
model_factory: AutoModelForCausalLM
max_seq_len: 512
max_batch_size: 8
+cuda_graph_config:
+  max_batch_size: 8
world_size: 1

# Gemma 3n uses shared-KV decode semantics in the tail layers. FlashInfer
3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/visual_gen/__init__.py
@@ -21,6 +21,7 @@
    VisualGenArgs,
    discover_pipeline_components,
)
+from .mapping import VisualGenMapping
from .models import AutoPipeline, BasePipeline, WanPipeline
from .pipeline_loader import PipelineLoader

@@ -42,6 +43,8 @@
    "DiffusionRequest",
    "DiffusionResponse",
    "MediaOutput",
+    # Mapping
+    "VisualGenMapping",
    # Pipelines
    "AutoPipeline",
    "BasePipeline",
68 changes: 34 additions & 34 deletions tensorrt_llm/_torch/visual_gen/config.py
@@ -9,6 +9,7 @@
from pydantic import BaseModel, ConfigDict, model_validator
from pydantic import Field as PydanticField

+from tensorrt_llm._torch.visual_gen.mapping import DEFAULT_DIM_ORDER
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.llmapi.utils import StrictBaseModel, set_api_status
from tensorrt_llm.logger import logger
@@ -64,11 +65,7 @@ class ParallelConfig(StrictBaseModel):
    - dit_ring_size: Ring attention (not implemented)
    - dit_cp_size, dit_dp_size, dit_fsdp_size: Other parallelism types

-    Total world_size = dit_cfg_size × dit_ulysses_size
-
-    Parallelism Strategy:
-    - CFG Parallelism: Distributes positive/negative prompts across GPUs
-    - Ulysses Parallelism: Distributes sequence within each CFG group
+    See mapping.py for more details.

    Example Configurations:
    1. cfg_size=1, ulysses_size=2 -> 2 GPUs (Ulysses only)
@@ -98,6 +95,13 @@ class ParallelConfig(StrictBaseModel):
    dit_cp_size: int = PydanticField(1, ge=1)
    dit_cfg_size: int = PydanticField(1, ge=1)  # Supported
    dit_fsdp_size: int = PydanticField(1, ge=1)
+    dit_dim_order: str = PydanticField(
+        DEFAULT_DIM_ORDER,
+        description=(
+            "Outermost-to-innermost ordering of parallelism axes for the "
+            "DeviceMesh. Innermost = most contiguous ranks."
+        ),
+    )

    # Refiner Parallelism (Optional)
    refiner_dit_dp_size: int = 1
@@ -114,36 +118,14 @@
    def n_workers(self) -> int:
        return self.dit_cfg_size * self.dit_ulysses_size

-    def to_mapping(self) -> Mapping:
-        """Convert to TRT-LLM Mapping."""
-        world_size = self.dit_tp_size * self.dit_cp_size
-        return Mapping(
-            world_size=world_size,
-            tp_size=self.dit_tp_size,
-            pp_size=1,
-            cp_size=self.dit_cp_size,
-        )
-
    @property
    def total_parallel_size(self) -> int:
        """Total parallelism across all DiT dimensions."""
-        return (
-            self.dit_tp_size
-            * self.dit_ulysses_size
-            * self.dit_ring_size
-            * self.dit_cp_size
-            * self.dit_dp_size
-            * self.dit_cfg_size
-        )
+        return self.dit_cfg_size * self.dit_tp_size * self.dit_ring_size * self.dit_ulysses_size

    def validate_world_size(self, world_size: int) -> None:
        """Validate that the parallel config is compatible with the given world size.

        Called at launch time when WORLD_SIZE is known (not at config construction).
        """
        if self.total_parallel_size > world_size:
            raise ValueError(
-                f"Total DiT parallel size ({self.total_parallel_size}) "
+                f"total_parallel_size ({self.total_parallel_size}) "
                f"exceeds world_size ({world_size})"
            )

@@ -420,6 +402,19 @@ def to_mapping(self) -> Mapping:
        """Derive Mapping from ParallelConfig."""
        return self.parallel.to_mapping()

+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary."""
+        return self.model_dump()
+
+    @set_api_status("prototype")
+    @classmethod
+    def from_dict(cls, config_dict: Dict[str, Any]) -> "VisualGenArgs":
+        """Create from dictionary with automatic nested config parsing.
+
+        Unknown fields cause a ValidationError (extra="forbid").
+        """
+        return cls(**config_dict)
+
    @set_api_status("prototype")
    @classmethod
    def from_yaml(cls, yaml_path: Union[str, Path], **overrides: Any) -> "VisualGenArgs":
@@ -479,7 +474,8 @@ class DiffusionModelConfig(BaseModel):
    Contains merged/parsed config from:
    - pretrained_config: From checkpoint/config.json
    - quant_config: From checkpoint or user quant config
-    - Sub-configs: From VisualGenArgs (pipeline, attention, parallel, teacache)
+    - Sub-configs: From VisualGenArgs (pipeline, attention, teacache)
+    - visual_gen_mapping: Populated by setup_visual_gen_mapping() from ParallelConfig
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -491,8 +487,12 @@
    allreduce_strategy: AllReduceStrategy = PydanticField(default=AllReduceStrategy.AUTO)
    extra_attrs: Dict = PydanticField(default_factory=dict)

-    # Distributed process groups
-    ulysses_process_group: Optional[torch.distributed.ProcessGroup] = None
+    # Unified parallelism mapping (populated by setup_visual_gen_mapping)
+    visual_gen_mapping: Optional[Any] = None  # VisualGenMapping (lazy import)
+
+    # VAE parallelism (promoted from ParallelConfig for pipeline_loader)
+    enable_parallel_vae: bool = True
+    parallel_vae_split_dim: Literal["width", "height"] = "width"

    dynamic_weight_quant: bool = False

@@ -505,7 +505,6 @@
    cuda_graph: CudaGraphConfig = PydanticField(default_factory=CudaGraphConfig)
    pipeline: PipelineConfig = PydanticField(default_factory=PipelineConfig)
    attention: AttentionConfig = PydanticField(default_factory=AttentionConfig)
-    parallel: ParallelConfig = PydanticField(default_factory=ParallelConfig)
    teacache: TeaCacheConfig = PydanticField(default_factory=TeaCacheConfig)

    @property
@@ -853,8 +852,9 @@ def from_pretrained(
            cuda_graph=cuda_graph_cfg,
            pipeline=pipeline_cfg,
            attention=attention_cfg,
-            parallel=parallel_cfg,
            teacache=teacache_cfg,
+            enable_parallel_vae=parallel_cfg.enable_parallel_vae,
+            parallel_vae_split_dim=parallel_cfg.parallel_vae_split_dim,
            skip_create_weights_in_init=True,
            extra_attrs=extra_attrs,
            **kwargs,
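The new `VisualGenArgs.from_dict()` pairs with `to_dict()` for round-tripping configs. A minimal usage sketch, not part of this diff, assuming the remaining `VisualGenArgs` fields have defaults:

```python
from tensorrt_llm._torch.visual_gen import VisualGenArgs

# Nested dicts parse into their sub-configs; unknown keys raise a
# ValidationError because of extra="forbid".
args = VisualGenArgs.from_dict({"parallel": {"dit_cfg_size": 2, "dit_ulysses_size": 4}})
assert args.parallel.total_parallel_size == 8  # cfg * tp * ring * ulysses = 2*1*1*4
```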
3 changes: 0 additions & 3 deletions tensorrt_llm/_torch/visual_gen/executor.py
@@ -232,9 +232,6 @@ def run_diffusion_worker(
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)

-    # Runtime check: parallel config vs actual world size
-    diffusion_args.parallel.validate_world_size(world_size)
-
    # Calculate device_id before init_process_group
    device_id = rank % torch.cuda.device_count() if torch.cuda.is_available() else 0
    if torch.cuda.is_available():
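The removed launch-time call appears to be subsumed by `VisualGenMapping.__init__` (added in mapping.py below), which raises when the product of the parallelism axes does not equal the world size (a stricter equality check than the old `>` comparison). A sketch of the failure mode, with illustrative values:

```python
# 2 (cfg) * 1 (tp) * 1 (ring) * 4 (ulysses) = 8 != 4 -> ValueError at construction
VisualGenMapping(world_size=4, rank=0, cfg_size=2, ulysses_size=4)
```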
177 changes: 177 additions & 0 deletions tensorrt_llm/_torch/visual_gen/mapping.py
@@ -0,0 +1,177 @@
"""Unified multi-dimensional communicator mesh for visual generation models.

VisualGenMapping subclasses DeviceMeshTopologyImpl and overrides build_mesh()
to create a single PyTorch DeviceMesh covering all parallelism axes
(CFG, TP, Ring, Ulysses). The resulting mesh is stored in the shared
DeviceMeshTopologyImpl.device_mesh class variable so that any Mapping object
constructed afterward (e.g. via to_llm_mapping()) can reuse the same
process groups.
"""

from __future__ import annotations

from typing import Optional

import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.distributed.device_mesh import init_device_mesh

from tensorrt_llm._torch.device_mesh import DeviceMeshTopologyImpl, SingleProcessGroup
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping

_VALID_DIM_NAMES = frozenset({"cfg", "tp", "ring", "ulysses"})
DEFAULT_DIM_ORDER = "cfg-tp-ring-ulysses"


class VisualGenMapping(DeviceMeshTopologyImpl):
"""Multi-dimensional communicator mesh for visual generation models.

Parallelism Strategy:
- CFG Parallelism: Distributes positive/negative prompts across GPUs
- Ulysses Parallelism: Distributes sequence within each CFG group

Ordering rationale (default ``"cfg-tp-ring-ulysses"``):
- Ulysses innermost: all-to-all is latency-sensitive, contiguous ranks
- Ring next: KV streaming between adjacent ranks
- TP next: all-reduce for Linear
- CFG outermost: independent until final all-gather

The *order* string maps directly to ``init_device_mesh``'s
``mesh_shape`` tuple (first = outermost / slowest-varying, last =
innermost / most contiguous).
"""

def __init__(
self,
world_size: int,
rank: int,
cfg_size: int = 1,
tp_size: int = 1,
ring_size: int = 1,
ulysses_size: int = 1,
order: str = DEFAULT_DIM_ORDER,
):
product = cfg_size * tp_size * ring_size * ulysses_size
if product != world_size:
raise ValueError(
f"cfg({cfg_size}) * tp({tp_size}) * ring({ring_size}) * "
f"ulysses({ulysses_size}) = {product} != world_size({world_size})"
)

dims = order.split("-")
if set(dims) != _VALID_DIM_NAMES or len(dims) != len(_VALID_DIM_NAMES):
raise ValueError(
f"order must be a '-'-separated permutation of "
f"{sorted(_VALID_DIM_NAMES)}, got '{order}'"
)

self.world_size = world_size
self._rank = rank
self.cfg_size = cfg_size
self.tp_size = tp_size
self.ring_size = ring_size
self.ulysses_size = ulysses_size
self._order = order
self._dim_names = tuple(dims)
self._dim_sizes = {
"cfg": cfg_size,
"tp": tp_size,
"ring": ring_size,
"ulysses": ulysses_size,
}

if dist.is_initialized() and world_size > 1:
self.build_mesh()

# ------------------------------------------------------------------
# Mesh construction
# ------------------------------------------------------------------
def build_mesh(self):
cls = DeviceMeshTopologyImpl
if cls.device_mesh is not None:
return

shape = tuple(self._dim_sizes[d] for d in self._dim_names)
cls.device_mesh = init_device_mesh(
"cuda",
mesh_shape=shape,
mesh_dim_names=self._dim_names,
)
logger.debug(
f"VisualGenMapping.build_mesh: dims={self._dim_names}, "
f"shape={shape}, mesh={cls.device_mesh}"
)

# ------------------------------------------------------------------
# Rank decomposition
# ------------------------------------------------------------------
def _local_rank(self, dim: str) -> int:
cls = DeviceMeshTopologyImpl
if cls.device_mesh is None:
return 0
return cls.device_mesh[dim].get_local_rank()

@property
def cfg_rank(self) -> int:
return self._local_rank("cfg")

@property
def tp_rank(self) -> int:
return self._local_rank("tp")

@property
def ring_rank(self) -> int:
return self._local_rank("ring")

@property
def ulysses_rank(self) -> int:
return self._local_rank("ulysses")

@property
def is_cfg_conditional(self) -> bool:
return self.cfg_rank == 0

# ------------------------------------------------------------------
# Process groups (None when size == 1 and mesh was not built)
# ------------------------------------------------------------------
def _group(self, dim: str) -> Optional[ProcessGroup]:
cls = DeviceMeshTopologyImpl
if cls.device_mesh is None:
if self.world_size == 1:
return SingleProcessGroup.get_group()
return None
return cls.device_mesh[dim].get_group()

@property
def ulysses_group(self) -> Optional[ProcessGroup]:
return self._group("ulysses")

@property
def ring_group(self) -> Optional[ProcessGroup]:
return self._group("ring")

@property
def tp_group_pg(self) -> Optional[ProcessGroup]:
return self._group("tp")

@property
def cfg_group(self) -> Optional[ProcessGroup]:
return self._group("cfg")

# ------------------------------------------------------------------
# Bridge to LLM Mapping (for Linear layers)
# ------------------------------------------------------------------
def to_llm_mapping(self) -> Mapping:
"""Return a ``Mapping`` whose TP group is backed by this mesh's TP dim.

``build_mesh()`` has already populated
``DeviceMeshTopologyImpl.device_mesh``, so the returned ``Mapping``'s
``build_mesh()`` is a no-op and ``tp_group_pg`` reads from the shared
mega-mesh.
"""
return Mapping(
world_size=self.tp_size,
rank=self.tp_rank,
tp_size=self.tp_size,
)
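A usage sketch for the mesh above, assuming `torch.distributed` is already initialized across 8 ranks (the axis sizes here are illustrative, not prescribed by the diff):

```python
import torch.distributed as dist

from tensorrt_llm._torch.visual_gen import VisualGenMapping

# cfg=2, tp=2, ring=1, ulysses=2 with the default "cfg-tp-ring-ulysses"
# order builds a DeviceMesh of shape (2, 2, 1, 2). Ulysses is innermost,
# so each Ulysses group spans contiguous ranks (e.g. {0, 1}), while ranks
# 0 and 4 differ only along the outermost CFG axis.
mapping = VisualGenMapping(
    world_size=dist.get_world_size(),
    rank=dist.get_rank(),
    cfg_size=2,
    tp_size=2,
    ulysses_size=2,
)
ulysses_pg = mapping.ulysses_group      # ProcessGroup backing the all-to-all
llm_mapping = mapping.to_llm_mapping()  # TP-only Mapping reusing the same mesh
```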
5 changes: 4 additions & 1 deletion tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py
@@ -45,7 +45,10 @@ class FluxPipeline(BasePipeline):
    """

    def __init__(self, model_config):
-        if model_config.parallel.dit_cfg_size != 1:
+        if (
+            model_config.visual_gen_mapping is not None
+            and model_config.visual_gen_mapping.cfg_size != 1
+        ):
            raise ValueError(
                "FluxPipeline does not support CFG parallelism. Please set dit_cfg_size to 1."
            )
5 changes: 4 additions & 1 deletion tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py
@@ -114,7 +114,10 @@ class Flux2Pipeline(BasePipeline):
    HIDDEN_STATE_LAYERS: Tuple[int, ...] = (10, 20, 30)

    def __init__(self, model_config):
-        if model_config.parallel.dit_cfg_size != 1:
+        if (
+            model_config.visual_gen_mapping is not None
+            and model_config.visual_gen_mapping.cfg_size != 1
+        ):
            raise ValueError(
                "Flux2Pipeline does not support CFG parallelism. Please set dit_cfg_size to 1."
            )