diff --git a/examples/auto_deploy/model_registry/configs/dashboard_default.yaml b/examples/auto_deploy/model_registry/configs/dashboard_default.yaml index 6c7fa4a243b0..e01a7a3a1b4a 100644 --- a/examples/auto_deploy/model_registry/configs/dashboard_default.yaml +++ b/examples/auto_deploy/model_registry/configs/dashboard_default.yaml @@ -14,3 +14,5 @@ transforms: stage: post_load_fusion enabled: true max_batch_size: 8 +cuda_graph_config: + max_batch_size: 8 diff --git a/examples/auto_deploy/model_registry/configs/deepseek-r1.yaml b/examples/auto_deploy/model_registry/configs/deepseek-r1.yaml index 38e58ced8f76..47efd051b4ea 100644 --- a/examples/auto_deploy/model_registry/configs/deepseek-r1.yaml +++ b/examples/auto_deploy/model_registry/configs/deepseek-r1.yaml @@ -3,7 +3,8 @@ compile_backend: torch-cudagraph max_batch_size: 64 max_seq_len: 8192 enable_chunked_prefill: true -cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64] +cuda_graph_config: + batch_sizes: [1, 2, 4, 8, 16, 32, 64] transforms: fuse_nvfp4_moe: allow_different_input_scales: true diff --git a/examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml b/examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml index 2e5e1f8d5cb4..11b5a6262077 100644 --- a/examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml +++ b/examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml @@ -6,6 +6,8 @@ compile_backend: torch-cudagraph model_factory: AutoModelForCausalLM max_seq_len: 512 max_batch_size: 8 +cuda_graph_config: + max_batch_size: 8 world_size: 1 # Gemma 3n uses shared-KV decode semantics in the tail layers. FlashInfer diff --git a/tensorrt_llm/_torch/visual_gen/__init__.py b/tensorrt_llm/_torch/visual_gen/__init__.py index 5926a61681ba..154d0aefd010 100644 --- a/tensorrt_llm/_torch/visual_gen/__init__.py +++ b/tensorrt_llm/_torch/visual_gen/__init__.py @@ -21,6 +21,7 @@ VisualGenArgs, discover_pipeline_components, ) +from .mapping import VisualGenMapping from .models import AutoPipeline, BasePipeline, WanPipeline from .pipeline_loader import PipelineLoader @@ -42,6 +43,8 @@ "DiffusionRequest", "DiffusionResponse", "MediaOutput", + # Mapping + "VisualGenMapping", # Pipelines "AutoPipeline", "BasePipeline", diff --git a/tensorrt_llm/_torch/visual_gen/config.py b/tensorrt_llm/_torch/visual_gen/config.py index 35b1c8d622a4..62ddbed69ce9 100644 --- a/tensorrt_llm/_torch/visual_gen/config.py +++ b/tensorrt_llm/_torch/visual_gen/config.py @@ -9,6 +9,7 @@ from pydantic import BaseModel, ConfigDict, model_validator from pydantic import Field as PydanticField +from tensorrt_llm._torch.visual_gen.mapping import DEFAULT_DIM_ORDER from tensorrt_llm.functional import AllReduceStrategy from tensorrt_llm.llmapi.utils import StrictBaseModel, set_api_status from tensorrt_llm.logger import logger @@ -64,11 +65,7 @@ class ParallelConfig(StrictBaseModel): - dit_ring_size: Ring attention (not implemented) - dit_cp_size, dit_dp_size, dit_fsdp_size: Other parallelism types - Total world_size = dit_cfg_size × dit_ulysses_size - - Parallelism Strategy: - - CFG Parallelism: Distributes positive/negative prompts across GPUs - - Ulysses Parallelism: Distributes sequence within each CFG group + See mapping.py for more details. Example Configurations: 1. cfg_size=1, ulysses_size=2 -> 2 GPUs (Ulysses only) @@ -98,6 +95,13 @@ class ParallelConfig(StrictBaseModel): dit_cp_size: int = PydanticField(1, ge=1) dit_cfg_size: int = PydanticField(1, ge=1) # Supported dit_fsdp_size: int = PydanticField(1, ge=1) + dit_dim_order: str = PydanticField( + DEFAULT_DIM_ORDER, + description=( + "Outermost-to-innermost ordering of parallelism axes for the " + "DeviceMesh. Innermost = most contiguous ranks." + ), + ) # Refiner Parallelism (Optional) refiner_dit_dp_size: int = 1 @@ -114,36 +118,14 @@ class ParallelConfig(StrictBaseModel): def n_workers(self) -> int: return self.dit_cfg_size * self.dit_ulysses_size - def to_mapping(self) -> Mapping: - """Convert to TRT-LLM Mapping.""" - world_size = self.dit_tp_size * self.dit_cp_size - return Mapping( - world_size=world_size, - tp_size=self.dit_tp_size, - pp_size=1, - cp_size=self.dit_cp_size, - ) - @property def total_parallel_size(self) -> int: - """Total parallelism across all DiT dimensions.""" - return ( - self.dit_tp_size - * self.dit_ulysses_size - * self.dit_ring_size - * self.dit_cp_size - * self.dit_dp_size - * self.dit_cfg_size - ) + return self.dit_cfg_size * self.dit_tp_size * self.dit_ring_size * self.dit_ulysses_size def validate_world_size(self, world_size: int) -> None: - """Validate that the parallel config is compatible with the given world size. - - Called at launch time when WORLD_SIZE is known (not at config construction). - """ if self.total_parallel_size > world_size: raise ValueError( - f"Total DiT parallel size ({self.total_parallel_size}) " + f"total_parallel_size ({self.total_parallel_size}) " f"exceeds world_size ({world_size})" ) @@ -420,6 +402,19 @@ def to_mapping(self) -> Mapping: """Derive Mapping from ParallelConfig.""" return self.parallel.to_mapping() + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary.""" + return self.model_dump() + + @set_api_status("prototype") + @classmethod + def from_dict(cls, config_dict: Dict[str, Any]) -> "VisualGenArgs": + """Create from dictionary with automatic nested config parsing. + + Unknown fields cause a ValidationError (extra="forbid"). + """ + return cls(**config_dict) + @set_api_status("prototype") @classmethod def from_yaml(cls, yaml_path: Union[str, Path], **overrides: Any) -> "VisualGenArgs": @@ -479,7 +474,8 @@ class DiffusionModelConfig(BaseModel): Contains merged/parsed config from: - pretrained_config: From checkpoint/config.json - quant_config: From checkpoint or user quant config - - Sub-configs: From VisualGenArgs (pipeline, attention, parallel, teacache) + - Sub-configs: From VisualGenArgs (pipeline, attention, teacache) + - visual_gen_mapping: Populated by setup_visual_gen_mapping() from ParallelConfig """ model_config = ConfigDict(arbitrary_types_allowed=True) @@ -491,8 +487,12 @@ class DiffusionModelConfig(BaseModel): allreduce_strategy: AllReduceStrategy = PydanticField(default=AllReduceStrategy.AUTO) extra_attrs: Dict = PydanticField(default_factory=dict) - # Distributed process groups - ulysses_process_group: Optional[torch.distributed.ProcessGroup] = None + # Unified parallelism mapping (populated by setup_visual_gen_mapping) + visual_gen_mapping: Optional[Any] = None # VisualGenMapping (lazy import) + + # VAE parallelism (promoted from ParallelConfig for pipeline_loader) + enable_parallel_vae: bool = True + parallel_vae_split_dim: Literal["width", "height"] = "width" dynamic_weight_quant: bool = False @@ -505,7 +505,6 @@ class DiffusionModelConfig(BaseModel): cuda_graph: CudaGraphConfig = PydanticField(default_factory=CudaGraphConfig) pipeline: PipelineConfig = PydanticField(default_factory=PipelineConfig) attention: AttentionConfig = PydanticField(default_factory=AttentionConfig) - parallel: ParallelConfig = PydanticField(default_factory=ParallelConfig) teacache: TeaCacheConfig = PydanticField(default_factory=TeaCacheConfig) @property @@ -853,8 +852,9 @@ def from_pretrained( cuda_graph=cuda_graph_cfg, pipeline=pipeline_cfg, attention=attention_cfg, - parallel=parallel_cfg, teacache=teacache_cfg, + enable_parallel_vae=parallel_cfg.enable_parallel_vae, + parallel_vae_split_dim=parallel_cfg.parallel_vae_split_dim, skip_create_weights_in_init=True, extra_attrs=extra_attrs, **kwargs, diff --git a/tensorrt_llm/_torch/visual_gen/executor.py b/tensorrt_llm/_torch/visual_gen/executor.py index e56b0225f67b..3062de13fd6b 100644 --- a/tensorrt_llm/_torch/visual_gen/executor.py +++ b/tensorrt_llm/_torch/visual_gen/executor.py @@ -232,9 +232,6 @@ def run_diffusion_worker( os.environ["RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size) - # Runtime check: parallel config vs actual world size - diffusion_args.parallel.validate_world_size(world_size) - # Calculate device_id before init_process_group device_id = rank % torch.cuda.device_count() if torch.cuda.is_available() else 0 if torch.cuda.is_available(): diff --git a/tensorrt_llm/_torch/visual_gen/mapping.py b/tensorrt_llm/_torch/visual_gen/mapping.py new file mode 100644 index 000000000000..5c93bb2fb8c8 --- /dev/null +++ b/tensorrt_llm/_torch/visual_gen/mapping.py @@ -0,0 +1,177 @@ +"""Unified multi-dimensional communicator mesh for visual generation models. + +VisualGenMapping subclasses DeviceMeshTopologyImpl and overrides build_mesh() +to create a single PyTorch DeviceMesh covering all parallelism axes +(CFG, TP, Ring, Ulysses). The resulting mesh is stored in the shared +DeviceMeshTopologyImpl.device_mesh class variable so that any Mapping object +constructed afterward (e.g. via to_llm_mapping()) can reuse the same +process groups. +""" + +from __future__ import annotations + +from typing import Optional + +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import init_device_mesh + +from tensorrt_llm._torch.device_mesh import DeviceMeshTopologyImpl, SingleProcessGroup +from tensorrt_llm.logger import logger +from tensorrt_llm.mapping import Mapping + +_VALID_DIM_NAMES = frozenset({"cfg", "tp", "ring", "ulysses"}) +DEFAULT_DIM_ORDER = "cfg-tp-ring-ulysses" + + +class VisualGenMapping(DeviceMeshTopologyImpl): + """Multi-dimensional communicator mesh for visual generation models. + + Parallelism Strategy: + - CFG Parallelism: Distributes positive/negative prompts across GPUs + - Ulysses Parallelism: Distributes sequence within each CFG group + + Ordering rationale (default ``"cfg-tp-ring-ulysses"``): + - Ulysses innermost: all-to-all is latency-sensitive, contiguous ranks + - Ring next: KV streaming between adjacent ranks + - TP next: all-reduce for Linear + - CFG outermost: independent until final all-gather + + The *order* string maps directly to ``init_device_mesh``'s + ``mesh_shape`` tuple (first = outermost / slowest-varying, last = + innermost / most contiguous). + """ + + def __init__( + self, + world_size: int, + rank: int, + cfg_size: int = 1, + tp_size: int = 1, + ring_size: int = 1, + ulysses_size: int = 1, + order: str = DEFAULT_DIM_ORDER, + ): + product = cfg_size * tp_size * ring_size * ulysses_size + if product != world_size: + raise ValueError( + f"cfg({cfg_size}) * tp({tp_size}) * ring({ring_size}) * " + f"ulysses({ulysses_size}) = {product} != world_size({world_size})" + ) + + dims = order.split("-") + if set(dims) != _VALID_DIM_NAMES or len(dims) != len(_VALID_DIM_NAMES): + raise ValueError( + f"order must be a '-'-separated permutation of " + f"{sorted(_VALID_DIM_NAMES)}, got '{order}'" + ) + + self.world_size = world_size + self._rank = rank + self.cfg_size = cfg_size + self.tp_size = tp_size + self.ring_size = ring_size + self.ulysses_size = ulysses_size + self._order = order + self._dim_names = tuple(dims) + self._dim_sizes = { + "cfg": cfg_size, + "tp": tp_size, + "ring": ring_size, + "ulysses": ulysses_size, + } + + if dist.is_initialized() and world_size > 1: + self.build_mesh() + + # ------------------------------------------------------------------ + # Mesh construction + # ------------------------------------------------------------------ + def build_mesh(self): + cls = DeviceMeshTopologyImpl + if cls.device_mesh is not None: + return + + shape = tuple(self._dim_sizes[d] for d in self._dim_names) + cls.device_mesh = init_device_mesh( + "cuda", + mesh_shape=shape, + mesh_dim_names=self._dim_names, + ) + logger.debug( + f"VisualGenMapping.build_mesh: dims={self._dim_names}, " + f"shape={shape}, mesh={cls.device_mesh}" + ) + + # ------------------------------------------------------------------ + # Rank decomposition + # ------------------------------------------------------------------ + def _local_rank(self, dim: str) -> int: + cls = DeviceMeshTopologyImpl + if cls.device_mesh is None: + return 0 + return cls.device_mesh[dim].get_local_rank() + + @property + def cfg_rank(self) -> int: + return self._local_rank("cfg") + + @property + def tp_rank(self) -> int: + return self._local_rank("tp") + + @property + def ring_rank(self) -> int: + return self._local_rank("ring") + + @property + def ulysses_rank(self) -> int: + return self._local_rank("ulysses") + + @property + def is_cfg_conditional(self) -> bool: + return self.cfg_rank == 0 + + # ------------------------------------------------------------------ + # Process groups (None when size == 1 and mesh was not built) + # ------------------------------------------------------------------ + def _group(self, dim: str) -> Optional[ProcessGroup]: + cls = DeviceMeshTopologyImpl + if cls.device_mesh is None: + if self.world_size == 1: + return SingleProcessGroup.get_group() + return None + return cls.device_mesh[dim].get_group() + + @property + def ulysses_group(self) -> Optional[ProcessGroup]: + return self._group("ulysses") + + @property + def ring_group(self) -> Optional[ProcessGroup]: + return self._group("ring") + + @property + def tp_group_pg(self) -> Optional[ProcessGroup]: + return self._group("tp") + + @property + def cfg_group(self) -> Optional[ProcessGroup]: + return self._group("cfg") + + # ------------------------------------------------------------------ + # Bridge to LLM Mapping (for Linear layers) + # ------------------------------------------------------------------ + def to_llm_mapping(self) -> Mapping: + """Return a ``Mapping`` whose TP group is backed by this mesh's TP dim. + + ``build_mesh()`` has already populated + ``DeviceMeshTopologyImpl.device_mesh``, so the returned ``Mapping``'s + ``build_mesh()`` is a no-op and ``tp_group_pg`` reads from the shared + mega-mesh. + """ + return Mapping( + world_size=self.tp_size, + rank=self.tp_rank, + tp_size=self.tp_size, + ) diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py index 422eb1594c03..457c90fb3e25 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py @@ -45,7 +45,10 @@ class FluxPipeline(BasePipeline): """ def __init__(self, model_config): - if model_config.parallel.dit_cfg_size != 1: + if ( + model_config.visual_gen_mapping is not None + and model_config.visual_gen_mapping.cfg_size != 1 + ): raise ValueError( "FluxPipeline does not support CFG parallelism. Please set dit_cfg_size to 1." ) diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py index 9b5eef5b80b9..522421124ad4 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py @@ -114,7 +114,10 @@ class Flux2Pipeline(BasePipeline): HIDDEN_STATE_LAYERS: Tuple[int, ...] = (10, 20, 30) def __init__(self, model_config): - if model_config.parallel.dit_cfg_size != 1: + if ( + model_config.visual_gen_mapping is not None + and model_config.visual_gen_mapping.cfg_size != 1 + ): raise ValueError( "Flux2Pipeline does not support CFG parallelism. Please set dit_cfg_size to 1." ) diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux.py b/tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux.py index a07d44c58c53..400beefc747b 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux.py @@ -32,7 +32,6 @@ from tensorrt_llm._torch.utils import maybe_compile from tensorrt_llm._torch.visual_gen.models.flux.attention import FluxJointAttention from tensorrt_llm._torch.visual_gen.models.flux.pos_embed_flux import FluxPosEmbed -from tensorrt_llm._torch.visual_gen.parallelism import setup_sequence_parallelism from tensorrt_llm._torch.visual_gen.quantization.loader import DynamicLinearWeightLoader from tensorrt_llm.models.modeling_utils import QuantConfig @@ -566,14 +565,18 @@ def __init__(self, model_config: "DiffusionModelConfig"): super().__init__() self.model_config = model_config - # Setup sequence parallelism (Ulysses) + vgm = model_config.visual_gen_mapping num_heads = getattr(model_config.pretrained_config, "num_attention_heads", 24) - self.use_ulysses, self.ulysses_size, self.ulysses_pg, self.ulysses_rank = ( - setup_sequence_parallelism( - model_config=model_config, - num_attention_heads=num_heads, + ulysses_size = vgm.ulysses_size if vgm else 1 + if ulysses_size > 1 and num_heads % ulysses_size != 0: + raise ValueError( + f"num_attention_heads ({num_heads}) must be divisible by " + f"ulysses_size ({ulysses_size})" ) - ) + self.use_ulysses = ulysses_size > 1 + self.ulysses_size = ulysses_size + self.ulysses_pg = vgm.ulysses_group if vgm else None + self.ulysses_rank = vgm.ulysses_rank if vgm else 0 # Extract pretrained config from model_config pretrained_config = model_config.pretrained_config diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux2.py b/tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux2.py index e08293d13b02..0de5a7dad6ba 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux2.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux2.py @@ -36,7 +36,6 @@ AdaLayerNormContinuous, _remap_checkpoint_keys, ) -from tensorrt_llm._torch.visual_gen.parallelism import setup_sequence_parallelism from tensorrt_llm._torch.visual_gen.quantization.loader import DynamicLinearWeightLoader from tensorrt_llm.models.modeling_utils import QuantConfig @@ -432,14 +431,18 @@ def __init__(self, model_config: "DiffusionModelConfig"): super().__init__() self.model_config = model_config - # Setup sequence parallelism (Ulysses) + vgm = model_config.visual_gen_mapping num_heads = getattr(model_config.pretrained_config, "num_attention_heads", 48) - self.use_ulysses, self.ulysses_size, self.ulysses_pg, self.ulysses_rank = ( - setup_sequence_parallelism( - model_config=model_config, - num_attention_heads=num_heads, + ulysses_size = vgm.ulysses_size if vgm else 1 + if ulysses_size > 1 and num_heads % ulysses_size != 0: + raise ValueError( + f"num_attention_heads ({num_heads}) must be divisible by " + f"ulysses_size ({ulysses_size})" ) - ) + self.use_ulysses = ulysses_size > 1 + self.ulysses_size = ulysses_size + self.ulysses_pg = vgm.ulysses_group if vgm else None + self.ulysses_rank = vgm.ulysses_rank if vgm else 0 # Extract pretrained config from model_config (following WAN/FLUX.1 pattern) pretrained_config = model_config.pretrained_config diff --git a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py index 23b038cdd33d..6c5e63683afd 100644 --- a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py +++ b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py @@ -1181,10 +1181,17 @@ def forward( # CFG parallel for multi-modal guidance: each GPU handles one # CFG pass (cond or uncond), results are all-gathered, then # STG/modality passes run on every GPU before the guidance formula. - cfg_size = self.model_config.parallel.dit_cfg_size - ulysses_size = self.model_config.parallel.dit_ulysses_size + vgm = self.model_config.visual_gen_mapping + cfg_size = vgm.cfg_size if vgm else 1 + ulysses_size = vgm.ulysses_size if vgm else 1 do_cfg_parallel_mm = use_multi_modal_guidance and cfg_size >= 2 and do_cfg - cfg_group = self.rank // ulysses_size + if do_cfg_parallel_mm and cfg_size != 2: + raise ValueError( + f"Multi-modal CFG parallel only supports cfg_size=2 " + f"(cond/uncond), got cfg_size={cfg_size}" + ) + cfg_rank = vgm.cfg_rank if vgm else 0 + cfg_pg = vgm.cfg_group if vgm else None if do_cfg_parallel_mm and self.rank == 0: logger.info( f"CFG parallel (multi-modal guidance): cfg_size={cfg_size}, " @@ -1462,8 +1469,9 @@ def forward_fn( # --- CFG: conditional + unconditional passes -------------------- if do_cfg_parallel_mm: - # CFG parallel: split cond/uncond across GPUs, all-gather - if cfg_group == 0: + # CFG parallel: each CFG rank runs one pass (cond or uncond), + # then all-gather across the CFG group (size 2). + if cfg_rank == 0: local_v, local_a = _run_transformer( video_latents, audio_latents_in, @@ -1483,17 +1491,17 @@ def forward_fn( ) local_v = local_v.contiguous() - gather_v = [torch.empty_like(local_v) for _ in range(self.world_size)] - dist.all_gather(gather_v, local_v) + gather_v = [torch.empty_like(local_v) for _ in range(cfg_size)] + dist.all_gather(gather_v, local_v, group=cfg_pg) cond_v = gather_v[0] - uncond_v = gather_v[ulysses_size] + uncond_v = gather_v[1] if local_a is not None: local_a = local_a.contiguous() - gather_a = [torch.empty_like(local_a) for _ in range(self.world_size)] - dist.all_gather(gather_a, local_a) + gather_a = [torch.empty_like(local_a) for _ in range(cfg_size)] + dist.all_gather(gather_a, local_a, group=cfg_pg) cond_a = gather_a[0] - uncond_a = gather_a[ulysses_size] + uncond_a = gather_a[1] else: cond_a = None uncond_a = 0.0 diff --git a/tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py b/tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py index e42257d90516..bbe61271ac6c 100644 --- a/tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py +++ b/tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py @@ -32,7 +32,6 @@ from tensorrt_llm._torch.modules.mlp import MLP from tensorrt_llm._torch.visual_gen.attention_backend.utils import create_attention from tensorrt_llm._torch.visual_gen.modules.attention import Attention, QKVMode -from tensorrt_llm._torch.visual_gen.parallelism import setup_sequence_parallelism from tensorrt_llm._torch.visual_gen.quantization.loader import DynamicLinearWeightLoader from tensorrt_llm.logger import logger from tensorrt_llm.models.modeling_utils import QuantConfig @@ -91,6 +90,7 @@ def __init__( from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig config = config or DiffusionModelConfig() + vgm = config.visual_gen_mapping # Store before super().__init__() — _init_qkv_proj needs _context_dim self._context_dim = context_dim if context_dim is not None else query_dim @@ -120,7 +120,7 @@ def __init__( # plain backend as fallback. The base class already set self.attn # to UlyssesAttention(inner_backend=sharded_backend). self._has_dual_attn = False - ulysses_size = config.parallel.dit_ulysses_size + ulysses_size = vgm.ulysses_size if vgm is not None else 1 if use_ulysses and not self._is_cross_attn and ulysses_size > 1: self._ulysses_attn = self.attn self._plain_attn = create_attention( @@ -294,10 +294,11 @@ def __init__( self._use_ulysses = False self._audio_is_sharded = False - if config is not None and config.parallel.dit_ulysses_size > 1: + vgm = config.visual_gen_mapping if config is not None else None + if vgm is not None and vgm.ulysses_size > 1: self._use_ulysses = True - self._ulysses_size = config.parallel.dit_ulysses_size - self._ulysses_pg = getattr(config, "ulysses_process_group", None) + self._ulysses_size = vgm.ulysses_size + self._ulysses_pg = vgm.ulysses_group if video is not None: self._init_video_modules(video, rope_type, norm_eps, config, idx) @@ -766,17 +767,20 @@ def __init__( self._init_preprocessors(cross_pe_max_pos) - # Ulysses sequence parallelism — must run before block/attention init - # so that model_config.ulysses_process_group is available. + vgm = model_config.visual_gen_mapping primary_heads = ( num_attention_heads if model_type.is_video_enabled() else audio_num_attention_heads ) - (self.use_ulysses, self.ulysses_size, self.ulysses_pg, self.ulysses_rank) = ( - setup_sequence_parallelism( - model_config=model_config, - num_attention_heads=primary_heads, + ulysses_size = vgm.ulysses_size if vgm else 1 + if ulysses_size > 1 and primary_heads % ulysses_size != 0: + raise ValueError( + f"num_attention_heads ({primary_heads}) must be divisible by " + f"ulysses_size ({ulysses_size})" ) - ) + self.use_ulysses = ulysses_size > 1 + self.ulysses_size = ulysses_size + self.ulysses_pg = vgm.ulysses_group if vgm else None + self.ulysses_rank = vgm.ulysses_rank if vgm else 0 # Audio is sharded by Ulysses only when its sequence length is # divisible by ulysses_size (checked at runtime in forward). # Head divisibility is validated here since the attention backend diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py index 49e56c4d23d5..791d2bf2a3fe 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py @@ -14,7 +14,6 @@ from tensorrt_llm._torch.modules.rms_norm import RMSNorm from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig from tensorrt_llm._torch.visual_gen.modules.attention import Attention, QKVMode -from tensorrt_llm._torch.visual_gen.parallelism import setup_sequence_parallelism from tensorrt_llm._torch.visual_gen.quantization.loader import DynamicLinearWeightLoader from tensorrt_llm.logger import logger from tensorrt_llm.models.modeling_utils import QuantConfig @@ -423,21 +422,21 @@ def __init__( self.model_config = model_config - # Validate no tensor parallelism - if model_config.parallel.dit_tp_size > 1: - raise ValueError( - f"WAN does not support tensor parallelism. " - f"Got dit_tp_size={model_config.parallel.dit_tp_size}" - ) + vgm = model_config.visual_gen_mapping + if vgm is not None and vgm.tp_size > 1: + raise ValueError(f"WAN does not support tensor parallelism. Got tp_size={vgm.tp_size}") - # Setup sequence parallelism (Ulysses) num_heads = getattr(model_config.pretrained_config, "num_attention_heads", 12) - self.use_ulysses, self.ulysses_size, self.ulysses_pg, self.ulysses_rank = ( - setup_sequence_parallelism( - model_config=model_config, - num_attention_heads=num_heads, + ulysses_size = vgm.ulysses_size if vgm else 1 + if ulysses_size > 1 and num_heads % ulysses_size != 0: + raise ValueError( + f"num_attention_heads ({num_heads}) must be divisible by " + f"ulysses_size ({ulysses_size})" ) - ) + self.use_ulysses = ulysses_size > 1 + self.ulysses_size = ulysses_size + self.ulysses_pg = vgm.ulysses_group if vgm else None + self.ulysses_rank = vgm.ulysses_rank if vgm else 0 config = model_config.pretrained_config diff --git a/tensorrt_llm/_torch/visual_gen/modules/attention.py b/tensorrt_llm/_torch/visual_gen/modules/attention.py index d779ddfd0caf..d4fb426c75cd 100644 --- a/tensorrt_llm/_torch/visual_gen/modules/attention.py +++ b/tensorrt_llm/_torch/visual_gen/modules/attention.py @@ -77,7 +77,8 @@ def __init__( self.interleave = interleave # Select compute backend (orthogonal to parallelism) - ulysses_size = config.parallel.dit_ulysses_size + vgm = config.visual_gen_mapping + ulysses_size = vgm.ulysses_size if vgm else 1 base_backend = config.attention.backend # TRTLLM doesn't support cross-attention (different Q/KV seq lengths); fall back to VANILLA @@ -144,14 +145,13 @@ def __init__( dtype=self.dtype, ) - # Wrap with parallelism strategy (orthogonal to backend choice) + # Wrap with parallelism strategies (orthogonal to backend choice) if ulysses_size > 1 and self.qkv_mode != QKVMode.SEPARATE_QKV: from ..attention_backend.parallel import UlyssesAttention - process_group = getattr(config, "ulysses_process_group", None) self.attn = UlyssesAttention( inner_backend=self.attn, - process_group=process_group, + process_group=vgm.ulysses_group, ) def _init_qkv_proj(self) -> None: diff --git a/tensorrt_llm/_torch/visual_gen/parallelism.py b/tensorrt_llm/_torch/visual_gen/parallelism.py deleted file mode 100644 index 1bda600fa015..000000000000 --- a/tensorrt_llm/_torch/visual_gen/parallelism.py +++ /dev/null @@ -1,100 +0,0 @@ -"""Utilities for distributed parallelism setup in diffusion models.""" - -from typing import Optional, Tuple - -import torch.distributed as dist - -from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig - - -def setup_sequence_parallelism( - model_config: DiffusionModelConfig, - num_attention_heads: int, -) -> Tuple[bool, int, Optional[dist.ProcessGroup], int]: - """ - Setup sequence parallelism (currently Ulysses only) with CFG support. - - Creates nested process groups where each CFG group has its own Ulysses group. - Example with cfg_size=2, ulysses_size=2, world_size=4: - GPU 0-1: CFG group 0, Ulysses group 0 - GPU 2-3: CFG group 1, Ulysses group 1 - - Args: - model_config: Model configuration containing parallel settings - num_attention_heads: Number of attention heads in the model - - Returns: - Tuple of (use_parallelism, parallelism_size, parallelism_pg, parallelism_rank): - - use_parallelism: Whether sequence parallelism is enabled - - parallelism_size: The sequence parallelism degree - - parallelism_pg: The process group for this rank (or None) - - parallelism_rank: This rank's position within its parallelism group - - Raises: - RuntimeError: If torch.distributed is not initialized - ValueError: If configuration is invalid (incompatible sizes, head count not divisible, etc.) - NotImplementedError: If Ring attention is requested (not yet implemented) - - Side Effects: - - Sets model_config.ulysses_process_group to the created process group - - Note: - Both num_attention_heads and sequence length must be divisible by ulysses_size. - Head count is validated here; sequence length is validated at runtime during forward pass. - """ - ulysses_size = model_config.parallel.dit_ulysses_size - ring_size = model_config.parallel.dit_ring_size - cfg_size = model_config.parallel.dit_cfg_size - - # Check for ring attention (not yet implemented) - if ring_size > 1: - raise NotImplementedError("Ring attention parallelism is not yet implemented") - - # Early exit if not using sequence parallelism - if ulysses_size <= 1: - model_config.ulysses_process_group = None - return False, 1, None, 0 - - # Validate distributed initialization - if not dist.is_initialized(): - raise RuntimeError( - "torch.distributed.init_process_group() must be called before " - "setting up sequence parallelism" - ) - - rank = dist.get_rank() - world_size = dist.get_world_size() - - # Validate total parallelism capacity - total_parallel = cfg_size * ulysses_size - if total_parallel > world_size: - raise ValueError( - f"cfg_size ({cfg_size}) * ulysses_size ({ulysses_size}) = " - f"{total_parallel} exceeds world_size ({world_size})" - ) - - # Validate head count divisibility - if num_attention_heads % ulysses_size != 0: - raise ValueError( - f"num_attention_heads ({num_attention_heads}) must be divisible by " - f"ulysses_size ({ulysses_size})" - ) - - # Create nested process groups - # Each CFG group has its own Ulysses group - ulysses_pg = None - ulysses_rank = 0 - - for cfg_id in range(cfg_size): - ulysses_ranks = list(range(cfg_id * ulysses_size, (cfg_id + 1) * ulysses_size)) - pg = dist.new_group(ulysses_ranks, use_local_synchronization=True) - - # Store if this rank belongs to this group - if rank in ulysses_ranks: - ulysses_pg = pg - ulysses_rank = rank - cfg_id * ulysses_size - - # Store in config for Attention modules - model_config.ulysses_process_group = ulysses_pg - - return True, ulysses_size, ulysses_pg, ulysses_rank diff --git a/tensorrt_llm/_torch/visual_gen/pipeline.py b/tensorrt_llm/_torch/visual_gen/pipeline.py index d2c22f8222c6..b564378a6fda 100644 --- a/tensorrt_llm/_torch/visual_gen/pipeline.py +++ b/tensorrt_llm/_torch/visual_gen/pipeline.py @@ -331,7 +331,7 @@ def _setup_teacache(self, model, coefficients: Optional[Dict] = None): self.cache_backend.enable(model) def setup_parallel_vae(self): - if not self.model_config.parallel.enable_parallel_vae: + if not self.model_config.enable_parallel_vae: return if not dist.is_initialized() or dist.get_world_size() <= 1: return @@ -343,7 +343,7 @@ def setup_parallel_vae(self): try: self.vae = ParallelVAEFactory.from_vae( self.vae, - split_dim=self.model_config.parallel.parallel_vae_split_dim, + split_dim=self.model_config.parallel_vae_split_dim, pg=pg, ) except ValueError: @@ -358,7 +358,7 @@ def setup_parallel_vae(self): self._parallel_vae_enabled = True logger.info( f"Parallel VAE enabled: {type(self.vae).__name__}, " - f"split_dim={self.model_config.parallel.parallel_vae_split_dim}, " + f"split_dim={self.model_config.parallel_vae_split_dim}, " f"world_size={dist.get_world_size(pg)}" ) @@ -544,11 +544,11 @@ def _setup_cfg_config( Returns: Dict with CFG configuration including split tensors """ - # Access parallel config directly (always present now) - cfg_size = self.model_config.parallel.dit_cfg_size - ulysses_size = self.model_config.parallel.dit_ulysses_size + vgm = self.model_config.visual_gen_mapping + cfg_size = vgm.cfg_size if vgm else 1 + ulysses_size = vgm.ulysses_size if vgm else 1 - cfg_group = self.rank // ulysses_size + is_conditional = vgm.is_cfg_conditional if vgm else True is_split_embeds = neg_prompt_embeds is not None do_cfg_parallel = cfg_size >= 2 and guidance_scale > 1.0 @@ -564,15 +564,14 @@ def _setup_cfg_config( else: neg_embeds, pos_embeds = prompt_embeds.chunk(2) - local_embeds = pos_embeds if cfg_group == 0 else neg_embeds + local_embeds = pos_embeds if is_conditional else neg_embeds # Split extra tensors if provided if extra_cfg_tensors: for name, (pos_tensor, neg_tensor) in extra_cfg_tensors.items(): if pos_tensor is not None and neg_tensor is not None: - local_extras[name] = pos_tensor if cfg_group == 0 else neg_tensor + local_extras[name] = pos_tensor if is_conditional else neg_tensor elif pos_tensor is not None: - # Only positive provided, use it for both local_extras[name] = pos_tensor else: local_embeds = None @@ -591,7 +590,7 @@ def _setup_cfg_config( "enabled": do_cfg_parallel, "cfg_size": cfg_size, "ulysses_size": ulysses_size, - "cfg_group": cfg_group, + "cfg_rank": vgm.cfg_rank if vgm else 0, "local_embeds": local_embeds, "prompt_embeds": prompt_embeds, "local_extras": local_extras, @@ -610,6 +609,10 @@ def _denoise_step_cfg_parallel( local_extras, ): """Execute single denoising step with CFG parallel.""" + vgm = self.model_config.visual_gen_mapping + cfg_pg = vgm.cfg_group if vgm else None + cfg_size = vgm.cfg_size if vgm else 1 + t_start = time.time() result = forward_fn(latents, extra_stream_latents, timestep, local_embeds, local_extras) @@ -624,22 +627,24 @@ def _denoise_step_cfg_parallel( c_start = time.time() - # All-gather primary noise (must be contiguous for NCCL) + # All-gather primary noise over the CFG group. + # Each entry in gather_list corresponds to one CFG rank + # (index 0 = conditional, index 1 = unconditional). noise_pred_local = noise_pred_local.contiguous() - gather_list = [torch.empty_like(noise_pred_local) for _ in range(self.world_size)] - dist.all_gather(gather_list, noise_pred_local) + gather_list = [torch.empty_like(noise_pred_local) for _ in range(cfg_size)] + dist.all_gather(gather_list, noise_pred_local, group=cfg_pg) noise_cond = gather_list[0] - noise_uncond = gather_list[ulysses_size] + noise_uncond = gather_list[1] noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond) # All-gather extra stream noises extra_noise_preds = {} for name, noise_local in extra_noise_locals.items(): noise_local = noise_local.contiguous() - gather_list_extra = [torch.empty_like(noise_local) for _ in range(self.world_size)] - dist.all_gather(gather_list_extra, noise_local) + gather_list_extra = [torch.empty_like(noise_local) for _ in range(cfg_size)] + dist.all_gather(gather_list_extra, noise_local, group=cfg_pg) noise_cond_extra = gather_list_extra[0] - noise_uncond_extra = gather_list_extra[ulysses_size] + noise_uncond_extra = gather_list_extra[1] extra_noise_preds[name] = noise_uncond_extra + guidance_scale * ( noise_cond_extra - noise_uncond_extra ) diff --git a/tensorrt_llm/_torch/visual_gen/pipeline_loader.py b/tensorrt_llm/_torch/visual_gen/pipeline_loader.py index dfebfe53f15e..fb06efc3fdf5 100644 --- a/tensorrt_llm/_torch/visual_gen/pipeline_loader.py +++ b/tensorrt_llm/_torch/visual_gen/pipeline_loader.py @@ -19,14 +19,15 @@ from typing import TYPE_CHECKING, Optional import torch +import torch.distributed as dist from tensorrt_llm._torch.autotuner import autotune from tensorrt_llm._torch.models.modeling_utils import MetaInitMode from tensorrt_llm.llmapi.utils import download_hf_model from tensorrt_llm.logger import logger -from tensorrt_llm.mapping import Mapping from .config import DiffusionModelConfig, VisualGenArgs +from .mapping import VisualGenMapping from .models import AutoPipeline if TYPE_CHECKING: @@ -54,7 +55,6 @@ def __init__( self, args: Optional[VisualGenArgs] = None, *, - mapping: Optional[Mapping] = None, device: str = "cuda", ): """ @@ -62,16 +62,10 @@ def __init__( Args: args: VisualGenArgs containing all configuration (preferred) - mapping: Tensor parallel mapping (fallback if args is None) device: Device to load model on (fallback if args is None) """ self.args = args - if args is not None: - self.mapping = args.to_mapping() - self.device = torch.device(args.device) - else: - self.mapping = mapping or Mapping() - self.device = torch.device(device) + self.device = torch.device(args.device if args is not None else device) def _resolve_checkpoint_dir(self, checkpoint_dir: str) -> str: """Resolve checkpoint_dir to a local directory path. @@ -108,6 +102,25 @@ def _resolve_checkpoint_dir(self, checkpoint_dir: str) -> str: ) from e return str(local_dir) + def _setup_visual_gen_mapping(self, config: DiffusionModelConfig) -> None: + if self.args is not None: + ws = dist.get_world_size() if dist.is_initialized() else 1 + rk = dist.get_rank() if dist.is_initialized() else 0 + vgm = VisualGenMapping( + ws, + rk, + cfg_size=self.args.parallel.dit_cfg_size, + tp_size=self.args.parallel.dit_tp_size, + ulysses_size=self.args.parallel.dit_ulysses_size, + ring_size=self.args.parallel.dit_ring_size, + order=self.args.parallel.dit_dim_order, + ) + else: + # Single-GPU fallback. no args = no parallelism. + vgm = VisualGenMapping(world_size=1, rank=0) + config.visual_gen_mapping = vgm + config.mapping = vgm.to_llm_mapping() + def load( self, checkpoint_dir: Optional[str] = None, @@ -151,7 +164,6 @@ def load( config = DiffusionModelConfig.from_pretrained( checkpoint_dir, args=self.args, - mapping=self.mapping, ) # Log quantization settings @@ -159,6 +171,11 @@ def load( logger.info(f"Quantization: {config.quant_config.quant_algo.name}") logger.info(f"Dynamic weight quant: {config.dynamic_weight_quant}") + # ===================================================================== + # STEP 1b: Build VisualGenMapping (must precede model creation) + # ===================================================================== + self._setup_visual_gen_mapping(config) + # ===================================================================== # STEP 2: Create Pipeline with MetaInit # Pipeline type is auto-detected from model_index.json @@ -206,7 +223,7 @@ def load( # ===================================================================== t0 = time.time() - if config.parallel.enable_parallel_vae: + if config.enable_parallel_vae: pipeline.setup_parallel_vae() if hasattr(pipeline, "post_load_weights"): diff --git a/tests/unittest/_torch/visual_gen/multi_gpu/test_flux_ulysses.py b/tests/unittest/_torch/visual_gen/multi_gpu/test_flux_ulysses.py index abf41e45d5a1..03af6005f841 100644 --- a/tests/unittest/_torch/visual_gen/multi_gpu/test_flux_ulysses.py +++ b/tests/unittest/_torch/visual_gen/multi_gpu/test_flux_ulysses.py @@ -26,10 +26,10 @@ from tensorrt_llm._torch.visual_gen.config import ( AttentionConfig, DiffusionModelConfig, - ParallelConfig, TeaCacheConfig, TorchCompileConfig, ) + from tensorrt_llm._torch.visual_gen.mapping import VisualGenMapping from tensorrt_llm._utils import get_free_port from tensorrt_llm.models.modeling_utils import QuantConfig @@ -138,17 +138,25 @@ def run_test_in_distributed(world_size: int, test_fn: Callable, use_cuda: bool = def _make_model_config(pretrained_dict, ulysses_size=1, backend="VANILLA"): """Create DiffusionModelConfig for testing.""" pretrained_config = SimpleNamespace(**pretrained_dict) - parallel = ParallelConfig(dit_ulysses_size=ulysses_size) - - return DiffusionModelConfig( + if ulysses_size > 1 and dist.is_initialized(): + ws = dist.get_world_size() + rk = dist.get_rank() + else: + ws = ulysses_size + rk = 0 + vgm = VisualGenMapping(world_size=ws, rank=rk, ulysses_size=ulysses_size) + + config = DiffusionModelConfig( pretrained_config=pretrained_config, quant_config=QuantConfig(), torch_compile=TorchCompileConfig(enable_torch_compile=False), attention=AttentionConfig(backend=backend), - parallel=parallel, + visual_gen_mapping=vgm, teacache=TeaCacheConfig(), skip_create_weights_in_init=False, ) + config.mapping = vgm.to_llm_mapping() + return config def _stabilize_model_weights(model): diff --git a/tests/unittest/_torch/visual_gen/multi_gpu/test_visual_gen_mapping.py b/tests/unittest/_torch/visual_gen/multi_gpu/test_visual_gen_mapping.py new file mode 100644 index 000000000000..576566bb35dd --- /dev/null +++ b/tests/unittest/_torch/visual_gen/multi_gpu/test_visual_gen_mapping.py @@ -0,0 +1,280 @@ +"""Tests for VisualGenMapping — the unified multi-dimensional communicator mesh. + +Single-GPU tests run without dist. Multi-GPU tests use mp.spawn with NCCL. +""" + +import itertools +import os + +os.environ["TLLM_DISABLE_MPI"] = "1" + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +try: + from tensorrt_llm._torch.visual_gen.mapping import _VALID_DIM_NAMES, VisualGenMapping + from tensorrt_llm._utils import get_free_port + + MODULES_AVAILABLE = True +except ImportError: + MODULES_AVAILABLE = False + + +@pytest.fixture(autouse=True, scope="module") +def _cleanup_mpi_env(): + yield + os.environ.pop("TLLM_DISABLE_MPI", None) + + +# ============================================================================= +# Helpers for multi-GPU tests +# ============================================================================= + + +def _init_dist(rank, world_size, port): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + torch.cuda.set_device(rank % torch.cuda.device_count()) + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + + +def _worker(rank, world_size, test_fn, port): + try: + _init_dist(rank, world_size, port) + test_fn(rank, world_size) + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + +def _run_multi_gpu(world_size, test_fn): + if not MODULES_AVAILABLE: + pytest.skip("Required modules not available") + if not torch.cuda.is_available() or torch.cuda.device_count() < world_size: + pytest.skip(f"Requires {world_size} GPUs, have {torch.cuda.device_count()}") + port = get_free_port() + mp.spawn(_worker, args=(world_size, test_fn, port), nprocs=world_size, join=True) + + +# ============================================================================= +# Single-GPU tests (no dist required) +# ============================================================================= + + +class TestConstruction: + def test_single_gpu_defaults(self): + vgm = VisualGenMapping(world_size=1, rank=0) + assert vgm.world_size == 1 + assert vgm.cfg_size == 1 + assert vgm.tp_size == 1 + assert vgm.ring_size == 1 + assert vgm.ulysses_size == 1 + + def test_stores_sizes(self): + vgm = VisualGenMapping( + world_size=8, + rank=0, + cfg_size=2, + tp_size=2, + ulysses_size=2, + ) + assert vgm.cfg_size == 2 + assert vgm.tp_size == 2 + assert vgm.ring_size == 1 + assert vgm.ulysses_size == 2 + assert vgm.world_size == 8 + + def test_product_mismatch_raises(self): + with pytest.raises(ValueError, match="!= world_size"): + VisualGenMapping(world_size=4, rank=0, cfg_size=2, ulysses_size=3) + + def test_invalid_order_raises(self): + with pytest.raises(ValueError, match="permutation"): + VisualGenMapping(world_size=1, rank=0, order="cfg-tp-ulysses") + + def test_duplicate_dim_raises(self): + with pytest.raises(ValueError, match="permutation"): + VisualGenMapping(world_size=1, rank=0, order="cfg-cfg-tp-ulysses") + + def test_custom_order_stored(self): + vgm = VisualGenMapping(world_size=1, rank=0, order="ulysses-ring-tp-cfg") + assert vgm._dim_names == ("ulysses", "ring", "tp", "cfg") + + def test_all_valid_orders(self): + for perm in itertools.permutations(sorted(_VALID_DIM_NAMES)): + order = "-".join(perm) + vgm = VisualGenMapping(world_size=1, rank=0, order=order) + assert vgm._dim_names == perm + + +class TestSingleGPURanksAndGroups: + def test_ranks_are_zero(self): + vgm = VisualGenMapping(world_size=1, rank=0) + assert vgm.cfg_rank == 0 + assert vgm.tp_rank == 0 + assert vgm.ring_rank == 0 + assert vgm.ulysses_rank == 0 + + def test_is_cfg_conditional(self): + vgm = VisualGenMapping(world_size=1, rank=0) + assert vgm.is_cfg_conditional is True + + def test_groups_return_single_process_group(self): + vgm = VisualGenMapping(world_size=1, rank=0) + assert vgm.ulysses_group is not None + assert vgm.ring_group is not None + assert vgm.tp_group_pg is not None + assert vgm.cfg_group is not None + + +class TestToLlmMapping: + def test_single_gpu(self): + vgm = VisualGenMapping(world_size=1, rank=0) + m = vgm.to_llm_mapping() + assert m.tp_size == 1 + assert m.world_size == 1 + + def test_tp_size_propagated(self): + vgm = VisualGenMapping(world_size=4, rank=0, tp_size=4) + m = vgm.to_llm_mapping() + assert m.tp_size == 4 + + def test_mixed_parallelism(self): + vgm = VisualGenMapping( + world_size=8, + rank=0, + cfg_size=2, + tp_size=2, + ulysses_size=2, + ) + m = vgm.to_llm_mapping() + assert m.tp_size == 2 + assert m.world_size == 2 + + +# ============================================================================= +# Multi-GPU tests — validate actual DeviceMesh groups and ranks +# ============================================================================= + + +def _logic_default_order_cfg2_ulysses2(rank, world_size): + """Default order cfg-tp-ring-ulysses with cfg=2, ulysses=2 on 4 GPUs. + + Expected rank layout (outermost=cfg, innermost=ulysses): + Rank 0: cfg=0, ulysses=0 (conditional, ulysses group 0) + Rank 1: cfg=0, ulysses=1 (conditional, ulysses group 1) + Rank 2: cfg=1, ulysses=0 (unconditional, ulysses group 0) + Rank 3: cfg=1, ulysses=1 (unconditional, ulysses group 1) + """ + from tensorrt_llm._torch.device_mesh import DeviceMeshTopologyImpl + + DeviceMeshTopologyImpl.device_mesh = None + + vgm = VisualGenMapping( + world_size=world_size, + rank=rank, + cfg_size=2, + ulysses_size=2, + ) + + assert vgm.cfg_rank == rank // 2 + assert vgm.ulysses_rank == rank % 2 + assert vgm.tp_rank == 0 + assert vgm.ring_rank == 0 + assert vgm.is_cfg_conditional == (rank < 2) + + assert vgm.cfg_group is not None + assert vgm.ulysses_group is not None + + cfg_pg_size = dist.get_world_size(vgm.cfg_group) + ulysses_pg_size = dist.get_world_size(vgm.ulysses_group) + assert cfg_pg_size == 2 + assert ulysses_pg_size == 2 + + m = vgm.to_llm_mapping() + assert m.tp_size == 1 + + +def _logic_custom_order_ulysses_outermost(rank, world_size): + """Custom order ulysses-ring-tp-cfg with cfg=2, ulysses=2 on 4 GPUs. + + Expected rank layout (outermost=ulysses, innermost=cfg): + Rank 0: ulysses=0, cfg=0 + Rank 1: ulysses=0, cfg=1 + Rank 2: ulysses=1, cfg=0 + Rank 3: ulysses=1, cfg=1 + """ + from tensorrt_llm._torch.device_mesh import DeviceMeshTopologyImpl + + DeviceMeshTopologyImpl.device_mesh = None + + vgm = VisualGenMapping( + world_size=world_size, + rank=rank, + cfg_size=2, + ulysses_size=2, + order="ulysses-ring-tp-cfg", + ) + + assert vgm.ulysses_rank == rank // 2 + assert vgm.cfg_rank == rank % 2 + assert vgm.is_cfg_conditional == (rank % 2 == 0) + + cfg_pg_size = dist.get_world_size(vgm.cfg_group) + ulysses_pg_size = dist.get_world_size(vgm.ulysses_group) + assert cfg_pg_size == 2 + assert ulysses_pg_size == 2 + + +def _logic_allreduce_over_tp_group(rank, world_size): + """Verify TP group works for collective ops (tp=2, ulysses=2 on 4 GPUs). + + Default order cfg-tp-ring-ulysses with cfg=1, tp=2, ring=1, ulysses=2: + Rank 0: tp=0, ulysses=0 + Rank 1: tp=0, ulysses=1 + Rank 2: tp=1, ulysses=0 + Rank 3: tp=1, ulysses=1 + TP groups: {0, 2} and {1, 3}. + """ + from tensorrt_llm._torch.device_mesh import DeviceMeshTopologyImpl + + DeviceMeshTopologyImpl.device_mesh = None + + vgm = VisualGenMapping( + world_size=world_size, + rank=rank, + tp_size=2, + ulysses_size=2, + ) + + device = torch.device(f"cuda:{rank}") + + # Each rank contributes 1.0; after all_reduce(sum) over tp_size=2, expect 2.0 + tensor = torch.ones(1, device=device) + dist.all_reduce(tensor, group=vgm.tp_group_pg) + assert tensor.item() == float(vgm.tp_size), ( + f"Rank {rank}: expected {vgm.tp_size}, got {tensor.item()}" + ) + + # Also verify ulysses group with the same pattern + tensor2 = torch.ones(1, device=device) + dist.all_reduce(tensor2, group=vgm.ulysses_group) + assert tensor2.item() == float(vgm.ulysses_size), ( + f"Rank {rank}: expected {vgm.ulysses_size}, got {tensor2.item()}" + ) + + +@pytest.mark.skipif(not MODULES_AVAILABLE, reason="Modules not available") +class TestMultiGPU: + def test_default_order_cfg2_ulysses2(self): + _run_multi_gpu(4, _logic_default_order_cfg2_ulysses2) + + def test_custom_order_ulysses_outermost(self): + _run_multi_gpu(4, _logic_custom_order_ulysses_outermost) + + def test_allreduce_over_tp_group(self): + _run_multi_gpu(4, _logic_allreduce_over_tp_group) diff --git a/tests/unittest/_torch/visual_gen/test_flux_pipeline.py b/tests/unittest/_torch/visual_gen/test_flux_pipeline.py index 2afdf934b75e..76dd425d9c06 100644 --- a/tests/unittest/_torch/visual_gen/test_flux_pipeline.py +++ b/tests/unittest/_torch/visual_gen/test_flux_pipeline.py @@ -1097,7 +1097,7 @@ def _run_all_optimizations_worker(rank, world_size, checkpoint_path, inputs_cpu, transformer = pipeline.transformer.eval() # Verify all optimizations are enabled - assert pipeline.model_config.parallel.dit_ulysses_size == world_size, ( + assert pipeline.model_config.visual_gen_mapping.ulysses_size == world_size, ( "Ulysses parallel not enabled" ) assert transformer.model_config.quant_config.quant_algo == QuantAlgo.FP8, "FP8 not enabled" diff --git a/tests/unittest/_torch/visual_gen/test_model_loader.py b/tests/unittest/_torch/visual_gen/test_model_loader.py index dc20e6b4c1d8..5185df59dc85 100644 --- a/tests/unittest/_torch/visual_gen/test_model_loader.py +++ b/tests/unittest/_torch/visual_gen/test_model_loader.py @@ -227,30 +227,6 @@ def test_diffusion_args_to_quant_config(): assert args.dynamic_weight_quant is True -def test_diffusion_args_to_mapping(): - """Test that VisualGenArgs correctly generates Mapping from ParallelConfig.""" - from tensorrt_llm._torch.visual_gen import ParallelConfig, VisualGenArgs - - # ParallelConfig validator requires WORLD_SIZE >= total parallel (tp*cp = 4) - old_world = os.environ.get("WORLD_SIZE") - try: - os.environ["WORLD_SIZE"] = "4" - args = VisualGenArgs( - checkpoint_path="/fake/path", - parallel=ParallelConfig(dit_tp_size=2, dit_cp_size=2), - ) - mapping = args.to_mapping() - assert mapping.tp_size == 2 - assert mapping.cp_size == 2 - # world_size = tp_size * pp_size * cp_size (DP is handled separately) - assert mapping.world_size == 4 - finally: - if old_world is not None: - os.environ["WORLD_SIZE"] = old_world - elif "WORLD_SIZE" in os.environ: - del os.environ["WORLD_SIZE"] - - def test_load_without_quant_config_no_fp8(checkpoint_exists): """Test that loading without quant_config does NOT produce FP8 weights.""" if not checkpoint_exists: diff --git a/tests/unittest/_torch/visual_gen/test_wan.py b/tests/unittest/_torch/visual_gen/test_wan.py index 05e5e239e89a..7edffdfa4f66 100644 --- a/tests/unittest/_torch/visual_gen/test_wan.py +++ b/tests/unittest/_torch/visual_gen/test_wan.py @@ -163,8 +163,8 @@ def _run_cfg_worker(rank, world_size, checkpoint_path, inputs_list, return_dict) pipeline = PipelineLoader(args).load(skip_warmup=True) # Verify CFG parallel configuration - assert pipeline.model_config.parallel.dit_cfg_size == world_size, ( - f"Expected cfg_size={world_size}, got {pipeline.model_config.parallel.dit_cfg_size}" + assert pipeline.model_config.visual_gen_mapping.cfg_size == world_size, ( + f"Expected cfg_size={world_size}, got {pipeline.model_config.visual_gen_mapping.cfg_size}" ) # Load inputs on this GPU @@ -271,7 +271,9 @@ def _run_all_optimizations_worker(rank, world_size, checkpoint_path, inputs_list transformer = pipeline.transformer.eval() # Verify all optimizations are enabled - assert pipeline.model_config.parallel.dit_cfg_size == world_size, "CFG parallel not enabled" + assert pipeline.model_config.visual_gen_mapping.cfg_size == world_size, ( + "CFG parallel not enabled" + ) assert transformer.model_config.quant_config.quant_algo == QuantAlgo.FP8, "FP8 not enabled" assert hasattr(pipeline, "cache_backend"), "TeaCache not enabled" assert transformer.blocks[0].attn1.attn_backend == "TRTLLM", ( diff --git a/tests/unittest/_torch/visual_gen/test_wan_i2v.py b/tests/unittest/_torch/visual_gen/test_wan_i2v.py index 43dd2f53d7c5..e8a9f583f816 100644 --- a/tests/unittest/_torch/visual_gen/test_wan_i2v.py +++ b/tests/unittest/_torch/visual_gen/test_wan_i2v.py @@ -284,8 +284,8 @@ def _run_cfg_worker_i2v(rank, world_size, checkpoint_path, inputs_list, return_d pipeline = PipelineLoader(args).load(skip_warmup=True) # Verify CFG parallel configuration - assert pipeline.model_config.parallel.dit_cfg_size == world_size, ( - f"Expected cfg_size={world_size}, got {pipeline.model_config.parallel.dit_cfg_size}" + assert pipeline.model_config.visual_gen_mapping.cfg_size == world_size, ( + f"Expected cfg_size={world_size}, got {pipeline.model_config.visual_gen_mapping.cfg_size}" ) # Load inputs on this GPU @@ -395,7 +395,9 @@ def _run_all_optimizations_worker_i2v(rank, world_size, checkpoint_path, inputs_ transformer = pipeline.transformer.eval() # Verify all optimizations are enabled - assert pipeline.model_config.parallel.dit_cfg_size == world_size, "CFG parallel not enabled" + assert pipeline.model_config.visual_gen_mapping.cfg_size == world_size, ( + "CFG parallel not enabled" + ) assert transformer.model_config.quant_config.quant_algo == QuantAlgo.FP8, "FP8 not enabled" assert hasattr(pipeline, "transformer_cache_backend"), "TeaCache not enabled" assert transformer.blocks[0].attn1.attn_backend == "TRTLLM", (