From 96bb8b76598af852ec8338c9c065fe5a9384067f Mon Sep 17 00:00:00 2001
From: dhansen-nvidia <218031328+dhansen-nvidia@users.noreply.github.com>
Date: Tue, 21 Apr 2026 11:50:20 -0400
Subject: [PATCH 1/6] [https://nvbugs/6074014][fix] Min-reduce available host
 memory to ensure that all ranks agree about whether prefetch is enabled
 (#13161)

Signed-off-by: Dan Hansen <1+dhansen-nvidia@users.noreply.github.com>
Co-authored-by: Dan Hansen <1+dhansen-nvidia@users.noreply.github.com>
---
 .../models/checkpoints/hf/weight_loader.py    | 42 +++++++++++++++++--
 1 file changed, 38 insertions(+), 4 deletions(-)

diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py
index f47e77a81661..222593173be4 100644
--- a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py
+++ b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py
@@ -1,3 +1,17 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import glob
 import multiprocessing
 import os
@@ -8,13 +22,14 @@
 import safetensors
 import torch
 import tqdm
+from mpi4py import MPI as _MPI
 
 from tensorrt_llm._torch.models.checkpoints.base_weight_loader import (
     BaseWeightLoader, ConsumableWeightsDict)
 from tensorrt_llm._torch.models.modeling_utils import (
     register_checkpoint_weight_loader, run_concurrently)
-from tensorrt_llm._utils import (local_mpi_barrier, local_mpi_rank,
-                                 local_mpi_size)
+from tensorrt_llm._utils import (ENABLE_MULTI_DEVICE, local_mpi_barrier,
+                                 local_mpi_comm, local_mpi_rank, local_mpi_size)
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 
@@ -26,6 +41,24 @@ class HfWeightLoader(BaseWeightLoader):
     Loads weights from SafeTensors/bin/pth files.
     """
 
+    @staticmethod
+    def _get_local_available_host_memory() -> int:
+        """Determine the minimum available memory observed on the local node
+        and distribute it to all local ranks.
+
+        Because psutil.virtual_memory().available is just a snapshot in time,
+        it is possible for the local ranks to get different numbers due to
+        timing differences. This can lead to disagreement among the local ranks
+        as to whether prefetch should be enabled, which causes a deadlock,
+        because the ranks that think prefetch is enabled will wait at a local
+        MPI barrier indefinitely for the ranks that do not.
+        """
+        available_host_memory = psutil.virtual_memory().available
+        if ENABLE_MULTI_DEVICE:
+            return local_mpi_comm().allreduce(available_host_memory,
+                                              op=_MPI.MIN)
+        return available_host_memory
+
     def load_weights(self, checkpoint_dir: str,
                      mapping: Mapping) -> dict[str, Any]:
         weight_files = glob.glob(f"{checkpoint_dir}/*.safetensors")
@@ -44,8 +77,9 @@ def load_weights(self, checkpoint_dir: str,
         # If the layer number is overridden, it indicates that only a subset of layers are loaded.
         # Prefetching all layers is unnecessary.
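
The decision this hunk goes on to compute is exactly where the ranks previously diverged. A minimal standalone sketch of the agreement pattern the new helper relies on (illustrative only, not part of the patch; it assumes mpi4py and psutil are installed, the script name and prefetch size are made up, and it is meant to be launched under mpirun with several ranks):

    # min_reduce_sketch.py -- hypothetical file name;
    # run e.g. `mpirun -n 4 python min_reduce_sketch.py`
    import psutil
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    prefetch_size = 8 * 1024**3  # hypothetical: 8 GiB of checkpoint files

    # Each rank samples available memory at a slightly different moment, so raw
    # snapshots can land on opposite sides of the threshold and disagree.
    snapshot = psutil.virtual_memory().available

    # Min-reducing gives every rank the same, most conservative value, so the
    # threshold test below yields an identical decision on all ranks.
    agreed = comm.allreduce(snapshot, op=MPI.MIN)
    enable_prefetch = prefetch_size < agreed * 0.9
    print(f"rank {comm.Get_rank()}: local={snapshot} "
          f"agreed={agreed} prefetch={enable_prefetch}")

Taking the minimum is the conservative choice: prefetch is enabled only if every rank's snapshot would have allowed it, which matches the semantics of the local MPI barrier that the prefetching ranks wait on.
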
num_layers = int(os.environ.get("TLLM_OVERRIDE_LAYER_NUM", "0")) - enable_prefetch = prefetch_size < psutil.virtual_memory( - ).available * 0.9 and num_layers == 0 + enable_prefetch = (prefetch_size + < self._get_local_available_host_memory() * 0.9 + and num_layers == 0) if enable_prefetch: logger.info( f"Prefetching {prefetch_size / (1024**3):.2f}GB checkpoint files." From 9d66b82666e11dd0bd502e8075e1c16bbbe9f193 Mon Sep 17 00:00:00 2001 From: o-stoner <245287810+o-stoner@users.noreply.github.com> Date: Tue, 21 Apr 2026 11:28:13 -0700 Subject: [PATCH 2/6] [TRTLLM-11339][fix] Wan tests refactor + small transformer fix (#12128) Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Signed-off-by: o-stoner <245287810+o-stoner@users.noreply.github.com> --- requirements.txt | 1 + .../visual_gen/models/wan/transformer_wan.py | 57 +- .../test_lists/test-db/l0_b200.yml | 10 +- .../test_lists/test-db/l0_dgx_b200.yml | 4 - tests/unittest/_torch/visual_gen/test_wan.py | 3514 ----------------- .../visual_gen/test_wan21_i2v_pipeline.py | 402 ++ .../visual_gen/test_wan21_i2v_teacache.py | 260 ++ .../visual_gen/test_wan21_t2v_pipeline.py | 655 +++ .../visual_gen/test_wan21_t2v_teacache.py | 242 ++ .../visual_gen/test_wan22_i2v_pipeline.py | 502 +++ .../visual_gen/test_wan22_t2v_pipeline.py | 472 +++ .../_torch/visual_gen/test_wan_i2v.py | 1544 -------- .../_torch/visual_gen/test_wan_transformer.py | 468 +++ 13 files changed, 3043 insertions(+), 5088 deletions(-) delete mode 100644 tests/unittest/_torch/visual_gen/test_wan.py create mode 100644 tests/unittest/_torch/visual_gen/test_wan21_i2v_pipeline.py create mode 100644 tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py create mode 100644 tests/unittest/_torch/visual_gen/test_wan21_t2v_pipeline.py create mode 100644 tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py create mode 100644 tests/unittest/_torch/visual_gen/test_wan22_i2v_pipeline.py create mode 100644 tests/unittest/_torch/visual_gen/test_wan22_t2v_pipeline.py delete mode 100644 tests/unittest/_torch/visual_gen/test_wan_i2v.py create mode 100644 tests/unittest/_torch/visual_gen/test_wan_transformer.py diff --git a/requirements.txt b/requirements.txt index eaa5d0082f31..110cf75aa64f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ build colored cuda-python>=13 diffusers>=0.27.0 +ftfy lark mpi4py numpy>=2.0.0,<2.4 # numba 0.63.1 requires numpy<2.4 diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py index 791d2bf2a3fe..c35c47fb0662 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py @@ -102,7 +102,7 @@ def __init__( dtype = model_config.torch_dtype if model_config else None # LayerNorm weights in fp32 (matches internal float32 normalization; avoids bf16/fp32 mismatch). 
self.norm1 = LayerNorm( - hidden_size=in_features, eps=1e-6, dtype=torch.float32, has_weights=True, has_bias=True + hidden_size=in_features, eps=1e-5, dtype=torch.float32, has_weights=True, has_bias=True ) # Match HF FeedForward structure: Linear(in, in) → GELU → Linear(in, out) @@ -136,7 +136,7 @@ def __init__( ) self.norm2 = LayerNorm( - hidden_size=out_features, eps=1e-6, dtype=torch.float32, has_weights=True, has_bias=True + hidden_size=out_features, eps=1e-5, dtype=torch.float32, has_weights=True, has_bias=True ) if pos_embed_seq_len is not None: @@ -245,6 +245,7 @@ def __init__( head_dim = getattr(config, "attention_head_dim", 128) ffn_dim = getattr(config, "ffn_dim", 8960) eps = getattr(config, "eps", 1e-6) + cross_attn_norm = getattr(config, "cross_attn_norm", True) dtype = model_config.torch_dtype quant_config = model_config.quant_config @@ -287,9 +288,16 @@ def __init__( layer_idx=_layer_idx, ) - self.norm2 = LayerNorm( - hidden_size=hidden_size, eps=eps, dtype=torch.float32, has_weights=True, has_bias=True - ) + if cross_attn_norm: + self.norm2 = LayerNorm( + hidden_size=hidden_size, + eps=eps, + dtype=torch.float32, + has_weights=True, + has_bias=True, + ) + else: + self.norm2 = nn.Identity() self.norm3 = LayerNorm( hidden_size=hidden_size, eps=eps, dtype=torch.float32, has_weights=False, has_bias=False ) @@ -375,32 +383,35 @@ def forward( encoder_hidden_states_text = encoder_hidden_states[:, image_context_length:] # Text cross-attention - attn2_output = self.attn2(norm_x, encoder_hidden_states=encoder_hidden_states_text) + batch_size, seq_len = norm_x.shape[:2] + q, k, v = self.attn2.get_qkv(norm_x, encoder_hidden_states_text) + q, k = self.attn2.apply_qk_norm(q, k) + attn2_output = self.attn2._attn_impl( + q, + k, + v, + batch_size=batch_size, + seq_len=seq_len, + kv_seq_len=encoder_hidden_states_text.shape[1], + ) - # I2V: Additional image cross-attention if image embeddings are present + # I2V: image cross-attention if encoder_hidden_states_img is not None: - batch_size, seq_len = norm_x.shape[:2] - - query = self.attn2.get_qkv(norm_x, None)[0] # Q only - query, _ = self.attn2.apply_qk_norm(query, query) - key_img = self.add_k_proj(encoder_hidden_states_img) value_img = self.add_v_proj(encoder_hidden_states_img) key_img = self.norm_added_k(key_img) - - query = query.view(batch_size, seq_len, self.num_heads, self.head_dim) - key_img = key_img.view( - batch_size, encoder_hidden_states_img.shape[1], self.num_heads, self.head_dim - ) - value_img = value_img.view( - batch_size, encoder_hidden_states_img.shape[1], self.num_heads, self.head_dim + attn_img_output = self.attn2._attn_impl( + q, + key_img, + value_img, + batch_size=batch_size, + seq_len=seq_len, + kv_seq_len=encoder_hidden_states_img.shape[1], ) - - attn_img_output = self.attn2._attn_impl(query, key_img, value_img) - attn2_output = attn2_output + attn_img_output - x = x + attn2_output + # Apply to_out once to the combined (text + image) attention output + x = x + self.attn2.to_out[0](attn2_output) # 3. 
Feed-forward normed = self.norm3(x.float()) * (1 + c_scale_msa) + c_shift_msa diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index 3f7e2f7aec9d..9260140fbd52 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -183,8 +183,6 @@ l0_b200: - unittest/_torch/visual_gen/test_attention_perf.py - unittest/_torch/visual_gen/test_trtllm_serve_endpoints.py - unittest/_torch/visual_gen/test_trtllm_serve_e2e.py - - unittest/_torch/visual_gen/test_wan.py -k "not TestWanTwoStageTransformer" - - unittest/_torch/visual_gen/test_wan_i2v.py - unittest/_torch/visual_gen/test_model_loader.py - unittest/_torch/visual_gen/test_flux_transformer.py - unittest/_torch/visual_gen/test_flux_attention.py @@ -192,6 +190,13 @@ l0_b200: - unittest/_torch/visual_gen/test_ltx2_transformer.py - unittest/_torch/visual_gen/test_ltx2_attention.py - unittest/_torch/visual_gen/test_ltx2_pipeline.py + - unittest/_torch/visual_gen/test_wan21_i2v_pipeline.py + - unittest/_torch/visual_gen/test_wan21_t2v_pipeline.py + - unittest/_torch/visual_gen/test_wan22_i2v_pipeline.py + - unittest/_torch/visual_gen/test_wan22_t2v_pipeline.py + - unittest/_torch/visual_gen/test_wan21_i2v_teacache.py + - unittest/_torch/visual_gen/test_wan21_t2v_teacache.py + - unittest/_torch/visual_gen/test_wan_transformer.py - examples/test_visual_gen.py::test_wan_t2v_example # - examples/test_visual_gen.py # ------------- Host perf module regression tests (6 representative scenarios) --------------- @@ -278,7 +283,6 @@ l0_b200: - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTEDSL-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestSeedOss_36B::test_auto_dtype - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B_Instruct_RocketKV::test_auto_dtype - - unittest/_torch/visual_gen/test_wan.py::TestWanTwoStageTransformer # ------------- AutoDeploy Backend Stages --------------- - condition: ranges: diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index dc908d922471..d5d9602a7815 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -43,10 +43,6 @@ l0_dgx_b200: - kv_cache/test_kv_cache_v2_scheduler.py::TestKVCacheV2DSv3Lite::test_mtp_eviction # ------------- VisualGen multi-GPU tests --------------- - unittest/_torch/visual_gen/multi_gpu - - unittest/_torch/visual_gen/test_wan.py::TestWanParallelism::test_cfg_2gpu_correctness - - unittest/_torch/visual_gen/test_wan.py::TestWanCombinedOptimizations::test_all_optimizations_combined - - unittest/_torch/visual_gen/test_wan_i2v.py::TestWanI2VParallelism::test_cfg_2gpu_correctness - - unittest/_torch/visual_gen/test_wan_i2v.py::TestWanI2VCombinedOptimizations::test_all_optimizations_combined - unittest/_torch/visual_gen/test_flux_pipeline.py::TestFluxParallelism::test_ulysses_2gpu_correctness - unittest/_torch/visual_gen/test_flux_pipeline.py::TestFluxCombinedOptimizations::test_all_optimizations_combined - condition: diff --git a/tests/unittest/_torch/visual_gen/test_wan.py b/tests/unittest/_torch/visual_gen/test_wan.py deleted file mode 100644 index 729e63c51dda..000000000000 --- a/tests/unittest/_torch/visual_gen/test_wan.py +++ /dev/null @@ -1,3514 +0,0 @@ -"""Comprehensive unit tests for the Wan model and pipeline.""" - 
-import os - -os.environ["TLLM_DISABLE_MPI"] = "1" - -import unittest -from copy import deepcopy -from pathlib import Path -from types import SimpleNamespace - -import pytest -import torch -import torch.distributed as dist -import torch.multiprocessing as mp -import torch.nn.functional as F -from diffusers import WanTransformer3DModel as HFWanTransformer3DModel -from parameterized import parameterized - -from tensorrt_llm._torch.modules.linear import Linear -from tensorrt_llm._torch.visual_gen.config import ( - AttentionConfig, - DiffusionModelConfig, - ParallelConfig, - PipelineComponent, - TeaCacheConfig, - VisualGenArgs, -) -from tensorrt_llm._torch.visual_gen.models.wan.transformer_wan import WanTransformer3DModel -from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader -from tensorrt_llm.models.modeling_utils import QuantConfig -from tensorrt_llm.quantization.mode import QuantAlgo - - -@pytest.fixture(autouse=True, scope="module") -def _cleanup_mpi_env(): - """Clean up TLLM_DISABLE_MPI env var after tests complete.""" - yield - os.environ.pop("TLLM_DISABLE_MPI", None) - - -def _llm_models_root() -> str: - """Return LLM_MODELS_ROOT path if it is set in env, assert when it's set but not a valid path.""" - root = Path("/home/scratch.trt_llm_data_ci/llm-models/") - if "LLM_MODELS_ROOT" in os.environ: - root = Path(os.environ["LLM_MODELS_ROOT"]) - if not root.exists(): - root = Path("/scratch.trt_llm_data/llm-models/") - assert root.exists(), ( - "You shall set LLM_MODELS_ROOT env or be able to access scratch.trt_llm_data to run this test" - ) - return str(root) - - -# Checkpoint paths for integration tests -CHECKPOINT_PATH = os.environ.get( - "DIFFUSION_MODEL_PATH", - os.path.join(_llm_models_root(), "Wan2.1-T2V-1.3B-Diffusers"), -) -# Wan 2.2 TI2V-5B: BF16 base, FP8 pre-quantized, NVFP4 pre-quantized -CHECKPOINT_PATH_WAN22_BF16 = os.environ.get( - "DIFFUSION_MODEL_PATH_WAN22_BF16", - os.path.join(_llm_models_root(), "Wan2.2-TI2V-5B-Diffusers"), -) -CHECKPOINT_PATH_WAN22_FP8 = os.environ.get( - "DIFFUSION_MODEL_PATH_WAN22_FP8", - os.path.join(_llm_models_root(), "Wan2.2-TI2V-5B-Diffusers-FP8"), -) -CHECKPOINT_PATH_WAN22_NVFP4 = os.environ.get( - "DIFFUSION_MODEL_PATH_WAN22_NVFP4", - os.path.join(_llm_models_root(), "Wan2.2-TI2V-5B-Diffusers-NVFP4"), -) -# Wan 2.2 T2V (two-stage transformer) -CHECKPOINT_PATH_WAN22_T2V = os.environ.get( - "DIFFUSION_MODEL_PATH_WAN22_T2V", - os.path.join(_llm_models_root(), "Wan2.2-T2V-A14B-Diffusers"), -) -CHECKPOINT_PATH_WAN22_T2V_NVFP4 = os.environ.get( - "DIFFUSION_MODEL_PATH_WAN22_T2V_NVFP4", - os.path.join(_llm_models_root(), "Wan2.2-T2V-A14B-Diffusers-NVFP4"), -) -SKIP_COMPONENTS = [ - PipelineComponent.TEXT_ENCODER, - PipelineComponent.VAE, - PipelineComponent.TOKENIZER, - PipelineComponent.SCHEDULER, -] - - -def is_wan21_checkpoint() -> bool: - """Check if DIFFUSION_MODEL_PATH is Wan 2.1 (contains '2.1' in path).""" - return "2.1" in CHECKPOINT_PATH - - -def is_wan22_checkpoint() -> bool: - """Check if DIFFUSION_MODEL_PATH is Wan 2.2 (contains '2.2' in path).""" - return "2.2" in CHECKPOINT_PATH_WAN22_T2V - - -WAN_1_3B_CONFIG = { - "attention_head_dim": 128, - "eps": 1e-06, - "ffn_dim": 8960, - "freq_dim": 256, - "in_channels": 16, - "num_attention_heads": 12, - "num_layers": 30, - "out_channels": 16, - "patch_size": [1, 2, 2], - "qk_norm": "rms_norm_across_heads", - "rope_max_seq_len": 1024, - "text_dim": 4096, - "torch_dtype": "bfloat16", - "cross_attn_norm": True, -} - - -def reduce_wan_config(mem_for_full_model: int, 
config_dict: dict): - """Reduce model size if insufficient GPU memory.""" - _, total_mem = torch.cuda.mem_get_info() - if total_mem < mem_for_full_model: - model_fraction = total_mem / mem_for_full_model - num_layers = max(1, int(config_dict["num_layers"] * model_fraction)) - config_dict["num_layers"] = min(num_layers, 4) - - -def setup_distributed(rank, world_size, backend="nccl"): - """Initialize distributed process group for multi-GPU tests.""" - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12355" - os.environ["RANK"] = str(rank) - os.environ["WORLD_SIZE"] = str(world_size) - - dist.init_process_group(backend=backend, rank=rank, world_size=world_size) - torch.cuda.set_device(rank) - - -def cleanup_distributed(): - """Clean up distributed process group.""" - if dist.is_initialized(): - dist.destroy_process_group() - - -def _run_cfg_worker(rank, world_size, checkpoint_path, inputs_list, return_dict): - """Worker function for CFG Parallelism multi-GPU test. - - Must be module-level for multiprocessing.spawn() pickling. - """ - try: - setup_distributed(rank, world_size) - - from tensorrt_llm._torch.visual_gen.config import ParallelConfig, VisualGenArgs - from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader - - # Load pipeline with CFG parallel - args = VisualGenArgs( - checkpoint_path=checkpoint_path, - device=f"cuda:{rank}", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - parallel=ParallelConfig(dit_cfg_size=world_size), - ) - pipeline = PipelineLoader(args).load(skip_warmup=True) - - # Verify CFG parallel configuration - assert pipeline.model_config.visual_gen_mapping.cfg_size == world_size, ( - f"Expected cfg_size={world_size}, got {pipeline.model_config.visual_gen_mapping.cfg_size}" - ) - - # Load inputs on this GPU - prompt_embeds = inputs_list[0].to(f"cuda:{rank}") - neg_prompt_embeds = inputs_list[1].to(f"cuda:{rank}") - latents = inputs_list[2].to(f"cuda:{rank}") - timestep = inputs_list[3].to(f"cuda:{rank}") - - # Setup CFG config - cfg_config = pipeline._setup_cfg_config( - guidance_scale=5.0, - prompt_embeds=prompt_embeds, - neg_prompt_embeds=neg_prompt_embeds, - ) - - # Verify CFG parallel is enabled - assert cfg_config["enabled"], f"Rank {rank}: CFG parallel not enabled" - assert cfg_config["cfg_size"] == world_size, f"Rank {rank}: Wrong cfg_size" - - expected_cfg_group = rank // cfg_config["ulysses_size"] - assert cfg_config["cfg_rank"] == expected_cfg_group, ( - f"Rank {rank}: Wrong cfg_rank. 
Expected {expected_cfg_group}, got {cfg_config['cfg_rank']}" - ) - - if rank == 0: - print(f"[CFG Rank {rank}] Loaded with cfg_size={world_size}") - print(f" cfg_rank: {cfg_config['cfg_rank']}") - print(f" local_embeds shape: {cfg_config['local_embeds'].shape}") - print(f" Using {'positive' if cfg_config['cfg_rank'] == 0 else 'negative'} prompts") - - # Verify prompt splitting - rank 0 gets positive, rank 1 gets negative - expected_embeds = prompt_embeds if cfg_config["cfg_rank"] == 0 else neg_prompt_embeds - assert torch.allclose(cfg_config["local_embeds"], expected_embeds), ( - f"Rank {rank}: local_embeds doesn't match expected" - f"{'positive' if cfg_config['cfg_rank'] == 0 else 'negative'} embeds" - ) - - # Run single denoising step with CFG parallel - def forward_fn( - latents, extra_stream_latents, timestep, encoder_hidden_states, extra_tensors - ): - return pipeline.transformer( # noqa: F821 - hidden_states=latents, - timestep=timestep, - encoder_hidden_states=encoder_hidden_states, - ) - - with torch.no_grad(): - noise_pred, _, _, _ = pipeline._denoise_step_cfg_parallel( - latents=latents, - extra_stream_latents={}, - timestep=timestep, - local_embeds=cfg_config["local_embeds"], - forward_fn=forward_fn, - guidance_scale=5.0, - guidance_rescale=0.0, - ulysses_size=cfg_config["ulysses_size"], - local_extras={}, - ) - - # Validate output - assert not torch.isnan(noise_pred).any(), f"Rank {rank}: Output contains NaN" - assert not torch.isinf(noise_pred).any(), f"Rank {rank}: Output contains Inf" - - # Return output from rank 0 - if rank == 0: - return_dict["output"] = noise_pred.cpu() - print(f"[CFG Rank {rank}] ✓ Output shape: {noise_pred.shape}") - print( - f"[CFG Rank {rank}] ✓ Output range: [{noise_pred.min():.4f}, {noise_pred.max():.4f}]" - ) - - del pipeline - torch.cuda.empty_cache() - - finally: - cleanup_distributed() - - -def _run_all_optimizations_worker(rank, world_size, checkpoint_path, inputs_list, return_dict): - """Worker function for all optimizations combined test (FP8 + TeaCache + TRTLLM + CFG). - - Must be module-level for multiprocessing.spawn() pickling. 
- """ - try: - setup_distributed(rank, world_size) - - # Load pipeline with ALL optimizations - args_full = VisualGenArgs( - checkpoint_path=checkpoint_path, - device=f"cuda:{rank}", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - quant_config={"quant_algo": "FP8", "dynamic": True}, - cache=TeaCacheConfig( - teacache_thresh=0.2, - use_ret_steps=True, - ), - attention=AttentionConfig(backend="TRTLLM"), - parallel=ParallelConfig(dit_cfg_size=world_size), - ) - pipeline = PipelineLoader(args_full).load(skip_warmup=True) - transformer = pipeline.transformer.eval() - - # Verify all optimizations are enabled - assert pipeline.model_config.visual_gen_mapping.cfg_size == world_size, ( - "CFG parallel not enabled" - ) - assert transformer.model_config.quant_config.quant_algo == QuantAlgo.FP8, "FP8 not enabled" - assert ( - getattr(pipeline, "cache_accelerator", None) is not None - and pipeline.cache_accelerator.is_enabled() - ), "TeaCache not enabled" - assert transformer.blocks[0].attn1.attn_backend == "TRTLLM", ( - "TRTLLM not enabled for self-attn" - ) - - if rank == 0: - print(f" ✓ All optimizations verified on rank {rank}:") - print(f" - FP8 quantization: {transformer.model_config.quant_config.quant_algo}") - print(" - TeaCache: enabled") - print(f" - TRTLLM attention: {transformer.blocks[0].attn1.attn_backend}") - print(f" - CFG Parallelism: cfg_size={world_size}") - - # Initialize TeaCache for single-step inference - if getattr(pipeline, "cache_accelerator", None) and pipeline.cache_accelerator.is_enabled(): - pipeline.cache_accelerator.refresh(num_inference_steps=1) - - # Load inputs on this GPU - prompt_embeds = inputs_list[0].to(f"cuda:{rank}") - neg_prompt_embeds = inputs_list[1].to(f"cuda:{rank}") - latents = inputs_list[2].to(f"cuda:{rank}") - timestep = inputs_list[3].to(f"cuda:{rank}") - - # Setup CFG config - cfg_config = pipeline._setup_cfg_config( - guidance_scale=5.0, - prompt_embeds=prompt_embeds, - neg_prompt_embeds=neg_prompt_embeds, - ) - - assert cfg_config["enabled"], "CFG parallel not enabled" - - # Run single denoising step with all optimizations - def forward_fn( - latents, extra_stream_latents, timestep, encoder_hidden_states, extra_tensors - ): - return transformer( # noqa: F821 - hidden_states=latents, - timestep=timestep, - encoder_hidden_states=encoder_hidden_states, - ) - - with torch.no_grad(): - noise_pred, _, _, _ = pipeline._denoise_step_cfg_parallel( - latents=latents, - extra_stream_latents={}, - timestep=timestep, - local_embeds=cfg_config["local_embeds"], - forward_fn=forward_fn, - guidance_scale=5.0, - guidance_rescale=0.0, - ulysses_size=cfg_config["ulysses_size"], - local_extras={}, - ) - - # Validate output - assert not torch.isnan(noise_pred).any(), f"Rank {rank}: Output contains NaN" - assert not torch.isinf(noise_pred).any(), f"Rank {rank}: Output contains Inf" - - # Return output from rank 0 - if rank == 0: - return_dict["output"] = noise_pred.cpu() - print(f" ✓ Combined optimization output shape: {noise_pred.shape}") - print( - f" ✓ Combined optimization range: [{noise_pred.min():.4f}, {noise_pred.max():.4f}]" - ) - - del pipeline, transformer - torch.cuda.empty_cache() - - finally: - cleanup_distributed() - - -# ============================================================================= -# Basic Unit Tests -# ============================================================================= - - -class TestWan(unittest.TestCase): - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - def _create_model_config(self, 
config_dict): - """Helper to create DiffusionModelConfig from test config dict.""" - # Create pretrained_config as SimpleNamespace - pretrained_config = SimpleNamespace(**config_dict) - - # Use default quantization (no quantization for unit tests) - quant_config = QuantConfig() - dynamic_weight_quant = False - dynamic_activation_quant = False - - # Create DiffusionModelConfig - model_config = DiffusionModelConfig( - pretrained_config=pretrained_config, - quant_config=quant_config, - quant_config_dict=None, - dynamic_weight_quant=dynamic_weight_quant, - force_dynamic_quantization=dynamic_activation_quant, - skip_create_weights_in_init=False, # Create weights immediately for testing - ) - return model_config - - def test_wan_model_structure(self): - """Test that model structure matches HuggingFace naming.""" - config = deepcopy(WAN_1_3B_CONFIG) - config["num_layers"] = 1 - hidden_size = config["num_attention_heads"] * config["attention_head_dim"] - config["hidden_size"] = hidden_size - - model_config = self._create_model_config(config) - - model = WanTransformer3DModel(model_config=model_config) - - # Check FFN structure - param_names = [n for n in model.state_dict().keys() if "ffn" in n] - print("\n[DEBUG] FFN parameter names in TRT-LLM model:") - for pn in param_names[:5]: - print(f" - {pn}") - - # Verify expected structure exists (MLP uses up_proj/down_proj) - assert any("ffn.up_proj" in n for n in param_names), "Missing ffn.up_proj structure" - assert any("ffn.down_proj" in n for n in param_names), "Missing ffn.down_proj structure" - - def test_wan_sanity(self): - """Basic sanity test that the model can run forward pass.""" - config = deepcopy(WAN_1_3B_CONFIG) - dtype = getattr(torch, config["torch_dtype"]) - # Use fewer layers for sanity test - config["num_layers"] = 2 - - hidden_size = config["num_attention_heads"] * config["attention_head_dim"] - config["hidden_size"] = hidden_size - - # Create model config - model_config = self._create_model_config(config) - - # Create model with model_config - model = WanTransformer3DModel(model_config=model_config).to(self.DEVICE, dtype=dtype).eval() - - batch_size = 1 - num_frames = 1 - height, width = 64, 64 - seq_len = 128 - generator = torch.Generator(device=self.DEVICE).manual_seed(42) - - hidden_states = torch.randn( - batch_size, - config["in_channels"], - num_frames, - height, - width, - generator=generator, - device=self.DEVICE, - dtype=dtype, - ) - timestep = torch.tensor([50], device=self.DEVICE, dtype=torch.long) - encoder_hidden_states = torch.randn( - batch_size, - seq_len, - config["text_dim"], - generator=generator, - device=self.DEVICE, - dtype=dtype, - ) - - with torch.inference_mode(): - output = model( - hidden_states=hidden_states, - timestep=timestep, - encoder_hidden_states=encoder_hidden_states, - ) - - self.assertEqual(output.shape, hidden_states.shape) - - @parameterized.expand( - [ - ("1_3b", WAN_1_3B_CONFIG), - ] - ) - @torch.no_grad() - def test_wan_allclose_to_hf(self, name, config_template): - """Test TRT-LLM transformer matches HuggingFace output (BF16).""" - torch.random.manual_seed(42) - config = deepcopy(config_template) - dtype = getattr(torch, config["torch_dtype"]) - - mem_for_full_model = (2 + 1) * 1.3 * 2**30 - reduce_wan_config(mem_for_full_model, config) - - if config["num_layers"] <= 0: - self.skipTest("Insufficient memory for a single Wan layer") - - hidden_size = config["num_attention_heads"] * config["attention_head_dim"] - - # Create HuggingFace model (random weights) - hf_model = ( - 
HFWanTransformer3DModel( - patch_size=config["patch_size"], - num_attention_heads=config["num_attention_heads"], - attention_head_dim=config["attention_head_dim"], - in_channels=config["in_channels"], - out_channels=config["out_channels"], - text_dim=config["text_dim"], - freq_dim=config["freq_dim"], - ffn_dim=config["ffn_dim"], - num_layers=config["num_layers"], - cross_attn_norm=config["cross_attn_norm"], - qk_norm=config["qk_norm"], - eps=config["eps"], - ) - .to(self.DEVICE, dtype=dtype) - .eval() - ) - - # Create TRT-LLM model with model_config - config["hidden_size"] = hidden_size - model_config = self._create_model_config(config) - - trtllm_model = ( - WanTransformer3DModel(model_config=model_config).to(self.DEVICE, dtype=dtype).eval() - ) - - # Copy weights from HF to TRT-LLM - loaded_count = self._load_weights_from_hf(trtllm_model, hf_model.state_dict()) - print(f"[DEBUG] Loaded {loaded_count} weight tensors from HF to TRT-LLM") - - # Create test inputs - batch_size = 1 - num_frames = 1 - height, width = 64, 64 - seq_len = 128 - generator = torch.Generator(device=self.DEVICE).manual_seed(42) - - hidden_states = torch.randn( - batch_size, - config["in_channels"], - num_frames, - height, - width, - generator=generator, - device=self.DEVICE, - dtype=dtype, - ) - timestep = torch.tensor([50], device=self.DEVICE, dtype=torch.long) - encoder_hidden_states = torch.randn( - batch_size, - seq_len, - config["text_dim"], - generator=generator, - device=self.DEVICE, - dtype=dtype, - ) - - # Run both models - with ( - torch.inference_mode(), - torch.backends.cuda.sdp_kernel( - enable_flash=False, enable_math=True, enable_mem_efficient=False - ), - ): - hf_output = hf_model( - hidden_states=hidden_states, - timestep=timestep, - encoder_hidden_states=encoder_hidden_states, - return_dict=False, - )[0] - - trtllm_output = trtllm_model( - hidden_states=hidden_states, - timestep=timestep, - encoder_hidden_states=encoder_hidden_states, - ) - - # Compare outputs - hf_output = hf_output.float() - trtllm_output = trtllm_output.float() - - # Debug: Check for NaN/Inf - hf_has_nan = torch.isnan(hf_output).any().item() - trtllm_has_nan = torch.isnan(trtllm_output).any().item() - hf_has_inf = torch.isinf(hf_output).any().item() - trtllm_has_inf = torch.isinf(trtllm_output).any().item() - - print("\n[DEBUG] Output validation:") - print(f" HF has NaN: {hf_has_nan}, Inf: {hf_has_inf}") - print(f" TRT-LLM has NaN: {trtllm_has_nan}, Inf: {trtllm_has_inf}") - - if not (hf_has_nan or trtllm_has_nan or hf_has_inf or trtllm_has_inf): - # Compute detailed comparison metrics - diff = (trtllm_output - hf_output).abs() - max_diff = diff.max().item() - mean_diff = diff.mean().item() - - cos_sim = torch.nn.functional.cosine_similarity( - trtllm_output.flatten(), hf_output.flatten(), dim=0 - ).item() - - print("\n[DEBUG] Comparison metrics:") - print(f" Max absolute diff: {max_diff:.6f}") - print(f" Mean absolute diff: {mean_diff:.6f}") - print(f" Cosine similarity: {cos_sim:.6f}") - print(f" HF output range: [{hf_output.min():.4f}, {hf_output.max():.4f}]") - print(f" TRT-LLM output range: [{trtllm_output.min():.4f}, {trtllm_output.max():.4f}]") - - torch.testing.assert_close( - trtllm_output, hf_output, atol=0.4, rtol=0.4, msg=f"Output mismatch for {name} config" - ) - - def _load_weights_from_hf(self, trtllm_model, hf_state_dict): - """Load weights from HuggingFace model to TRT-LLM model. 
- - TRT-LLM structure: - - blocks.0.attn1.qkv_proj (fused QKV for self-attention) - - blocks.0.attn2.to_q/to_k/to_v (separate for cross-attention) - - blocks.0.attn1.to_out.0 and blocks.0.attn2.to_out.0 - - HuggingFace structure: - - blocks.0.attn1.to_q/to_k/to_v (separate Q/K/V) - - blocks.0.attn2.to_q/to_k/to_v (separate Q/K/V) - - blocks.0.attn1.to_out.0 and blocks.0.attn2.to_out.0 - """ - loaded_count = 0 - missing_weights = [] - - def load_linear(module, trtllm_key, hf_key, sd): - """Load weights from HF key into TRT-LLM module.""" - if f"{hf_key}.weight" in sd: - weight_dict = {"weight": sd[f"{hf_key}.weight"]} - if f"{hf_key}.bias" in sd: - weight_dict["bias"] = sd[f"{hf_key}.bias"] - module.load_weights([weight_dict]) - return 1 - else: - missing_weights.append(hf_key) - return 0 - - for name, module in trtllm_model.named_modules(): - if isinstance(module, Linear): - # Self-attention fused QKV: blocks.0.attn1.qkv_proj - # Load from HF separate Q/K/V: blocks.0.attn1.to_q/to_k/to_v - if "attn1.qkv_proj" in name: - base = name.replace(".qkv_proj", "") - q_key, k_key, v_key = f"{base}.to_q", f"{base}.to_k", f"{base}.to_v" - if f"{q_key}.weight" in hf_state_dict: - q_dict = {"weight": hf_state_dict[f"{q_key}.weight"]} - k_dict = {"weight": hf_state_dict[f"{k_key}.weight"]} - v_dict = {"weight": hf_state_dict[f"{v_key}.weight"]} - if f"{q_key}.bias" in hf_state_dict: - q_dict["bias"] = hf_state_dict[f"{q_key}.bias"] - k_dict["bias"] = hf_state_dict[f"{k_key}.bias"] - v_dict["bias"] = hf_state_dict[f"{v_key}.bias"] - module.load_weights([q_dict, k_dict, v_dict]) - loaded_count += 1 - - # Cross-attention separate Q/K/V: blocks.0.attn2.to_q (same path as HF) - elif "attn2.to_q" in name or "attn2.to_k" in name or "attn2.to_v" in name: - # Direct mapping - TRT-LLM and HF use same paths for cross-attention - loaded_count += load_linear(module, name, name, hf_state_dict) - - # Output projections: blocks.0.attn1.to_out.0 (same path as HF) - elif ".to_out" in name: - # Direct mapping - TRT-LLM and HF use same paths for output projections - loaded_count += load_linear(module, name, name, hf_state_dict) - - # FFN layers: TRT-LLM uses up_proj/down_proj, HF uses net.0.proj/net.2 - elif "ffn.up_proj" in name: - hf_key = name.replace(".ffn.up_proj", ".ffn.net.0.proj") - loaded_count += load_linear(module, name, hf_key, hf_state_dict) - elif "ffn.down_proj" in name: - hf_key = name.replace(".ffn.down_proj", ".ffn.net.2") - loaded_count += load_linear(module, name, hf_key, hf_state_dict) - - # Other layers: direct mapping - elif "condition_embedder" in name or "proj_out" in name: - loaded_count += load_linear(module, name, name, hf_state_dict) - - else: - # Direct mapping for any other Linear modules - loaded_count += load_linear(module, name, name, hf_state_dict) - - elif hasattr(module, "weight") and f"{name}.weight" in hf_state_dict: - # Norms & embeddings - with torch.no_grad(): - module.weight.copy_(hf_state_dict[f"{name}.weight"]) - if ( - getattr(module, "bias", None) is not None - and f"{name}.bias" in hf_state_dict - ): - module.bias.copy_(hf_state_dict[f"{name}.bias"]) - loaded_count += 1 - - # Load scale_shift_table parameters - for name, param in trtllm_model.named_parameters(): - if "scale_shift_table" in name and name in hf_state_dict: - with torch.no_grad(): - param.copy_(hf_state_dict[name].view(param.shape)) - loaded_count += 1 - - if missing_weights: - print(f"[DEBUG] Missing {len(missing_weights)} weights:") - for mw in missing_weights[:10]: # Show first 10 - print(f" - {mw}") - - 
return loaded_count - - def _load_weights_from_state_dict(self, model, state_dict): - """Load weights from state_dict into model (same structure).""" - for name, module in model.named_modules(): - if isinstance(module, Linear): - weight_key = f"{name}.weight" - if weight_key in state_dict: - weight_dict = {"weight": state_dict[weight_key]} - bias_key = f"{name}.bias" - if bias_key in state_dict: - weight_dict["bias"] = state_dict[bias_key] - module.load_weights([weight_dict]) - - elif hasattr(module, "weight") and f"{name}.weight" in state_dict: - with torch.no_grad(): - module.weight.copy_(state_dict[f"{name}.weight"]) - if getattr(module, "bias", None) is not None and f"{name}.bias" in state_dict: - module.bias.copy_(state_dict[f"{name}.bias"]) - - # Load parameters - for name, param in model.named_parameters(): - if name in state_dict: - with torch.no_grad(): - param.copy_(state_dict[name].view(param.shape)) - - -# ============================================================================= -# Pipeline Test - Require Real Checkpoint -# ============================================================================= - - -@pytest.fixture -def checkpoint_exists(): - """Check if checkpoint path is set and exists.""" - return CHECKPOINT_PATH and os.path.exists(CHECKPOINT_PATH) - - -@pytest.fixture(autouse=True) -def cleanup_gpu_memory(): - """Automatically cleanup GPU memory after each test to prevent OOM errors. - - This fixture runs automatically after every test in this file. - It performs garbage collection and clears CUDA cache to free up GPU memory. - """ - yield # Test runs here - # Cleanup after test completes - import gc - - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - -def _is_fp32_layernorm_param(param_name: str) -> bool: - """True if param is a LayerNorm weight/bias we keep in float32. Only LayerNorm (norm1/norm2/norm3/norm_out).""" - if not param_name.endswith((".weight", ".bias")): - return False - # blocks..norm1, norm2, norm3 (LayerNorm only; attn norm_q/norm_k are RMSNorm) - if ".norm" in param_name and "blocks." in param_name: - parts = param_name.split(".") - for p in parts: - if p in ("norm1", "norm2", "norm3"): - return True - return False - # top-level norm_out (LayerNorm) - if param_name == "norm_out.weight" or param_name == "norm_out.bias": - return True - # condition_embedder.norm1, norm2 (LayerNorm) - if param_name.startswith("condition_embedder.") and ".norm" in param_name: - return True - return False - - -class TestWanPipeline: - """Pipeline tests for Wan pipeline loading with PipelineLoader. - - These tests require a real checkpoint (set DIFFUSION_MODEL_PATH env var). - They test the full loading flow: config → model → weight loading → inference. - """ - - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - def test_load_wan_pipeline_basic(self, checkpoint_exists): - """Test loading Wan pipeline without quantization via PipelineLoader.""" - if not checkpoint_exists: - pytest.skip("Checkpoint not available. Set DIFFUSION_MODEL_PATH.") - if not is_wan21_checkpoint(): - pytest.skip( - "This test requires Wan 2.1 checkpoint (single-stage). Use DIFFUSION_MODEL_PATH with '2.1' in the path." 
- ) - - args = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - ) - pipeline = PipelineLoader(args).load(skip_warmup=True) - - # Verify pipeline loaded correctly - assert pipeline.transformer is not None - assert len(pipeline.transformer.blocks) > 0 - - # Verify weights are loaded - # Check that non-scale parameters are bfloat16 - bf16_count = 0 - f32_scale_count = 0 - for name, param in pipeline.transformer.named_parameters(): - assert param.device.type == "cuda", f"Parameter {name} not on CUDA" - if "scale" in name.lower(): - # Scale parameters can stay float32 for FP8 kernels - assert param.dtype in [torch.float32, torch.bfloat16], ( - f"Scale param {name} has unexpected dtype {param.dtype}" - ) - if param.dtype == torch.float32: - f32_scale_count += 1 - elif _is_fp32_layernorm_param(name): - # LayerNorm (norm1/norm2/norm3/norm_out) use float32; RMSNorm (norm_q, norm_k, etc.) stay bf16 - assert param.dtype == torch.float32, ( - f"LayerNorm param {name} expected float32 but got {param.dtype}" - ) - else: - # Non-scale parameters should be bfloat16 - assert param.dtype == torch.bfloat16, ( - f"Parameter {name} expected bfloat16 but got {param.dtype}" - ) - bf16_count += 1 - - assert bf16_count > 0, "Should have at least some bfloat16 parameters" - print( - f"\n[Pipeline] BF16 pipeline loaded: {bf16_count} bf16 params" - f"\n{f32_scale_count} f32 scale params, {len(pipeline.transformer.blocks)} blocks" - ) - - @pytest.mark.parametrize("quant_algo", ["FP8", "FP8_BLOCK_SCALES"]) - def test_load_wan_pipeline_with_quantization(self, checkpoint_exists, quant_algo): - """Test loading Wan with FP8 quantization (per-tensor or blockwise).""" - if not checkpoint_exists: - pytest.skip("Checkpoint not available. Set DIFFUSION_MODEL_PATH.") - if not is_wan21_checkpoint(): - pytest.skip( - "This test requires Wan 2.1 checkpoint. Use DIFFUSION_MODEL_PATH with '2.1' in the path." - ) - - args = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - quant_config={"quant_algo": quant_algo, "dynamic": True}, - ) - pipeline = PipelineLoader(args).load(skip_warmup=True) - - # Verify FP8 weights in transformer blocks - found_fp8 = False - for name, module in pipeline.transformer.named_modules(): - if isinstance(module, Linear): - if "blocks." in name and hasattr(module, "weight") and module.weight is not None: - assert module.weight.dtype == torch.float8_e4m3fn, ( - f"Linear {name} should have FP8 weight, got {module.weight.dtype}" - ) - assert hasattr(module, "weight_scale"), f"Linear {name} missing weight_scale" - found_fp8 = True - print(f"[{quant_algo}] FP8 layer {name}: weight {module.weight.shape}") - break - - assert found_fp8, f"No FP8 Linear modules found for {quant_algo}" - - def test_load_wan_pipeline_with_nvfp4_quantization(self, checkpoint_exists): - """Test loading Wan with NVFP4 dynamic quantization.""" - if not checkpoint_exists: - pytest.skip("Checkpoint not available. Set DIFFUSION_MODEL_PATH.") - if not is_wan21_checkpoint(): - pytest.skip( - "This test requires Wan 2.1 checkpoint. Use DIFFUSION_MODEL_PATH with '2.1' in the path." 
- ) - - from tensorrt_llm.quantization.utils import fp4_utils - - args = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - quant_config={ - "quant_algo": "NVFP4", - "dynamic": True, - }, - ) - pipeline = PipelineLoader(args).load(skip_warmup=True) - - # Verify NVFP4 weights in transformer blocks - found_nvfp4 = False - for name, module in pipeline.transformer.named_modules(): - if isinstance(module, Linear): - if "blocks." in name and hasattr(module, "weight") and module.weight is not None: - # NVFP4 uses packed FP4 format (float4_e2m1x2) - assert module.weight.dtype == fp4_utils.float4_e2m1x2, ( - f"Linear {name} should have NVFP4 weight, got {module.weight.dtype}" - ) - assert hasattr(module, "weight_scale"), f"Linear {name} missing weight_scale" - assert hasattr(module, "weight_scale_2"), ( - f"Linear {name} missing weight_scale_2" - ) - found_nvfp4 = True - print(f"[NVFP4] NVFP4 layer {name}: weight {module.weight.shape}") - break - - assert found_nvfp4, "No NVFP4 Linear modules found" - - @pytest.mark.parametrize("quant_algo", ["FP8", "FP8_BLOCK_SCALES"]) - def test_fp8_vs_bf16_numerical_correctness(self, checkpoint_exists, quant_algo): - """Test FP8 vs BF16 numerical accuracy on real checkpoint weights. - - Pattern (similar to that in test_pipeline_dynamic_quant.py): - 1. Use F.linear() with BF16 weights as ground truth reference - 2. Verify BF16 layer matches F.linear exactly - 3. Compare FP8 layer output against reference - 4. Check max_diff, cosine_similarity, mse_loss - """ - if not checkpoint_exists: - pytest.skip("Checkpoint not available. Set DIFFUSION_MODEL_PATH.") - if not is_wan21_checkpoint(): - pytest.skip( - "This test requires Wan 2.1 checkpoint (loads 2 full models and " - "Needs single transformer). Use DIFFUSION_MODEL_PATH with '2.1' in the path." 
- ) - - # ===================================================================== - # Load BF16 Pipeline (Reference) - # ===================================================================== - print(f"\n[Compare {quant_algo}] Loading BF16 pipeline...") - - args_bf16 = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - ) - pipeline_bf16 = PipelineLoader(args_bf16).load(skip_warmup=True) - - # ===================================================================== - # Load FP8 Pipeline - # ===================================================================== - print(f"[Compare {quant_algo}] Loading {quant_algo} pipeline...") - - args_fp8 = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - quant_config={"quant_algo": quant_algo, "dynamic": True}, - ) - pipeline_fp8 = PipelineLoader(args_fp8).load(skip_warmup=True) - - # ===================================================================== - # Get Linear Layers from Both Pipelines - # ===================================================================== - attn_bf16 = pipeline_bf16.transformer.blocks[0].attn1 - attn_fp8 = pipeline_fp8.transformer.blocks[0].attn1 - - # Get linear layer - try fused qkv_proj first, fallback to qkv_proj on attention module - if hasattr(attn_bf16, "qkv_proj"): - linear_bf16 = attn_bf16.qkv_proj - linear_fp8 = attn_fp8.qkv_proj - layer_name = "blocks.0.attn1.qkv_proj" - elif hasattr(attn_bf16, "attn") and hasattr(attn_bf16.attn, "qkv_proj"): - linear_bf16 = attn_bf16.attn.qkv_proj - linear_fp8 = attn_fp8.attn.qkv_proj - layer_name = "blocks.0.attn1.attn.qkv_proj" - else: - # Use FFN linear instead (always available) - linear_bf16 = pipeline_bf16.transformer.blocks[0].ffn.net[0]["proj"] - linear_fp8 = pipeline_fp8.transformer.blocks[0].ffn.net[0]["proj"] - layer_name = "blocks.0.ffn.net.0.proj" - - # ===================================================================== - # Get BF16 weights and bias for F.linear reference - # ===================================================================== - weight_bf16 = linear_bf16.weight.data.clone() - bias_bf16 = linear_bf16.bias.data.clone() if linear_bf16.bias is not None else None - - # ===================================================================== - # Create Test Input - # ===================================================================== - torch.manual_seed(42) - hidden_size = linear_bf16.in_features - batch_size = 1 - seq_len = 14040 - - # 2D input for FP8 kernel compatibility - input_tensor = torch.randn( - batch_size * seq_len, hidden_size, dtype=torch.bfloat16, device="cuda" - ) - print(f"[Compare] Input shape: {input_tensor.shape}") - - # ===================================================================== - # Compute Reference Output: F.linear (ground truth) - # ===================================================================== - with torch.no_grad(): - expected = F.linear(input_tensor, weight_bf16, bias_bf16) - - # ===================================================================== - # Compute FP8 Output - # ===================================================================== - with torch.no_grad(): - result_fp8 = linear_fp8(input_tensor) - - # ===================================================================== - # Compute BF16 Layer Output - # ===================================================================== - with torch.no_grad(): - result_bf16 = linear_bf16(input_tensor) - - # Verify BF16 layer matches F.linear reference - assert torch.allclose(result_bf16, expected, 
rtol=1e-5, atol=1e-6), ( - "BF16 layer should match F.linear reference exactly" - ) - - # Compare FP8 vs Reference - max_diff = torch.max(torch.abs(result_fp8 - expected)).item() - cos_sim = F.cosine_similarity( - result_fp8.flatten().float(), expected.flatten().float(), dim=0 - ) - mse = F.mse_loss(result_fp8.flatten().float(), expected.flatten().float()) - - print( - f"\n[{layer_name}] max_diff={max_diff:.6f}, cos_sim={cos_sim.item():.6f}, mse={mse.item():.6f}" - ) - - assert cos_sim > 0.99, f"Cosine similarity too low: {cos_sim.item()}" - assert mse < 1.0, f"MSE too high: {mse.item()}" - - # Cleanup - del pipeline_bf16, pipeline_fp8 - torch.cuda.empty_cache() - - def test_fp8_vs_bf16_memory_comparison(self, checkpoint_exists): - """Test FP8 uses ~2x less memory than BF16.""" - if not checkpoint_exists: - pytest.skip("Checkpoint not available. Set DIFFUSION_MODEL_PATH.") - if not is_wan21_checkpoint(): - pytest.skip( - "This test requires Wan 2.1 checkpoint. Use DIFFUSION_MODEL_PATH with '2.1' in the path." - ) - - def get_module_memory_gb(module): - return sum(p.numel() * p.element_size() for p in module.parameters()) / 1024**3 - - # Load BF16 - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - - args_bf16 = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - ) - pipeline_bf16 = PipelineLoader(args_bf16).load(skip_warmup=True) - - bf16_model_mem = get_module_memory_gb(pipeline_bf16.transformer) - bf16_peak_mem = torch.cuda.max_memory_allocated() / 1024**3 - - print(f"\n[BF16] Transformer memory: {bf16_model_mem:.2f} GB") - print(f"[BF16] Peak memory: {bf16_peak_mem:.2f} GB") - - del pipeline_bf16 - torch.cuda.empty_cache() - - # Load FP8 - torch.cuda.reset_peak_memory_stats() - - args_fp8 = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - quant_config={"quant_algo": "FP8", "dynamic": True}, - ) - pipeline_fp8 = PipelineLoader(args_fp8).load(skip_warmup=True) - - fp8_model_mem = get_module_memory_gb(pipeline_fp8.transformer) - fp8_peak_mem = torch.cuda.max_memory_allocated() / 1024**3 - - print(f"\n[FP8] Transformer memory: {fp8_model_mem:.2f} GB") - print(f"[FP8] Peak memory: {fp8_peak_mem:.2f} GB") - - # Verify memory savings - model_mem_ratio = bf16_model_mem / fp8_model_mem - peak_mem_ratio = bf16_peak_mem / fp8_peak_mem - - print(f"\n[Comparison] Model memory ratio (BF16/FP8): {model_mem_ratio:.2f}x") - print(f"[Comparison] Peak memory ratio (BF16/FP8): {peak_mem_ratio:.2f}x") - - # FP8 should use ~2x less memory - assert model_mem_ratio > 1.8, f"FP8 should use ~2x less memory, got {model_mem_ratio:.2f}x" - - del pipeline_fp8 - torch.cuda.empty_cache() - - @pytest.mark.parametrize("quant_algo", ["FP8", "FP8_BLOCK_SCALES"]) - def test_fp8_vs_bf16_full_transformer_e2e(self, checkpoint_exists, quant_algo): - """End-to-end test: Compare full Wan transformer FP8 vs BF16 output. - - Unlike test_fp8_vs_bf16_numerical_correctness which tests a single Linear layer, - this test runs the ENTIRE transformer (all 30 blocks) and compares outputs. - - Expectations: - - Errors accumulate across 30 layers, so use relaxed tolerances - - Cosine similarity should be high (>0.95) but lower than single-layer test (>0.99) - - This validates that FP8 quantization doesn't degrade quality too much end-to-end - """ - if not checkpoint_exists: - pytest.skip("Checkpoint not available. Set DIFFUSION_MODEL_PATH.") - if not is_wan21_checkpoint(): - pytest.skip( - "This test requires Wan 2.1 checkpoint. 
Use DIFFUSION_MODEL_PATH with '2.1' in the path." - ) - - # ===================================================================== - # Load BF16 Transformer (Reference) - # ===================================================================== - print("\n[E2E] Loading BF16 transformer...") - - args_bf16 = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - ) - pipeline_bf16 = PipelineLoader(args_bf16).load(skip_warmup=True) - transformer_bf16 = pipeline_bf16.transformer - - # ===================================================================== - # Load FP8 Transformer - # ===================================================================== - print(f"[E2E] Loading {quant_algo} transformer...") - - args_fp8 = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - quant_config={"quant_algo": quant_algo, "dynamic": True}, - ) - pipeline_fp8 = PipelineLoader(args_fp8).load(skip_warmup=True) - transformer_fp8 = pipeline_fp8.transformer - - # ===================================================================== - # Create Realistic Inputs - # ===================================================================== - torch.manual_seed(42) - - # Use smaller size for faster testing (still realistic) - batch_size = 1 - num_frames = 1 - height, width = 64, 64 # Smaller than full 720x1280 - in_channels = 16 - text_seq_len = 128 - text_dim = 4096 - - # Create inputs - hidden_states = torch.randn( - batch_size, in_channels, num_frames, height, width, dtype=torch.bfloat16, device="cuda" - ) - timestep = torch.tensor([500], dtype=torch.long, device="cuda") - encoder_hidden_states = torch.randn( - batch_size, text_seq_len, text_dim, dtype=torch.bfloat16, device="cuda" - ) - - print("[E2E] Input shapes:") - print(f" hidden_states: {hidden_states.shape}") - print(f" timestep: {timestep.shape}") - print(f" encoder_hidden_states: {encoder_hidden_states.shape}") - - # ===================================================================== - # Run Full Transformer Forward Pass - # ===================================================================== - print("[E2E] Running BF16 transformer forward...") - with torch.no_grad(): - output_bf16 = transformer_bf16( - hidden_states=hidden_states.clone(), - timestep=timestep, - encoder_hidden_states=encoder_hidden_states.clone(), - ) - - print(f"[E2E] Running {quant_algo} transformer forward...") - with torch.no_grad(): - output_fp8 = transformer_fp8( - hidden_states=hidden_states.clone(), - timestep=timestep, - encoder_hidden_states=encoder_hidden_states.clone(), - ) - - # ===================================================================== - # Verify Outputs - # ===================================================================== - assert output_bf16.shape == output_fp8.shape, ( - f"Output shape mismatch: BF16={output_bf16.shape}, FP8={output_fp8.shape}" - ) - print(f"[E2E] Output shape: {output_bf16.shape}") - - # Check for NaN/Inf - bf16_has_nan = torch.isnan(output_bf16).any().item() - fp8_has_nan = torch.isnan(output_fp8).any().item() - bf16_has_inf = torch.isinf(output_bf16).any().item() - fp8_has_inf = torch.isinf(output_fp8).any().item() - - assert not bf16_has_nan, "BF16 output contains NaN" - assert not bf16_has_inf, "BF16 output contains Inf" - assert not fp8_has_nan, f"{quant_algo} output contains NaN" - assert not fp8_has_inf, f"{quant_algo} output contains Inf" - - # ===================================================================== - # Compare Numerical Accuracy - # 
===================================================================== - output_bf16_float = output_bf16.float() - output_fp8_float = output_fp8.float() - - max_diff = torch.max(torch.abs(output_fp8_float - output_bf16_float)).item() - mean_diff = torch.mean(torch.abs(output_fp8_float - output_bf16_float)).item() - - cos_sim = F.cosine_similarity( - output_fp8_float.flatten(), output_bf16_float.flatten(), dim=0 - ).item() - - mse = F.mse_loss(output_fp8_float, output_bf16_float).item() - - # Relative error - rel_error = mean_diff / (output_bf16_float.abs().mean().item() + 1e-8) - - print(f"\n{'=' * 60}") - print(f"END-TO-END TRANSFORMER COMPARISON ({quant_algo} vs BF16)") - print(f"{'=' * 60}") - print(f"Number of layers: {len(transformer_bf16.blocks)}") - print(f"Output shape: {output_bf16.shape}") - print("") - print(f"Max absolute difference: {max_diff:.6f}") - print(f"Mean absolute difference: {mean_diff:.6f}") - print(f"Relative error: {rel_error:.6f}") - print(f"Cosine similarity: {cos_sim:.6f}") - print(f"MSE loss: {mse:.6f}") - print("") - print(f"BF16 output range: [{output_bf16_float.min():.4f}, {output_bf16_float.max():.4f}]") - print( - f"{quant_algo} output range: [{output_fp8_float.min():.4f}, {output_fp8_float.max():.4f}]" - ) - print(f"{'=' * 60}") - - # ===================================================================== - # Assert Numerical Correctness (Relaxed Tolerances) - # ===================================================================== - # Cosine similarity should be high, but lower than single-layer test - # due to error accumulation across 30 layers - assert cos_sim > 0.95, ( - f"Cosine similarity too low for full transformer: {cos_sim:.6f} (expected >0.95)" - ) - - # Relative error should be reasonable - # Note: Error accumulates across 30 layers, so we use a relaxed tolerance - assert rel_error < 0.15, f"Relative error too high: {rel_error:.6f} (expected <0.15)" - - print(f"\n[PASS] {quant_algo} full transformer output matches BF16 within tolerance!") - print(f" ✓ Cosine similarity: {cos_sim:.4f} (>0.95)") - print(f" ✓ Relative error: {rel_error:.4f} (<0.15)") - - # Cleanup - del pipeline_bf16, pipeline_fp8, transformer_bf16, transformer_fp8 - torch.cuda.empty_cache() - - def test_attention_backend_comparison(self, checkpoint_exists): - """Test accuracy of full Wan forward pass with attention backend comparison. - - Wan uses both self-attention (attn1) and cross-attention (attn2). TRTLLM backend - doesn't support cross-attention (seq_len != kv_seq_len), but WanAttention - automatically falls back to VANILLA for cross-attention when TRTLLM is configured. - - This test verifies: - 1. VANILLA backend works correctly - 2. TRTLLM backend with automatic VANILLA fallback for cross-attention produces - numerically similar results to pure VANILLA - """ - if not checkpoint_exists: - pytest.skip("Checkpoint not available. Set DIFFUSION_MODEL_PATH.") - if not is_wan21_checkpoint(): - pytest.skip( - "This test requires Wan 2.1 checkpoint. Use DIFFUSION_MODEL_PATH with '2.1' in the path." 
- ) - - # ===================================================================== - # Load Baseline Transformer (Default VANILLA) - # ===================================================================== - print("\n[Attention Backend Test] Loading baseline transformer (default VANILLA)...") - - from tensorrt_llm._torch.visual_gen.config import AttentionConfig - - args_baseline = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - ) - # Default attention backend is VANILLA - pipeline_baseline = PipelineLoader(args_baseline).load(skip_warmup=True) - transformer_baseline = pipeline_baseline.transformer - - # ===================================================================== - # Load VANILLA Transformer - # ===================================================================== - print("[Attention Backend Test] Loading VANILLA transformer (explicit)...") - - args_vanilla = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - ) - args_vanilla.attention = AttentionConfig(backend="VANILLA") - pipeline_vanilla = PipelineLoader(args_vanilla).load(skip_warmup=True) - transformer_vanilla = pipeline_vanilla.transformer - - # ===================================================================== - # Create Fixed Test Inputs - # ===================================================================== - torch.manual_seed(42) - - # Smaller size for faster testing - batch_size = 1 - num_frames = 1 - height, width = 64, 64 - in_channels = 16 - text_seq_len = 128 - text_dim = 4096 - - # Create inputs - hidden_states = torch.randn( - batch_size, in_channels, num_frames, height, width, dtype=torch.bfloat16, device="cuda" - ) - timestep = torch.tensor([500], dtype=torch.long, device="cuda") - encoder_hidden_states = torch.randn( - batch_size, text_seq_len, text_dim, dtype=torch.bfloat16, device="cuda" - ) - - print("[Attention Backend Test] Input shapes:") - print(f" hidden_states: {hidden_states.shape}") - print(f" timestep: {timestep.shape}") - print(f" encoder_hidden_states: {encoder_hidden_states.shape}") - - # ===================================================================== - # Run Full Transformer Forward Pass - # ===================================================================== - print("[Attention Backend Test] Running baseline transformer forward...") - with torch.no_grad(): - output_baseline = transformer_baseline( - hidden_states=hidden_states.clone(), - timestep=timestep, - encoder_hidden_states=encoder_hidden_states.clone(), - ) - - print("[Attention Backend Test] Running VANILLA transformer forward...") - with torch.no_grad(): - output_vanilla = transformer_vanilla( - hidden_states=hidden_states.clone(), - timestep=timestep, - encoder_hidden_states=encoder_hidden_states.clone(), - ) - - # ===================================================================== - # Verify Output Shapes - # ===================================================================== - assert output_baseline.shape == output_vanilla.shape, ( - f"Output shape mismatch: baseline={output_baseline.shape}, " - f"VANILLA={output_vanilla.shape}" - ) - print(f"[Attention Backend Test] Output shape: {output_baseline.shape}") - - # ===================================================================== - # Check for NaN/Inf in All Outputs - # ===================================================================== - for name, output in [("baseline", output_baseline), ("VANILLA", output_vanilla)]: - has_nan = torch.isnan(output).any().item() - has_inf = 
torch.isinf(output).any().item() - assert not has_nan, f"{name} output contains NaN" - assert not has_inf, f"{name} output contains Inf" - print(f"[Attention Backend Test] {name} output: NaN={has_nan}, Inf={has_inf}") - - # ===================================================================== - # Compare VANILLA (Explicit) vs Baseline - # ===================================================================== - output_baseline_float = output_baseline.float() - output_vanilla_float = output_vanilla.float() - - # VANILLA explicit vs baseline (same backend, outputs should agree closely) - max_diff_vanilla = torch.max(torch.abs(output_vanilla_float - output_baseline_float)).item() - mean_diff_vanilla = torch.mean( - torch.abs(output_vanilla_float - output_baseline_float) - ).item() - cos_sim_vanilla = F.cosine_similarity( - output_vanilla_float.flatten(), output_baseline_float.flatten(), dim=0 - ).item() - mse_vanilla = F.mse_loss(output_vanilla_float, output_baseline_float).item() - - print(f"\n{'=' * 60}") - print("VANILLA (Explicit) vs Baseline Comparison") - print(f"{'=' * 60}") - print(f"Max absolute difference: {max_diff_vanilla:.6f}") - print(f"Mean absolute difference: {mean_diff_vanilla:.6f}") - print(f"Cosine similarity: {cos_sim_vanilla:.6f}") - print(f"MSE loss: {mse_vanilla:.6f}") - print(f"{'=' * 60}") - - # VANILLA explicit should match baseline closely (same backend), but the - # outputs are not expected to be bit-identical, so assert a high cosine - # similarity rather than exact equality. - assert cos_sim_vanilla > 0.995, ( - f"VANILLA explicit should match baseline closely: cos_sim={cos_sim_vanilla:.6f}" - ) - - print("\n[PASS] VANILLA backend produces consistent outputs!") - print(f" ✓ VANILLA (explicit) matches baseline: cos_sim={cos_sim_vanilla:.6f} (>0.995)") - - # ===================================================================== - # Load TRTLLM Transformer (with automatic VANILLA fallback for cross-attention) - # ===================================================================== - print("\n[Attention Backend Test] Loading TRTLLM transformer...") - - args_trtllm = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - ) - args_trtllm.attention = AttentionConfig(backend="TRTLLM") - pipeline_trtllm = PipelineLoader(args_trtllm).load(skip_warmup=True) - transformer_trtllm = pipeline_trtllm.transformer - - # Verify automatic backend override for cross-attention - print("[Attention Backend Test] Verifying backend configuration...") - first_block = transformer_trtllm.blocks[0] - attn1_backend = first_block.attn1.attn_backend - attn2_backend = first_block.attn2.attn_backend - print(f" attn1 (self-attention) backend: {attn1_backend}") - print(f" attn2 (cross-attention) backend: {attn2_backend}") - assert attn1_backend == "TRTLLM", f"Expected attn1 to use TRTLLM, got {attn1_backend}" - assert attn2_backend == "VANILLA", f"Expected attn2 to use VANILLA, got {attn2_backend}" - print(" ✓ Automatic backend override working correctly!") - - # ===================================================================== - # Run TRTLLM Transformer Forward Pass - # ===================================================================== - print("[Attention Backend Test] Running TRTLLM transformer forward...") - with torch.no_grad(): - output_trtllm = transformer_trtllm( - hidden_states=hidden_states.clone(), - timestep=timestep, - encoder_hidden_states=encoder_hidden_states.clone(), - ) - -
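# ---------------------------------------------------------------------
# [Editor's note] A minimal sketch of the backend-selection rule the
# asserts above verify. This illustrates the assumed fallback logic, not
# the actual WanAttention implementation; the function name is
# hypothetical.
def pick_attn_backend(configured: str, is_cross_attention: bool) -> str:
    # TRTLLM fused attention assumes seq_len == kv_seq_len, which holds
    # for self-attention (attn1) but not for cross-attention (attn2),
    # so cross-attention must fall back to VANILLA.
    if configured == "TRTLLM" and is_cross_attention:
        return "VANILLA"
    return configured

assert pick_attn_backend("TRTLLM", is_cross_attention=False) == "TRTLLM"  # attn1
assert pick_attn_backend("TRTLLM", is_cross_attention=True) == "VANILLA"  # attn2
assert pick_attn_backend("VANILLA", is_cross_attention=True) == "VANILLA"
# ---------------------------------------------------------------------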
# ===================================================================== - # Check for NaN/Inf in TRTLLM Output - # ===================================================================== - has_nan = torch.isnan(output_trtllm).any().item() - has_inf = torch.isinf(output_trtllm).any().item() - assert not has_nan, "TRTLLM output contains NaN" - assert not has_inf, "TRTLLM output contains Inf" - print(f"[Attention Backend Test] TRTLLM output: NaN={has_nan}, Inf={has_inf}") - - # ===================================================================== - # Compare TRTLLM vs Baseline - # ===================================================================== - output_trtllm_float = output_trtllm.float() - - max_diff_trtllm = torch.max(torch.abs(output_trtllm_float - output_baseline_float)).item() - mean_diff_trtllm = torch.mean(torch.abs(output_trtllm_float - output_baseline_float)).item() - cos_sim_trtllm = F.cosine_similarity( - output_trtllm_float.flatten(), output_baseline_float.flatten(), dim=0 - ).item() - mse_trtllm = F.mse_loss(output_trtllm_float, output_baseline_float).item() - - print(f"\n{'=' * 60}") - print("TRTLLM (with auto VANILLA fallback) vs Baseline Comparison") - print(f"{'=' * 60}") - print(f"Max absolute difference: {max_diff_trtllm:.6f}") - print(f"Mean absolute difference: {mean_diff_trtllm:.6f}") - print(f"Cosine similarity: {cos_sim_trtllm:.6f}") - print(f"MSE loss: {mse_trtllm:.6f}") - print(f"{'=' * 60}") - - # TRTLLM should produce similar results (attn1 uses TRTLLM, attn2 uses VANILLA). - # Allow slightly more tolerance since the two attention implementations differ. - assert cos_sim_trtllm > 0.99, ( - f"TRTLLM should produce similar results to baseline: cos_sim={cos_sim_trtllm:.6f}" - ) - - print("\n[PASS] TRTLLM backend with automatic fallback works correctly!") - print(f" ✓ TRTLLM matches baseline: cos_sim={cos_sim_trtllm:.6f} (>0.99)") - - # Cleanup - del pipeline_baseline, pipeline_vanilla, pipeline_trtllm - del transformer_baseline, transformer_vanilla, transformer_trtllm - torch.cuda.empty_cache() - - @pytest.mark.parametrize("quant_algo", ["FP8", "FP8_BLOCK_SCALES"]) - def test_fp8_mixed_quant_numerical_correctness(self, checkpoint_exists, quant_algo): - """Test numerical correctness with mixed quantization (some layers excluded). - - Compares outputs between: - 1. Full BF16 model (reference) - 2. Mixed FP8 model (some layers excluded from quantization) - - Expected behavior: - - Mixed model output should stay close to the BF16 reference - - Excluding sensitive layers (like first/last blocks) may improve accuracy - """ - if not checkpoint_exists: - pytest.skip("Checkpoint not available. Set DIFFUSION_MODEL_PATH.") - if not is_wan21_checkpoint(): - pytest.skip( - "This test requires Wan 2.1 checkpoint. Use DIFFUSION_MODEL_PATH with '2.1' in the path." 
- ) - - # ===================================================================== - # Define Mixed Quant Config - # ===================================================================== - # Exclude first block and output projection (often sensitive layers) - mixed_ignore_patterns = [ - "proj_out", - "condition_embedder.*", - "blocks.0.*", - "blocks.29.*", # Last block (if exists) - ] - - # ===================================================================== - # Load Models - # ===================================================================== - print("\n[Mixed Quant Accuracy] Loading BF16 model (reference)...") - args_bf16 = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - ) - pipeline_bf16 = PipelineLoader(args_bf16).load(skip_warmup=True) - - print(f"[Mixed Quant Accuracy] Loading mixed {quant_algo} model...") - args_fp8_mixed = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - quant_config={ - "quant_algo": quant_algo, - "dynamic": True, - "ignore": mixed_ignore_patterns, - }, - ) - pipeline_fp8_mixed = PipelineLoader(args_fp8_mixed).load(skip_warmup=True) - - # ===================================================================== - # Create Test Inputs - # ===================================================================== - torch.manual_seed(42) - - batch_size = 1 - num_frames = 1 - height, width = 64, 64 - in_channels = 16 - text_seq_len = 128 - text_dim = 4096 - - hidden_states = torch.randn( - batch_size, in_channels, num_frames, height, width, dtype=torch.bfloat16, device="cuda" - ) - timestep = torch.tensor([500], dtype=torch.long, device="cuda") - encoder_hidden_states = torch.randn( - batch_size, text_seq_len, text_dim, dtype=torch.bfloat16, device="cuda" - ) - - # ===================================================================== - # Run Forward Pass - # ===================================================================== - print("[Mixed Quant Accuracy] Running forward passes...") - - with torch.no_grad(): - output_bf16 = pipeline_bf16.transformer( - hidden_states=hidden_states.clone(), - timestep=timestep, - encoder_hidden_states=encoder_hidden_states.clone(), - ) - - output_fp8_mixed = pipeline_fp8_mixed.transformer( - hidden_states=hidden_states.clone(), - timestep=timestep, - encoder_hidden_states=encoder_hidden_states.clone(), - ) - - # ===================================================================== - # Compute Metrics - # ===================================================================== - output_bf16_float = output_bf16.float() - output_fp8_mixed_float = output_fp8_mixed.float() - - # Mixed FP8 vs BF16 - cos_sim_mixed = F.cosine_similarity( - output_fp8_mixed_float.flatten(), output_bf16_float.flatten(), dim=0 - ).item() - mse_mixed = F.mse_loss(output_fp8_mixed_float, output_bf16_float).item() - - print(f"\n{'=' * 60}") - print(f"MIXED QUANTIZATION ACCURACY TEST ({quant_algo})") - print(f"{'=' * 60}") - print(f"Ignored patterns: {mixed_ignore_patterns}") - print("") - print(f"Mixed {quant_algo} vs BF16:") - print(f" Cosine similarity: {cos_sim_mixed:.6f}") - print(f" MSE: {mse_mixed:.6f}") - print(f"{'=' * 60}") - - # ===================================================================== - # Assertions - # ===================================================================== - # Both should maintain reasonable accuracy - assert cos_sim_mixed > 0.99, ( - f"Mixed {quant_algo} cosine similarity too low: 
{cos_sim_mixed}" - ) - assert mse_mixed < 1.0, f"Mixed {quant_algo} MSE too high: {mse_mixed}" - - print("\n[PASS] Mixed quantization numerical correctness verified!") - print(f" ✓ Mixed {quant_algo}: cos_sim={cos_sim_mixed:.4f}") - - # Cleanup - del pipeline_bf16, pipeline_fp8_mixed - torch.cuda.empty_cache() - - def test_fp8_vs_bf16_accuracy(self, wan22_both_checkpoints_exist): - """Test FP8 static and dynamic quantization accuracy against BF16 reference. - - Compares outputs from: - 1. TRT-LLM BF16 model (reference checkpoint) - 2. TRT-LLM FP8 static quantized model (pre-quantized checkpoint) - 3. TRT-LLM FP8 dynamic quantized model (BF16 checkpoint + on-the-fly quant) - - Uses spatially-correlated inputs that mimic real VAE latent patterns, - which achieves much higher accuracy than random noise inputs. - """ - if not wan22_both_checkpoints_exist: - pytest.skip( - f"Both checkpoints required. FP8: {CHECKPOINT_PATH_WAN22_FP8}, " - f"BF16: {CHECKPOINT_PATH_WAN22_BF16}" - ) - - # Reset dynamo cache to avoid recompile-limit errors from prior - # tests that compiled kernels with different dtypes (e.g. Float32). - torch._dynamo.reset() - - print("\n" + "=" * 70) - print("FP8 STATIC & DYNAMIC QUANT vs BF16 ACCURACY TEST") - print("=" * 70) - - # Load BF16 reference model - print(f"\n[BF16] Loading from {CHECKPOINT_PATH_WAN22_BF16}") - args_bf16 = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH_WAN22_BF16, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - ) - pipeline_bf16 = PipelineLoader(args_bf16).load(skip_warmup=True) - - # Load FP8 static quantized model (from pre-quantized checkpoint) - print(f"\n[FP8 Static] Loading from {CHECKPOINT_PATH_WAN22_FP8}") - args_fp8_static = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH_WAN22_FP8, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - ) - pipeline_fp8_static = PipelineLoader(args_fp8_static).load(skip_warmup=True) - - # Load FP8 dynamic quantized model (from BF16 checkpoint with on-the-fly quant) - print(f"\n[FP8 Dynamic] Loading from {CHECKPOINT_PATH_WAN22_BF16} with dynamic quant") - args_fp8_dynamic = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH_WAN22_BF16, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - quant_config={ - "quant_algo": "FP8", - "dynamic": True, - }, - ) - pipeline_fp8_dynamic = PipelineLoader(args_fp8_dynamic).load(skip_warmup=True) - - # Verify FP8 static model has calibrated scales - static_quant_modules = 0 - for name, module in pipeline_fp8_static.transformer.named_modules(): - if isinstance(module, Linear): - if hasattr(module, "input_scale") and module.input_scale is not None: - static_quant_modules += 1 - print(f"[FP8 Static] Quantized Linear modules with input_scale: {static_quant_modules}") - assert static_quant_modules > 0, "FP8 static model should have calibrated scales" - - # Verify FP8 dynamic model has quantized weights - dynamic_quant_modules = 0 - for name, module in pipeline_fp8_dynamic.transformer.named_modules(): - if isinstance(module, Linear): - if hasattr(module, "weight_scale") and module.weight_scale is not None: - dynamic_quant_modules += 1 - print(f"[FP8 Dynamic] Quantized Linear modules: {dynamic_quant_modules}") - - # Create spatially-correlated test inputs (mimics real VAE latent patterns) - # Wan 2.2 TI2V-5B specs: - # - VAE compression: 16x16x4 (spatial x spatial x temporal) - # - Latent channels: 48 (z_dim=48) - # - 720P resolution: 1280x704 -> latent: 80x44 - # - Text encoder: UMT5, max_length=512, dim=4096 - 
torch.manual_seed(42) - - batch_size = 2 # For CFG (positive + negative) - in_channels = 48 # Wan 2.2 TI2V-5B uses 48 latent channels - time_dim = 1 # Single frame for unit test - - # 720P latent dimensions: 1280/16=80 width, 704/16=44 height - height = 44 # 720P latent height (704 / 16) - width = 80 # 720P latent width (1280 / 16) - - # Text encoder: UMT5 with 4096 dim, typical sequence length ~226 - text_seq_len = 226 # Default max_sequence_length for Wan - text_dim = 4096 - - # Create structured latent (not purely random - simulate real VAE output) - base_pattern = torch.randn( - 1, in_channels, time_dim, height // 4, width // 4, device="cuda", dtype=torch.bfloat16 - ) - hidden_states = F.interpolate( - base_pattern.view(1, in_channels, height // 4, width // 4), - size=(height, width), - mode="bilinear", - align_corners=False, - ).view(1, in_channels, time_dim, height, width) - hidden_states = hidden_states * 2.0 - hidden_states = hidden_states.expand(batch_size, -1, -1, -1, -1).contiguous() - - timestep = torch.tensor([500.0, 500.0], device="cuda", dtype=torch.bfloat16) - - text_base = ( - torch.randn(1, text_seq_len, text_dim, device="cuda", dtype=torch.bfloat16) * 0.1 - ) - encoder_hidden_states = text_base.expand(batch_size, -1, -1).contiguous() - - print( - f"\n[Input] 720P latent: {hidden_states.shape} " - f"(batch={batch_size}, ch={in_channels}, t={time_dim}, h={height}, w={width})" - ) - print(f"[Input] range: [{hidden_states.min():.2f}, {hidden_states.max():.2f}]") - print(f"[Input] encoder_hidden_states: {encoder_hidden_states.shape}") - - # Run forward passes - print("\n[Forward] Running BF16 model...") - with torch.no_grad(): - output_bf16 = pipeline_bf16.transformer( - hidden_states=hidden_states.clone(), - timestep=timestep, - encoder_hidden_states=encoder_hidden_states.clone(), - ) - - print("[Forward] Running FP8 static quant model...") - with torch.no_grad(): - output_fp8_static = pipeline_fp8_static.transformer( - hidden_states=hidden_states.clone(), - timestep=timestep, - encoder_hidden_states=encoder_hidden_states.clone(), - ) - - print("[Forward] Running FP8 dynamic quant model...") - with torch.no_grad(): - output_fp8_dynamic = pipeline_fp8_dynamic.transformer( - hidden_states=hidden_states.clone(), - timestep=timestep, - encoder_hidden_states=encoder_hidden_states.clone(), - ) - - # Compute metrics - output_bf16_float = output_bf16.float() - output_fp8_static_float = output_fp8_static.float() - output_fp8_dynamic_float = output_fp8_dynamic.float() - - # FP8 Static vs BF16 - cos_sim_static = F.cosine_similarity( - output_fp8_static_float.flatten(), output_bf16_float.flatten(), dim=0 - ).item() - mse_static = F.mse_loss(output_fp8_static_float, output_bf16_float).item() - - # FP8 Dynamic vs BF16 - cos_sim_dynamic = F.cosine_similarity( - output_fp8_dynamic_float.flatten(), output_bf16_float.flatten(), dim=0 - ).item() - mse_dynamic = F.mse_loss(output_fp8_dynamic_float, output_bf16_float).item() - - # Output statistics - bf16_range = (output_bf16_float.min().item(), output_bf16_float.max().item()) - fp8_static_range = ( - output_fp8_static_float.min().item(), - output_fp8_static_float.max().item(), - ) - fp8_dynamic_range = ( - output_fp8_dynamic_float.min().item(), - output_fp8_dynamic_float.max().item(), - ) - - print("\n" + "=" * 70) - print("RESULTS: FP8 QUANT vs BF16") - print("=" * 70) - print(f"{'Method':<20} {'Cosine Sim':>12} {'MSE':>12}") - print("-" * 70) - print(f"{'FP8 Static':<20} {cos_sim_static:>12.6f} {mse_static:>12.6f}") - print(f"{'FP8 
Dynamic':<20} {cos_sim_dynamic:>12.6f} {mse_dynamic:>12.6f}") - print("-" * 70) - print(f"BF16 Output Range: [{bf16_range[0]:.4f}, {bf16_range[1]:.4f}]") - print(f"FP8 Static Output Range: [{fp8_static_range[0]:.4f}, {fp8_static_range[1]:.4f}]") - print(f"FP8 Dynamic Output Range:[{fp8_dynamic_range[0]:.4f}, {fp8_dynamic_range[1]:.4f}]") - print("=" * 70) - - # Assertions - # Static should have high accuracy (calibrated scales) - assert cos_sim_static > 0.99, ( - f"FP8 Static cosine similarity too low: {cos_sim_static:.6f}. Expected >0.99." - ) - # Dynamic may have slightly lower accuracy (no calibration) - assert cos_sim_dynamic > 0.95, ( - f"FP8 Dynamic cosine similarity too low: {cos_sim_dynamic:.6f}. Expected >0.95." - ) - assert not torch.isnan(output_fp8_static).any(), "FP8 static output contains NaN" - assert not torch.isnan(output_fp8_dynamic).any(), "FP8 dynamic output contains NaN" - - print("\n[PASS] FP8 quantization accuracy test passed!") - print(f" - FP8 Static: cos_sim={cos_sim_static:.4f} (>0.99), MSE={mse_static:.6f}") - print(f" - FP8 Dynamic: cos_sim={cos_sim_dynamic:.4f} (>0.95), MSE={mse_dynamic:.6f}") - - # Cleanup - del pipeline_bf16, pipeline_fp8_static, pipeline_fp8_dynamic - torch.cuda.empty_cache() - - def test_nvfp4_vs_bf16_accuracy(self, wan22_nvfp4_bf16_checkpoints_exist): - """Test NVFP4 static and dynamic quantization accuracy against BF16 reference. - - Compares outputs from: - 1. TRT-LLM BF16 model (reference checkpoint) - 2. TRT-LLM NVFP4 static quantized model (pre-quantized checkpoint) - 3. TRT-LLM NVFP4 dynamic quantized model (BF16 checkpoint + on-the-fly quant) - - Uses spatially-correlated inputs that mimic real VAE latent patterns. - NVFP4 (4-bit) has lower precision than FP8 (8-bit), so we use relaxed thresholds. - """ - if not wan22_nvfp4_bf16_checkpoints_exist: - pytest.skip( - f"Both checkpoints required. NVFP4: {CHECKPOINT_PATH_WAN22_NVFP4}, " - f"BF16: {CHECKPOINT_PATH_WAN22_BF16}" - ) - - # Reset dynamo cache to avoid recompile-limit errors from prior - # tests that compiled kernels with different dtypes (e.g. Float32). 
- torch._dynamo.reset() - - print("\n" + "=" * 70) - print("NVFP4 STATIC & DYNAMIC QUANT vs BF16 ACCURACY TEST") - print("=" * 70) - - # Load BF16 reference model - print(f"\n[BF16] Loading from {CHECKPOINT_PATH_WAN22_BF16}") - args_bf16 = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH_WAN22_BF16, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - ) - pipeline_bf16 = PipelineLoader(args_bf16).load(skip_warmup=True) - - # Load NVFP4 static quantized model (from pre-quantized checkpoint) - print(f"\n[NVFP4 Static] Loading from {CHECKPOINT_PATH_WAN22_NVFP4}") - args_nvfp4_static = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH_WAN22_NVFP4, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - ) - pipeline_nvfp4_static = PipelineLoader(args_nvfp4_static).load(skip_warmup=True) - - # Load NVFP4 dynamic quantized model (from BF16 checkpoint with on-the-fly quant) - print(f"\n[NVFP4 Dynamic] Loading from {CHECKPOINT_PATH_WAN22_BF16} with dynamic quant") - args_nvfp4_dynamic = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH_WAN22_BF16, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - quant_config={ - "quant_algo": "NVFP4", - "dynamic": True, - }, - ) - pipeline_nvfp4_dynamic = PipelineLoader(args_nvfp4_dynamic).load(skip_warmup=True) - - # Verify NVFP4 static model has quantized weights - static_quant_modules = 0 - for name, module in pipeline_nvfp4_static.transformer.named_modules(): - if isinstance(module, Linear): - if hasattr(module, "weight_scale") and module.weight_scale is not None: - if module.weight_scale.numel() > 1: - static_quant_modules += 1 - print(f"[NVFP4 Static] Quantized Linear modules: {static_quant_modules}") - assert static_quant_modules > 0, "NVFP4 static model should have quantization scales" - - # Verify NVFP4 dynamic model has quantized weights - dynamic_quant_modules = 0 - for name, module in pipeline_nvfp4_dynamic.transformer.named_modules(): - if isinstance(module, Linear): - if hasattr(module, "weight_scale") and module.weight_scale is not None: - if module.weight_scale.numel() > 1: - dynamic_quant_modules += 1 - print(f"[NVFP4 Dynamic] Quantized Linear modules: {dynamic_quant_modules}") - - # Read model config for input dimensions (auto-detect from checkpoint) - cfg = pipeline_bf16.model_config.pretrained_config - in_channels = getattr(cfg, "in_channels", 16) - text_dim = getattr(cfg, "text_dim", 4096) - - # Create spatially-correlated test inputs (mimics real VAE latent patterns) - torch.manual_seed(42) - - batch_size = 2 # For CFG (positive + negative) - time_dim = 1 # Single frame for unit test - - # 720P latent dimensions: 1280/16=80 width, 704/16=44 height - height = 44 # 720P latent height (704 / 16) - width = 80 # 720P latent width (1280 / 16) - - # Text encoder: UMT5 with text_dim, typical sequence length ~226 - text_seq_len = 226 # Default max_sequence_length for Wan - - # Create structured latent (not purely random - simulate real VAE output) - base_pattern = torch.randn( - 1, in_channels, time_dim, height // 4, width // 4, device="cuda", dtype=torch.bfloat16 - ) - hidden_states = F.interpolate( - base_pattern.view(1, in_channels, height // 4, width // 4), - size=(height, width), - mode="bilinear", - align_corners=False, - ).view(1, in_channels, time_dim, height, width) - hidden_states = hidden_states * 2.0 - hidden_states = hidden_states.expand(batch_size, -1, -1, -1, -1).contiguous() - - timestep = torch.tensor([500.0, 500.0], device="cuda", dtype=torch.bfloat16) - - text_base = ( - 
torch.randn(1, text_seq_len, text_dim, device="cuda", dtype=torch.bfloat16) * 0.1 - ) - encoder_hidden_states = text_base.expand(batch_size, -1, -1).contiguous() - - print( - f"\n[Input] latent: {hidden_states.shape} " - f"(batch={batch_size}, ch={in_channels}, t={time_dim}, h={height}, w={width})" - ) - print(f"[Input] range: [{hidden_states.min():.2f}, {hidden_states.max():.2f}]") - print(f"[Input] encoder_hidden_states: {encoder_hidden_states.shape}") - - # Run forward passes - print("\n[Forward] Running BF16 model...") - with torch.no_grad(): - output_bf16 = pipeline_bf16.transformer( - hidden_states=hidden_states.clone(), - timestep=timestep, - encoder_hidden_states=encoder_hidden_states.clone(), - ) - - print("[Forward] Running NVFP4 static quant model...") - with torch.no_grad(): - output_nvfp4_static = pipeline_nvfp4_static.transformer( - hidden_states=hidden_states.clone(), - timestep=timestep, - encoder_hidden_states=encoder_hidden_states.clone(), - ) - - print("[Forward] Running NVFP4 dynamic quant model...") - with torch.no_grad(): - output_nvfp4_dynamic = pipeline_nvfp4_dynamic.transformer( - hidden_states=hidden_states.clone(), - timestep=timestep, - encoder_hidden_states=encoder_hidden_states.clone(), - ) - - # Compute metrics - output_bf16_float = output_bf16.float() - output_nvfp4_static_float = output_nvfp4_static.float() - output_nvfp4_dynamic_float = output_nvfp4_dynamic.float() - - # NVFP4 Static vs BF16 - cos_sim_static = F.cosine_similarity( - output_nvfp4_static_float.flatten(), output_bf16_float.flatten(), dim=0 - ).item() - mse_static = F.mse_loss(output_nvfp4_static_float, output_bf16_float).item() - - # NVFP4 Dynamic vs BF16 - cos_sim_dynamic = F.cosine_similarity( - output_nvfp4_dynamic_float.flatten(), output_bf16_float.flatten(), dim=0 - ).item() - mse_dynamic = F.mse_loss(output_nvfp4_dynamic_float, output_bf16_float).item() - - # Output statistics - bf16_range = (output_bf16_float.min().item(), output_bf16_float.max().item()) - nvfp4_static_range = ( - output_nvfp4_static_float.min().item(), - output_nvfp4_static_float.max().item(), - ) - nvfp4_dynamic_range = ( - output_nvfp4_dynamic_float.min().item(), - output_nvfp4_dynamic_float.max().item(), - ) - - print("\n" + "=" * 70) - print("RESULTS: NVFP4 QUANT vs BF16") - print("=" * 70) - print(f"{'Method':<25} {'Cosine Sim':>12} {'MSE':>12}") - print("-" * 70) - print(f"{'NVFP4 Static':<25} {cos_sim_static:>12.6f} {mse_static:>12.6f}") - print(f"{'NVFP4 Dynamic (TRT-LLM)':<25} {cos_sim_dynamic:>12.6f} {mse_dynamic:>12.6f}") - print("-" * 70) - print(f"BF16 Output Range: [{bf16_range[0]:.4f}, {bf16_range[1]:.4f}]") - print( - f"NVFP4 Static Range: [{nvfp4_static_range[0]:.4f}, {nvfp4_static_range[1]:.4f}]" - ) - print( - f"NVFP4 Dynamic (TRT) Range: [{nvfp4_dynamic_range[0]:.4f}, {nvfp4_dynamic_range[1]:.4f}]" - ) - print("=" * 70) - - # Assertions - NVFP4 (4-bit) has lower precision than FP8 (8-bit) - # Static has calibrated input scales, dynamic does not - assert cos_sim_static > 0.95, ( - f"NVFP4 Static cosine similarity too low: {cos_sim_static:.6f}. Expected >0.95." - ) - assert cos_sim_dynamic > 0.95, ( - f"NVFP4 Dynamic cosine similarity too low: {cos_sim_dynamic:.6f}. Expected >0.95." 
- ) - assert not torch.isnan(output_nvfp4_static).any(), "NVFP4 static output contains NaN" - assert not torch.isnan(output_nvfp4_dynamic).any(), "NVFP4 dynamic output contains NaN" - - print("\n[PASS] NVFP4 quantization accuracy test passed!") - print( - f" - NVFP4 Static: cos_sim={cos_sim_static:.4f} (>0.95), MSE={mse_static:.6f}" - ) - print( - f" - NVFP4 Dynamic (TRT-LLM): cos_sim={cos_sim_dynamic:.4f} (>0.95), MSE={mse_dynamic:.6f}" - ) - - # Cleanup - del pipeline_bf16, pipeline_nvfp4_static, pipeline_nvfp4_dynamic - torch.cuda.empty_cache() - - def test_nvfp4_vs_bf16_accuracy_mixed_quant(self, wan22_t2v_bf16_checkpoint_exists): - """Test NVFP4 mixed quantization accuracy on Wan 2.2 T2V A14B. - - The A14B NVFP4 checkpoint uses mixed quantization — certain layers are - excluded from quantization via an ignore list (first/last blocks, - condition_embedder, patch_embedding, proj_out). - - Compares outputs from: - 1. TRT-LLM BF16 model (reference) - 2. TRT-LLM NVFP4 static model (pre-quantized with ignore patterns) - — skipped if NVFP4 checkpoint not available - 3. TRT-LLM NVFP4 dynamic model (BF16 checkpoint + on-the-fly quant - with the same ignore patterns) - - This validates that dynamic quantization with ignore patterns produces - comparable accuracy to the statically pre-quantized checkpoint. - """ - if not wan22_t2v_bf16_checkpoint_exists: - pytest.skip(f"BF16 checkpoint required: {CHECKPOINT_PATH_WAN22_T2V}") - - have_nvfp4_static = CHECKPOINT_PATH_WAN22_T2V_NVFP4 and os.path.exists( - CHECKPOINT_PATH_WAN22_T2V_NVFP4 - ) - - torch._dynamo.reset() - - # Ignore patterns from the A14B NVFP4 pre-quantized checkpoint - mixed_ignore_patterns = [ - "blocks.0*", - "blocks.1", - "blocks.1.*", - "blocks.38*", - "blocks.39*", - "condition_embedder*", - "patch_embedding", - "proj_out", - ] - - print("\n" + "=" * 70) - print("NVFP4 MIXED QUANT (A14B T2V) vs BF16 ACCURACY TEST") - print("=" * 70) - print(f"Ignore patterns: {mixed_ignore_patterns}") - print( - f"NVFP4 static checkpoint: {'available' if have_nvfp4_static else 'NOT available (skipping static)'}" - ) - - # Load BF16 reference model - print(f"\n[BF16] Loading from {CHECKPOINT_PATH_WAN22_T2V}") - args_bf16 = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH_WAN22_T2V, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - ) - pipeline_bf16 = PipelineLoader(args_bf16).load(skip_warmup=True) - - # Load NVFP4 static model (if checkpoint available) - pipeline_nvfp4_static = None - static_quant_modules = 0 - static_bf16_modules = 0 - if have_nvfp4_static: - print(f"\n[NVFP4 Static] Loading from {CHECKPOINT_PATH_WAN22_T2V_NVFP4}") - args_nvfp4_static = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH_WAN22_T2V_NVFP4, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - ) - pipeline_nvfp4_static = PipelineLoader(args_nvfp4_static).load(skip_warmup=True) - - # Verify static model has quantized weights (but NOT in ignored layers) - for name, module in pipeline_nvfp4_static.transformer.named_modules(): - if isinstance(module, Linear): - if hasattr(module, "weight_scale") and module.weight_scale is not None: - if module.weight_scale.numel() > 1: - static_quant_modules += 1 - else: - static_bf16_modules += 1 - else: - static_bf16_modules += 1 - print( - f"[NVFP4 Static] Quantized: {static_quant_modules}, " - f"BF16 (ignored): {static_bf16_modules}" - ) - assert static_quant_modules > 0, "NVFP4 static model should have quantization scales" - assert static_bf16_modules > 0, ( - "NVFP4 static mixed model 
should have some BF16 (ignored) modules" - ) - else: - print("\n[NVFP4 Static] SKIPPED — checkpoint not available") - - # Load NVFP4 dynamic model (BF16 checkpoint + on-the-fly quant with same ignores) - print( - f"\n[NVFP4 Dynamic] Loading from {CHECKPOINT_PATH_WAN22_T2V} " - f"with dynamic quant + ignore patterns" - ) - args_nvfp4_dynamic = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH_WAN22_T2V, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - quant_config={ - "quant_algo": "NVFP4", - "dynamic": True, - "ignore": mixed_ignore_patterns, - }, - ) - pipeline_nvfp4_dynamic = PipelineLoader(args_nvfp4_dynamic).load(skip_warmup=True) - - # Verify dynamic model has quantized weights - dynamic_quant_modules = 0 - dynamic_bf16_modules = 0 - for name, module in pipeline_nvfp4_dynamic.transformer.named_modules(): - if isinstance(module, Linear): - if hasattr(module, "weight_scale") and module.weight_scale is not None: - if module.weight_scale.numel() > 1: - dynamic_quant_modules += 1 - else: - dynamic_bf16_modules += 1 - else: - dynamic_bf16_modules += 1 - print( - f"[NVFP4 Dynamic] Quantized: {dynamic_quant_modules}, " - f"BF16 (ignored): {dynamic_bf16_modules}" - ) - - # When both are available, verify they have the same quant/bf16 split - if pipeline_nvfp4_static is not None: - assert static_quant_modules == dynamic_quant_modules, ( - f"Quant module count mismatch: static={static_quant_modules}, " - f"dynamic={dynamic_quant_modules}. " - f"Same ignore patterns should produce the same quantized layer set." - ) - assert static_bf16_modules == dynamic_bf16_modules, ( - f"BF16 module count mismatch: static={static_bf16_modules}, " - f"dynamic={dynamic_bf16_modules}. " - f"Same ignore patterns should produce the same ignored layer set." 
- ) - print( - f"[Verify] Static and dynamic have identical quant/bf16 split: " - f"{static_quant_modules} quantized, {static_bf16_modules} BF16" - ) - - # Read model config for input dimensions - cfg = pipeline_bf16.model_config.pretrained_config - in_channels = getattr(cfg, "in_channels", 16) - text_dim = getattr(cfg, "text_dim", 4096) - - # Create spatially-correlated test inputs - torch.manual_seed(42) - - batch_size = 2 - time_dim = 1 - height = 44 - width = 80 - text_seq_len = 226 - - base_pattern = torch.randn( - 1, - in_channels, - time_dim, - height // 4, - width // 4, - device="cuda", - dtype=torch.bfloat16, - ) - hidden_states = F.interpolate( - base_pattern.view(1, in_channels, height // 4, width // 4), - size=(height, width), - mode="bilinear", - align_corners=False, - ).view(1, in_channels, time_dim, height, width) - hidden_states = hidden_states * 2.0 - hidden_states = hidden_states.expand(batch_size, -1, -1, -1, -1).contiguous() - - timestep = torch.tensor([500.0, 500.0], device="cuda", dtype=torch.bfloat16) - - text_base = ( - torch.randn(1, text_seq_len, text_dim, device="cuda", dtype=torch.bfloat16) * 0.1 - ) - encoder_hidden_states = text_base.expand(batch_size, -1, -1).contiguous() - - print( - f"\n[Input] latent: {hidden_states.shape} " - f"(batch={batch_size}, ch={in_channels}, t={time_dim}, h={height}, w={width})" - ) - print(f"[Input] range: [{hidden_states.min():.2f}, {hidden_states.max():.2f}]") - print(f"[Input] encoder_hidden_states: {encoder_hidden_states.shape}") - - # Run forward passes - print("\n[Forward] Running BF16 model...") - with torch.no_grad(): - output_bf16 = pipeline_bf16.transformer( - hidden_states=hidden_states.clone(), - timestep=timestep, - encoder_hidden_states=encoder_hidden_states.clone(), - ) - - output_nvfp4_static = None - if pipeline_nvfp4_static is not None: - print("[Forward] Running NVFP4 static (mixed) quant model...") - with torch.no_grad(): - output_nvfp4_static = pipeline_nvfp4_static.transformer( - hidden_states=hidden_states.clone(), - timestep=timestep, - encoder_hidden_states=encoder_hidden_states.clone(), - ) - - print("[Forward] Running NVFP4 dynamic (mixed) quant model...") - with torch.no_grad(): - output_nvfp4_dynamic = pipeline_nvfp4_dynamic.transformer( - hidden_states=hidden_states.clone(), - timestep=timestep, - encoder_hidden_states=encoder_hidden_states.clone(), - ) - - # Compute metrics - output_bf16_float = output_bf16.float() - output_nvfp4_dynamic_float = output_nvfp4_dynamic.float() - - cos_sim_dynamic = F.cosine_similarity( - output_nvfp4_dynamic_float.flatten(), output_bf16_float.flatten(), dim=0 - ).item() - mse_dynamic = F.mse_loss(output_nvfp4_dynamic_float, output_bf16_float).item() - - # Static metrics (if available) - cos_sim_static = None - mse_static = None - cos_sim_static_vs_dynamic = None - mse_static_vs_dynamic = None - if output_nvfp4_static is not None: - output_nvfp4_static_float = output_nvfp4_static.float() - - cos_sim_static = F.cosine_similarity( - output_nvfp4_static_float.flatten(), output_bf16_float.flatten(), dim=0 - ).item() - mse_static = F.mse_loss(output_nvfp4_static_float, output_bf16_float).item() - - cos_sim_static_vs_dynamic = F.cosine_similarity( - output_nvfp4_static_float.flatten(), - output_nvfp4_dynamic_float.flatten(), - dim=0, - ).item() - mse_static_vs_dynamic = F.mse_loss( - output_nvfp4_static_float, output_nvfp4_dynamic_float - ).item() - - # Output statistics - bf16_range = (output_bf16_float.min().item(), output_bf16_float.max().item()) - nvfp4_dynamic_range = ( 
- output_nvfp4_dynamic_float.min().item(), - output_nvfp4_dynamic_float.max().item(), - ) - - print("\n" + "=" * 70) - print("RESULTS: NVFP4 MIXED QUANT (A14B T2V) vs BF16") - print("=" * 70) - print(f"{'Comparison':<30} {'Cosine Sim':>12} {'MSE':>12}") - print("-" * 70) - if cos_sim_static is not None: - print(f"{'NVFP4 Static vs BF16':<30} {cos_sim_static:>12.6f} {mse_static:>12.6f}") - else: - print(f"{'NVFP4 Static vs BF16':<30} {'N/A (no ckpt)':>12} {'N/A':>12}") - print(f"{'NVFP4 Dynamic vs BF16':<30} {cos_sim_dynamic:>12.6f} {mse_dynamic:>12.6f}") - if cos_sim_static_vs_dynamic is not None: - print( - f"{'Static vs Dynamic':<30} " - f"{cos_sim_static_vs_dynamic:>12.6f} {mse_static_vs_dynamic:>12.6f}" - ) - print("-" * 70) - print(f"BF16 Output Range: [{bf16_range[0]:.4f}, {bf16_range[1]:.4f}]") - if output_nvfp4_static is not None: - nvfp4_static_range = ( - output_nvfp4_static.float().min().item(), - output_nvfp4_static.float().max().item(), - ) - print( - f"NVFP4 Static Range: [{nvfp4_static_range[0]:.4f}, " - f"{nvfp4_static_range[1]:.4f}]" - ) - print( - f"NVFP4 Dynamic Range: [{nvfp4_dynamic_range[0]:.4f}, " - f"{nvfp4_dynamic_range[1]:.4f}]" - ) - if pipeline_nvfp4_static is not None: - print(f"Static quant/bf16 modules: {static_quant_modules}/{static_bf16_modules}") - print(f"Dynamic quant/bf16 modules: {dynamic_quant_modules}/{dynamic_bf16_modules}") - print("=" * 70) - - # Assertions - if cos_sim_static is not None: - assert cos_sim_static > 0.99, ( - f"NVFP4 Static (mixed) cosine similarity too low: {cos_sim_static:.6f}. " - f"Expected >0.99." - ) - assert not torch.isnan(output_nvfp4_static).any(), "NVFP4 static output contains NaN" - - assert cos_sim_dynamic > 0.99, ( - f"NVFP4 Dynamic (mixed) cosine similarity too low: {cos_sim_dynamic:.6f}. " - f"Expected >0.99." 
- ) - assert not torch.isnan(output_nvfp4_dynamic).any(), "NVFP4 dynamic output contains NaN" - - print("\n[PASS] NVFP4 mixed quantization accuracy test passed!") - if cos_sim_static is not None: - print( - f" - NVFP4 Static (mixed): cos_sim={cos_sim_static:.4f} (>0.99), " - f"MSE={mse_static:.6f}" - ) - print( - f" - NVFP4 Dynamic (mixed): cos_sim={cos_sim_dynamic:.4f} (>0.99), " - f"MSE={mse_dynamic:.6f}" - ) - if cos_sim_static_vs_dynamic is not None: - print( - f" - Static vs Dynamic: cos_sim={cos_sim_static_vs_dynamic:.4f}, " - f"MSE={mse_static_vs_dynamic:.6f}" - ) - - # Cleanup - del pipeline_bf16, pipeline_nvfp4_dynamic - if pipeline_nvfp4_static is not None: - del pipeline_nvfp4_static - torch.cuda.empty_cache() - - -# ============================================================================= -# Wan 2.2 checkpoint fixtures -# ============================================================================= - - -@pytest.fixture -def wan22_both_checkpoints_exist(): - """Check if both Wan 2.2 FP8 and BF16 checkpoints exist.""" - fp8_exists = CHECKPOINT_PATH_WAN22_FP8 and os.path.exists(CHECKPOINT_PATH_WAN22_FP8) - bf16_exists = CHECKPOINT_PATH_WAN22_BF16 and os.path.exists(CHECKPOINT_PATH_WAN22_BF16) - return fp8_exists and bf16_exists - - -@pytest.fixture -def wan22_nvfp4_bf16_checkpoints_exist(): - """Check if both NVFP4 and BF16 checkpoints exist.""" - nvfp4_exists = CHECKPOINT_PATH_WAN22_NVFP4 and os.path.exists(CHECKPOINT_PATH_WAN22_NVFP4) - bf16_exists = CHECKPOINT_PATH_WAN22_BF16 and os.path.exists(CHECKPOINT_PATH_WAN22_BF16) - return nvfp4_exists and bf16_exists - - -@pytest.fixture -def wan22_t2v_bf16_checkpoint_exists(): - """Check if Wan 2.2 T2V A14B BF16 checkpoint exists.""" - return CHECKPOINT_PATH_WAN22_T2V and os.path.exists(CHECKPOINT_PATH_WAN22_T2V) - - -# ============================================================================= -# Optimization Tests -# ============================================================================= - - -class TestWanOptimizations(unittest.TestCase): - """Runtime optimization correctness tests.""" - - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - def setUp(self): - """Set up test fixtures and skip if checkpoint not available.""" - torch.manual_seed(42) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(42) - if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH): - self.skipTest( - "Checkpoint not available. Set DIFFUSION_MODEL_PATH environment variable." - ) - - def tearDown(self): - """Clean up GPU memory.""" - import gc - - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() - - @torch.no_grad() - def test_teacache_multi_step(self): - """Test TeaCache correctness across multiple timesteps (validates caching behavior). - - TeaCache is a runtime optimization that caches transformer outputs when timestep - embeddings change slowly. This test validates: - 1. Correctness against HuggingFace baseline - 2. Actual caching behavior across 20 timesteps - 3. Cache hits occur after warmup phase - """ - if not os.path.exists(CHECKPOINT_PATH): - pytest.skip("Checkpoint not available. Set DIFFUSION_MODEL_PATH.") - if not is_wan21_checkpoint(): - pytest.skip( - "This test requires Wan 2.1 checkpoint. Use DIFFUSION_MODEL_PATH with '2.1' in the path." 
- ) - - from safetensors.torch import load_file - - print("\n" + "=" * 80) - print("TEACACHE MULTI-STEP TEST (20 steps, validates caching)") - print("=" * 80) - - # Load HuggingFace baseline - print("\n[1/4] Loading HuggingFace baseline...") - args_trtllm = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - ) - pipeline_trtllm = PipelineLoader(args_trtllm).load(skip_warmup=True) - config = pipeline_trtllm.transformer.model_config.pretrained_config - - hf_model = ( - HFWanTransformer3DModel( - patch_size=[config.patch_size[0], config.patch_size[1], config.patch_size[2]], - num_attention_heads=config.num_attention_heads, - attention_head_dim=config.attention_head_dim, - in_channels=config.in_channels, - out_channels=config.out_channels, - text_dim=config.text_dim, - freq_dim=config.freq_dim, - ffn_dim=config.ffn_dim, - num_layers=config.num_layers, - cross_attn_norm=config.cross_attn_norm, - qk_norm=config.qk_norm, - eps=config.eps, - ) - .to("cuda", dtype=torch.bfloat16) - .eval() - ) - - # Load weights from checkpoint (auto-discover all shard files) - import glob - - transformer_dir = os.path.join(CHECKPOINT_PATH, "transformer") - shard_pattern = os.path.join(transformer_dir, "diffusion_pytorch_model-*.safetensors") - shard_files = sorted(glob.glob(shard_pattern)) - - checkpoint_weights = {} - for shard_file in shard_files: - if os.path.exists(shard_file): - checkpoint_weights.update(load_file(shard_file)) - hf_model.load_state_dict(checkpoint_weights, strict=True) - print(" ✓ HuggingFace model loaded") - - # Load TeaCache-enabled pipeline - print("\n[2/4] Loading TeaCache-enabled TRT-LLM pipeline...") - args_teacache = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - cache=TeaCacheConfig( - teacache_thresh=0.2, - use_ret_steps=True, - ), - ) - pipeline_teacache = PipelineLoader(args_teacache).load(skip_warmup=True) - transformer_teacache = pipeline_teacache.transformer.eval() - - # Verify TeaCache is enabled - assert ( - getattr(pipeline_teacache, "cache_accelerator", None) is not None - and pipeline_teacache.cache_accelerator.is_enabled() - ), "TeaCache not enabled on pipeline" - assert hasattr(transformer_teacache, "_original_forward"), ( - "TeaCache forward hook not installed" - ) - print(" ✓ TeaCache enabled and verified") - - # Create FIXED test inputs - print("\n[3/4] Creating fixed test inputs...") - torch.manual_seed(42) - batch_size, num_frames, height, width, seq_len = 1, 1, 64, 64, 128 - - hidden_states = torch.randn( - batch_size, - config.in_channels, - num_frames, - height, - width, - dtype=torch.bfloat16, - device="cuda", - ) - encoder_hidden_states = torch.randn( - batch_size, seq_len, config.text_dim, dtype=torch.bfloat16, device="cuda" - ) - - # Run multi-step inference - print("\n[4/4] Running 20-step inference with TeaCache...") - num_steps = 20 - pipeline_teacache.cache_accelerator.refresh(num_inference_steps=num_steps) - - # Simulate diffusion timestep schedule (from high to low) - timesteps = torch.linspace(999, 0, num_steps, dtype=torch.long, device="cuda") - - hf_outputs, teacache_outputs = [], [] - - for step_idx, timestep in enumerate(timesteps): - timestep_tensor = timestep.unsqueeze(0) - - # Run HuggingFace - with torch.no_grad(): - hf_out = hf_model( - hidden_states=hidden_states.clone(), - timestep=timestep_tensor, - encoder_hidden_states=encoder_hidden_states.clone(), - return_dict=False, - )[0] - 
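# ---------------------------------------------------------------------
# [Editor's note] A simplified sketch of the caching decision this
# 20-step loop exercises. The real accelerator is installed by
# TeaCacheConfig via the _original_forward hook checked above; this toy
# class only illustrates the assumed accumulate-and-threshold idea and
# is not the actual implementation.
import torch

class TinyTeaCache:
    def __init__(self, thresh: float = 0.2):
        self.thresh = thresh
        self.accumulated = 0.0
        self.prev_emb = None

    def should_recompute(self, t_emb: torch.Tensor) -> bool:
        if self.prev_emb is None:
            self.prev_emb = t_emb  # warmup step: always compute
            return True
        # Accumulate the relative change of the timestep embedding.
        rel = ((t_emb - self.prev_emb).abs().mean() /
               (self.prev_emb.abs().mean() + 1e-8)).item()
        self.accumulated += rel
        self.prev_emb = t_emb
        if self.accumulated < self.thresh:
            return False  # cache hit: reuse the cached block output
        self.accumulated = 0.0  # cache miss: recompute and reset
        return True

cache = TinyTeaCache(thresh=0.2)
embs = [torch.full((8,), 1.0 + 0.05 * i) for i in range(5)]  # slow drift
decisions = [cache.should_recompute(e) for e in embs]
assert decisions[0] is True    # warmup always computes
assert False in decisions[1:]  # later steps include cache hits
# ---------------------------------------------------------------------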
hf_outputs.append(hf_out) - - # Run TeaCache - with torch.no_grad(): - teacache_out = transformer_teacache( - hidden_states=hidden_states.clone(), - timestep=timestep_tensor, - encoder_hidden_states=encoder_hidden_states.clone(), - ) - teacache_outputs.append(teacache_out) - - if step_idx % 5 == 0: - print(f" Step {step_idx}/{num_steps} - timestep: {timestep.item()}") - - # Compare outputs at selected steps - print("\n[Comparison] TeaCache vs HuggingFace at different steps:") - test_steps = [0, num_steps // 2, num_steps - 1] - - for step_idx in test_steps: - hf_float = hf_outputs[step_idx].float() - teacache_float = teacache_outputs[step_idx].float() - - cos_sim = F.cosine_similarity( - teacache_float.flatten(), hf_float.flatten(), dim=0 - ).item() - - print(f"\n Step {step_idx} (timestep={timesteps[step_idx].item()}):") - print(f" Cosine similarity: {cos_sim:.6f}") - - assert cos_sim > 0.99, ( - f"Step {step_idx}: TeaCache cosine similarity {cos_sim:.6f} below threshold 0.99" - ) - - print("\n[PASS] TeaCache multi-step correctness validated!") - print("=" * 80) - - # Cleanup - del pipeline_trtllm, pipeline_teacache, transformer_teacache, hf_model - torch.cuda.empty_cache() - - -# ============================================================================= -# Parallelism Tests -# ============================================================================= - - -class TestWanParallelism(unittest.TestCase): - """Distributed parallelism correctness tests (CFG Parallelism).""" - - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - def setUp(self): - """Set up test fixtures and skip if checkpoint not available.""" - torch.manual_seed(42) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(42) - if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH): - self.skipTest( - "Checkpoint not available. Set DIFFUSION_MODEL_PATH environment variable." - ) - - def tearDown(self): - """Clean up GPU memory.""" - import gc - - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() - - def test_cfg_2gpu_correctness(self): - """Test CFG Parallelism (cfg_size=2) correctness against standard CFG baseline.""" - num_gpus = torch.cuda.device_count() - if num_gpus < 2: - pytest.skip("CFG parallel test requires at least 2 GPUs") - if not is_wan21_checkpoint(): - pytest.skip( - "This test requires Wan 2.1 checkpoint. Use DIFFUSION_MODEL_PATH with '2.1' in the path." 
- ) - - print("\n" + "=" * 80) - print("CFG PARALLELISM (cfg_size=2) CORRECTNESS TEST") - print("=" * 80) - - # Load standard CFG baseline on GPU 0 - print("\n[1/3] Loading standard CFG baseline (cfg_size=1) on GPU 0...") - args_baseline = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda:0", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - parallel=ParallelConfig(dit_cfg_size=1), # Standard CFG (no parallel) - ) - pipeline_baseline = PipelineLoader(args_baseline).load(skip_warmup=True) - config = pipeline_baseline.transformer.model_config.pretrained_config - - # Reset torch compile state to avoid BFloat16 dtype issues - torch._dynamo.reset() - - # Create FIXED test inputs - print("\n[2/3] Creating fixed test inputs...") - torch.manual_seed(42) - batch_size, num_frames, height, width, seq_len = 1, 1, 64, 64, 128 - - latents = torch.randn( - batch_size, - config.in_channels, - num_frames, - height, - width, - dtype=torch.bfloat16, - device="cuda:0", - ) - timestep = torch.tensor([500], dtype=torch.long, device="cuda:0") - prompt_embeds = torch.randn( - batch_size, seq_len, config.text_dim, dtype=torch.bfloat16, device="cuda:0" - ) - neg_prompt_embeds = torch.randn( - batch_size, seq_len, config.text_dim, dtype=torch.bfloat16, device="cuda:0" - ) - - # Setup standard CFG config - cfg_config_baseline = pipeline_baseline._setup_cfg_config( - guidance_scale=5.0, - prompt_embeds=prompt_embeds, - neg_prompt_embeds=neg_prompt_embeds, - ) - - print(" Baseline CFG config:") - print(f" enabled: {cfg_config_baseline['enabled']}") - print(f" cfg_size: {cfg_config_baseline['cfg_size']}") - - # Verify standard CFG is NOT parallel - assert not cfg_config_baseline["enabled"], "Baseline should not use CFG parallel" - assert cfg_config_baseline["cfg_size"] == 1, "Baseline cfg_size should be 1" - - # Run standard CFG denoising step - def forward_fn( - latents, extra_stream_latents, timestep, encoder_hidden_states, extra_tensors - ): - return pipeline_baseline.transformer( # noqa: F821 - hidden_states=latents, - timestep=timestep, - encoder_hidden_states=encoder_hidden_states, - ) - - with torch.no_grad(): - baseline_output, _, _, _ = pipeline_baseline._denoise_step_standard( - latents=latents.clone(), - extra_stream_latents={}, - timestep=timestep, - prompt_embeds=cfg_config_baseline["prompt_embeds"], - forward_fn=forward_fn, - guidance_scale=5.0, - guidance_rescale=0.0, - local_extras={}, - ) - - print(f" ✓ Baseline output shape: {baseline_output.shape}") - print(f" ✓ Baseline range: [{baseline_output.min():.4f}, {baseline_output.max():.4f}]") - - # Cleanup baseline to free memory for CFG workers - del pipeline_baseline - torch.cuda.empty_cache() - - # Run CFG parallel (cfg_size=2) in distributed processes - print("\n[3/3] Running CFG Parallelism (cfg_size=2) across 2 GPUs...") - cfg_size = 2 - - inputs_cpu = [ - prompt_embeds.cpu(), - neg_prompt_embeds.cpu(), - latents.cpu(), - timestep.cpu(), - ] - - manager = mp.Manager() - return_dict = manager.dict() - - # Spawn CFG workers - mp.spawn( - _run_cfg_worker, - args=(cfg_size, CHECKPOINT_PATH, inputs_cpu, return_dict), - nprocs=cfg_size, - join=True, - ) - - # Get CFG parallel output from rank 0 - cfg_parallel_output = return_dict["output"].to("cuda:0") - print(f" ✓ CFG parallel output shape: {cfg_parallel_output.shape}") - - # Compare outputs - print("\n[Comparison] CFG Parallel vs Standard CFG:") - baseline_float = baseline_output.float() - cfg_parallel_float = cfg_parallel_output.float() - - cos_sim = F.cosine_similarity( - 
cfg_parallel_float.flatten(), baseline_float.flatten(), dim=0 - ).item() - - max_diff = torch.max(torch.abs(cfg_parallel_float - baseline_float)).item() - mean_diff = torch.mean(torch.abs(cfg_parallel_float - baseline_float)).item() - - print(f" Cosine similarity: {cos_sim:.6f}") - print(f" Max absolute difference: {max_diff:.6f}") - print(f" Mean absolute difference: {mean_diff:.6f}") - print( - f" CFG parallel range: [{cfg_parallel_float.min():.4f}, {cfg_parallel_float.max():.4f}]" - ) - print(f" Baseline range: [{baseline_float.min():.4f}, {baseline_float.max():.4f}]") - - assert cos_sim > 0.99, ( - f"CFG parallel cosine similarity {cos_sim:.6f} below threshold 0.99. " - f"CFG Parallelism does not match standard CFG baseline." - ) - - print("\n[PASS] CFG Parallelism (cfg_size=2) validated!") - print(" ✓ CFG parallel produces same output as standard CFG") - print(" ✓ Prompt splitting and all-gather working correctly") - print("=" * 80) - - torch.cuda.empty_cache() - - -# ============================================================================= -# Combined Optimizations Tests -# ============================================================================= - - -class TestWanCombinedOptimizations(unittest.TestCase): - """Test all optimizations combined: FP8 + TeaCache + TRTLLM attention + CFG Parallelism.""" - - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - def setUp(self): - """Set up test fixtures and skip if checkpoint not available.""" - torch.manual_seed(42) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(42) - if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH): - self.skipTest( - "Checkpoint not available. Set DIFFUSION_MODEL_PATH environment variable." - ) - - def tearDown(self): - """Clean up GPU memory.""" - import gc - - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() - - def test_all_optimizations_combined(self): - """Test FP8 + TeaCache + TRTLLM attention + CFG=2 combined correctness. - - This test validates that all optimizations work together correctly: - 1. FP8 per-tensor quantization for reduced memory/compute - 2. TeaCache for caching repeated computations - 3. TRTLLM attention backend for optimized attention kernels - 4. CFG Parallelism (cfg_size=2) for distributed CFG computation - - We compare against a standard CFG baseline with relaxed thresholds since multiple - optimizations compound numerical differences. - """ - num_gpus = torch.cuda.device_count() - if num_gpus < 2: - pytest.skip("Combined optimization test requires at least 2 GPUs for CFG parallel") - if not is_wan21_checkpoint(): - pytest.skip( - "This test requires Wan 2.1 checkpoint. Use DIFFUSION_MODEL_PATH with '2.1' in the path." 
- ) - - print("\n" + "=" * 80) - print("ALL OPTIMIZATIONS COMBINED TEST") - print("FP8 + TeaCache + TRTLLM Attention + CFG Parallelism (cfg_size=2)") - print("=" * 80) - - # Load baseline on GPU 0 (no optimizations, standard CFG) - print("\n[1/3] Loading baseline on GPU 0 (standard CFG, no optimizations)...") - args_baseline = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda:0", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - parallel=ParallelConfig(dit_cfg_size=1), # Standard CFG - ) - pipeline_baseline = PipelineLoader(args_baseline).load(skip_warmup=True) - config = pipeline_baseline.transformer.model_config.pretrained_config - - # Reset torch compile state to avoid BFloat16 dtype issues - torch._dynamo.reset() - - # Create FIXED test inputs - print("\n[2/3] Creating fixed test inputs...") - torch.manual_seed(42) - batch_size, num_frames, height, width, seq_len = 1, 1, 64, 64, 128 - - latents = torch.randn( - batch_size, - config.in_channels, - num_frames, - height, - width, - dtype=torch.bfloat16, - device="cuda:0", - ) - timestep = torch.tensor([500], dtype=torch.long, device="cuda:0") - prompt_embeds = torch.randn( - batch_size, seq_len, config.text_dim, dtype=torch.bfloat16, device="cuda:0" - ) - neg_prompt_embeds = torch.randn( - batch_size, seq_len, config.text_dim, dtype=torch.bfloat16, device="cuda:0" - ) - - # Setup standard CFG config - cfg_config_baseline = pipeline_baseline._setup_cfg_config( - guidance_scale=5.0, - prompt_embeds=prompt_embeds, - neg_prompt_embeds=neg_prompt_embeds, - ) - - # Run baseline standard CFG - print(" Running baseline (standard CFG)...") - - def forward_fn_baseline( - latents, extra_stream_latents, timestep, encoder_hidden_states, extra_tensors - ): - return pipeline_baseline.transformer( # noqa: F821 - hidden_states=latents, - timestep=timestep, - encoder_hidden_states=encoder_hidden_states, - ) - - with torch.no_grad(): - baseline_output, _, _, _ = pipeline_baseline._denoise_step_standard( - latents=latents.clone(), - extra_stream_latents={}, - timestep=timestep, - prompt_embeds=cfg_config_baseline["prompt_embeds"], - forward_fn=forward_fn_baseline, - guidance_scale=5.0, - guidance_rescale=0.0, - local_extras={}, - ) - - print(f" ✓ Baseline output shape: {baseline_output.shape}") - print(f" ✓ Baseline range: [{baseline_output.min():.4f}, {baseline_output.max():.4f}]") - - # Cleanup baseline to free memory for workers - del pipeline_baseline - torch.cuda.empty_cache() - - # Run with ALL optimizations combined in distributed processes - print("\n[3/3] Running with ALL optimizations (FP8 + TeaCache + TRTLLM + CFG=2)...") - cfg_size = 2 - - inputs_cpu = [ - prompt_embeds.cpu(), - neg_prompt_embeds.cpu(), - latents.cpu(), - timestep.cpu(), - ] - - manager = mp.Manager() - return_dict = manager.dict() - - # Spawn workers - mp.spawn( - _run_all_optimizations_worker, - args=(cfg_size, CHECKPOINT_PATH, inputs_cpu, return_dict), - nprocs=cfg_size, - join=True, - ) - - # Get combined optimization output - combined_output = return_dict["output"].to("cuda:0") - - # Compare outputs with RELAXED thresholds (multiple optimizations compound errors) - print("\n[Comparison] Combined Optimizations vs Baseline:") - baseline_float = baseline_output.float() - combined_float = combined_output.float() - - cos_sim = F.cosine_similarity( - combined_float.flatten(), baseline_float.flatten(), dim=0 - ).item() - - max_diff = torch.max(torch.abs(combined_float - baseline_float)).item() - mean_diff = torch.mean(torch.abs(combined_float - 
baseline_float)).item() - - print(f" Cosine similarity: {cos_sim:.6f}") - print(f" Max absolute difference: {max_diff:.6f}") - print(f" Mean absolute difference: {mean_diff:.6f}") - print(f" Combined range: [{combined_float.min():.4f}, {combined_float.max():.4f}]") - print(f" Baseline range: [{baseline_float.min():.4f}, {baseline_float.max():.4f}]") - - # Relaxed threshold: cos_sim > 0.90 (compounded numerical differences from 4 optimizations) - assert cos_sim > 0.90, ( - f"Combined optimization cosine similarity {cos_sim:.6f} below threshold 0.90. " - f"This suggests an issue with optimization interactions." - ) - - print("\n[PASS] All optimizations (FP8 + TeaCache + TRTLLM + CFG) validated!") - print(" ✓ All optimizations work correctly together") - print(" ✓ Numerical accuracy within acceptable tolerance") - print("=" * 80) - - torch.cuda.empty_cache() - - -# ============================================================================= -# Two-Stage Transformer Tests (Wan 2.2) -# ============================================================================= - - -class TestWanTwoStageTransformer(unittest.TestCase): - """Test two-stage transformer support for Wan 2.2 T2V.""" - - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - def setUp(self): - """Set up test fixtures and skip if checkpoint not available.""" - torch.manual_seed(42) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(42) - if not CHECKPOINT_PATH_WAN22_T2V or not os.path.exists(CHECKPOINT_PATH_WAN22_T2V): - self.skipTest( - "Wan 2.2 T2V checkpoint not available. Set DIFFUSION_MODEL_PATH_WAN22_T2V." - ) - - def tearDown(self): - """Clean up GPU memory.""" - import gc - - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() - - def test_two_stage_pipeline_initialization(self): - """Test that Wan 2.2 pipeline initializes with two transformers.""" - if not is_wan22_checkpoint(): - pytest.skip( - "This test requires Wan 2.2 T2V checkpoint. Set DIFFUSION_MODEL_PATH_WAN22_T2V." 
- ) - print("\n" + "=" * 80) - print("WAN 2.2 TWO-STAGE PIPELINE INITIALIZATION TEST") - print("=" * 80) - - args = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH_WAN22_T2V, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - ) - pipeline = PipelineLoader(args).load(skip_warmup=True) - - try: - # Check if this is a two-stage model - has_boundary_ratio = pipeline.boundary_ratio is not None - has_transformer_2 = pipeline.transformer_2 is not None - - print(f"\n[Pipeline] boundary_ratio: {pipeline.boundary_ratio}") - print(f"[Pipeline] transformer: {pipeline.transformer is not None}") - print(f"[Pipeline] transformer_2: {has_transformer_2}") - - if not has_boundary_ratio: - pytest.skip("Checkpoint is not Wan 2.2 (no boundary_ratio)") - - # Verify two-stage configuration - assert pipeline.transformer is not None, "Transformer (high-noise) should exist" - assert has_transformer_2, "Transformer_2 (low-noise) should exist for Wan 2.2" - assert 0.0 < pipeline.boundary_ratio < 1.0, ( - f"boundary_ratio should be in (0, 1), got {pipeline.boundary_ratio}" - ) - - print("\n[PASS] ✓ Wan 2.2 two-stage pipeline initialized correctly") - print(f" ✓ boundary_ratio: {pipeline.boundary_ratio}") - print("=" * 80) - - finally: - del pipeline - import gc - - gc.collect() - torch.cuda.empty_cache() - - def test_two_stage_transformer_selection_logic(self): - """Test that correct transformer is selected based on timestep.""" - if not is_wan22_checkpoint(): - pytest.skip( - "This test requires Wan 2.2 T2V checkpoint. Set DIFFUSION_MODEL_PATH_WAN22_T2V." - ) - print("\n" + "=" * 80) - print("WAN 2.2 TRANSFORMER SELECTION LOGIC TEST") - print("=" * 80) - - args = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH_WAN22_T2V, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - ) - pipeline = PipelineLoader(args).load(skip_warmup=True) - - try: - # Skip if not two-stage - if pipeline.boundary_ratio is None or pipeline.transformer_2 is None: - pytest.skip("Checkpoint is not Wan 2.2 (two-stage)") - - # Calculate boundary timestep - num_train_timesteps = 1000 # Default for Wan models - boundary_timestep = pipeline.boundary_ratio * num_train_timesteps - - print(f"\n[Selection Logic] boundary_ratio: {pipeline.boundary_ratio}") - print(f"[Selection Logic] boundary_timestep: {boundary_timestep:.1f}") - - # Create mock tensors for testing - batch_size, num_frames, height, width = 1, 1, 64, 64 - seq_len = 128 - # Use standard Wan model dimensions - in_channels = 16 # Standard for Wan models - text_dim = 4096 # Standard for Wan models - - latents = torch.randn( - batch_size, - in_channels, - num_frames, - height, - width, - dtype=torch.bfloat16, - device=self.DEVICE, - ) - encoder_hidden_states = torch.randn( - batch_size, seq_len, text_dim, dtype=torch.bfloat16, device=self.DEVICE - ) - - # Test high-noise timestep (should use transformer) - high_noise_t = torch.tensor([900.0], device=self.DEVICE) - print(f"\n[High-Noise] timestep: {high_noise_t.item():.1f}") - print(f"[High-Noise] {high_noise_t.item():.1f} >= {boundary_timestep:.1f}: True") - print("[High-Noise] Should use: transformer (high-noise)") - - with torch.no_grad(): - high_noise_output = pipeline.transformer( - hidden_states=latents, - timestep=high_noise_t, - encoder_hidden_states=encoder_hidden_states, - ) - print(f"[High-Noise] ✓ Output shape: {high_noise_output.shape}") - - # Test low-noise timestep (should use transformer_2) - low_noise_t = torch.tensor([200.0], device=self.DEVICE) - print(f"\n[Low-Noise] timestep: 
{low_noise_t.item():.1f}") - print(f"[Low-Noise] {low_noise_t.item():.1f} < {boundary_timestep:.1f}: True") - print("[Low-Noise] Should use: transformer_2 (low-noise)") - - with torch.no_grad(): - low_noise_output = pipeline.transformer_2( - hidden_states=latents, - timestep=low_noise_t, - encoder_hidden_states=encoder_hidden_states, - ) - print(f"[Low-Noise] ✓ Output shape: {low_noise_output.shape}") - - # Verify outputs have same shape but different values - assert high_noise_output.shape == low_noise_output.shape - assert not torch.allclose(high_noise_output, low_noise_output, atol=1e-3), ( - "Different transformers should produce different outputs" - ) - - print("\n[PASS] ✓ Transformer selection logic working correctly") - print(" ✓ High-noise stage uses transformer") - print(" ✓ Low-noise stage uses transformer_2") - print("=" * 80) - - finally: - del pipeline - import gc - - gc.collect() - torch.cuda.empty_cache() - - def test_two_stage_with_custom_boundary_ratio(self): - """Test overriding boundary_ratio at inference time.""" - if not is_wan22_checkpoint(): - pytest.skip( - "This test requires Wan 2.2 T2V checkpoint. Set DIFFUSION_MODEL_PATH_WAN22_T2V." - ) - print("\n" + "=" * 80) - print("WAN 2.2 CUSTOM BOUNDARY_RATIO TEST") - print("=" * 80) - - args = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH_WAN22_T2V, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - ) - pipeline = PipelineLoader(args).load(skip_warmup=True) - - try: - # Skip if not two-stage - if pipeline.boundary_ratio is None or pipeline.transformer_2 is None: - pytest.skip("Checkpoint is not Wan 2.2 (two-stage)") - - model_boundary_ratio = pipeline.boundary_ratio - custom_boundary_ratio = 0.3 # Override value - - print(f"\n[Custom Boundary] Model default: {model_boundary_ratio}") - print(f"[Custom Boundary] Custom override: {custom_boundary_ratio}") - - # Verify custom value would change boundary timestep - num_train_timesteps = 1000 - model_boundary_t = model_boundary_ratio * num_train_timesteps - custom_boundary_t = custom_boundary_ratio * num_train_timesteps - - print(f"[Custom Boundary] Model boundary_timestep: {model_boundary_t:.1f}") - print(f"[Custom Boundary] Custom boundary_timestep: {custom_boundary_t:.1f}") - print( - f"[Custom Boundary] Difference: {abs(model_boundary_t - custom_boundary_t):.1f} timesteps" - ) - - assert custom_boundary_ratio != model_boundary_ratio - print("\n[PASS] ✓ Custom boundary_ratio can override model default") - print("=" * 80) - - finally: - del pipeline - import gc - - gc.collect() - torch.cuda.empty_cache() - - def test_two_stage_guidance_scale_2(self): - """Test two-stage denoising with different guidance_scale_2 values.""" - if not is_wan22_checkpoint(): - pytest.skip( - "This test requires Wan 2.2 T2V checkpoint. Set DIFFUSION_MODEL_PATH_WAN22_T2V." 
- ) - print("\n" + "=" * 80) - print("WAN 2.2 GUIDANCE_SCALE_2 SUPPORT TEST") - print("=" * 80) - - args = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH_WAN22_T2V, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - ) - pipeline = PipelineLoader(args).load(skip_warmup=True) - - try: - # Skip if not two-stage - if pipeline.boundary_ratio is None or pipeline.transformer_2 is None: - pytest.skip("Checkpoint is not Wan 2.2 (two-stage)") - - print("\n[Guidance Scale 2] Two-stage model supports separate guidance scales:") - print("[Guidance Scale 2] High-noise stage: uses guidance_scale (e.g., 4.0)") - print("[Guidance Scale 2] Low-noise stage: uses guidance_scale_2 (e.g., 2.0, 3.0, 4.0)") - print("\n[PASS] ✓ Different guidance scales supported for two stages") - print("=" * 80) - - finally: - del pipeline - import gc - - gc.collect() - torch.cuda.empty_cache() - - def test_two_stage_with_fp8_quantization(self): - """Test two-stage with FP8 quantization on both transformers.""" - if not is_wan22_checkpoint(): - pytest.skip( - "This test requires Wan 2.2 T2V checkpoint. Set DIFFUSION_MODEL_PATH_WAN22_T2V." - ) - print("\n" + "=" * 80) - print("WAN 2.2 TWO-STAGE + FP8 QUANTIZATION TEST") - print("=" * 80) - - args = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH_WAN22_T2V, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - quant_config={"quant_algo": "FP8", "dynamic": True}, - ) - pipeline = PipelineLoader(args).load(skip_warmup=True) - - try: - # Skip if not two-stage - if pipeline.boundary_ratio is None or pipeline.transformer_2 is None: - pytest.skip("Checkpoint is not Wan 2.2 (two-stage)") - - # Verify FP8 in transformer (high-noise) - found_fp8_t1 = False - for name, param in pipeline.transformer.named_parameters(): - if "blocks.0" in name and "weight" in name and param.dtype == torch.float8_e4m3fn: - found_fp8_t1 = True - print(f"\n[FP8] ✓ Transformer: Found FP8 weight in {name}") - break - assert found_fp8_t1, "No FP8 weights found in transformer" - - # Verify FP8 in transformer_2 (low-noise) - found_fp8_t2 = False - for name, param in pipeline.transformer_2.named_parameters(): - if "blocks.0" in name and "weight" in name and param.dtype == torch.float8_e4m3fn: - found_fp8_t2 = True - print(f"[FP8] ✓ Transformer_2: Found FP8 weight in {name}") - break - assert found_fp8_t2, "No FP8 weights found in transformer_2" - - print("\n[PASS] ✓ FP8 quantization enabled for BOTH transformers") - print("=" * 80) - - finally: - del pipeline - import gc - - gc.collect() - torch.cuda.empty_cache() - - def test_two_stage_with_trtllm_attention(self): - """Test two-stage with TRTLLM attention backend on both transformers.""" - if not is_wan22_checkpoint(): - pytest.skip( - "This test requires Wan 2.2 T2V checkpoint. Set DIFFUSION_MODEL_PATH_WAN22_T2V." 
- ) - print("\n" + "=" * 80) - print("WAN 2.2 TWO-STAGE + TRTLLM ATTENTION TEST") - print("=" * 80) - - args = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH_WAN22_T2V, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - attention=AttentionConfig(backend="TRTLLM"), - ) - pipeline = PipelineLoader(args).load(skip_warmup=True) - - try: - # Skip if not two-stage - if pipeline.boundary_ratio is None or pipeline.transformer_2 is None: - pytest.skip("Checkpoint is not Wan 2.2 (two-stage)") - - # Verify TRTLLM attention on transformer (high-noise) - first_block_t1 = pipeline.transformer.blocks[0] - attn1_backend_t1 = first_block_t1.attn1.attn_backend - attn2_backend_t1 = first_block_t1.attn2.attn_backend - - assert attn1_backend_t1 == "TRTLLM", ( - f"Expected TRTLLM for transformer self-attn, got {attn1_backend_t1}" - ) - assert attn2_backend_t1 == "VANILLA", ( - f"Expected VANILLA for transformer cross-attn, got {attn2_backend_t1}" - ) - - print("\n[Attention] Transformer (high-noise):") - print(f" ✓ Self-attention: {attn1_backend_t1}") - print(f" ✓ Cross-attention: {attn2_backend_t1}") - - # Verify TRTLLM attention on transformer_2 (low-noise) - first_block_t2 = pipeline.transformer_2.blocks[0] - attn1_backend_t2 = first_block_t2.attn1.attn_backend - attn2_backend_t2 = first_block_t2.attn2.attn_backend - - assert attn1_backend_t2 == "TRTLLM", ( - f"Expected TRTLLM for transformer_2 self-attn, got {attn1_backend_t2}" - ) - assert attn2_backend_t2 == "VANILLA", ( - f"Expected VANILLA for transformer_2 cross-attn, got {attn2_backend_t2}" - ) - - print("[Attention] Transformer_2 (low-noise):") - print(f" ✓ Self-attention: {attn1_backend_t2}") - print(f" ✓ Cross-attention: {attn2_backend_t2}") - - print("\n[PASS] ✓ TRTLLM attention enabled for BOTH transformers") - print("=" * 80) - - finally: - del pipeline - import gc - - gc.collect() - torch.cuda.empty_cache() - - def test_two_stage_all_optimizations(self): - """Test two-stage with all supported optimizations: FP8 + TRTLLM.""" - if not is_wan22_checkpoint(): - pytest.skip( - "This test requires Wan 2.2 T2V checkpoint. Set DIFFUSION_MODEL_PATH_WAN22_T2V." 
- ) - print("\n" + "=" * 80) - print("WAN 2.2 TWO-STAGE + ALL OPTIMIZATIONS TEST") - print("FP8 + TRTLLM Attention (TeaCache not supported for Wan 2.2)") - print("=" * 80) - - args = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH_WAN22_T2V, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True}, - attention=AttentionConfig(backend="TRTLLM"), - ) - pipeline = PipelineLoader(args).load(skip_warmup=True) - - try: - # Skip if not two-stage - if pipeline.boundary_ratio is None or pipeline.transformer_2 is None: - pytest.skip("Checkpoint is not Wan 2.2 (two-stage)") - - optimizations = [] - - # Check FP8 - for name, param in pipeline.transformer.named_parameters(): - if "blocks.0" in name and "weight" in name and param.dtype == torch.float8_e4m3fn: - optimizations.append("FP8") - break - - # Check TRTLLM - if pipeline.transformer.blocks[0].attn1.attn_backend == "TRTLLM": - optimizations.append("TRTLLM") - - # Check two-stage - optimizations.append("Two-Stage") - - print(f"\n[All Optimizations] Enabled: {', '.join(optimizations)}") - assert len(optimizations) == 3, ( - f"Expected 3 optimizations, got {len(optimizations)}: {optimizations}" - ) - - # Verify all optimizations on transformer_2 as well - for name, param in pipeline.transformer_2.named_parameters(): - if "blocks.0" in name and "weight" in name and param.dtype == torch.float8_e4m3fn: - print("[All Optimizations] ✓ Transformer_2: FP8 enabled") - break - - if pipeline.transformer_2.blocks[0].attn1.attn_backend == "TRTLLM": - print("[All Optimizations] ✓ Transformer_2: TRTLLM enabled") - - print("\n[PASS] ✓ All optimizations working on BOTH transformers") - print("=" * 80) - - finally: - del pipeline - import gc - - gc.collect() - torch.cuda.empty_cache() - - -# ============================================================================= -# Robustness Tests -# ============================================================================= - - -class TestWanRobustness(unittest.TestCase): - """Error handling and edge case tests.""" - - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - def setUp(self): - """Set up test fixtures and skip if checkpoint not available.""" - if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH): - self.skipTest( - "Checkpoint not available. Set DIFFUSION_MODEL_PATH environment variable." - ) - - def tearDown(self): - """Clean up GPU memory.""" - import gc - - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() - - def test_invalid_quant_config(self): - """Test that invalid quantization config raises appropriate error.""" - with pytest.raises((ValueError, KeyError)): - args = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_COMPONENTS, - quant_config={"quant_algo": "INVALID_ALGO"}, - ) - pipeline = PipelineLoader(args).load(skip_warmup=True) # noqa: F841 - - print("\n[Error Handling] ✓ Invalid quant_algo raises error as expected") - - -# ============================================================================= -# Batch Generation Tests -# ============================================================================= - - -class TestWanBatchGeneration: - """Batch generation tests for WAN T2V pipeline. - - Tests that passing a list of prompts produces batched output - and matches sequential generation with the same seeds. 
- """ - - @pytest.fixture(scope="class") - def wan21_full_pipeline(self): - """Load full Wan 2.1 pipeline (all components) for batch tests.""" - if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH): - pytest.skip("Checkpoint not available. Set DIFFUSION_MODEL_PATH.") - if not is_wan21_checkpoint(): - pytest.skip("Batch tests require Wan 2.1 checkpoint") - - from tensorrt_llm._torch.visual_gen.config import TorchCompileConfig - - args = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - torch_compile=TorchCompileConfig(enable_torch_compile=False), - ) - pipeline = PipelineLoader(args).load(skip_warmup=True) - yield pipeline - del pipeline - import gc - - gc.collect() - torch.cuda.empty_cache() - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_single_prompt_backward_compat(self, wan21_full_pipeline): - """Single prompt returns (T, H, W, C) for backward compatibility.""" - result = wan21_full_pipeline.forward( - prompt="a cat walking", - height=480, - width=832, - num_frames=9, - num_inference_steps=4, - guidance_scale=5.0, - seed=42, - ) - assert result.video.dim() == 5, f"Expected 5D (B,T,H,W,C), got {result.video.dim()}D" - B, _T, H, W, C = result.video.shape - assert B == 1 and H == 480 and W == 832 and C == 3 - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_batch_prompt_shape(self, wan21_full_pipeline): - """List of prompts returns (B, T, H, W, C).""" - prompts = ["a sunset over mountains", "a cat on a roof"] - result = wan21_full_pipeline.forward( - prompt=prompts, - height=480, - width=832, - num_frames=9, - num_inference_steps=4, - guidance_scale=5.0, - seed=42, - ) - assert result.video.dim() == 5, f"Expected 5D (B,T,H,W,C), got {result.video.dim()}D" - B, _T, H, W, C = result.video.shape - assert B == 2 and H == 480 and W == 832 and C == 3 - - -if __name__ == "__main__": - unittest.main(verbosity=2) diff --git a/tests/unittest/_torch/visual_gen/test_wan21_i2v_pipeline.py b/tests/unittest/_torch/visual_gen/test_wan21_i2v_pipeline.py new file mode 100644 index 000000000000..646c4451b807 --- /dev/null +++ b/tests/unittest/_torch/visual_gen/test_wan21_i2v_pipeline.py @@ -0,0 +1,402 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Correctness tests for Wan 2.1 I2V pipeline against HuggingFace reference. + +Verifies >= 0.99 cosine similarity on decoded video frames (T, H, W, C) after +full denoising + VAE decode. Image conditioning uses CLIP embeddings + VAE-encoded latent. 
+ +Models: + - Wan2.1-I2V-14B-480P-Diffusers (480x832, 33 frames) + - Wan2.1-I2V-14B-720P-Diffusers (720x1280, 33 frames) + +Run: + pytest tests/unittest/_torch/visual_gen/test_wan21_i2v_pipeline.py -v -s -k 480p + pytest tests/unittest/_torch/visual_gen/test_wan21_i2v_pipeline.py -v -s -k 720p + +Override checkpoint paths: + DIFFUSION_MODEL_PATH_WAN21_I2V_480P=/path/to/480p \\ + DIFFUSION_MODEL_PATH_WAN21_I2V_720P=/path/to/720p \\ + pytest tests/unittest/_torch/visual_gen/test_wan21_i2v_pipeline.py -v -s +""" + +import gc +import os +from pathlib import Path + +os.environ["TLLM_DISABLE_MPI"] = "1" + +import numpy as np +import pytest +import torch +import torch.nn.functional as F +from diffusers import DiffusionPipeline +from PIL import Image + +from tensorrt_llm._torch.visual_gen.config import ( + AttentionConfig, + TeaCacheConfig, + TorchCompileConfig, + VisualGenArgs, +) +from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader + + +@pytest.fixture(autouse=True, scope="module") +def _cleanup_mpi_env(): + yield + os.environ.pop("TLLM_DISABLE_MPI", None) + + +# ============================================================================ +# Path helpers +# ============================================================================ + + +def _llm_models_root() -> str: + """Return LLM_MODELS_ROOT path if set in env, assert when it's set but not a valid path.""" + root = Path("/home/scratch.trt_llm_data_ci/llm-models/") + if "LLM_MODELS_ROOT" in os.environ: + root = Path(os.environ["LLM_MODELS_ROOT"]) + if not root.exists(): + root = Path("/scratch.trt_llm_data/llm-models/") + assert root.exists(), ( + "Set LLM_MODELS_ROOT or ensure /home/scratch.trt_llm_data_ci/llm-models/ is accessible." + ) + return str(root) + + +def _checkpoint(env_var: str, default_name: str) -> str: + return os.environ.get(env_var) or os.path.join(_llm_models_root(), default_name) + + +WAN21_I2V_480P_PATH = _checkpoint( + "DIFFUSION_MODEL_PATH_WAN21_I2V_480P", "Wan2.1-I2V-14B-480P-Diffusers" +) +WAN21_I2V_720P_PATH = _checkpoint( + "DIFFUSION_MODEL_PATH_WAN21_I2V_720P", "Wan2.1-I2V-14B-720P-Diffusers" +) + +# ============================================================================ +# Test constants +# ============================================================================ + +PROMPT = "A cat sitting on a sunny windowsill watching birds outside." 
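+
+# Illustrative note (not used by the tests below): Wan's causal video VAE
+# compresses the temporal axis with stride 4, so a clip of T input frames
+# maps to (T - 1) // 4 + 1 latent frames -- the 33 frames used here decode
+# from 9 latent frames. A minimal sketch of that relationship for picking
+# valid frame counts; `_num_latent_frames` is a hypothetical helper, not a
+# pipeline API:
+def _num_latent_frames(num_frames: int, temporal_stride: int = 4) -> int:
+    """Latent frame count for a Wan-style causal VAE (expects 4k+1 frames)."""
+    assert (num_frames - 1) % temporal_stride == 0, "use 4k+1 frame counts"
+    return (num_frames - 1) // temporal_stride + 1
+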
+NEGATIVE_PROMPT = "" +NUM_STEPS = 10 +SEED = 42 +COS_SIM_THRESHOLD = 0.99 + + +# ============================================================================ +# Helpers +# ============================================================================ + + +def _make_test_image(height: int, width: int) -> Image.Image: + """Create a deterministic gradient test image.""" + img = np.zeros((height, width, 3), dtype=np.uint8) + img[:, :, 0] = np.linspace(0, 255, height, dtype=np.uint8)[:, np.newaxis] + img[:, :, 1] = np.linspace(0, 255, width, dtype=np.uint8)[np.newaxis, :] + img[:, :, 2] = 128 + return Image.fromarray(img, mode="RGB") + + +def _load_trtllm_pipeline(checkpoint_path: str): + """Load TRTLLM WanImageToVideoPipeline without torch.compile or warmup.""" + if not os.path.exists(checkpoint_path): + pytest.skip(f"Checkpoint not found: {checkpoint_path}") + args = VisualGenArgs( + checkpoint_path=checkpoint_path, + device="cuda", + dtype="bfloat16", + torch_compile=TorchCompileConfig(enable_torch_compile=False), + ) + return PipelineLoader(args).load(skip_warmup=True) + + +def _load_hf_pipeline(checkpoint_path: str): + """Load HuggingFace diffusers pipeline (auto-detects class from model_index.json).""" + hf_pipe = DiffusionPipeline.from_pretrained( + checkpoint_path, + torch_dtype=torch.bfloat16, + ) + hf_pipe = hf_pipe.to("cuda") + hf_pipe.set_progress_bar_config(disable=True) + return hf_pipe + + +def _capture_trtllm_video( + pipeline, + image: Image.Image, + prompt: str, + negative_prompt: str, + height: int, + width: int, + num_frames: int, + num_inference_steps: int, + guidance_scale: float, + seed: int, +) -> torch.Tensor: + """Run full TRTLLM pipeline including VAE decode; return (T, H, W, C) float in [0, 1].""" + with torch.no_grad(): + result = pipeline.forward( + image=image, + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + seed=seed, + ) + video = result.video # (T, H, W, C) uint8 + return video.float() / 255.0 + + +def _capture_hf_video( + hf_pipe, + image: Image.Image, + prompt: str, + negative_prompt: str, + height: int, + width: int, + num_frames: int, + num_inference_steps: int, + guidance_scale: float, + seed: int, +) -> torch.Tensor: + """Run HF pipeline with output_type='np'; return (T, H, W, C) float in [0, 1].""" + generator = torch.Generator(device="cuda").manual_seed(seed) + output = hf_pipe( + image=image, + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + generator=generator, + output_type="np", + ) + frames = output.frames # (1, T, H, W, C) numpy float32 in [0, 1] + if isinstance(frames, np.ndarray): + return torch.from_numpy(frames[0]).float() + return frames[0].float() + + +def _cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float: + """Cosine similarity between two tensors (flattened to 1D, cast to float32 on CPU).""" + a_flat = a.float().cpu().reshape(-1) + b_flat = b.float().cpu().reshape(-1) + return F.cosine_similarity(a_flat.unsqueeze(0), b_flat.unsqueeze(0)).clamp(-1.0, 1.0).item() + + +def _assert_pipeline_matches_hf( + checkpoint_path: str, + height: int, + width: int, + num_frames: int, + guidance_scale: float, + model_label: str, +) -> None: + """Run TRTLLM and HF pipelines sequentially, compare decoded video output.""" + test_image = _make_test_image(height, width) + + # 
--- TRTLLM --- + trtllm_pipe = _load_trtllm_pipeline(checkpoint_path) + trtllm_video = _capture_trtllm_video( + trtllm_pipe, + image=test_image, + prompt=PROMPT, + negative_prompt=NEGATIVE_PROMPT, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=NUM_STEPS, + guidance_scale=guidance_scale, + seed=SEED, + ) + del trtllm_pipe + gc.collect() + torch.cuda.empty_cache() + + # --- HF reference --- + hf_pipe = _load_hf_pipeline(checkpoint_path) + hf_video = _capture_hf_video( + hf_pipe, + image=test_image, + prompt=PROMPT, + negative_prompt=NEGATIVE_PROMPT, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=NUM_STEPS, + guidance_scale=guidance_scale, + seed=SEED, + ) + del hf_pipe + gc.collect() + torch.cuda.empty_cache() + + # --- Compare --- + assert trtllm_video.numel() == hf_video.numel(), ( + f"{model_label}: element count mismatch — " + f"TRTLLM {trtllm_video.shape} ({trtllm_video.numel()}) vs " + f"HF {hf_video.shape} ({hf_video.numel()})" + ) + + cos_sim = _cosine_similarity(trtllm_video, hf_video) + print(f"\n {model_label} cosine similarity: {cos_sim:.6f}") + assert cos_sim >= COS_SIM_THRESHOLD, ( + f"{model_label}: cosine similarity {cos_sim:.6f} < {COS_SIM_THRESHOLD}. " + f"TRTLLM pipeline output diverges from the HuggingFace reference. " + f"Video shapes — TRTLLM: {trtllm_video.shape}, HF: {hf_video.shape}." + ) + + +# ============================================================================ +# Tests +# ============================================================================ + + +@pytest.mark.integration +@pytest.mark.wan_i2v +class TestWan21_I2V_480P_PipelineCorrectness: + """Wan2.1-I2V-14B-480P correctness vs HuggingFace reference (480x832, 33 frames).""" + + def test_cosine_similarity(self): + _assert_pipeline_matches_hf( + checkpoint_path=WAN21_I2V_480P_PATH, + height=480, + width=832, + num_frames=33, + guidance_scale=5.0, + model_label="Wan2.1-I2V-14B-480P", + ) + + +# ============================================================================= +# Batch Generation Tests (I2V) +# ============================================================================= + + +@pytest.mark.integration +@pytest.mark.i2v +class TestWanI2VBatchGeneration: + """Batch generation tests for WAN I2V pipeline (Wan 2.1 and Wan 2.2). + + Tests that passing a list of prompts produces batched output + and matches sequential generation with the same seeds. + """ + + @pytest.fixture(scope="class") + def i2v_full_pipeline(self): + """Load full I2V pipeline (all components) for batch tests.""" + if not WAN21_I2V_480P_PATH or not os.path.exists(WAN21_I2V_480P_PATH): + pytest.skip("Checkpoint not available. 
Set DIFFUSION_MODEL_PATH_WAN21_I2V_480P.") + + args = VisualGenArgs( + checkpoint_path=WAN21_I2V_480P_PATH, + device="cuda", + dtype="bfloat16", + torch_compile=TorchCompileConfig(enable_torch_compile=False), + ) + pipeline = PipelineLoader(args).load(skip_warmup=True) + yield pipeline + del pipeline + import gc + + gc.collect() + torch.cuda.empty_cache() + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_single_prompt_backward_compat(self, i2v_full_pipeline): + """Single prompt returns (T, H, W, C) for backward compatibility.""" + test_image = _make_test_image(480, 832) + result = i2v_full_pipeline.forward( + prompt="a cat walking", + image=test_image, + height=480, + width=832, + num_frames=33, + num_inference_steps=4, + guidance_scale=5.0, + seed=42, + ) + assert result.video.dim() == 5, f"Expected 5D (B,T,H,W,C), got {result.video.dim()}D" + B, _T, H, W, C = result.video.shape + assert B == 1 and H == 480 and W == 832 and C == 3 + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_batch_prompt_shape(self, i2v_full_pipeline): + """List of prompts returns (B, T, H, W, C).""" + test_image = _make_test_image(480, 832) + prompts = ["a sunset over mountains", "a cat on a roof"] + result = i2v_full_pipeline.forward( + prompt=prompts, + image=test_image, + height=480, + width=832, + num_frames=33, + num_inference_steps=4, + guidance_scale=5.0, + seed=42, + ) + assert result.video.dim() == 5, f"Expected 5D (B,T,H,W,C), got {result.video.dim()}D" + B, _T, H, W, C = result.video.shape + assert B == 2 and H == 480 and W == 832 and C == 3 + + +# ============================================================================= +# Combined Optimization Tests +# ============================================================================= + + +@pytest.mark.integration +@pytest.mark.wan_i2v +class TestWan21I2VCombinedOptimizations: + """FP8 + TeaCache + TRTLLM attention combined on Wan 2.1 I2V (480P, 480x832).""" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_fp8_teacache_trtllm(self): + if not os.path.exists(WAN21_I2V_480P_PATH): + pytest.skip(f"Checkpoint not found: {WAN21_I2V_480P_PATH}") + args = VisualGenArgs( + checkpoint_path=WAN21_I2V_480P_PATH, + device="cuda", + dtype="bfloat16", + torch_compile=TorchCompileConfig(enable_torch_compile=False), + quant_config={"quant_algo": "FP8", "dynamic": True}, + attention=AttentionConfig(backend="TRTLLM"), + cache=TeaCacheConfig(teacache_thresh=0.2), + ) + pipeline = PipelineLoader(args).load(skip_warmup=True) + try: + test_image = _make_test_image(480, 832) + with torch.no_grad(): + result = pipeline.forward( + image=test_image, + prompt="a cat sitting on a windowsill", + negative_prompt="", + height=480, + width=832, + num_frames=33, + num_inference_steps=10, + guidance_scale=5.0, + seed=42, + ) + assert result.video.dim() == 5 + B, _T, H, W, C = result.video.shape + assert B == 1 and H == 480 and W == 832 and C == 3 + + assert pipeline.cache_accelerator is not None + assert pipeline.cache_accelerator.is_enabled() + stats = pipeline.cache_accelerator.get_stats() + assert stats["cached_steps"] > 0, f"No TeaCache hits with FP8+TRTLLM. 
Stats: {stats}" + finally: + del pipeline + gc.collect() + torch.cuda.empty_cache() diff --git a/tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py b/tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py new file mode 100644 index 000000000000..1e509a164110 --- /dev/null +++ b/tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py @@ -0,0 +1,260 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for Wan I2V TeaCache using the full pipeline. + +Wan 2.1 models are tested with TeaCache enabled: + - Wan2.1-I2V-14B-480P-Diffusers 480x832 + - Wan2.1-I2V-14B-720P-Diffusers 720x1280 + +Wan 2.2 is tested to confirm that enabling TeaCache raises a ValueError. + +Loads all components (VAE, text encoder, scheduler) and calls pipeline.forward() +so that TeaCache runs on the actual scheduler timesteps. + +Run all: + pytest tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py -v -s + +Run one model: + pytest tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py -v -s -k wan21_i2v_480p + pytest tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py -v -s -k wan21_i2v_720p + pytest tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py -v -s -k wan22_raises + +Override checkpoint paths: + DIFFUSION_MODEL_PATH_WAN21_I2V_480P=/path/to/480p \\ + DIFFUSION_MODEL_PATH_WAN21_I2V_720P=/path/to/720p \\ + DIFFUSION_MODEL_PATH_WAN22_I2V=/path/to/wan22 \\ + pytest tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py -v -s +""" + +import os + +os.environ["TLLM_DISABLE_MPI"] = "1" + +import gc +from pathlib import Path + +import numpy as np +import pytest +import torch +from PIL import Image + +from tensorrt_llm._torch.visual_gen.config import TeaCacheConfig, VisualGenArgs +from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader + + +@pytest.fixture(autouse=True, scope="module") +def _cleanup_mpi_env(): + yield + os.environ.pop("TLLM_DISABLE_MPI", None) + + +@pytest.fixture(autouse=True) +def _cleanup_gpu(): + gc.collect() + torch.cuda.empty_cache() + yield + gc.collect() + torch.cuda.empty_cache() + + +# ============================================================================ +# Path helpers +# ============================================================================ + + +def _llm_models_root() -> Path: + if "LLM_MODELS_ROOT" in os.environ: + root = Path(os.environ["LLM_MODELS_ROOT"]) + else: + root = Path("/home/scratch.trt_llm_data_ci/llm-models/") + if not root.exists(): + root = Path("/scratch.trt_llm_data/llm-models/") + assert root.exists(), ( + "Set LLM_MODELS_ROOT or ensure /home/scratch.trt_llm_data_ci/llm-models/ is accessible." 
+ ) + return root + + +def _checkpoint(env_var: str, default_name: str) -> str: + return os.environ.get(env_var) or str(_llm_models_root() / default_name) + + +WAN21_I2V_480P_PATH = _checkpoint( + "DIFFUSION_MODEL_PATH_WAN21_I2V_480P", "Wan2.1-I2V-14B-480P-Diffusers" +) +WAN21_I2V_720P_PATH = _checkpoint( + "DIFFUSION_MODEL_PATH_WAN21_I2V_720P", "Wan2.1-I2V-14B-720P-Diffusers" +) + +WAN22_I2V_A14B_PATH = _checkpoint("DIFFUSION_MODEL_PATH_WAN22_I2V", "Wan2.2-I2V-A14B-Diffusers") + +INFER_NUM_FRAMES = 33 # (33-1)/4+1 = 9 latent frames; smallest realistic shape +INFER_NUM_STEPS = 50 # Required for meaningful cache hits with calibrated coefficients +INFER_SEED = 42 + + +# ============================================================================ +# Pipeline fixture factory +# ============================================================================ + + +def _make_pipeline(checkpoint_path: str, use_ret_steps: bool = False): + if not checkpoint_path or not os.path.exists(checkpoint_path): + pytest.skip(f"Checkpoint not found: {checkpoint_path}") + args = VisualGenArgs( + checkpoint_path=checkpoint_path, + device="cuda", + dtype="bfloat16", + cache=TeaCacheConfig( + teacache_thresh=0.2, + use_ret_steps=use_ret_steps, + ), + ) + pipeline = PipelineLoader(args).load(skip_warmup=True) + return pipeline + + +@pytest.fixture +def wan21_i2v_480p_pipeline(): + pipeline = _make_pipeline(WAN21_I2V_480P_PATH) + yield pipeline + del pipeline + torch.cuda.empty_cache() + + +@pytest.fixture +def wan21_i2v_480p_ret_steps_pipeline(): + pipeline = _make_pipeline(WAN21_I2V_480P_PATH, use_ret_steps=True) + yield pipeline + del pipeline + torch.cuda.empty_cache() + + +@pytest.fixture +def wan21_i2v_720p_pipeline(): + pipeline = _make_pipeline(WAN21_I2V_720P_PATH) + yield pipeline + del pipeline + torch.cuda.empty_cache() + + +# ============================================================================ +# Shared helpers +# ============================================================================ + + +def _make_test_image(height: int, width: int) -> Image.Image: + img = np.zeros((height, width, 3), dtype=np.uint8) + img[:, :, 0] = np.linspace(0, 255, height, dtype=np.uint8)[:, None] + return Image.fromarray(img, mode="RGB") + + +def _assert_i2v_teacache( + pipeline, + height: int, + width: int, + model: str = "", + expected_hit_rate: float = None, + atol: float = 0.02, +) -> None: + """Run forward and verify TeaCache produces cache hits (single-stage Wan 2.1 I2V).""" + test_image = _make_test_image(height, width) + + with torch.no_grad(): + pipeline.forward( + image=test_image, + prompt="a cat sitting on a windowsill", + negative_prompt="", + height=height, + width=width, + num_frames=INFER_NUM_FRAMES, + num_inference_steps=INFER_NUM_STEPS, + seed=INFER_SEED, + ) + + stats = pipeline.transformer_cache_backend.get_stats() + + print(f"\n ===== TeaCache — Wan 2.1 {model} single-stage {height}x{width} =====") + print( + f" transformer: {stats['cached_steps']}/{stats['total_steps']} cached " + f"({stats['hit_rate']:.1%} hit rate)" + ) + if expected_hit_rate is not None: + # Reference hit rates derived from vFly reference runs + print(f" expected: {expected_hit_rate:.1%} (vFly reference, atol={atol:.0%})") + delta = stats["hit_rate"] - expected_hit_rate + print(f" delta: {delta:+.1%}") + print(" ================================================") + + assert stats["total_steps"] == INFER_NUM_STEPS, ( + f"total_steps {stats['total_steps']} != {INFER_NUM_STEPS}" + ) + assert stats["compute_steps"] + 
stats["cached_steps"] == stats["total_steps"] + assert stats["cached_steps"] > 0, ( + f"0 cache hits after {stats['total_steps']} steps. TeaCache is not working. Stats: {stats}" + ) + if expected_hit_rate is not None: + assert abs(stats["hit_rate"] - expected_hit_rate) <= atol + 1e-9, ( + f"Hit rate {stats['hit_rate']:.1%} not within {atol:.0%} " + f"of expected {expected_hit_rate:.1%} (vFly reference)" + ) + + +# ============================================================================ +# Tests +# ============================================================================ + + +@pytest.mark.integration +@pytest.mark.wan_i2v +@pytest.mark.teacache +class TestWan21I2V_480P_TeaCache: + """Wan2.1-I2V-14B-480P 480x832 single-stage.""" + + def test_wan21_i2v_480p_teacache(self, wan21_i2v_480p_pipeline): + _assert_i2v_teacache( + wan21_i2v_480p_pipeline, height=480, width=832, model="I2V-14B", expected_hit_rate=0.54 + ) + + def test_wan21_i2v_480p_teacache_ret_steps(self, wan21_i2v_480p_ret_steps_pipeline): + _assert_i2v_teacache( + wan21_i2v_480p_ret_steps_pipeline, + height=480, + width=832, + model="I2V-14B", + expected_hit_rate=0.50, + ) + + +@pytest.mark.integration +@pytest.mark.wan_i2v +@pytest.mark.teacache +class TestWan21I2V_720P_TeaCache: + """Wan2.1-I2V-14B-720P 720x1280 single-stage.""" + + def test_wan21_i2v_720p_teacache(self, wan21_i2v_720p_pipeline): + _assert_i2v_teacache( + wan21_i2v_720p_pipeline, height=720, width=1280, model="I2V-14B", expected_hit_rate=0.54 + ) + + +@pytest.mark.integration +@pytest.mark.wan_i2v +@pytest.mark.teacache +class TestWan22_I2V_TeaCacheRaisesError: + """Wan2.2-I2V-A14B must raise ValueError when TeaCache is enabled.""" + + def test_wan22_raises_if_teacache_enabled(self): + if not os.path.exists(WAN22_I2V_A14B_PATH): + pytest.skip( + f"Checkpoint not found: {WAN22_I2V_A14B_PATH} (set DIFFUSION_MODEL_PATH_WAN22_I2V)" + ) + args = VisualGenArgs( + checkpoint_path=WAN22_I2V_A14B_PATH, + device="cuda", + dtype="bfloat16", + cache=TeaCacheConfig(), + ) + with pytest.raises(ValueError, match="TeaCache is not supported for Wan 2\\.2"): + PipelineLoader(args).load(skip_warmup=True) diff --git a/tests/unittest/_torch/visual_gen/test_wan21_t2v_pipeline.py b/tests/unittest/_torch/visual_gen/test_wan21_t2v_pipeline.py new file mode 100644 index 000000000000..aa0f492ae350 --- /dev/null +++ b/tests/unittest/_torch/visual_gen/test_wan21_t2v_pipeline.py @@ -0,0 +1,655 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Correctness tests for Wan 2.1 T2V pipeline against HuggingFace reference. + +Verifies >= 0.99 cosine similarity on decoded video frames (T, H, W, C) after +full denoising + VAE decode. 
+ +Models: + - Wan2.1-T2V-1.3B-Diffusers (480x832, 33 frames) + - Wan2.1-T2V-14B-Diffusers (720x1280, 33 frames) + +Run: + pytest tests/unittest/_torch/visual_gen/test_wan21_t2v_pipeline.py -v -s -k 1_3b + pytest tests/unittest/_torch/visual_gen/test_wan21_t2v_pipeline.py -v -s -k 14b + +Override checkpoint paths: + DIFFUSION_MODEL_PATH_WAN21_1_3B=/path/to/1.3b \\ + DIFFUSION_MODEL_PATH_WAN21_14B=/path/to/14b \\ + pytest tests/unittest/_torch/visual_gen/test_wan21_t2v_pipeline.py -v -s +""" + +import gc +import os +from pathlib import Path + +os.environ["TLLM_DISABLE_MPI"] = "1" + +import numpy as np +import pytest +import torch +import torch.nn.functional as F +from diffusers import DiffusionPipeline + +from tensorrt_llm._torch.modules.linear import Linear +from tensorrt_llm._torch.visual_gen.config import ( + AttentionConfig, + TeaCacheConfig, + TorchCompileConfig, + VisualGenArgs, +) +from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader + + +@pytest.fixture(autouse=True, scope="module") +def _cleanup_mpi_env(): + yield + os.environ.pop("TLLM_DISABLE_MPI", None) + + +# ============================================================================ +# Path helpers +# ============================================================================ + + +def _llm_models_root() -> str: + """Return LLM_MODELS_ROOT path if set in env, assert when it's set but not a valid path.""" + root = Path("/home/scratch.trt_llm_data_ci/llm-models/") + if "LLM_MODELS_ROOT" in os.environ: + root = Path(os.environ["LLM_MODELS_ROOT"]) + if not root.exists(): + root = Path("/scratch.trt_llm_data/llm-models/") + assert root.exists(), ( + "Set LLM_MODELS_ROOT or ensure /home/scratch.trt_llm_data_ci/llm-models/ is accessible." + ) + return str(root) + + +def _checkpoint(env_var: str, default_name: str) -> str: + return os.environ.get(env_var) or os.path.join(_llm_models_root(), default_name) + + +WAN21_1_3B_PATH = _checkpoint("DIFFUSION_MODEL_PATH_WAN21_1_3B", "Wan2.1-T2V-1.3B-Diffusers") +WAN21_14B_PATH = _checkpoint("DIFFUSION_MODEL_PATH_WAN21_14B", "Wan2.1-T2V-14B-Diffusers") + +# ============================================================================ +# Test constants +# ============================================================================ + +PROMPT = "A cat sitting on a sunny windowsill watching birds outside." 
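+
+# Illustrative note (not used by the tests below): guidance_scale=5.0 selects
+# standard classifier-free guidance, where each denoise step runs a
+# conditional and an unconditional prediction and blends them. A minimal
+# sketch of that blend; `_cfg_combine` is a hypothetical helper -- the real
+# pipelines implement this step internally:
+def _cfg_combine(cond: torch.Tensor, uncond: torch.Tensor, scale: float) -> torch.Tensor:
+    """Standard CFG blend: uncond + scale * (cond - uncond)."""
+    return uncond + scale * (cond - uncond)
+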
+NEGATIVE_PROMPT = "" +NUM_STEPS = 4 +SEED = 42 +COS_SIM_THRESHOLD = 0.99 + + +# ============================================================================ +# Helpers +# ============================================================================ + + +def _load_trtllm_pipeline(checkpoint_path: str): + """Load TRTLLM WanPipeline without torch.compile or warmup.""" + if not os.path.exists(checkpoint_path): + pytest.skip(f"Checkpoint not found: {checkpoint_path}") + args = VisualGenArgs( + checkpoint_path=checkpoint_path, + device="cuda", + dtype="bfloat16", + torch_compile=TorchCompileConfig(enable_torch_compile=False), + ) + return PipelineLoader(args).load(skip_warmup=True) + + +def _load_hf_pipeline(checkpoint_path: str): + """Load HuggingFace diffusers pipeline (auto-detects class from model_index.json).""" + hf_pipe = DiffusionPipeline.from_pretrained( + checkpoint_path, + torch_dtype=torch.bfloat16, + ) + hf_pipe = hf_pipe.to("cuda") + hf_pipe.set_progress_bar_config(disable=True) + return hf_pipe + + +def _capture_trtllm_video( + pipeline, + prompt: str, + negative_prompt: str, + height: int, + width: int, + num_frames: int, + num_inference_steps: int, + guidance_scale: float, + seed: int, +) -> torch.Tensor: + """Run full TRTLLM pipeline including VAE decode; return (T, H, W, C) float in [0, 1].""" + with torch.no_grad(): + result = pipeline.forward( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + seed=seed, + ) + video = result.video # (T, H, W, C) uint8 + return video.float() / 255.0 + + +def _capture_hf_video( + hf_pipe, + prompt: str, + negative_prompt: str, + height: int, + width: int, + num_frames: int, + num_inference_steps: int, + guidance_scale: float, + seed: int, +) -> torch.Tensor: + """Run HF pipeline with output_type='np'; return (T, H, W, C) float in [0, 1].""" + generator = torch.Generator(device="cuda").manual_seed(seed) + output = hf_pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + generator=generator, + output_type="np", + ) + frames = output.frames # (1, T, H, W, C) numpy float32 in [0, 1] + if isinstance(frames, np.ndarray): + return torch.from_numpy(frames[0]).float() + return frames[0].float() + + +def _cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float: + """Cosine similarity between two tensors (flattened to 1D, cast to float32 on CPU).""" + a_flat = a.float().cpu().reshape(-1) + b_flat = b.float().cpu().reshape(-1) + return F.cosine_similarity(a_flat.unsqueeze(0), b_flat.unsqueeze(0)).clamp(-1.0, 1.0).item() + + +def _assert_pipeline_matches_hf( + checkpoint_path: str, + height: int, + width: int, + num_frames: int, + guidance_scale: float, + model_label: str, +) -> None: + """Run TRTLLM and HF pipelines sequentially, compare decoded video output.""" + # --- TRTLLM --- + trtllm_pipe = _load_trtllm_pipeline(checkpoint_path) + trtllm_video = _capture_trtllm_video( + trtllm_pipe, + prompt=PROMPT, + negative_prompt=NEGATIVE_PROMPT, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=NUM_STEPS, + guidance_scale=guidance_scale, + seed=SEED, + ) + del trtllm_pipe + gc.collect() + torch.cuda.empty_cache() + + # --- HF reference --- + hf_pipe = _load_hf_pipeline(checkpoint_path) + hf_video = _capture_hf_video( + hf_pipe, + prompt=PROMPT, + 
negative_prompt=NEGATIVE_PROMPT, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=NUM_STEPS, + guidance_scale=guidance_scale, + seed=SEED, + ) + del hf_pipe + gc.collect() + torch.cuda.empty_cache() + + # --- Compare --- + assert trtllm_video.numel() == hf_video.numel(), ( + f"{model_label}: element count mismatch — " + f"TRTLLM {trtllm_video.shape} ({trtllm_video.numel()}) vs " + f"HF {hf_video.shape} ({hf_video.numel()})" + ) + + cos_sim = _cosine_similarity(trtllm_video, hf_video) + print(f"\n {model_label} cosine similarity: {cos_sim:.6f}") + assert cos_sim >= COS_SIM_THRESHOLD, ( + f"{model_label}: cosine similarity {cos_sim:.6f} < {COS_SIM_THRESHOLD}. " + f"TRTLLM pipeline output diverges from the HuggingFace reference. " + f"Video shapes — TRTLLM: {trtllm_video.shape}, HF: {hf_video.shape}." + ) + + +# ============================================================================ +# Tests +# ============================================================================ + + +@pytest.mark.integration +@pytest.mark.wan_t2v +class TestWan21_1_3B_PipelineCorrectness: + """Wan2.1-T2V-1.3B correctness vs HuggingFace reference (480x832, 33 frames).""" + + def test_cosine_similarity(self): + _assert_pipeline_matches_hf( + checkpoint_path=WAN21_1_3B_PATH, + height=480, + width=832, + num_frames=9, + guidance_scale=5.0, + model_label="Wan2.1-T2V-1.3B", + ) + + +@pytest.mark.integration +@pytest.mark.wan_t2v +class TestWan21_14B_PipelineCorrectness: + """Wan2.1-T2V-14B correctness vs HuggingFace reference (720x1280, 33 frames).""" + + def test_cosine_similarity(self): + _assert_pipeline_matches_hf( + checkpoint_path=WAN21_14B_PATH, + height=720, + width=1280, + num_frames=9, + guidance_scale=5.0, + model_label="Wan2.1-T2V-14B", + ) + + +# ============================================================================= +# Batch Generation Tests +# ============================================================================= + + +class TestWanBatchGeneration: + """Batch generation tests for WAN T2V pipeline. + + Tests that passing a list of prompts produces batched output + and matches sequential generation with the same seeds. + """ + + @pytest.fixture(scope="class") + def wan21_full_pipeline(self): + """Load full Wan 2.1 pipeline (all components) for batch tests.""" + if not WAN21_1_3B_PATH or not os.path.exists(WAN21_1_3B_PATH): + pytest.skip("Checkpoint not available. 
Set DIFFUSION_MODEL_PATH_WAN21_1_3B.") + + args = VisualGenArgs( + checkpoint_path=WAN21_1_3B_PATH, + device="cuda", + dtype="bfloat16", + torch_compile=TorchCompileConfig(enable_torch_compile=False), + ) + pipeline = PipelineLoader(args).load(skip_warmup=True) + yield pipeline + del pipeline + import gc + + gc.collect() + torch.cuda.empty_cache() + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_single_prompt_backward_compat(self, wan21_full_pipeline): + """Single prompt returns (B, T, H, W, C) for backward compatibility.""" + result = wan21_full_pipeline.forward( + prompt="a cat walking", + height=480, + width=832, + num_frames=9, + num_inference_steps=4, + guidance_scale=5.0, + seed=42, + ) + assert result.video.dim() == 5, f"Expected 5D (B,T,H,W,C), got {result.video.dim()}D" + B, _T, H, W, C = result.video.shape + assert B == 1 and H == 480 and W == 832 and C == 3 + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_batch_prompt_shape(self, wan21_full_pipeline): + """List of prompts returns (B, T, H, W, C).""" + prompts = ["a sunset over mountains", "a cat on a roof"] + result = wan21_full_pipeline.forward( + prompt=prompts, + height=480, + width=832, + num_frames=9, + num_inference_steps=4, + guidance_scale=5.0, + seed=42, + ) + assert result.video.dim() == 5, f"Expected 5D (B,T,H,W,C), got {result.video.dim()}D" + B, _T, H, W, C = result.video.shape + assert B == 2 and H == 480 and W == 832 and C == 3 + + +# ============================================================================= +# Combined Optimization Tests +# ============================================================================= + + +@pytest.mark.integration +@pytest.mark.wan_t2v +class TestWan21T2VCombinedOptimizations: + """FP8 + TeaCache + TRTLLM attention combined on Wan 2.1 T2V (1.3B, 480x832).""" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_fp8_teacache_trtllm(self): + if not os.path.exists(WAN21_1_3B_PATH): + pytest.skip(f"Checkpoint not found: {WAN21_1_3B_PATH}") + args = VisualGenArgs( + checkpoint_path=WAN21_1_3B_PATH, + device="cuda", + dtype="bfloat16", + torch_compile=TorchCompileConfig(enable_torch_compile=False), + quant_config={"quant_algo": "FP8", "dynamic": True}, + attention=AttentionConfig(backend="TRTLLM"), + cache=TeaCacheConfig(teacache_thresh=0.2), + ) + pipeline = PipelineLoader(args).load(skip_warmup=True) + try: + with torch.no_grad(): + result = pipeline.forward( + prompt="a cat sitting on a windowsill", + negative_prompt="", + height=480, + width=832, + num_frames=9, + num_inference_steps=10, + guidance_scale=5.0, + seed=42, + ) + assert result.video.dim() == 5 + B, _T, H, W, C = result.video.shape + assert B == 1 and H == 480 and W == 832 and C == 3 + + assert pipeline.cache_accelerator is not None + assert pipeline.cache_accelerator.is_enabled() + stats = pipeline.cache_accelerator.get_stats() + assert stats["cached_steps"] > 0, f"No TeaCache hits with FP8+TRTLLM. 
Stats: {stats}"
+        finally:
+            del pipeline
+            gc.collect()
+            torch.cuda.empty_cache()
+
+
+# =============================================================================
+# Quantization / dtype feature tests (transformer only, module-scoped fixtures)
+# =============================================================================
+
+_SKIP_AUX = ["text_encoder", "vae", "tokenizer", "scheduler"]
+
+
+def _make_wan21_t2v(quant_config=None):
+    if not os.path.exists(WAN21_1_3B_PATH):
+        pytest.skip(f"Checkpoint not found: {WAN21_1_3B_PATH}")
+    kwargs = dict(
+        checkpoint_path=WAN21_1_3B_PATH,
+        device="cuda",
+        dtype="bfloat16",
+        skip_components=_SKIP_AUX,
+        torch_compile=TorchCompileConfig(enable_torch_compile=False),
+    )
+    if quant_config is not None:
+        kwargs["quant_config"] = quant_config
+    return PipelineLoader(VisualGenArgs(**kwargs)).load(skip_warmup=True)
+
+
+@pytest.fixture(scope="module")
+def wan21_t2v_bf16():
+    pipeline = _make_wan21_t2v()
+    yield pipeline
+    del pipeline
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+@pytest.fixture(scope="module")
+def wan21_t2v_fp8():
+    pipeline = _make_wan21_t2v({"quant_algo": "FP8", "dynamic": True})
+    yield pipeline
+    del pipeline
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+@pytest.fixture(scope="module")
+def wan21_t2v_fp8_block():
+    pipeline = _make_wan21_t2v({"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True})
+    yield pipeline
+    del pipeline
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+@pytest.fixture(scope="module")
+def wan21_t2v_nvfp4():
+    pipeline = _make_wan21_t2v({"quant_algo": "NVFP4", "dynamic": True})
+    yield pipeline
+    del pipeline
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+def _transformer_inputs(device: str = "cuda"):
+    torch.manual_seed(42)
+    return (
+        torch.randn(1, 16, 1, 64, 64, dtype=torch.bfloat16, device=device),
+        torch.tensor([500], dtype=torch.long, device=device),
+        torch.randn(1, 128, 4096, dtype=torch.bfloat16, device=device),
+    )
+
+
+def _is_fp32_layernorm_param(name: str) -> bool:
+    if not name.endswith((".weight", ".bias")):
+        return False
+    if ".norm" in name and "blocks." in name:
+        return any(p in name.split(".") for p in ("norm1", "norm2", "norm3"))
+    if name in ("norm_out.weight", "norm_out.bias"):
+        return True
+    if name.startswith("condition_embedder.") and ".norm" in name:
+        return True
+    return False
+
+
+@pytest.mark.integration
+@pytest.mark.wan_t2v
+class TestWan21T2VPipelineFeatures:
+    """Quantization loading, dtype layout, numerical accuracy, and memory for Wan 2.1 T2V."""
+
+    def test_parameter_dtypes(self, wan21_t2v_bf16):
+        """BF16 pipeline: CUDA tensors, FP32 LayerNorms, BF16 everything else."""
+        bf16_count = 0
+        for name, param in wan21_t2v_bf16.transformer.named_parameters():
+            assert param.device.type == "cuda", f"{name} not on CUDA"
+            if _is_fp32_layernorm_param(name):
+                assert param.dtype == torch.float32, f"{name}: expected float32, got {param.dtype}"
+            elif "scale" not in name.lower():
+                assert param.dtype == torch.bfloat16, (
+                    f"{name}: expected bfloat16, got {param.dtype}"
+                )
+                bf16_count += 1
+        assert bf16_count > 0, "No BF16 parameters found"
+
+    def test_fp8_weights_loaded(self, wan21_t2v_fp8):
+        """FP8 transformer blocks have float8_e4m3fn weights and weight_scale."""
+        try:
+            if not hasattr(torch.ops, "tensorrt_llm"):
+                pytest.skip("tensorrt_llm torch ops not available")
+            _ = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor
+            _ = torch.ops.tensorrt_llm.quantize_e4m3_activation
+        except (AttributeError, RuntimeError) as e:
+            pytest.skip(f"FP8 quantization ops not available: {e}")
+        for name, module in wan21_t2v_fp8.transformer.named_modules():
+            if isinstance(module, Linear) and "blocks." in name:
+                assert module.weight.dtype == torch.float8_e4m3fn, (
+                    f"{name}: expected float8_e4m3fn, got {module.weight.dtype}"
+                )
+                assert hasattr(module, "weight_scale"), f"{name}: missing weight_scale"
+                return
+        pytest.fail("No FP8 Linear found in transformer blocks")
+
+    def test_fp8_block_scales_weights_loaded(self, wan21_t2v_fp8_block):
+        """FP8_BLOCK_SCALES transformer blocks have float8_e4m3fn weights and weight_scale."""
+        try:
+            if not hasattr(torch.ops, "tensorrt_llm"):
+                pytest.skip("tensorrt_llm torch ops not available")
+            _ = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor
+            _ = torch.ops.tensorrt_llm.quantize_e4m3_activation
+        except (AttributeError, RuntimeError) as e:
+            pytest.skip(f"FP8 quantization ops not available: {e}")
+        for name, module in wan21_t2v_fp8_block.transformer.named_modules():
+            if isinstance(module, Linear) and "blocks." in name:
+                assert module.weight.dtype == torch.float8_e4m3fn, (
+                    f"{name}: expected float8_e4m3fn, got {module.weight.dtype}"
+                )
+                assert hasattr(module, "weight_scale"), f"{name}: missing weight_scale"
+                return
+        pytest.fail("No FP8_BLOCK_SCALES Linear found in transformer blocks")
+
+    def test_nvfp4_weights_loaded(self, wan21_t2v_nvfp4):
+        """NVFP4 transformer blocks have packed FP4 weights with two-level scale."""
+        if torch.cuda.get_device_capability(0) < (10, 0):
+            pytest.skip("NVFP4 requires SM>=10.0 (Blackwell+)")
+        try:
+            _ = torch.ops.trtllm.fp4_quantize
+        except (AttributeError, RuntimeError) as e:
+            pytest.skip(f"fp4_quantize op not available: {e}")
+        from tensorrt_llm.quantization.utils import fp4_utils
+
+        for name, module in wan21_t2v_nvfp4.transformer.named_modules():
+            if isinstance(module, Linear) and "blocks." in name:
+                assert module.weight.dtype == fp4_utils.float4_e2m1x2, (
+                    f"{name}: expected float4_e2m1x2, got {module.weight.dtype}"
+                )
+                assert hasattr(module, "weight_scale"), f"{name}: missing weight_scale"
+                assert hasattr(module, "weight_scale_2"), f"{name}: missing weight_scale_2"
+                return
+        pytest.fail("No NVFP4 Linear found in transformer blocks")
+
+    def test_fp8_single_layer_accuracy(self, wan21_t2v_bf16, wan21_t2v_fp8):
+        """FP8 qkv_proj output matches BF16 F.linear reference (cos_sim > 0.99)."""
+        linear_bf16 = wan21_t2v_bf16.transformer.blocks[0].attn1.qkv_proj
+        linear_fp8 = wan21_t2v_fp8.transformer.blocks[0].attn1.qkv_proj
+
+        weight = linear_bf16.weight.data.clone()
+        bias = linear_bf16.bias.data.clone() if linear_bf16.bias is not None else None
+        x = torch.randn(
+            1024,
+            linear_bf16.in_features,
+            dtype=torch.bfloat16,
+            device="cuda",
+            generator=torch.Generator("cuda").manual_seed(42),
+        )
+
+        with torch.no_grad():
+            ref = torch.nn.functional.linear(x, weight, bias)
+            fp8_out = linear_fp8(x)
+
+        cos_sim = torch.nn.functional.cosine_similarity(
+            fp8_out.flatten().float(), ref.flatten().float(), dim=0
+        ).item()
+        mse = torch.nn.functional.mse_loss(fp8_out.float(), ref.float()).item()
+        print(f"\n FP8 qkv_proj: cos_sim={cos_sim:.6f}, mse={mse:.6f}")
+        assert cos_sim > 0.99, f"cos_sim too low: {cos_sim:.6f}"
+        assert mse < 1.0, f"MSE too high: {mse:.6f}"
+
+    def test_fp8_memory_savings(self, wan21_t2v_bf16, wan21_t2v_fp8):
+        """FP8 transformer uses ~2x less parameter memory than BF16."""
+
+        def _mem_gb(pipeline):
+            return (
+                sum(p.numel() * p.element_size() for p in pipeline.transformer.parameters())
+                / 1024**3
+            )
+
+        bf16_gb = _mem_gb(wan21_t2v_bf16)
+        fp8_gb = _mem_gb(wan21_t2v_fp8)
+        ratio = bf16_gb / fp8_gb
+        print(f"\n BF16={bf16_gb:.3f} GB, FP8={fp8_gb:.3f} GB, ratio={ratio:.2f}x")
+        assert ratio > 1.8, f"Expected ~2x savings, got {ratio:.2f}x"
+
+    @pytest.mark.parametrize(
+        "quant_name,pipe_fixture",
+        [
+            ("FP8", "wan21_t2v_fp8"),
+            ("FP8_BLOCK_SCALES", "wan21_t2v_fp8_block"),
+        ],
+    )
+    def test_fp8_e2e_accuracy(
+        self, wan21_t2v_bf16, wan21_t2v_fp8, wan21_t2v_fp8_block, quant_name, pipe_fixture
+    ):
+        """FP8/FP8_BLOCK_SCALES full-transformer output close to BF16 (cos_sim > 0.99)."""
+        quant_pipeline = wan21_t2v_fp8 if pipe_fixture == "wan21_t2v_fp8" else wan21_t2v_fp8_block
+        hs, ts, enc = _transformer_inputs()
+
+        with torch.no_grad():
+            out_bf16 = wan21_t2v_bf16.transformer(
+                hidden_states=hs.clone(), timestep=ts, encoder_hidden_states=enc.clone()
+            ).float()
+            out_quant = quant_pipeline.transformer(
+                hidden_states=hs.clone(), timestep=ts, encoder_hidden_states=enc.clone()
+            ).float()
+
+        assert not torch.isnan(out_bf16).any(), "BF16 output contains NaN"
+        assert not torch.isinf(out_bf16).any(), "BF16 output contains Inf"
+        assert not torch.isnan(out_quant).any(), f"{quant_name} output contains NaN"
+        assert not torch.isinf(out_quant).any(), f"{quant_name} output contains Inf"
+
+        cos_sim = torch.nn.functional.cosine_similarity(
+            out_quant.flatten(), out_bf16.flatten(), dim=0
+        ).item()
+        mse = torch.nn.functional.mse_loss(out_quant, out_bf16).item()
+        print(
+            f"\n {quant_name} E2E ({len(wan21_t2v_bf16.transformer.blocks)} layers): "
+            f"cos_sim={cos_sim:.6f}, mse={mse:.6f}"
+        )
+        assert cos_sim > 0.99, f"cos_sim too low: {cos_sim:.6f}"
+
+    def test_nvfp4_e2e_accuracy(self, wan21_t2v_bf16, wan21_t2v_nvfp4):
+        """NVFP4 full-transformer output close to BF16 (cos_sim > 0.95)."""
+        if torch.cuda.get_device_capability(0) < (10, 0):
+            pytest.skip("NVFP4 requires SM>=10.0 (Blackwell+)")
+        try:
+            _ = torch.ops.trtllm.fp4_quantize
+        except (AttributeError, RuntimeError) as e:
+            pytest.skip(f"fp4_quantize op not available: {e}")
+
+        hs, ts, enc = _transformer_inputs()
+
+        with torch.no_grad():
+            out_bf16 = wan21_t2v_bf16.transformer(
+                hidden_states=hs.clone(), timestep=ts, encoder_hidden_states=enc.clone()
+            ).float()
+            out_nvfp4 = wan21_t2v_nvfp4.transformer(
+                hidden_states=hs.clone(), timestep=ts, encoder_hidden_states=enc.clone()
+            ).float()
+
+        assert not torch.isnan(out_nvfp4).any(), "NVFP4 output contains NaN"
+        assert not torch.isinf(out_nvfp4).any(), "NVFP4 output contains Inf"
+
+        cos_sim = torch.nn.functional.cosine_similarity(
+            out_nvfp4.flatten(), out_bf16.flatten(), dim=0
+        ).item()
+        mse = torch.nn.functional.mse_loss(out_nvfp4, out_bf16).item()
+        print(
+            f"\n NVFP4 E2E ({len(wan21_t2v_bf16.transformer.blocks)} layers): "
+            f"cos_sim={cos_sim:.6f}, mse={mse:.6f}"
+        )
+        assert cos_sim > 0.95, f"NVFP4 cos_sim too low: {cos_sim:.6f}"
diff --git a/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py b/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py
new file mode 100644
index 000000000000..5513c330c4c1
--- /dev/null
+++ b/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py
@@ -0,0 +1,242 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Tests for Wan T2V TeaCache using the full pipeline.
+
+Wan 2.1 models are tested with TeaCache enabled:
+  - Wan2.1-T2V-1.3B-Diffusers 480x832 single-stage
+  - Wan2.1-T2V-14B-Diffusers 720x1280 single-stage
+
+Wan 2.2 is tested to confirm that enabling TeaCache raises a ValueError.
+
+Each test loads all pipeline components (VAE, text encoder, scheduler) and calls
+pipeline.forward() so TeaCache runs on the actual scheduler timesteps.
+
+Run all:
+    pytest tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py -v -s
+
+Run one model:
+    pytest tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py -v -s -k wan21_1_3b
+    pytest tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py -v -s -k wan21_14b
+    pytest tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py -v -s -k wan22_raises
+
+Override checkpoint paths:
+    DIFFUSION_MODEL_PATH_WAN21_1_3B=/path/to/1.3b \\
+    DIFFUSION_MODEL_PATH_WAN21_14B=/path/to/14b \\
+    DIFFUSION_MODEL_PATH_WAN22_T2V=/path/to/wan22 \\
+    pytest tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py -v -s
+"""
+
+import os
+
+os.environ["TLLM_DISABLE_MPI"] = "1"
+
+import gc
+from pathlib import Path
+
+import pytest
+import torch
+
+from tensorrt_llm._torch.visual_gen.config import TeaCacheConfig, VisualGenArgs
+from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader
+
+
+@pytest.fixture(autouse=True, scope="module")
+def _cleanup_mpi_env():
+    yield
+    os.environ.pop("TLLM_DISABLE_MPI", None)
+
+
+@pytest.fixture(autouse=True)
+def _cleanup_gpu():
+    gc.collect()
+    torch.cuda.empty_cache()
+    yield
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+# ============================================================================
+# Path helpers
+# ============================================================================
+
+
+def _llm_models_root() -> Path:
+    if "LLM_MODELS_ROOT" in os.environ:
+        root = Path(os.environ["LLM_MODELS_ROOT"])
+    else:
+        root = Path("/home/scratch.trt_llm_data_ci/llm-models/")
+        if not root.exists():
+            root = Path("/scratch.trt_llm_data/llm-models/")
+    assert root.exists(), (
+        "Set LLM_MODELS_ROOT or ensure /home/scratch.trt_llm_data_ci/llm-models/ is accessible."
+    )
+    return root
+
+
+def _checkpoint(env_var: str, default_name: str) -> str:
+    return os.environ.get(env_var) or str(_llm_models_root() / default_name)
+
+
+WAN21_1_3B_PATH = _checkpoint("DIFFUSION_MODEL_PATH_WAN21_1_3B", "Wan2.1-T2V-1.3B-Diffusers")
+WAN21_14B_PATH = _checkpoint("DIFFUSION_MODEL_PATH_WAN21_14B", "Wan2.1-T2V-14B-Diffusers")
+WAN22_A14B_PATH = _checkpoint("DIFFUSION_MODEL_PATH_WAN22_T2V", "Wan2.2-T2V-A14B-Diffusers")
+
+INFER_NUM_FRAMES = 33  # (33-1)/4+1 = 9 latent frames; smallest realistic shape
+INFER_NUM_STEPS = 50  # Required for meaningful cache hits with calibrated coefficients
+INFER_SEED = 42
+
+
+# ============================================================================
+# Pipeline fixture factory
+# ============================================================================
+
+
+def _make_pipeline(checkpoint_path: str, use_ret_steps: bool = False):
+    if not os.path.exists(checkpoint_path):
+        pytest.skip(f"Checkpoint not found: {checkpoint_path}")
+    args = VisualGenArgs(
+        checkpoint_path=checkpoint_path,
+        device="cuda",
+        dtype="bfloat16",
+        cache=TeaCacheConfig(
+            teacache_thresh=0.2,
+            use_ret_steps=use_ret_steps,
+        ),
+    )
+    pipeline = PipelineLoader(args).load(skip_warmup=True)
+    return pipeline
+
+
+@pytest.fixture
+def wan21_1_3b_pipeline():
+    pipeline = _make_pipeline(WAN21_1_3B_PATH)
+    yield pipeline
+    del pipeline
+    torch.cuda.empty_cache()
+
+
+@pytest.fixture
+def wan21_1_3b_ret_steps_pipeline():
+    pipeline = _make_pipeline(WAN21_1_3B_PATH, use_ret_steps=True)
+    yield pipeline
+    del pipeline
+    torch.cuda.empty_cache()
+
+
+@pytest.fixture
+def wan21_14b_pipeline():
+    pipeline = _make_pipeline(WAN21_14B_PATH)
+    yield pipeline
+    del pipeline
+    torch.cuda.empty_cache()
+
+
+# ============================================================================
+# Shared assertion helper
+# ============================================================================
+
+
+def _assert_single_stage_teacache(
+    pipeline,
+    height: int,
+    width: int,
+    model: str = "",
+    expected_hit_rate: float = None,
+    atol: float = 0.02,
+) -> None:
+    """Run forward and verify TeaCache produces cache hits (single-stage Wan 2.1)."""
+    with torch.no_grad():
+        pipeline.forward(
+            prompt="a cat sitting on a windowsill",
+            negative_prompt="",
+            height=height,
+            width=width,
+            num_frames=INFER_NUM_FRAMES,
+            num_inference_steps=INFER_NUM_STEPS,
+            seed=INFER_SEED,
+        )
+
+    stats = pipeline.transformer_cache_backend.get_stats()
+
+    print(f"\n ===== TeaCache — Wan 2.1 {model} single-stage {height}x{width} =====")
+    print(
+        f" transformer: {stats['cached_steps']}/{stats['total_steps']} cached "
+        f"({stats['hit_rate']:.1%} hit rate)"
+    )
+    if expected_hit_rate is not None:
+        # Reference hit rates derived from vFly reference runs
+        print(f" expected: {expected_hit_rate:.1%} (vFly reference, atol={atol:.0%})")
+        delta = stats["hit_rate"] - expected_hit_rate
+        print(f" delta: {delta:+.1%}")
+    print(" ================================================================")
+
+    assert stats["total_steps"] == INFER_NUM_STEPS, (
+        f"total_steps {stats['total_steps']} != {INFER_NUM_STEPS}"
+    )
+    assert stats["compute_steps"] + stats["cached_steps"] == stats["total_steps"]
+    assert stats["cached_steps"] > 0, (
+        f"0 cache hits after {stats['total_steps']} steps. TeaCache is not working. Stats: {stats}"
+    )
+    if expected_hit_rate is not None:
+        assert abs(stats["hit_rate"] - expected_hit_rate) <= atol + 1e-9, (
+            f"Hit rate {stats['hit_rate']:.1%} not within {atol:.0%} "
+            f"of expected {expected_hit_rate:.1%} (vFly reference)"
+        )
+
+
+# ============================================================================
+# Tests
+# ============================================================================
+
+
+@pytest.mark.integration
+@pytest.mark.wan_t2v
+@pytest.mark.teacache
+class TestWan21_1_3B_TeaCache:
+    """Wan2.1-T2V-1.3B 480x832 single-stage."""
+
+    def test_wan21_1_3b_teacache(self, wan21_1_3b_pipeline):
+        _assert_single_stage_teacache(
+            wan21_1_3b_pipeline, height=480, width=832, model="T2V-1.3B", expected_hit_rate=0.68
+        )
+
+    def test_wan21_1_3b_teacache_ret_steps(self, wan21_1_3b_ret_steps_pipeline):
+        _assert_single_stage_teacache(
+            wan21_1_3b_ret_steps_pipeline,
+            height=480,
+            width=832,
+            model="T2V-1.3B",
+            expected_hit_rate=0.72,
+        )
+
+
+@pytest.mark.integration
+@pytest.mark.wan_t2v
+@pytest.mark.teacache
+class TestWan21_14B_TeaCache:
+    """Wan2.1-T2V-14B 720x1280 single-stage."""
+
+    def test_wan21_14b_teacache(self, wan21_14b_pipeline):
+        _assert_single_stage_teacache(
+            wan21_14b_pipeline, height=720, width=1280, model="T2V-14B", expected_hit_rate=0.48
+        )
+
+
+@pytest.mark.integration
+@pytest.mark.wan_t2v
+@pytest.mark.teacache
+class TestWan22_T2V_TeaCacheRaisesError:
+    """Wan2.2-T2V-A14B must raise ValueError when TeaCache is enabled."""
+
+    def test_wan22_raises_if_teacache_enabled(self):
+        if not os.path.exists(WAN22_A14B_PATH):
+            pytest.skip(f"Checkpoint not found: {WAN22_A14B_PATH}")
+        args = VisualGenArgs(
+            checkpoint_path=WAN22_A14B_PATH,
+            device="cuda",
+            dtype="bfloat16",
+            cache=TeaCacheConfig(),
+        )
+        with pytest.raises(ValueError, match=r"TeaCache is not supported for Wan 2\.2"):
+            PipelineLoader(args).load(skip_warmup=True)
diff --git a/tests/unittest/_torch/visual_gen/test_wan22_i2v_pipeline.py b/tests/unittest/_torch/visual_gen/test_wan22_i2v_pipeline.py
new file mode 100644
index 000000000000..f05402365145
--- /dev/null
+++ b/tests/unittest/_torch/visual_gen/test_wan22_i2v_pipeline.py
@@ -0,0 +1,502 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Correctness tests for Wan 2.2 I2V pipeline against HuggingFace reference.
+
+Tests verify that the TRTLLM WanImageToVideoPipeline (two-stage I2V) produces
+decoded video with >= 0.98 cosine similarity to the HuggingFace diffusers
+WanImageToVideoPipeline baseline. The threshold is 0.98 (not 0.99) because
+I2V image conditioning (VAE-encoded image + mask concatenated to the latent)
+accumulates additional bfloat16 error compared to T2V.
+
+Key differences from Wan 2.1 I2V:
+  - Two-stage denoising: transformer (high-noise) + transformer_2 (low-noise)
+  - No CLIP image encoder: image conditioning is purely via VAE-encoded latent
+  - boundary_ratio controls the timestep split between stages
+
+Model tested:
+  - Wan2.2-I2V-A14B (480x832, 9 frames)
+
+Run:
+    pytest tests/unittest/_torch/visual_gen/test_wan22_i2v_pipeline.py -v -s
+
+Override checkpoint path:
+    DIFFUSION_MODEL_PATH_WAN22_I2V=/path/to/wan22_i2v \\
+    pytest tests/unittest/_torch/visual_gen/test_wan22_i2v_pipeline.py -v -s
+"""
+
+import importlib.util
+import os
+
+os.environ["TLLM_DISABLE_MPI"] = "1"
+
+import gc
+from pathlib import Path
+
+import numpy as np
+import pytest
+import torch
+import torch.nn.functional as F
+from diffusers import DiffusionPipeline
+from PIL import Image
+
+from tensorrt_llm._torch.visual_gen.config import (
+    AttentionConfig,
+    CacheDiTConfig,
+    TorchCompileConfig,
+    VisualGenArgs,
+)
+from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader
+
+
+@pytest.fixture(autouse=True, scope="module")
+def _cleanup_mpi_env():
+    yield
+    os.environ.pop("TLLM_DISABLE_MPI", None)
+
+
+# ============================================================================
+# Path helpers
+# ============================================================================
+
+
+def _llm_models_root() -> str:
+    """Return LLM_MODELS_ROOT if set in env, else fall back to the shared scratch paths; assert that the resolved root exists."""
+    root = Path("/home/scratch.trt_llm_data_ci/llm-models/")
+    if "LLM_MODELS_ROOT" in os.environ:
+        root = Path(os.environ["LLM_MODELS_ROOT"])
+    if not root.exists():
+        root = Path("/scratch.trt_llm_data/llm-models/")
+    assert root.exists(), (
+        "Set LLM_MODELS_ROOT or ensure /home/scratch.trt_llm_data_ci/llm-models/ is accessible."
+    )
+    return str(root)
+
+
+def _checkpoint(env_var: str, default_name: str) -> str:
+    return os.environ.get(env_var) or os.path.join(_llm_models_root(), default_name)
+
+
+WAN22_I2V_PATH = _checkpoint("DIFFUSION_MODEL_PATH_WAN22_I2V", "Wan2.2-I2V-A14B-Diffusers")
+
+# ============================================================================
+# Test constants
+# ============================================================================
+
+PROMPT = "A cat sitting on a sunny windowsill watching birds outside."
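+
+# Editor's sketch (illustrative only, not pipeline code): the module docstring
+# above describes I2V conditioning as a VAE-encoded image plus a mask
+# concatenated onto the noise latent. Schematically, with hypothetical names
+# and shapes:
+#
+#     noise = torch.randn(1, 16, 9, 60, 104)    # (B, C, T_lat, h, w)
+#     image_latent = vae_encode(first_frame)    # same latent grid as noise
+#     mask = torch.zeros(1, 4, 9, 60, 104)      # 1.0 where frames are given
+#     mask[:, :, 0] = 1.0                       # only frame 0 is conditioned
+#     model_input = torch.cat([noise, mask, image_latent], dim=1)
+#
+# The exact channel layout is checkpoint-specific; the sketch only motivates
+# the relaxed 0.98 threshold (extra conditioned inputs accumulate bf16 error).
+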
+NEGATIVE_PROMPT = ""
+NUM_STEPS = 4
+SEED = 42
+COS_SIM_THRESHOLD = 0.98
+
+
+# ============================================================================
+# Helpers
+# ============================================================================
+
+
+def _make_test_image(height: int, width: int) -> Image.Image:
+    """Create a deterministic gradient test image."""
+    img = np.zeros((height, width, 3), dtype=np.uint8)
+    img[:, :, 0] = np.linspace(0, 255, height, dtype=np.uint8)[:, np.newaxis]
+    img[:, :, 1] = np.linspace(0, 255, width, dtype=np.uint8)[np.newaxis, :]
+    img[:, :, 2] = 128
+    return Image.fromarray(img, mode="RGB")
+
+
+def _load_trtllm_pipeline(checkpoint_path: str):
+    """Load TRTLLM WanImageToVideoPipeline (two-stage) without torch.compile or warmup."""
+    if not os.path.exists(checkpoint_path):
+        pytest.skip(f"Checkpoint not found: {checkpoint_path}")
+    args = VisualGenArgs(
+        checkpoint_path=checkpoint_path,
+        device="cuda",
+        dtype="bfloat16",
+        torch_compile=TorchCompileConfig(enable_torch_compile=False),
+    )
+    return PipelineLoader(args).load(skip_warmup=True)
+
+
+def _load_hf_pipeline(checkpoint_path: str):
+    """Load HuggingFace diffusers pipeline (auto-detects class from model_index.json)."""
+    hf_pipe = DiffusionPipeline.from_pretrained(
+        checkpoint_path,
+        torch_dtype=torch.bfloat16,
+    )
+    hf_pipe = hf_pipe.to("cuda")
+    hf_pipe.set_progress_bar_config(disable=True)
+    return hf_pipe
+
+
+def _capture_trtllm_video(
+    pipeline,
+    image: Image.Image,
+    prompt: str,
+    negative_prompt: str,
+    height: int,
+    width: int,
+    num_frames: int,
+    num_inference_steps: int,
+    guidance_scale: float,
+    seed: int,
+) -> torch.Tensor:
+    """Run full TRTLLM pipeline including VAE decode; return (B, T, H, W, C) float in [0, 1] (B == 1 here)."""
+    with torch.no_grad():
+        result = pipeline.forward(
+            image=image,
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            height=height,
+            width=width,
+            num_frames=num_frames,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            seed=seed,
+        )
+    video = result.video  # (B, T, H, W, C) uint8; B == 1 for a single prompt
+    return video.float() / 255.0
+
+
+def _capture_hf_video(
+    hf_pipe,
+    image: Image.Image,
+    prompt: str,
+    negative_prompt: str,
+    height: int,
+    width: int,
+    num_frames: int,
+    num_inference_steps: int,
+    guidance_scale: float,
+    seed: int,
+) -> torch.Tensor:
+    """Run HF pipeline with output_type='np'; return (T, H, W, C) float in [0, 1]."""
+    generator = torch.Generator(device="cuda").manual_seed(seed)
+    output = hf_pipe(
+        image=image,
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        height=height,
+        width=width,
+        num_frames=num_frames,
+        num_inference_steps=num_inference_steps,
+        guidance_scale=guidance_scale,
+        generator=generator,
+        output_type="np",
+    )
+    frames = output.frames  # (1, T, H, W, C) numpy float32 in [0, 1]
+    if isinstance(frames, np.ndarray):
+        return torch.from_numpy(frames[0]).float()
+    return frames[0].float()
+
+
+def _cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float:
+    """Cosine similarity between two tensors (flattened to 1D, cast to float32 on CPU)."""
+    a_flat = a.float().cpu().reshape(-1)
+    b_flat = b.float().cpu().reshape(-1)
+    return F.cosine_similarity(a_flat.unsqueeze(0), b_flat.unsqueeze(0)).clamp(-1.0, 1.0).item()
+
+
+def _assert_pipeline_matches_hf(
+    checkpoint_path: str,
+    height: int,
+    width: int,
+    num_frames: int,
+    guidance_scale: float,
+    model_label: str,
+) -> None:
+    """Run TRTLLM and HF pipelines sequentially, compare decoded video output."""
+    test_image = _make_test_image(height, width)
+
+    # --- TRTLLM ---
+    trtllm_pipe = _load_trtllm_pipeline(checkpoint_path)
+
+    # Confirm two-stage denoising is active (Wan 2.2 specific sanity check)
+    assert trtllm_pipe.transformer_2 is not None, (
+        f"{model_label}: expected two-stage pipeline (transformer_2 should not be None). "
+        "Check that the checkpoint is a Wan 2.2 model with boundary_ratio in model_index.json."
+    )
+    assert trtllm_pipe.boundary_ratio is not None, (
+        f"{model_label}: boundary_ratio is None — two-stage denoising will not activate. "
+        "Check model_index.json for 'boundary_ratio' key."
+    )
+
+    trtllm_video = _capture_trtllm_video(
+        trtllm_pipe,
+        image=test_image,
+        prompt=PROMPT,
+        negative_prompt=NEGATIVE_PROMPT,
+        height=height,
+        width=width,
+        num_frames=num_frames,
+        num_inference_steps=NUM_STEPS,
+        guidance_scale=guidance_scale,
+        seed=SEED,
+    )
+    del trtllm_pipe
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # --- HF reference ---
+    hf_pipe = _load_hf_pipeline(checkpoint_path)
+    hf_video = _capture_hf_video(
+        hf_pipe,
+        image=test_image,
+        prompt=PROMPT,
+        negative_prompt=NEGATIVE_PROMPT,
+        height=height,
+        width=width,
+        num_frames=num_frames,
+        num_inference_steps=NUM_STEPS,
+        guidance_scale=guidance_scale,
+        seed=SEED,
+    )
+    del hf_pipe
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # --- Compare ---
+    assert trtllm_video.numel() == hf_video.numel(), (
+        f"{model_label}: element count mismatch — "
+        f"TRTLLM {trtllm_video.shape} ({trtllm_video.numel()}) vs "
+        f"HF {hf_video.shape} ({hf_video.numel()})"
+    )
+
+    cos_sim = _cosine_similarity(trtllm_video, hf_video)
+    print(f"\n {model_label} cosine similarity: {cos_sim:.6f}")
+    assert cos_sim >= COS_SIM_THRESHOLD, (
+        f"{model_label}: cosine similarity {cos_sim:.6f} < {COS_SIM_THRESHOLD}. "
+        f"TRTLLM pipeline output diverges from the HuggingFace reference. "
+        f"Video shapes — TRTLLM: {trtllm_video.shape}, HF: {hf_video.shape}."
+    )
+
+
+# ============================================================================
+# Tests
+# ============================================================================
+
+
+@pytest.mark.integration
+@pytest.mark.wan_i2v
+class TestWan22_I2V_A14B_PipelineCorrectness:
+    """Wan2.2-I2V-A14B correctness vs HuggingFace reference (480x832, 9 frames).
+
+    Wan 2.2 I2V uses two-stage denoising without CLIP: transformer handles
+    high-noise timesteps (t >= boundary_timestep) and transformer_2 handles
+    low-noise timesteps. Image conditioning is via VAE-encoded latent only.
+    The test verifies the combined two-stage output matches the HF reference.
+    """
+
+    def test_cosine_similarity(self):
+        _assert_pipeline_matches_hf(
+            checkpoint_path=WAN22_I2V_PATH,
+            height=480,
+            width=832,
+            num_frames=9,
+            guidance_scale=4.0,
+            model_label="Wan2.2-I2V-A14B",
+        )
+
+
+# ============================================================================
+# Two-stage feature fixtures (skip aux components, loaded once per module)
+# ============================================================================
+
+_SKIP_AUX = ["text_encoder", "vae", "tokenizer", "scheduler", "image_encoder", "image_processor"]
+
+
+def _make_wan22_i2v(quant_config=None, attention=None):
+    if not os.path.exists(WAN22_I2V_PATH):
+        pytest.skip(f"Checkpoint not found: {WAN22_I2V_PATH}")
+    kwargs = dict(
+        checkpoint_path=WAN22_I2V_PATH,
+        device="cuda",
+        dtype="bfloat16",
+        skip_components=_SKIP_AUX,
+        torch_compile=TorchCompileConfig(enable_torch_compile=False),
+    )
+    if quant_config is not None:
+        kwargs["quant_config"] = quant_config
+    if attention is not None:
+        kwargs["attention"] = attention
+    return PipelineLoader(VisualGenArgs(**kwargs)).load(skip_warmup=True)
+
+
+@pytest.fixture(scope="module")
+def wan22_i2v_fp8():
+    pipeline = _make_wan22_i2v(quant_config={"quant_algo": "FP8", "dynamic": True})
+    yield pipeline
+    del pipeline
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+@pytest.fixture(scope="module")
+def wan22_i2v_trtllm():
+    pipeline = _make_wan22_i2v(attention=AttentionConfig(backend="TRTLLM"))
+    yield pipeline
+    del pipeline
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+# ============================================================================
+# Tests
+# ============================================================================
+
+
+@pytest.mark.integration
+@pytest.mark.wan_i2v
+class TestWan22TwoStageI2VFeatures:
+    """Verify that FP8 and TRTLLM attention apply to both transformer stages in Wan 2.2 I2V."""
+
+    def test_fp8_on_both_stages(self, wan22_i2v_fp8):
+        """FP8 quantization is applied to both transformer and transformer_2."""
+        if wan22_i2v_fp8.transformer_2 is None:
+            pytest.skip("Not a two-stage checkpoint")
+
+        def _has_fp8(module):
+            return any(
+                p.dtype == torch.float8_e4m3fn
+                for name, p in module.named_parameters()
+                if "blocks.0" in name and "weight" in name
+            )
+
+        assert _has_fp8(wan22_i2v_fp8.transformer), "No FP8 weights in transformer"
+        assert _has_fp8(wan22_i2v_fp8.transformer_2), "No FP8 weights in transformer_2"
+
+    def test_trtllm_attention_both_stages(self, wan22_i2v_trtllm):
+        """TRTLLM self-attention and VANILLA cross-attention on both stages."""
+        if wan22_i2v_trtllm.transformer_2 is None:
+            pytest.skip("Not a two-stage checkpoint")
+
+        for stage_name, transformer in [
+            ("transformer", wan22_i2v_trtllm.transformer),
+            ("transformer_2", wan22_i2v_trtllm.transformer_2),
+        ]:
+            b0 = transformer.blocks[0]
+            assert b0.attn1.attn_backend == "TRTLLM", (
+                f"{stage_name} self-attn: expected TRTLLM, got {b0.attn1.attn_backend}"
+            )
+            assert b0.attn2.attn_backend == "VANILLA", (
+                f"{stage_name} cross-attn should fall back to VANILLA, got {b0.attn2.attn_backend}"
+            )
+
+
+# =============================================================================
+# Batch Generation Tests
+# =============================================================================
+
+
+class TestWan22I2VBatchGeneration:
+    """Batch generation tests for Wan 2.2 I2V pipeline.
+
+    Verifies that batched prompts produce correct output shape through
+    both stages of two-stage denoising (transformer + transformer_2).
+    A single image is broadcast across all items in the batch.
+    """
+
+    @pytest.fixture(scope="class")
+    def wan22_i2v_full_pipeline(self):
+        """Load full Wan 2.2 I2V pipeline (all components) for batch tests."""
+        if not WAN22_I2V_PATH or not os.path.exists(WAN22_I2V_PATH):
+            pytest.skip("Checkpoint not available. Set DIFFUSION_MODEL_PATH_WAN22_I2V.")
+
+        args = VisualGenArgs(
+            checkpoint_path=WAN22_I2V_PATH,
+            device="cuda",
+            dtype="bfloat16",
+            torch_compile=TorchCompileConfig(enable_torch_compile=False),
+        )
+        pipeline = PipelineLoader(args).load(skip_warmup=True)
+        yield pipeline
+        del pipeline
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    def test_single_prompt_backward_compat(self, wan22_i2v_full_pipeline):
+        """Single prompt returns (B, T, H, W, C) for backward compatibility."""
+        test_image = _make_test_image(480, 832)
+        result = wan22_i2v_full_pipeline.forward(
+            prompt="a cat walking",
+            image=test_image,
+            height=480,
+            width=832,
+            num_frames=9,
+            num_inference_steps=4,
+            guidance_scale=4.0,
+            seed=42,
+        )
+        assert result.video.dim() == 5, f"Expected 5D (B,T,H,W,C), got {result.video.dim()}D"
+        B, _T, H, W, C = result.video.shape
+        assert B == 1 and H == 480 and W == 832 and C == 3
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    def test_batch_prompt_shape(self, wan22_i2v_full_pipeline):
+        """List of prompts with a single broadcast image returns (B, T, H, W, C)."""
+        test_image = _make_test_image(480, 832)
+        prompts = ["a sunset over mountains", "a cat on a roof"]
+        result = wan22_i2v_full_pipeline.forward(
+            prompt=prompts,
+            image=test_image,
+            height=480,
+            width=832,
+            num_frames=9,
+            num_inference_steps=4,
+            guidance_scale=4.0,
+            seed=42,
+        )
+        assert result.video.dim() == 5, f"Expected 5D (B,T,H,W,C), got {result.video.dim()}D"
+        B, _T, H, W, C = result.video.shape
+        assert B == 2 and H == 480 and W == 832 and C == 3
+
+
+# =============================================================================
+# Combined Optimization Tests
+# =============================================================================
+
+
+@pytest.mark.integration
+@pytest.mark.wan_i2v
+@pytest.mark.skipif(importlib.util.find_spec("cache_dit") is None, reason="cache_dit not installed")
+class TestWan22I2VCombinedOptimizations:
+    """FP8 + CacheDiT + TRTLLM attention combined on Wan 2.2 I2V (480x832)."""
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    def test_fp8_cache_dit_trtllm(self):
+        if not os.path.exists(WAN22_I2V_PATH):
+            pytest.skip(f"Checkpoint not found: {WAN22_I2V_PATH}")
+        args = VisualGenArgs(
+            checkpoint_path=WAN22_I2V_PATH,
+            device="cuda",
+            dtype="bfloat16",
+            torch_compile=TorchCompileConfig(enable_torch_compile=False),
+            quant_config={"quant_algo": "FP8", "dynamic": True},
+            attention=AttentionConfig(backend="TRTLLM"),
+            cache=CacheDiTConfig(),
+        )
+        pipeline = PipelineLoader(args).load(skip_warmup=True)
+        try:
+            test_image = _make_test_image(480, 832)
+            with torch.no_grad():
+                result = pipeline.forward(
+                    image=test_image,
+                    prompt="a cat sitting on a windowsill",
+                    negative_prompt="",
+                    height=480,
+                    width=832,
+                    num_frames=9,
+                    num_inference_steps=10,
+                    guidance_scale=4.0,
+                    seed=42,
+                )
+            assert result.video.dim() == 5
+            B, _T, H, W, C = result.video.shape
+            assert B == 1 and H == 480 and W == 832 and C == 3
+
+            assert pipeline.cache_accelerator is not None
+            assert pipeline.cache_accelerator.is_enabled()
+        finally:
+            del pipeline
+            gc.collect()
+            torch.cuda.empty_cache()
diff --git a/tests/unittest/_torch/visual_gen/test_wan22_t2v_pipeline.py b/tests/unittest/_torch/visual_gen/test_wan22_t2v_pipeline.py
new file mode 100644
index 000000000000..17c1aedc113a
--- /dev/null
+++ b/tests/unittest/_torch/visual_gen/test_wan22_t2v_pipeline.py
@@ -0,0 +1,472 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Correctness tests for Wan 2.2 T2V pipeline against HuggingFace reference.
+
+Tests verify that the TRTLLM WanPipeline (two-stage T2V) produces decoded video
+with >= 0.99 cosine similarity to the HuggingFace diffusers WanPipeline baseline.
+Comparison is done on decoded video (post-VAE).
+
+Wan 2.2 uses two-stage denoising (transformer + transformer_2 split at boundary_ratio).
+Both TRTLLM and HF pipelines read boundary_ratio from the checkpoint model_index.json.
+
+Model tested:
+  - Wan2.2-T2V-A14B-Diffusers (480x832, 9 frames)
+
+Run:
+    pytest tests/unittest/_torch/visual_gen/test_wan22_t2v_pipeline.py -v -s
+
+Override checkpoint path:
+    DIFFUSION_MODEL_PATH_WAN22_T2V=/path/to/wan22 \\
+    pytest tests/unittest/_torch/visual_gen/test_wan22_t2v_pipeline.py -v -s
+"""
+
+import importlib.util
+import os
+
+os.environ["TLLM_DISABLE_MPI"] = "1"
+
+import gc
+from pathlib import Path
+
+import numpy as np
+import pytest
+import torch
+import torch.nn.functional as F
+from diffusers import DiffusionPipeline
+
+from tensorrt_llm._torch.visual_gen.config import (
+    AttentionConfig,
+    CacheDiTConfig,
+    TorchCompileConfig,
+    VisualGenArgs,
+)
+from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader
+
+
+@pytest.fixture(autouse=True, scope="module")
+def _cleanup_mpi_env():
+    yield
+    os.environ.pop("TLLM_DISABLE_MPI", None)
+
+
+# ============================================================================
+# Path helpers
+# ============================================================================
+
+
+def _llm_models_root() -> str:
+    """Return LLM_MODELS_ROOT if set in env, else fall back to the shared scratch paths; assert that the resolved root exists."""
+    root = Path("/home/scratch.trt_llm_data_ci/llm-models/")
+    if "LLM_MODELS_ROOT" in os.environ:
+        root = Path(os.environ["LLM_MODELS_ROOT"])
+    if not root.exists():
+        root = Path("/scratch.trt_llm_data/llm-models/")
+    assert root.exists(), (
+        "Set LLM_MODELS_ROOT or ensure /home/scratch.trt_llm_data_ci/llm-models/ is accessible."
+    )
+    return str(root)
+
+
+def _checkpoint(env_var: str, default_name: str) -> str:
+    return os.environ.get(env_var) or os.path.join(_llm_models_root(), default_name)
+
+
+WAN22_A14B_PATH = _checkpoint("DIFFUSION_MODEL_PATH_WAN22_T2V", "Wan2.2-T2V-A14B-Diffusers")
+
+# ============================================================================
+# Test constants
+# ============================================================================
+
+PROMPT = "A cat sitting on a sunny windowsill watching birds outside."
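+
+# Editor's sketch (illustrative only, not pipeline code): the two-stage split
+# described in the module docstring. boundary_ratio is read from
+# model_index.json; the names below are hypothetical:
+#
+#     boundary_t = boundary_ratio * scheduler.config.num_train_timesteps
+#     for t in timesteps:
+#         stage = transformer if t >= boundary_t else transformer_2
+#         noise_pred = stage(hidden_states=latents, timestep=t, ...)
+#
+# TRTLLM and HF must agree on this split for the decoded videos to be
+# comparable at the 0.99 cosine-similarity threshold.
+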
+NEGATIVE_PROMPT = ""
+NUM_STEPS = 4
+SEED = 42
+COS_SIM_THRESHOLD = 0.99
+
+
+# ============================================================================
+# Helpers
+# ============================================================================
+
+
+def _load_trtllm_pipeline(checkpoint_path: str):
+    """Load TRTLLM WanPipeline (two-stage) without torch.compile or warmup."""
+    if not os.path.exists(checkpoint_path):
+        pytest.skip(f"Checkpoint not found: {checkpoint_path}")
+    args = VisualGenArgs(
+        checkpoint_path=checkpoint_path,
+        device="cuda",
+        dtype="bfloat16",
+        torch_compile=TorchCompileConfig(enable_torch_compile=False),
+    )
+    return PipelineLoader(args).load(skip_warmup=True)
+
+
+def _load_hf_pipeline(checkpoint_path: str):
+    """Load HuggingFace diffusers pipeline (auto-detects class from model_index.json)."""
+    hf_pipe = DiffusionPipeline.from_pretrained(
+        checkpoint_path,
+        torch_dtype=torch.bfloat16,
+    )
+    hf_pipe = hf_pipe.to("cuda")
+    hf_pipe.set_progress_bar_config(disable=True)
+    return hf_pipe
+
+
+def _capture_trtllm_video(
+    pipeline,
+    prompt: str,
+    negative_prompt: str,
+    height: int,
+    width: int,
+    num_frames: int,
+    num_inference_steps: int,
+    guidance_scale: float,
+    seed: int,
+) -> torch.Tensor:
+    """Run full TRTLLM pipeline including VAE decode; return (B, T, H, W, C) float in [0, 1] (B == 1 here)."""
+    with torch.no_grad():
+        result = pipeline.forward(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            height=height,
+            width=width,
+            num_frames=num_frames,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            seed=seed,
+        )
+    video = result.video  # (B, T, H, W, C) uint8; B == 1 for a single prompt
+    return video.float() / 255.0
+
+
+def _capture_hf_video(
+    hf_pipe,
+    prompt: str,
+    negative_prompt: str,
+    height: int,
+    width: int,
+    num_frames: int,
+    num_inference_steps: int,
+    guidance_scale: float,
+    seed: int,
+) -> torch.Tensor:
+    """Run HF pipeline with output_type='np'; return (T, H, W, C) float in [0, 1]."""
+    generator = torch.Generator(device="cuda").manual_seed(seed)
+    output = hf_pipe(
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        height=height,
+        width=width,
+        num_frames=num_frames,
+        num_inference_steps=num_inference_steps,
+        guidance_scale=guidance_scale,
+        generator=generator,
+        output_type="np",
+    )
+    frames = output.frames  # (1, T, H, W, C) numpy float32 in [0, 1]
+    if isinstance(frames, np.ndarray):
+        return torch.from_numpy(frames[0]).float()
+    return frames[0].float()
+
+
+def _cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float:
+    """Cosine similarity between two tensors (flattened to 1D, cast to float32 on CPU)."""
+    a_flat = a.float().cpu().reshape(-1)
+    b_flat = b.float().cpu().reshape(-1)
+    return F.cosine_similarity(a_flat.unsqueeze(0), b_flat.unsqueeze(0)).clamp(-1.0, 1.0).item()
+
+
+def _assert_pipeline_matches_hf(
+    checkpoint_path: str,
+    height: int,
+    width: int,
+    num_frames: int,
+    guidance_scale: float,
+    model_label: str,
+) -> None:
+    """Run TRTLLM and HF pipelines sequentially, compare decoded video output."""
+    # --- TRTLLM ---
+    trtllm_pipe = _load_trtllm_pipeline(checkpoint_path)
+
+    # Confirm two-stage denoising is active (Wan 2.2 specific sanity check)
+    assert trtllm_pipe.transformer_2 is not None, (
+        f"{model_label}: expected two-stage pipeline (transformer_2 should not be None). "
+        "Check that the checkpoint is a Wan 2.2 model with boundary_ratio in model_index.json."
+    )
+    assert trtllm_pipe.boundary_ratio is not None, (
+        f"{model_label}: boundary_ratio is None — two-stage denoising will not activate. "
+        "Check model_index.json for 'boundary_ratio' key."
+    )
+
+    trtllm_video = _capture_trtllm_video(
+        trtllm_pipe,
+        prompt=PROMPT,
+        negative_prompt=NEGATIVE_PROMPT,
+        height=height,
+        width=width,
+        num_frames=num_frames,
+        num_inference_steps=NUM_STEPS,
+        guidance_scale=guidance_scale,
+        seed=SEED,
+    )
+    del trtllm_pipe
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # --- HF reference ---
+    hf_pipe = _load_hf_pipeline(checkpoint_path)
+    hf_video = _capture_hf_video(
+        hf_pipe,
+        prompt=PROMPT,
+        negative_prompt=NEGATIVE_PROMPT,
+        height=height,
+        width=width,
+        num_frames=num_frames,
+        num_inference_steps=NUM_STEPS,
+        guidance_scale=guidance_scale,
+        seed=SEED,
+    )
+    del hf_pipe
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # --- Compare ---
+    assert trtllm_video.numel() == hf_video.numel(), (
+        f"{model_label}: element count mismatch — "
+        f"TRTLLM {trtllm_video.shape} ({trtllm_video.numel()}) vs "
+        f"HF {hf_video.shape} ({hf_video.numel()})"
+    )
+
+    cos_sim = _cosine_similarity(trtllm_video, hf_video)
+    print(f"\n {model_label} cosine similarity: {cos_sim:.6f}")
+    assert cos_sim >= COS_SIM_THRESHOLD, (
+        f"{model_label}: cosine similarity {cos_sim:.6f} < {COS_SIM_THRESHOLD}. "
+        f"TRTLLM pipeline output diverges from the HuggingFace reference. "
+        f"Video shapes — TRTLLM: {trtllm_video.shape}, HF: {hf_video.shape}."
+    )
+
+
+# ============================================================================
+# Tests
+# ============================================================================
+
+
+@pytest.mark.integration
+@pytest.mark.wan_t2v
+class TestWan22_A14B_PipelineCorrectness:
+    """Wan2.2-T2V-A14B correctness vs HuggingFace reference (480x832, 9 frames).
+
+    Wan 2.2 uses two-stage denoising: transformer handles high-noise timesteps
+    (t >= boundary_timestep) and transformer_2 handles low-noise timesteps.
+    The test verifies the combined two-stage output matches the HF reference.
+    """
+
+    def test_cosine_similarity(self):
+        _assert_pipeline_matches_hf(
+            checkpoint_path=WAN22_A14B_PATH,
+            height=480,
+            width=832,
+            num_frames=9,
+            guidance_scale=4.0,
+            model_label="Wan2.2-T2V-A14B",
+        )
+
+
+# ============================================================================
+# Two-stage feature fixtures (skip aux components, loaded once per module)
+# ============================================================================
+
+_SKIP_AUX = ["text_encoder", "vae", "tokenizer", "scheduler"]
+
+
+def _make_wan22_t2v(quant_config=None, attention=None):
+    if not os.path.exists(WAN22_A14B_PATH):
+        pytest.skip(f"Checkpoint not found: {WAN22_A14B_PATH}")
+    kwargs = dict(
+        checkpoint_path=WAN22_A14B_PATH,
+        device="cuda",
+        dtype="bfloat16",
+        skip_components=_SKIP_AUX,
+        torch_compile=TorchCompileConfig(enable_torch_compile=False),
+    )
+    if quant_config is not None:
+        kwargs["quant_config"] = quant_config
+    if attention is not None:
+        kwargs["attention"] = attention
+    return PipelineLoader(VisualGenArgs(**kwargs)).load(skip_warmup=True)
+
+
+@pytest.fixture(scope="module")
+def wan22_t2v_fp8():
+    pipeline = _make_wan22_t2v(quant_config={"quant_algo": "FP8", "dynamic": True})
+    yield pipeline
+    del pipeline
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+@pytest.fixture(scope="module")
+def wan22_t2v_trtllm():
+    pipeline = _make_wan22_t2v(attention=AttentionConfig(backend="TRTLLM"))
+    yield pipeline
+    del pipeline
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+# ============================================================================
+# Tests
+# ============================================================================
+
+
+@pytest.mark.integration
+@pytest.mark.wan_t2v
+class TestWan22TwoStageFeatures:
+    """Verify that FP8 and TRTLLM attention apply to both transformer stages in Wan 2.2."""
+
+    def test_fp8_on_both_stages(self, wan22_t2v_fp8):
+        """FP8 quantization is applied to both transformer and transformer_2."""
+        if wan22_t2v_fp8.transformer_2 is None:
+            pytest.skip("Not a two-stage checkpoint")
+
+        def _has_fp8(module):
+            return any(
+                p.dtype == torch.float8_e4m3fn
+                for name, p in module.named_parameters()
+                if "blocks.0" in name and "weight" in name
+            )
+
+        assert _has_fp8(wan22_t2v_fp8.transformer), "No FP8 weights in transformer"
+        assert _has_fp8(wan22_t2v_fp8.transformer_2), "No FP8 weights in transformer_2"
+
+    def test_trtllm_attention_both_stages(self, wan22_t2v_trtllm):
+        """TRTLLM self-attention and VANILLA cross-attention on both stages."""
+        if wan22_t2v_trtllm.transformer_2 is None:
+            pytest.skip("Not a two-stage checkpoint")
+
+        for stage_name, transformer in [
+            ("transformer", wan22_t2v_trtllm.transformer),
+            ("transformer_2", wan22_t2v_trtllm.transformer_2),
+        ]:
+            b0 = transformer.blocks[0]
+            assert b0.attn1.attn_backend == "TRTLLM", (
+                f"{stage_name} self-attn: expected TRTLLM, got {b0.attn1.attn_backend}"
+            )
+            assert b0.attn2.attn_backend == "VANILLA", (
+                f"{stage_name} cross-attn should fall back to VANILLA, got {b0.attn2.attn_backend}"
+            )
+
+
+# =============================================================================
+# Batch Generation Tests
+# =============================================================================
+
+
+class TestWan22T2VBatchGeneration:
+    """Batch generation tests for Wan 2.2 T2V pipeline.
+
+    Verifies that batched prompts produce correct output shape through
+    both stages of two-stage denoising (transformer + transformer_2).
+    """
+
+    @pytest.fixture(scope="class")
+    def wan22_t2v_full_pipeline(self):
+        """Load full Wan 2.2 T2V pipeline (all components) for batch tests."""
+        if not WAN22_A14B_PATH or not os.path.exists(WAN22_A14B_PATH):
+            pytest.skip("Checkpoint not available. Set DIFFUSION_MODEL_PATH_WAN22_T2V.")
+
+        args = VisualGenArgs(
+            checkpoint_path=WAN22_A14B_PATH,
+            device="cuda",
+            dtype="bfloat16",
+            torch_compile=TorchCompileConfig(enable_torch_compile=False),
+        )
+        pipeline = PipelineLoader(args).load(skip_warmup=True)
+        yield pipeline
+        del pipeline
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    def test_single_prompt_backward_compat(self, wan22_t2v_full_pipeline):
+        """Single prompt returns (B, T, H, W, C) for backward compatibility."""
+        result = wan22_t2v_full_pipeline.forward(
+            prompt="a cat walking",
+            height=480,
+            width=832,
+            num_frames=9,
+            num_inference_steps=4,
+            guidance_scale=4.0,
+            seed=42,
+        )
+        assert result.video.dim() == 5, f"Expected 5D (B,T,H,W,C), got {result.video.dim()}D"
+        B, _T, H, W, C = result.video.shape
+        assert B == 1 and H == 480 and W == 832 and C == 3
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    def test_batch_prompt_shape(self, wan22_t2v_full_pipeline):
+        """List of prompts returns (B, T, H, W, C) through both denoising stages."""
+        prompts = ["a sunset over mountains", "a cat on a roof"]
+        result = wan22_t2v_full_pipeline.forward(
+            prompt=prompts,
+            height=480,
+            width=832,
+            num_frames=9,
+            num_inference_steps=4,
+            guidance_scale=4.0,
+            seed=42,
+        )
+        assert result.video.dim() == 5, f"Expected 5D (B,T,H,W,C), got {result.video.dim()}D"
+        B, _T, H, W, C = result.video.shape
+        assert B == 2 and H == 480 and W == 832 and C == 3
+
+
+# =============================================================================
+# Combined Optimization Tests
+# =============================================================================
+
+
+@pytest.mark.integration
+@pytest.mark.wan_t2v
+@pytest.mark.skipif(importlib.util.find_spec("cache_dit") is None, reason="cache_dit not installed")
+class TestWan22T2VCombinedOptimizations:
+    """FP8 + CacheDiT + TRTLLM attention combined on Wan 2.2 T2V (480x832)."""
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    def test_fp8_cache_dit_trtllm(self):
+        if not os.path.exists(WAN22_A14B_PATH):
+            pytest.skip(f"Checkpoint not found: {WAN22_A14B_PATH}")
+        args = VisualGenArgs(
+            checkpoint_path=WAN22_A14B_PATH,
+            device="cuda",
+            dtype="bfloat16",
+            torch_compile=TorchCompileConfig(enable_torch_compile=False),
+            quant_config={"quant_algo": "FP8", "dynamic": True},
+            attention=AttentionConfig(backend="TRTLLM"),
+            cache=CacheDiTConfig(),
+        )
+        pipeline = PipelineLoader(args).load(skip_warmup=True)
+        try:
+            with torch.no_grad():
+                result = pipeline.forward(
+                    prompt="a cat sitting on a windowsill",
+                    negative_prompt="",
+                    height=480,
+                    width=832,
+                    num_frames=9,
+                    num_inference_steps=10,
+                    guidance_scale=4.0,
+                    seed=42,
+                )
+            assert result.video.dim() == 5
+            B, _T, H, W, C = result.video.shape
+            assert B == 1 and H == 480 and W == 832 and C == 3
+
+            assert pipeline.cache_accelerator is not None
+            assert pipeline.cache_accelerator.is_enabled()
+        finally:
+            del pipeline
+            gc.collect()
+            torch.cuda.empty_cache()
diff --git a/tests/unittest/_torch/visual_gen/test_wan_i2v.py b/tests/unittest/_torch/visual_gen/test_wan_i2v.py
deleted file mode 100644
index da51447a7c6b..000000000000
--- a/tests/unittest/_torch/visual_gen/test_wan_i2v.py
+++ /dev/null
@@ -1,1544 +0,0 @@
-"""Optimized tests for Wan Image-to-Video (I2V) pipeline with module-scoped fixtures.
-
-Run with:
-    pytest tests/visual_gen/test_wan_i2v_2.py -v
-
-    # With real checkpoint:
-    DIFFUSION_MODEL_PATH=/path/to/Wan-I2V-Diffusers pytest tests/visual_gen/test_wan_i2v_2.py -v
-
-    # Run only smoke tests:
-    pytest tests/visual_gen/test_wan_i2v_2.py -v -m "unit and smoke"
-
-    # Run only Wan 2.1 tests:
-    pytest tests/visual_gen/test_wan_i2v_2.py -v -m "wan21"
-
-    # Run only Wan 2.2 tests:
-    pytest tests/visual_gen/test_wan_i2v_2.py -v -m "wan22"
-"""
-
-import os
-
-os.environ["TLLM_DISABLE_MPI"] = "1"
-
-import unittest
-from pathlib import Path
-from types import SimpleNamespace
-
-import pytest
-import torch
-import torch.distributed as dist
-import torch.multiprocessing as mp
-import torch.nn.functional as F
-from PIL import Image
-
-from tensorrt_llm._torch.visual_gen.config import (
-    AttentionConfig,
-    DiffusionModelConfig,
-    ParallelConfig,
-    TeaCacheConfig,
-    TorchCompileConfig,
-    VisualGenArgs,
-)
-from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan_i2v import WanImageToVideoPipeline
-from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader
-from tensorrt_llm.models.modeling_utils import QuantConfig
-from tensorrt_llm.quantization.mode import QuantAlgo
-
-
-@pytest.fixture(autouse=True, scope="module")
-def _cleanup_mpi_env():
-    """Clean up TLLM_DISABLE_MPI env var after tests complete."""
-    yield
-    os.environ.pop("TLLM_DISABLE_MPI", None)
-
-
-def _llm_models_root() -> str:
-    """Return LLM_MODELS_ROOT path if it is set in env, assert when it's set but not a valid path."""
-    root = Path("/home/scratch.trt_llm_data_ci/llm-models/")
-    if "LLM_MODELS_ROOT" in os.environ:
-        root = Path(os.environ["LLM_MODELS_ROOT"])
-    if not root.exists():
-        root = Path("/scratch.trt_llm_data/llm-models/")
-    assert root.exists(), (
-        "You shall set LLM_MODELS_ROOT env or be able to access scratch.trt_llm_data to run this test"
-    )
-    return str(root)
-
-
-# Checkpoint paths
-CHECKPOINT_PATH = os.environ.get(
-    "DIFFUSION_MODEL_PATH",
-    os.path.join(_llm_models_root(), "Wan2.2-I2V-A14B-Diffusers"),
-)
-
-# Skip components for different test scenarios
-SKIP_MINIMAL = ["text_encoder", "vae", "tokenizer", "scheduler", "image_encoder", "image_processor"]
-SKIP_WITH_IMAGE = ["text_encoder", "vae", "tokenizer", "scheduler"]
-
-
-# ============================================================================
-# VERSION DETECTION HELPERS
-# ============================================================================
-
-
-def is_wan21_checkpoint() -> bool:
-    """Check if DIFFUSION_MODEL_PATH is Wan 2.1 (contains '2.1' in path)."""
-    return "2.1" in CHECKPOINT_PATH
-
-
-def is_wan22_checkpoint() -> bool:
-    """Check if DIFFUSION_MODEL_PATH is Wan 2.2 (contains '2.2' in path)."""
-    return "2.2" in CHECKPOINT_PATH
-
-
-# ============================================================================
-# MODULE-SCOPED FIXTURES
-# ============================================================================
-
-
-@pytest.fixture(scope="module")
-def wan21_i2v_pipeline_bf16():
-    """Load Wan 2.1 I2V BF16 pipeline once per module."""
-    if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH):
-        pytest.skip("I2V checkpoint not available")
-    if not is_wan21_checkpoint():
-        pytest.skip("This fixture requires Wan 2.1 checkpoint")
-
-    args = VisualGenArgs(
-        checkpoint_path=CHECKPOINT_PATH,
-        device="cuda",
-        dtype="bfloat16",
-        skip_components=SKIP_MINIMAL,
-    )
-    pipeline = PipelineLoader(args).load(skip_warmup=True)
-    yield pipeline
-    del pipeline
-    torch.cuda.empty_cache()
-
-
-@pytest.fixture(scope="module")
-def wan21_i2v_pipeline_fp8():
-    """Load Wan 2.1 I2V FP8 pipeline once per module."""
-    if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH):
-        pytest.skip("I2V checkpoint not available")
-    if not is_wan21_checkpoint():
-        pytest.skip("This fixture requires Wan 2.1 checkpoint")
-
-    args = VisualGenArgs(
-        checkpoint_path=CHECKPOINT_PATH,
-        device="cuda",
-        dtype="bfloat16",
-        skip_components=SKIP_MINIMAL,
-        quant_config={"quant_algo": "FP8", "dynamic": True},
-    )
-    pipeline = PipelineLoader(args).load(skip_warmup=True)
-    yield pipeline
-    del pipeline
-    torch.cuda.empty_cache()
-
-
-@pytest.fixture(scope="module")
-def wan21_i2v_pipeline_fp8_blockwise():
-    """Load Wan 2.1 I2V FP8 blockwise pipeline once per module."""
-    if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH):
-        pytest.skip("I2V checkpoint not available")
-    if not is_wan21_checkpoint():
-        pytest.skip("This fixture requires Wan 2.1 checkpoint")
-
-    args = VisualGenArgs(
-        checkpoint_path=CHECKPOINT_PATH,
-        device="cuda",
-        dtype="bfloat16",
-        skip_components=SKIP_MINIMAL,
-        quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True},
-    )
-    pipeline = PipelineLoader(args).load(skip_warmup=True)
-    yield pipeline
-    del pipeline
-    torch.cuda.empty_cache()
-
-
-@pytest.fixture(scope="module")
-def wan21_i2v_pipeline_with_image_encoder():
-    """Load Wan 2.1 I2V pipeline with image encoder once per module."""
-    if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH):
-        pytest.skip("I2V checkpoint not available")
-    if not is_wan21_checkpoint():
-        pytest.skip("This fixture requires Wan 2.1 checkpoint")
-
-    args = VisualGenArgs(
-        checkpoint_path=CHECKPOINT_PATH,
-        device="cuda",
-        dtype="bfloat16",
-        skip_components=SKIP_WITH_IMAGE,
-    )
-    pipeline = PipelineLoader(args).load(skip_warmup=True)
-    yield pipeline
-    del pipeline
-    torch.cuda.empty_cache()
-
-
-@pytest.fixture(scope="module")
-def wan22_i2v_pipeline_bf16():
-    """Load Wan 2.2 I2V BF16 pipeline once per module."""
-    if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH):
-        pytest.skip("I2V checkpoint not available")
-    if not is_wan22_checkpoint():
-        pytest.skip("This fixture requires Wan 2.2 checkpoint")
-
-    args = VisualGenArgs(
-        checkpoint_path=CHECKPOINT_PATH,
-        device="cuda",
-        dtype="bfloat16",
-        skip_components=SKIP_MINIMAL,
-    )
-    pipeline = PipelineLoader(args).load(skip_warmup=True)
-    yield pipeline
-    del pipeline
-    torch.cuda.empty_cache()
-
-
-@pytest.fixture(scope="module")
-def wan22_i2v_pipeline_fp8():
-    """Load Wan 2.2 I2V FP8 pipeline once per module."""
-    if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH):
-        pytest.skip("I2V checkpoint not available")
-    if not is_wan22_checkpoint():
-        pytest.skip("This fixture requires Wan 2.2 checkpoint")
-
-    args = VisualGenArgs(
-        checkpoint_path=CHECKPOINT_PATH,
-        device="cuda",
-        dtype="bfloat16",
-        skip_components=SKIP_MINIMAL,
-        quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True},
-    )
-    pipeline = PipelineLoader(args).load(skip_warmup=True)
-    yield pipeline
-    del pipeline
-    torch.cuda.empty_cache()
-
-
-@pytest.fixture(scope="module")
-def test_image():
-    """Create a shared test image for I2V tests."""
-    import numpy as np
-
-    img_array = np.zeros((480, 832, 3), dtype=np.uint8)
-    for i in range(480):
-        img_array[i, :, 0] = int((i / 480) * 255)
-        img_array[i, :, 1] = 128
-    return Image.fromarray(img_array, mode="RGB")
-
-
-@pytest.fixture(autouse=True)
-def cleanup_gpu():
-    """GPU cleanup fixture."""
-    import gc
-
-    gc.collect()
-    torch.cuda.empty_cache()
-    yield
-    gc.collect()
-    torch.cuda.empty_cache()
-
-
-# ============================================================================
-# DISTRIBUTED HELPERS (for CFG Parallelism tests)
-# ============================================================================
-
-
-def setup_distributed(rank, world_size, backend="nccl"):
-    """Initialize distributed process group for multi-GPU tests."""
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "12356"  # Different port from T2V tests
-    os.environ["RANK"] = str(rank)
-    os.environ["WORLD_SIZE"] = str(world_size)
-
-    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
-    torch.cuda.set_device(rank)
-
-
-def cleanup_distributed():
-    """Clean up distributed process group."""
-    if dist.is_initialized():
-        dist.destroy_process_group()
-
-
-def _run_cfg_worker_i2v(rank, world_size, checkpoint_path, inputs_list, return_dict):
-    """Worker function for I2V CFG Parallelism multi-GPU test."""
-    try:
-        setup_distributed(rank, world_size)
-
-        from tensorrt_llm._torch.visual_gen.config import ParallelConfig, VisualGenArgs
-        from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader
-
-        # Load I2V pipeline with CFG parallel
-        args = VisualGenArgs(
-            checkpoint_path=checkpoint_path,
-            device=f"cuda:{rank}",
-            dtype="bfloat16",
-            skip_components=SKIP_MINIMAL,
-            parallel=ParallelConfig(dit_cfg_size=world_size),
-        )
-        pipeline = PipelineLoader(args).load(skip_warmup=True)
-
-        # Verify CFG parallel configuration
-        assert pipeline.model_config.visual_gen_mapping.cfg_size == world_size, (
-            f"Expected cfg_size={world_size}, got {pipeline.model_config.visual_gen_mapping.cfg_size}"
-        )
-
-        # Load inputs on this GPU
-        prompt_embeds = inputs_list[0].to(f"cuda:{rank}")
-        neg_prompt_embeds = inputs_list[1].to(f"cuda:{rank}")
-        latents = inputs_list[2].to(f"cuda:{rank}")
-        timestep = inputs_list[3].to(f"cuda:{rank}")
-        # I2V-specific: image embeddings (if present)
-        image_embeds = inputs_list[4].to(f"cuda:{rank}") if inputs_list[4] is not None else None
-
-        # Setup CFG config
-        cfg_config = pipeline._setup_cfg_config(
-            guidance_scale=5.0,
-            prompt_embeds=prompt_embeds,
-            neg_prompt_embeds=neg_prompt_embeds,
-        )
-
-        # Verify CFG parallel is enabled
-        assert cfg_config["enabled"], f"Rank {rank}: CFG parallel not enabled"
-        assert cfg_config["cfg_size"] == world_size, f"Rank {rank}: Wrong cfg_size"
-
-        expected_cfg_group = rank // cfg_config["ulysses_size"]
-        assert cfg_config["cfg_rank"] == expected_cfg_group, (
-            f"Rank {rank}: Wrong cfg_rank. Expected {expected_cfg_group}, got {cfg_config['cfg_rank']}"
-        )
-
-        if rank == 0:
-            print(f"[CFG I2V Rank {rank}] Loaded with cfg_size={world_size}")
-            print(f" cfg_rank: {cfg_config['cfg_rank']}")
-            print(f" local_embeds shape: {cfg_config['local_embeds'].shape}")
-            print(f" Using {'positive' if cfg_config['cfg_rank'] == 0 else 'negative'} prompts")
-            print(f" Image embeds: {'present' if image_embeds is not None else 'None'}")
-
-        # Verify prompt splitting
-        expected_embeds = prompt_embeds if cfg_config["cfg_rank"] == 0 else neg_prompt_embeds
-        assert torch.allclose(cfg_config["local_embeds"], expected_embeds), (
-            f"Rank {rank}: local_embeds doesn't match expected embeds"
-        )
-
-        # Run single denoising step with CFG parallel
-        def forward_fn(
-            latents, extra_stream_latents, timestep, encoder_hidden_states, extra_tensors
-        ):
-            # I2V-specific: include image embeddings in extra_tensors if present
-            return pipeline.transformer(  # noqa: F821
-                hidden_states=latents,
-                timestep=timestep,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_hidden_states_image=extra_tensors.get("encoder_hidden_states_image"),
-            )
-
-        with torch.no_grad():
-            local_extras = (
-                {"encoder_hidden_states_image": image_embeds} if image_embeds is not None else {}
-            )
-            noise_pred, _, _, _ = pipeline._denoise_step_cfg_parallel(
-                latents=latents,
-                extra_stream_latents={},
-                timestep=timestep,
-                local_embeds=cfg_config["local_embeds"],
-                forward_fn=forward_fn,
-                guidance_scale=5.0,
-                guidance_rescale=0.0,
-                ulysses_size=cfg_config["ulysses_size"],
-                local_extras=local_extras,
-            )
-
-        # Validate output
-        assert not torch.isnan(noise_pred).any(), f"Rank {rank}: Output contains NaN"
-        assert not torch.isinf(noise_pred).any(), f"Rank {rank}: Output contains Inf"
-
-        # Return output from rank 0
-        if rank == 0:
-            return_dict["output"] = noise_pred.cpu()
-            print(f"[CFG I2V Rank {rank}] ✓ Output shape: {noise_pred.shape}")
-            print(
-                f"[CFG I2V Rank {rank}] ✓ Output range: [{noise_pred.min():.4f}, {noise_pred.max():.4f}]"
-            )
-
-        del pipeline
-        torch.cuda.empty_cache()
-
-    finally:
-        cleanup_distributed()
-
-
-def _run_all_optimizations_worker_i2v(rank, world_size, checkpoint_path, inputs_list, return_dict):
-    try:
-        setup_distributed(rank, world_size)
-
-        # Load I2V pipeline with ALL optimizations
-        args_full = VisualGenArgs(
-            checkpoint_path=checkpoint_path,
-            device=f"cuda:{rank}",
-            dtype="bfloat16",
-            skip_components=SKIP_MINIMAL,
-            quant_config={"quant_algo": "FP8", "dynamic": True},
-            cache=TeaCacheConfig(
-                teacache_thresh=0.2,
-                use_ret_steps=True,
-            ),
-            attention=AttentionConfig(backend="TRTLLM"),
-            parallel=ParallelConfig(dit_cfg_size=world_size),
-        )
-        pipeline = PipelineLoader(args_full).load(skip_warmup=True)
-        transformer = pipeline.transformer.eval()
-
-        # Verify all optimizations are enabled
-        assert pipeline.model_config.visual_gen_mapping.cfg_size == world_size, (
-            "CFG parallel not enabled"
-        )
-        assert transformer.model_config.quant_config.quant_algo == QuantAlgo.FP8, "FP8 not enabled"
-        assert hasattr(pipeline, "transformer_cache_backend"), "TeaCache not enabled"
-        assert transformer.blocks[0].attn1.attn_backend == "TRTLLM", (
-            "TRTLLM not enabled for self-attn"
-        )
-
-        if rank == 0:
-            print(f" ✓ All optimizations verified on I2V rank {rank}:")
-            print(f" - FP8 quantization: {transformer.model_config.quant_config.quant_algo}")
-            print(" - TeaCache: enabled")
-            print(f" - TRTLLM attention: {transformer.blocks[0].attn1.attn_backend}")
-            print(f" - CFG Parallelism: cfg_size={world_size}")
-
-        # Initialize TeaCache for single-step inference
-        if hasattr(pipeline, "transformer_cache_backend"):
-            pipeline.transformer_cache_backend.refresh(num_inference_steps=1)
-
-        # Load inputs on this GPU
-        prompt_embeds = inputs_list[0].to(f"cuda:{rank}")
-        neg_prompt_embeds = inputs_list[1].to(f"cuda:{rank}")
-        latents = inputs_list[2].to(f"cuda:{rank}")
-        timestep = inputs_list[3].to(f"cuda:{rank}")
-        image_embeds = inputs_list[4].to(f"cuda:{rank}") if inputs_list[4] is not None else None
-
-        # Setup CFG config
-        cfg_config = pipeline._setup_cfg_config(
-            guidance_scale=5.0,
-            prompt_embeds=prompt_embeds,
-            neg_prompt_embeds=neg_prompt_embeds,
-        )
-
-        assert cfg_config["enabled"], "CFG parallel not enabled"
-
-        # Run single denoising step with all optimizations
-        def forward_fn(
-            latents, extra_stream_latents, timestep, encoder_hidden_states, extra_tensors
-        ):
-            return transformer(  # noqa: F821
-                hidden_states=latents,
-                timestep=timestep,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_hidden_states_image=extra_tensors.get("encoder_hidden_states_image"),
-            )
-
-        with torch.no_grad():
-            local_extras = (
-                {"encoder_hidden_states_image": image_embeds} if image_embeds is not None else {}
-            )
-            noise_pred, _, _, _ = pipeline._denoise_step_cfg_parallel(
-                latents=latents,
-                extra_stream_latents={},
-                timestep=timestep,
-                local_embeds=cfg_config["local_embeds"],
-                forward_fn=forward_fn,
-                guidance_scale=5.0,
-                guidance_rescale=0.0,
-                ulysses_size=cfg_config["ulysses_size"],
-                local_extras=local_extras,
-            )
-
-        # Validate output
-        assert not torch.isnan(noise_pred).any(), f"Rank {rank}: Output contains NaN"
-        assert not torch.isinf(noise_pred).any(), f"Rank {rank}: Output contains Inf"
-
-        # Return output from rank 0
-        if rank == 0:
-            return_dict["output"] = noise_pred.cpu()
-            print(f" ✓ Combined optimization I2V output shape: {noise_pred.shape}")
-            print(
-                f" ✓ Combined optimization I2V range: [{noise_pred.min():.4f}, {noise_pred.max():.4f}]"
-            )
-
-        del pipeline, transformer
-        torch.cuda.empty_cache()
-
-    finally:
-        cleanup_distributed()
-
-
-# ============================================================================
-# SMOKE TESTS (No Checkpoint Required)
-# ============================================================================
-
-
-@pytest.mark.unit
-@pytest.mark.smoke
-class TestWanI2VSmoke:
-    def _create_model_config(self, boundary_ratio=None):
-        """Helper to create test model config."""
-        config_dict = {
-            "attention_head_dim": 128,
-            "in_channels": 16,
-            "out_channels": 16,
-            "num_attention_heads": 4,
-            "num_layers": 1,
-            "patch_size": [1, 2, 2],
-            "text_dim": 4096,
-            "freq_dim": 256,
-            "ffn_dim": 1024,
-            "torch_dtype": "bfloat16",
-            "hidden_size": 512,
-            "qk_norm": "rms_norm_across_heads",
-            "cross_attn_norm": "layer_norm",
-            "eps": 1e-06,
-            "image_dim": 1280,  # CLIP dimension (HF naming convention)
-            "added_kv_proj_dim": 1280,  # Added KV projection dimension for I2V
-            "boundary_ratio": boundary_ratio,
-        }
-        pretrained_config = SimpleNamespace(**config_dict)
-        quant_config = QuantConfig()
-
-        return DiffusionModelConfig(
-            pretrained_config=pretrained_config,
-            quant_config=quant_config,
-            skip_create_weights_in_init=True,
-        )
-
-    def test_wan21_instantiation(self):
-        """Test Wan 2.1 I2V pipeline (single-stage)."""
-        model_config = self._create_model_config(boundary_ratio=None)
-        pipeline = WanImageToVideoPipeline(model_config)
-
-        assert pipeline.transformer is not None
-        assert pipeline.transformer_2 is None  # Single-stage
-        assert pipeline.boundary_ratio is None
-
-    def test_wan22_instantiation(self):
-        """Test Wan 2.2 I2V pipeline (two-stage)."""
-        model_config = self._create_model_config(boundary_ratio=0.4)
-        pipeline = WanImageToVideoPipeline(model_config)
-
-        assert pipeline.transformer is not None
-        assert pipeline.transformer_2 is not None  # Two-stage
-        assert pipeline.boundary_ratio == 0.4
-
-    def test_retrieve_latents(self):
-        """Test retrieve_latents helper."""
-        from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan_i2v import retrieve_latents
-
-        class MockLatentDist:
-            def mode(self):
-                return torch.randn(1, 16, 1, 64, 64)
-
-            def sample(self, generator=None):
-                return torch.randn(1, 16, 1, 64, 64)
-
-        class MockEncoderOutput:
-            def __init__(self):
-                self.latent_dist = MockLatentDist()
-
-        encoder_output = MockEncoderOutput()
-
-        # Test argmax mode (I2V default for deterministic encoding)
-        latents_argmax = retrieve_latents(encoder_output, sample_mode="argmax")
-        assert latents_argmax.shape == (1, 16, 1, 64, 64)
-
-        # Test sample mode
-        latents_sample = retrieve_latents(encoder_output, sample_mode="sample")
-        assert latents_sample.shape == (1, 16, 1, 64, 64)
-
-
-# ============================================================================
-# INTEGRATION TESTS - WAN 2.1 (Require Wan 2.1 Checkpoint)
-# ============================================================================
-
-
-@pytest.mark.integration
-@pytest.mark.i2v
-@pytest.mark.wan21
-class TestWanI2VIntegration:
-    """Integration tests with Wan 2.1 checkpoint."""
-
-    def test_load_pipeline(self, wan21_i2v_pipeline_bf16):
-        """Test loading I2V pipeline from checkpoint."""
-        # Verify I2V pipeline
-        assert "ImageToVideo" in type(wan21_i2v_pipeline_bf16).__name__
-        assert wan21_i2v_pipeline_bf16.transformer is not None
-        assert len(wan21_i2v_pipeline_bf16.transformer.blocks) > 0
-
-        # Detect version
-        is_two_stage = (
-            wan21_i2v_pipeline_bf16.boundary_ratio is not None
-            and wan21_i2v_pipeline_bf16.transformer_2 is not None
-        )
-
-        print(f"\n✓ Pipeline: {type(wan21_i2v_pipeline_bf16).__name__}")
-        print(f"✓ Transformer blocks: {len(wan21_i2v_pipeline_bf16.transformer.blocks)}")
-        print(f"✓ boundary_ratio: {wan21_i2v_pipeline_bf16.boundary_ratio}")
-        print(f"✓ Two-stage: {is_two_stage}")
-
-    def test_image_encoding(self, wan21_i2v_pipeline_with_image_encoder, test_image):
-        """Test CLIP image encoding (if model uses it)."""
-        # Check if model uses image encoder
-        if (
-            not hasattr(wan21_i2v_pipeline_with_image_encoder, "image_encoder")
-            or wan21_i2v_pipeline_with_image_encoder.image_encoder is None
-        ):
-            pytest.skip("This checkpoint doesn't use image encoder")
-
-        # Encode test image
-        image_embeds = wan21_i2v_pipeline_with_image_encoder._encode_image(test_image)
-
-        assert image_embeds is not None
-        assert image_embeds.dim() == 3  # [batch, seq_len, embed_dim]
-        print(f"\n✓ Image embeddings: {image_embeds.shape}, dtype={image_embeds.dtype}")
-
-    def test_fp8_per_tensor_quantization(self, wan21_i2v_pipeline_fp8):
-        """Test FP8 per-tensor dynamic quantization."""
-        # Check transformer for FP8 weights
-        found_fp8 = any(
-            param.dtype == torch.float8_e4m3fn
-            for name, param in wan21_i2v_pipeline_fp8.transformer.named_parameters()
-            if "blocks.0" in name and "weight" in name
-        )
-        assert found_fp8, "No FP8 weights found for FP8"
-        print("\n✓ FP8: FP8 weights found in transformer")
-
-        # Check transformer_2 if two-stage
-        if wan21_i2v_pipeline_fp8.transformer_2 is not None:
-            found_fp8_t2 = any(
-                param.dtype == torch.float8_e4m3fn
-                for name, param in
wan21_i2v_pipeline_fp8.transformer_2.named_parameters() - if "blocks.0" in name and "weight" in name - ) - assert found_fp8_t2, "No FP8 weights in transformer_2" - print("✓ FP8: FP8 weights found in transformer_2") - - def test_fp8_blockwise_quantization(self, wan21_i2v_pipeline_fp8_blockwise): - """Test FP8 blockwise dynamic quantization.""" - # Check transformer for FP8 weights - found_fp8 = any( - param.dtype == torch.float8_e4m3fn - for name, param in wan21_i2v_pipeline_fp8_blockwise.transformer.named_parameters() - if "blocks.0" in name and "weight" in name - ) - assert found_fp8, "No FP8 weights found for FP8_BLOCK_SCALES" - print("\n✓ FP8_BLOCK_SCALES: FP8 weights found in transformer") - - @pytest.mark.parametrize("backend", ["VANILLA", "TRTLLM"]) - def test_attention_backends(self, backend): - """Test different attention backends.""" - if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH): - pytest.skip("DIFFUSION_MODEL_PATH not set") - if not is_wan21_checkpoint(): - pytest.skip("This test requires Wan 2.1 checkpoint") - - args = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_MINIMAL, - attention=AttentionConfig(backend=backend), - ) - pipeline = PipelineLoader(args).load(skip_warmup=True) - - try: - # Check transformer attention backend - first_block = pipeline.transformer.blocks[0] - attn1_backend = first_block.attn1.attn_backend - attn2_backend = first_block.attn2.attn_backend - - # TRTLLM for self-attention, VANILLA for cross-attention - if backend == "TRTLLM": - assert attn1_backend == "TRTLLM", f"Expected TRTLLM, got {attn1_backend}" - assert attn2_backend == "VANILLA", ( - f"Cross-attn should be VANILLA, got {attn2_backend}" - ) - else: - assert attn1_backend == "VANILLA" - assert attn2_backend == "VANILLA" - - print(f"\n✓ Attention backend: {backend}") - print(f" Self-attn: {attn1_backend}, Cross-attn: {attn2_backend}") - - # Check transformer_2 if two-stage - if pipeline.transformer_2 is not None: - first_block_t2 = pipeline.transformer_2.blocks[0] - attn1_backend_t2 = first_block_t2.attn1.attn_backend - attn2_backend_t2 = first_block_t2.attn2.attn_backend - - if backend == "TRTLLM": - assert attn1_backend_t2 == "TRTLLM" - assert attn2_backend_t2 == "VANILLA" - print( - f" Transformer_2 - Self-attn: {attn1_backend_t2}, Cross-attn: {attn2_backend_t2}" - ) - - finally: - del pipeline - torch.cuda.empty_cache() - - def test_teacache(self): - """Test TeaCache on both transformers.""" - if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH): - pytest.skip("DIFFUSION_MODEL_PATH not set") - if not is_wan21_checkpoint(): - pytest.skip("This test requires Wan 2.1 checkpoint") - - args = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_MINIMAL, - cache=TeaCacheConfig( - teacache_thresh=0.2, - use_ret_steps=True, - ), - ) - pipeline = PipelineLoader(args).load(skip_warmup=True) - - try: - # Verify TeaCache on transformer - assert hasattr(pipeline, "transformer_cache_backend") - assert pipeline.transformer_cache_backend is not None - print("\n✓ TeaCache enabled on transformer (high-noise)") - - # Verify get_stats method - stats = pipeline.transformer_cache_backend.get_stats() - assert "total_steps" in stats - assert "cached_steps" in stats - assert "compute_steps" in stats - print("✓ TeaCache stats available") - - # Check transformer_2 if two-stage - if pipeline.transformer_2 is not None: - assert hasattr(pipeline, "transformer_2_cache_backend") 
- assert pipeline.transformer_2_cache_backend is not None - stats2 = pipeline.transformer_2_cache_backend.get_stats() - assert "total_steps" in stats2 - print("✓ TeaCache enabled on transformer_2 (low-noise)") - - finally: - del pipeline - torch.cuda.empty_cache() - - def test_all_optimizations_combined(self): - """Test all optimizations enabled simultaneously.""" - if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH): - pytest.skip("DIFFUSION_MODEL_PATH not set") - if not is_wan21_checkpoint(): - pytest.skip("This test requires Wan 2.1 checkpoint") - - args = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_MINIMAL, - quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True}, - attention=AttentionConfig(backend="VANILLA"), # VANILLA more stable with all opts - cache=TeaCacheConfig( - teacache_thresh=0.2, - use_ret_steps=True, - ), - ) - pipeline = PipelineLoader(args).load(skip_warmup=True) - - try: - optimizations = [] - - # Check FP8 - if any(p.dtype == torch.float8_e4m3fn for p in pipeline.transformer.parameters()): - optimizations.append("FP8") - - # Check TeaCache - if ( - hasattr(pipeline, "transformer_cache_backend") - and pipeline.transformer_cache_backend - ): - optimizations.append("TeaCache") - - # Check two-stage - if pipeline.transformer_2 is not None: - optimizations.append("Two-Stage") - - # Check attention backend - optimizations.append(f"Attention={args.attention.backend}") - - print(f"\n✓ All optimizations: {', '.join(optimizations)}") - assert len(optimizations) >= 3 - - finally: - del pipeline - torch.cuda.empty_cache() - - def test_fp8_vs_bf16_numerical_correctness( - self, wan21_i2v_pipeline_bf16, wan21_i2v_pipeline_fp8 - ): - """Test FP8 vs BF16 numerical accuracy on I2V transformer.""" - # Get linear layers from first transformer - attn_bf16 = wan21_i2v_pipeline_bf16.transformer.blocks[0].attn1 - attn_fp8 = wan21_i2v_pipeline_fp8.transformer.blocks[0].attn1 - - # Get qkv_proj layer - if hasattr(attn_bf16, "qkv_proj"): - linear_bf16 = attn_bf16.qkv_proj - linear_fp8 = attn_fp8.qkv_proj - layer_name = "blocks.0.attn1.qkv_proj" - elif hasattr(attn_bf16, "attn") and hasattr(attn_bf16.attn, "qkv_proj"): - linear_bf16 = attn_bf16.attn.qkv_proj - linear_fp8 = attn_fp8.attn.qkv_proj - layer_name = "blocks.0.attn1.attn.qkv_proj" - else: - # Use FFN linear instead - linear_bf16 = wan21_i2v_pipeline_bf16.transformer.blocks[0].ffn.net[0]["proj"] - linear_fp8 = wan21_i2v_pipeline_fp8.transformer.blocks[0].ffn.net[0]["proj"] - layer_name = "blocks.0.ffn.net.0.proj" - - # Get weights - weight_bf16 = linear_bf16.weight.data.clone() - bias_bf16 = linear_bf16.bias.data.clone() if linear_bf16.bias is not None else None - - # Create test input - torch.manual_seed(42) - hidden_size = linear_bf16.in_features - batch_size = 1 - seq_len = 14040 - - input_tensor = torch.randn( - batch_size * seq_len, hidden_size, dtype=torch.bfloat16, device="cuda" - ) - print(f"\n[Compare] Input shape: {input_tensor.shape}") - - # Compute reference output - with torch.no_grad(): - expected = F.linear(input_tensor, weight_bf16, bias_bf16) - - # Compute FP8 output - with torch.no_grad(): - result_fp8 = linear_fp8(input_tensor) - - # Compute BF16 output - with torch.no_grad(): - result_bf16 = linear_bf16(input_tensor) - - # Verify BF16 matches reference - assert torch.allclose(result_bf16, expected, rtol=1e-5, atol=1e-6), ( - "BF16 layer should match F.linear reference exactly" - ) - - # Compare FP8 vs reference - max_diff = 
torch.max(torch.abs(result_fp8 - expected)).item() - cos_sim = F.cosine_similarity( - result_fp8.flatten().float(), expected.flatten().float(), dim=0 - ) - mse = F.mse_loss(result_fp8.flatten().float(), expected.flatten().float()) - - print( - f"\n[{layer_name}] max_diff={max_diff:.6f}, cos_sim={cos_sim.item():.6f}, mse={mse.item():.6f}" - ) - - assert cos_sim > 0.99, f"Cosine similarity too low: {cos_sim.item()}" - assert mse < 1.0, f"MSE too high: {mse.item()}" - - # Test transformer_2 if two-stage - if ( - wan21_i2v_pipeline_bf16.transformer_2 is not None - and wan21_i2v_pipeline_fp8.transformer_2 is not None - ): - print("\n[Testing transformer_2]") - attn2_bf16 = wan21_i2v_pipeline_bf16.transformer_2.blocks[0].attn1 - attn2_fp8 = wan21_i2v_pipeline_fp8.transformer_2.blocks[0].attn1 - - if hasattr(attn2_bf16, "qkv_proj"): - linear2_bf16 = attn2_bf16.qkv_proj - linear2_fp8 = attn2_fp8.qkv_proj - else: - linear2_bf16 = wan21_i2v_pipeline_bf16.transformer_2.blocks[0].ffn.net[0]["proj"] - linear2_fp8 = wan21_i2v_pipeline_fp8.transformer_2.blocks[0].ffn.net[0]["proj"] - - weight2_bf16 = linear2_bf16.weight.data.clone() - bias2_bf16 = linear2_bf16.bias.data.clone() if linear2_bf16.bias is not None else None - - with torch.no_grad(): - expected2 = F.linear(input_tensor, weight2_bf16, bias2_bf16) - result2_fp8 = linear2_fp8(input_tensor) - - cos_sim2 = F.cosine_similarity( - result2_fp8.flatten().float(), expected2.flatten().float(), dim=0 - ) - print(f"[transformer_2] cos_sim={cos_sim2.item():.6f}") - assert cos_sim2 > 0.99, f"Transformer_2 cosine similarity too low: {cos_sim2.item()}" - - def test_fp8_vs_bf16_memory_comparison(self, wan21_i2v_pipeline_bf16, wan21_i2v_pipeline_fp8): - """Test FP8 uses ~2x less memory than BF16 for I2V.""" - - def get_module_memory_gb(module): - return sum(p.numel() * p.element_size() for p in module.parameters()) / 1024**3 - - bf16_model_mem = get_module_memory_gb(wan21_i2v_pipeline_bf16.transformer) - if wan21_i2v_pipeline_bf16.transformer_2 is not None: - bf16_model_mem += get_module_memory_gb(wan21_i2v_pipeline_bf16.transformer_2) - - fp8_model_mem = get_module_memory_gb(wan21_i2v_pipeline_fp8.transformer) - if wan21_i2v_pipeline_fp8.transformer_2 is not None: - fp8_model_mem += get_module_memory_gb(wan21_i2v_pipeline_fp8.transformer_2) - - print(f"\n[BF16] Transformer(s) memory: {bf16_model_mem:.2f} GB") - print(f"[FP8] Transformer(s) memory: {fp8_model_mem:.2f} GB") - - # Verify memory savings - model_mem_ratio = bf16_model_mem / fp8_model_mem - - print(f"\n[Comparison] Model memory ratio (BF16/FP8): {model_mem_ratio:.2f}x") - - # FP8 should use ~2x less memory - assert model_mem_ratio > 1.8, f"FP8 should use ~2x less memory, got {model_mem_ratio:.2f}x" - - -# ============================================================================ -# TWO-STAGE SPECIFIC TESTS - WAN 2.2 (Require Wan 2.2 Checkpoint) -# ============================================================================ - - -@pytest.mark.integration -@pytest.mark.i2v -@pytest.mark.wan22 -class TestWanI2VTwoStage: - """Tests specific to Wan 2.2 two-stage denoising.""" - - def test_transformer_selection_logic(self, wan22_i2v_pipeline_bf16): - """Test boundary_timestep logic for transformer selection.""" - # Skip if not two-stage - if ( - wan22_i2v_pipeline_bf16.boundary_ratio is None - or wan22_i2v_pipeline_bf16.transformer_2 is None - ): - pytest.skip("Not a two-stage checkpoint") - - # Calculate boundary - num_train_timesteps = 1000 - boundary_timestep = wan22_i2v_pipeline_bf16.boundary_ratio * 
num_train_timesteps - - print(f"\n✓ boundary_ratio: {wan22_i2v_pipeline_bf16.boundary_ratio}") - print(f"✓ boundary_timestep: {boundary_timestep:.1f}") - print(f"✓ High-noise (t >= {boundary_timestep:.1f}): uses transformer") - print(f"✓ Low-noise (t < {boundary_timestep:.1f}): uses transformer_2") - - @pytest.mark.parametrize("guidance_scale_2", [2.0, 3.0, 4.0]) - def test_guidance_scale_2_parameter(self, wan22_i2v_pipeline_bf16, guidance_scale_2): - """Test guidance_scale_2 for low-noise stage.""" - # Skip if not two-stage - if ( - wan22_i2v_pipeline_bf16.boundary_ratio is None - or wan22_i2v_pipeline_bf16.transformer_2 is None - ): - pytest.skip("Not a two-stage checkpoint") - - print(f"\n✓ Two-stage model supports guidance_scale_2={guidance_scale_2}") - print("✓ High-noise: uses guidance_scale") - print(f"✓ Low-noise: uses guidance_scale_2={guidance_scale_2}") - - def test_custom_boundary_ratio(self, wan22_i2v_pipeline_bf16): - """Test overriding boundary_ratio at runtime.""" - # Skip if not two-stage - if ( - wan22_i2v_pipeline_bf16.boundary_ratio is None - or wan22_i2v_pipeline_bf16.transformer_2 is None - ): - pytest.skip("Not a two-stage checkpoint") - - default_ratio = wan22_i2v_pipeline_bf16.boundary_ratio - custom_ratio = 0.3 - - print(f"\n✓ Model default boundary_ratio: {default_ratio}") - print(f"✓ Custom override: {custom_ratio}") - print("✓ forward() accepts boundary_ratio parameter for runtime override") - - def test_two_stage_with_all_optimizations(self, wan22_i2v_pipeline_fp8): - """Test Wan 2.2 with FP8 and TRTLLM attention (TeaCache not supported for Wan 2.2).""" - # Skip if not two-stage - if ( - wan22_i2v_pipeline_fp8.boundary_ratio is None - or wan22_i2v_pipeline_fp8.transformer_2 is None - ): - pytest.skip("Not a two-stage checkpoint") - - # Load pipeline with all optimizations - args = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_MINIMAL, - quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True}, - attention=AttentionConfig(backend="TRTLLM"), - ) - pipeline = PipelineLoader(args).load(skip_warmup=True) - - try: - print("\n[Two-Stage + All Optimizations]") - - # Check FP8 on both transformers - fp8_t1 = any(p.dtype == torch.float8_e4m3fn for p in pipeline.transformer.parameters()) - fp8_t2 = any( - p.dtype == torch.float8_e4m3fn for p in pipeline.transformer_2.parameters() - ) - print(f"✓ FP8: transformer={fp8_t1}, transformer_2={fp8_t2}") - assert fp8_t1 and fp8_t2 - - # Check TRTLLM attention - attn1_backend = pipeline.transformer.blocks[0].attn1.attn_backend - attn2_backend = pipeline.transformer_2.blocks[0].attn1.attn_backend - print(f"✓ TRTLLM: transformer={attn1_backend}, transformer_2={attn2_backend}") - assert attn1_backend == "TRTLLM" - assert attn2_backend == "TRTLLM" - - print("✓ All optimizations working on two-stage model!") - - finally: - del pipeline - torch.cuda.empty_cache() - - -# ============================================================================ -# ROBUSTNESS TESTS -# ============================================================================ - - -@pytest.mark.robustness -class TestWanI2VRobustness: - """Robustness and error handling tests.""" - - def test_invalid_quant_config(self): - """Test that invalid quantization config raises appropriate error.""" - with pytest.raises((ValueError, KeyError)): - args = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_MINIMAL, - quant_config={"quant_algo": 
"INVALID_ALGO", "dynamic": True}, - ) - pipeline = PipelineLoader(args).load(skip_warmup=True) - del pipeline - - def test_mismatched_image_size(self, test_image): - """Test handling of unexpected image dimensions.""" - if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH): - pytest.skip("DIFFUSION_MODEL_PATH not set") - - args = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - skip_components=SKIP_WITH_IMAGE, - ) - pipeline = PipelineLoader(args).load(skip_warmup=True) - - try: - # Check if model uses image encoder - if not hasattr(pipeline, "image_encoder") or pipeline.image_encoder is None: - pytest.skip("This checkpoint doesn't use image encoder") - - # Create image with unexpected size - import numpy as np - - small_img = np.zeros((224, 224, 3), dtype=np.uint8) - small_image = Image.fromarray(small_img, mode="RGB") - - # Should handle gracefully - try: - image_embeds = pipeline._encode_image(small_image) - assert image_embeds is not None - print("\n✓ Handled non-standard image size gracefully") - except Exception as e: - # Some error is expected - print(f"\n✓ Raised appropriate error for mismatched size: {type(e).__name__}") - - finally: - del pipeline - torch.cuda.empty_cache() - - -# ============================================================================= -# Batch Generation Tests (I2V) -# ============================================================================= - - -@pytest.mark.integration -@pytest.mark.i2v -class TestWanI2VBatchGeneration: - """Batch generation tests for WAN I2V pipeline (Wan 2.1 and Wan 2.2). - - Tests that passing a list of prompts produces batched output - and matches sequential generation with the same seeds. - """ - - @pytest.fixture(scope="class") - def i2v_full_pipeline(self): - """Load full I2V pipeline (all components) for batch tests.""" - if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH): - pytest.skip("Checkpoint not available. 
Set DIFFUSION_MODEL_PATH.") - - args = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda", - dtype="bfloat16", - torch_compile=TorchCompileConfig(enable_torch_compile=False), - ) - pipeline = PipelineLoader(args).load(skip_warmup=True) - yield pipeline - del pipeline - import gc - - gc.collect() - torch.cuda.empty_cache() - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_single_prompt_backward_compat(self, i2v_full_pipeline, test_image): - """Single prompt returns (T, H, W, C) for backward compatibility.""" - result = i2v_full_pipeline.forward( - prompt="a cat walking", - image=test_image, - height=480, - width=832, - num_frames=9, - num_inference_steps=4, - guidance_scale=5.0, - seed=42, - ) - assert result.video.dim() == 5, f"Expected 5D (B,T,H,W,C), got {result.video.dim()}D" - B, _T, H, W, C = result.video.shape - assert B == 1 and H == 480 and W == 832 and C == 3 - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_batch_prompt_shape(self, i2v_full_pipeline, test_image): - """List of prompts returns (B, T, H, W, C).""" - prompts = ["a sunset over mountains", "a cat on a roof"] - result = i2v_full_pipeline.forward( - prompt=prompts, - image=test_image, - height=480, - width=832, - num_frames=9, - num_inference_steps=4, - guidance_scale=5.0, - seed=42, - ) - assert result.video.dim() == 5, f"Expected 5D (B,T,H,W,C), got {result.video.dim()}D" - B, _T, H, W, C = result.video.shape - assert B == 2 and H == 480 and W == 832 and C == 3 - - -# ============================================================================ -# CFG PARALLELISM TESTS (Requires 2+ GPUs) -# ============================================================================ - - -@pytest.mark.parallelism -class TestWanI2VParallelism(unittest.TestCase): - """Distributed parallelism correctness tests for I2V (CFG Parallelism).""" - - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - def setUp(self): - """Set up test fixtures and skip if checkpoint not available.""" - torch.manual_seed(42) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(42) - if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH): - self.skipTest( - "Checkpoint not available. Set DIFFUSION_MODEL_PATH environment variable." - ) - - def tearDown(self): - """Clean up GPU memory.""" - import gc - - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() - - def test_cfg_2gpu_correctness(self): - """Test I2V CFG Parallelism (cfg_size=2) correctness against standard CFG baseline.""" - num_gpus = torch.cuda.device_count() - if num_gpus < 2: - pytest.skip("CFG parallel test requires at least 2 GPUs") - if not is_wan21_checkpoint(): - pytest.skip( - "This test requires Wan 2.1 checkpoint. Use DIFFUSION_MODEL_PATH with '2.1' in the path." 
- ) - - print("\n" + "=" * 80) - print("I2V CFG PARALLELISM (cfg_size=2) CORRECTNESS TEST") - print("=" * 80) - - # Load standard CFG baseline on GPU 0 - print("\n[1/3] Loading standard CFG I2V baseline (cfg_size=1) on GPU 0...") - args_baseline = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda:0", - dtype="bfloat16", - skip_components=SKIP_MINIMAL, - parallel=ParallelConfig(dit_cfg_size=1), # Standard CFG (no parallel) - ) - pipeline_baseline = PipelineLoader(args_baseline).load(skip_warmup=True) - config = pipeline_baseline.transformer.model_config.pretrained_config - - # Reset torch compile state - torch._dynamo.reset() - - # Create FIXED test inputs - print("\n[2/3] Creating fixed test inputs...") - torch.manual_seed(42) - batch_size, num_frames, height, width, seq_len = 1, 1, 64, 64, 128 - - latents = torch.randn( - batch_size, - config.in_channels, - num_frames, - height, - width, - dtype=torch.bfloat16, - device="cuda:0", - ) - timestep = torch.tensor([500], dtype=torch.long, device="cuda:0") - prompt_embeds = torch.randn( - batch_size, seq_len, config.text_dim, dtype=torch.bfloat16, device="cuda:0" - ) - neg_prompt_embeds = torch.randn( - batch_size, seq_len, config.text_dim, dtype=torch.bfloat16, device="cuda:0" - ) - - # I2V-specific: Create image embeddings (or None if Wan 2.2) - image_embeds = None - image_dim = getattr(config, "image_dim", getattr(config, "image_embed_dim", None)) - if image_dim is not None: - # Wan 2.1 uses CLIP image embeddings - image_seq_len = 256 # CLIP patch count - image_embeds = torch.randn( - batch_size, image_seq_len, image_dim, dtype=torch.bfloat16, device="cuda:0" - ) - print(f" ✓ Created image embeddings: {image_embeds.shape}") - - # Setup standard CFG config - cfg_config_baseline = pipeline_baseline._setup_cfg_config( - guidance_scale=5.0, - prompt_embeds=prompt_embeds, - neg_prompt_embeds=neg_prompt_embeds, - ) - - print(" Baseline CFG config:") - print(f" enabled: {cfg_config_baseline['enabled']}") - print(f" cfg_size: {cfg_config_baseline['cfg_size']}") - - # Verify standard CFG is NOT parallel - assert not cfg_config_baseline["enabled"], "Baseline should not use CFG parallel" - assert cfg_config_baseline["cfg_size"] == 1, "Baseline cfg_size should be 1" - - # Run standard CFG denoising step - def forward_fn( - latents, extra_stream_latents, timestep, encoder_hidden_states, extra_tensors - ): - return pipeline_baseline.transformer( # noqa: F821 - hidden_states=latents, - timestep=timestep, - encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_image=extra_tensors.get("encoder_hidden_states_image"), - ) - - with torch.no_grad(): - local_extras = ( - {"encoder_hidden_states_image": image_embeds} if image_embeds is not None else {} - ) - baseline_output, _, _, _ = pipeline_baseline._denoise_step_standard( - latents=latents.clone(), - extra_stream_latents={}, - timestep=timestep, - prompt_embeds=cfg_config_baseline["prompt_embeds"], - forward_fn=forward_fn, - guidance_scale=5.0, - guidance_rescale=0.0, - local_extras=local_extras, - ) - - print(f" ✓ Baseline output shape: {baseline_output.shape}") - print(f" ✓ Baseline range: [{baseline_output.min():.4f}, {baseline_output.max():.4f}]") - - # Cleanup baseline to free memory for CFG workers - del pipeline_baseline - torch.cuda.empty_cache() - - # Run CFG parallel (cfg_size=2) in distributed processes - print("\n[3/3] Running I2V CFG Parallelism (cfg_size=2) across 2 GPUs...") - cfg_size = 2 - - inputs_cpu = [ - prompt_embeds.cpu(), - neg_prompt_embeds.cpu(), - 
latents.cpu(), - timestep.cpu(), - image_embeds.cpu() if image_embeds is not None else None, - ] - - manager = mp.Manager() - return_dict = manager.dict() - - # Spawn CFG workers - mp.spawn( - _run_cfg_worker_i2v, - args=(cfg_size, CHECKPOINT_PATH, inputs_cpu, return_dict), - nprocs=cfg_size, - join=True, - ) - - # Get CFG parallel output from rank 0 - cfg_parallel_output = return_dict["output"].to("cuda:0") - print(f" ✓ CFG parallel output shape: {cfg_parallel_output.shape}") - - # Compare outputs - print("\n[Comparison] I2V CFG Parallel vs Standard CFG:") - baseline_float = baseline_output.float() - cfg_parallel_float = cfg_parallel_output.float() - - cos_sim = F.cosine_similarity( - cfg_parallel_float.flatten(), baseline_float.flatten(), dim=0 - ).item() - - max_diff = torch.max(torch.abs(cfg_parallel_float - baseline_float)).item() - mean_diff = torch.mean(torch.abs(cfg_parallel_float - baseline_float)).item() - - print(f" Cosine similarity: {cos_sim:.6f}") - print(f" Max absolute difference: {max_diff:.6f}") - print(f" Mean absolute difference: {mean_diff:.6f}") - print( - f" CFG parallel range: [{cfg_parallel_float.min():.4f}, {cfg_parallel_float.max():.4f}]" - ) - print(f" Baseline range: [{baseline_float.min():.4f}, {baseline_float.max():.4f}]") - - assert cos_sim > 0.99, ( - f"I2V CFG parallel cosine similarity {cos_sim:.6f} below threshold 0.99. " - f"CFG Parallelism does not match standard CFG baseline." - ) - - print("\n[PASS] I2V CFG Parallelism (cfg_size=2) validated!") - print(" ✓ CFG parallel produces same output as standard CFG") - print(" ✓ Prompt splitting and all-gather working correctly") - print(" ✓ Image embeddings handled correctly") - print("=" * 80) - - torch.cuda.empty_cache() - - -# ============================================================================ -# COMBINED OPTIMIZATIONS TESTS (I2V) -# ============================================================================ - - -@pytest.mark.parallelism -class TestWanI2VCombinedOptimizations(unittest.TestCase): - """Test all optimizations combined for I2V: FP8 + TeaCache + TRTLLM + CFG Parallelism.""" - - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - def setUp(self): - """Set up test fixtures and skip if checkpoint not available.""" - torch.manual_seed(42) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(42) - if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH): - self.skipTest( - "Checkpoint not available. Set DIFFUSION_MODEL_PATH environment variable." - ) - - def tearDown(self): - """Clean up GPU memory.""" - import gc - - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() - - def test_all_optimizations_combined(self): - """Test I2V FP8 + TeaCache + TRTLLM attention + CFG=2 combined correctness. - - This test validates that all optimizations work together correctly for I2V: - 1. FP8 per-tensor quantization for reduced memory/compute - 2. TeaCache for caching repeated computations - 3. TRTLLM attention backend for optimized attention kernels - 4. CFG Parallelism (cfg_size=2) for distributed CFG computation - - We compare against a standard CFG baseline with relaxed thresholds. - """ - num_gpus = torch.cuda.device_count() - if num_gpus < 2: - pytest.skip("Combined optimization test requires at least 2 GPUs for CFG parallel") - if not is_wan21_checkpoint(): - pytest.skip( - "This test requires Wan 2.1 checkpoint. Use DIFFUSION_MODEL_PATH with '2.1' in the path." 
- ) - - print("\n" + "=" * 80) - print("I2V ALL OPTIMIZATIONS COMBINED TEST") - print("FP8 + TeaCache + TRTLLM Attention + CFG Parallelism (cfg_size=2)") - print("=" * 80) - - # Load baseline on GPU 0 (no optimizations, standard CFG) - print("\n[1/3] Loading I2V baseline on GPU 0 (standard CFG, no optimizations)...") - args_baseline = VisualGenArgs( - checkpoint_path=CHECKPOINT_PATH, - device="cuda:0", - dtype="bfloat16", - skip_components=SKIP_MINIMAL, - parallel=ParallelConfig(dit_cfg_size=1), # Standard CFG - ) - pipeline_baseline = PipelineLoader(args_baseline).load(skip_warmup=True) - config = pipeline_baseline.transformer.model_config.pretrained_config - - # Reset torch compile state - torch._dynamo.reset() - - # Create FIXED test inputs - print("\n[2/3] Creating fixed test inputs...") - torch.manual_seed(42) - batch_size, num_frames, height, width, seq_len = 1, 1, 64, 64, 128 - - latents = torch.randn( - batch_size, - config.in_channels, - num_frames, - height, - width, - dtype=torch.bfloat16, - device="cuda:0", - ) - timestep = torch.tensor([500], dtype=torch.long, device="cuda:0") - prompt_embeds = torch.randn( - batch_size, seq_len, config.text_dim, dtype=torch.bfloat16, device="cuda:0" - ) - neg_prompt_embeds = torch.randn( - batch_size, seq_len, config.text_dim, dtype=torch.bfloat16, device="cuda:0" - ) - - # I2V-specific: Create image embeddings - image_embeds = None - image_dim = getattr(config, "image_dim", getattr(config, "image_embed_dim", None)) - if image_dim is not None: - image_seq_len = 256 - image_embeds = torch.randn( - batch_size, image_seq_len, image_dim, dtype=torch.bfloat16, device="cuda:0" - ) - - # Setup standard CFG config - cfg_config_baseline = pipeline_baseline._setup_cfg_config( - guidance_scale=5.0, - prompt_embeds=prompt_embeds, - neg_prompt_embeds=neg_prompt_embeds, - ) - - # Run baseline standard CFG - print(" Running baseline (standard CFG)...") - - def forward_fn_baseline( - latents, extra_stream_latents, timestep, encoder_hidden_states, extra_tensors - ): - return pipeline_baseline.transformer( # noqa: F821 - hidden_states=latents, - timestep=timestep, - encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_image=extra_tensors.get("encoder_hidden_states_image"), - ) - - with torch.no_grad(): - local_extras = ( - {"encoder_hidden_states_image": image_embeds} if image_embeds is not None else {} - ) - baseline_output, _, _, _ = pipeline_baseline._denoise_step_standard( - latents=latents.clone(), - extra_stream_latents={}, - timestep=timestep, - prompt_embeds=cfg_config_baseline["prompt_embeds"], - forward_fn=forward_fn_baseline, - guidance_scale=5.0, - guidance_rescale=0.0, - local_extras=local_extras, - ) - - print(f" ✓ Baseline output shape: {baseline_output.shape}") - print(f" ✓ Baseline range: [{baseline_output.min():.4f}, {baseline_output.max():.4f}]") - - # Cleanup baseline - del pipeline_baseline - torch.cuda.empty_cache() - - # Run with ALL optimizations (FP8 + TeaCache + TRTLLM + CFG=2) - print("\n[3/3] Running with ALL optimizations (FP8 + TeaCache + TRTLLM + CFG=2)...") - cfg_size = 2 - - inputs_cpu = [ - prompt_embeds.cpu(), - neg_prompt_embeds.cpu(), - latents.cpu(), - timestep.cpu(), - image_embeds.cpu() if image_embeds is not None else None, - ] - - manager = mp.Manager() - return_dict = manager.dict() - - # Spawn workers with all optimizations - mp.spawn( - _run_all_optimizations_worker_i2v, - args=(cfg_size, CHECKPOINT_PATH, inputs_cpu, return_dict), - nprocs=cfg_size, - join=True, - ) - - # Get combined optimization 
output - combined_output = return_dict["output"].to("cuda:0") - print(f" ✓ Combined optimization output shape: {combined_output.shape}") - - # Compare outputs (relaxed threshold for combined optimizations) - print("\n[Comparison] I2V Combined Optimizations vs Baseline:") - baseline_float = baseline_output.float() - combined_float = combined_output.float() - - cos_sim = F.cosine_similarity( - combined_float.flatten(), baseline_float.flatten(), dim=0 - ).item() - - max_diff = torch.max(torch.abs(combined_float - baseline_float)).item() - mean_diff = torch.mean(torch.abs(combined_float - baseline_float)).item() - - print(f" Cosine similarity: {cos_sim:.6f}") - print(f" Max absolute difference: {max_diff:.6f}") - print(f" Mean absolute difference: {mean_diff:.6f}") - - # Relaxed threshold (0.95) since multiple optimizations compound numerical differences - assert cos_sim > 0.95, ( - f"I2V combined optimization cosine similarity {cos_sim:.6f} below threshold 0.95" - ) - - print("\n[PASS] All optimizations (FP8 + TeaCache + TRTLLM + CFG) validated!") - print(" ✓ All optimizations work together correctly") - print(" ✓ I2V image embeddings handled correctly with all opts") - print("=" * 80) - - torch.cuda.empty_cache() - - -if __name__ == "__main__": - import unittest - - unittest.main(verbosity=2) diff --git a/tests/unittest/_torch/visual_gen/test_wan_transformer.py b/tests/unittest/_torch/visual_gen/test_wan_transformer.py new file mode 100644 index 000000000000..73fcb1fe2796 --- /dev/null +++ b/tests/unittest/_torch/visual_gen/test_wan_transformer.py @@ -0,0 +1,468 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Correctness tests for WanTransformer3DModel against the HuggingFace reference. + +Compares our implementation against diffusers WanTransformer3DModel using +identical weights and inputs, asserting cosine similarity >= 0.99. 
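+Similarity is computed with F.cosine_similarity over the flattened fp32 outputs
+(see _cosine_similarity below), and the max absolute difference is printed
+alongside for debugging.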
+ +Two modes are tested: + - T2V (1.3B): 480x832, no image conditioning + - I2V (14B 480P): 480x832, CLIP image conditioning + +Run all: + pytest tests/unittest/_torch/visual_gen/test_wan_transformer.py -v -s + +Run one: + pytest tests/unittest/_torch/visual_gen/test_wan_transformer.py -v -s -k t2v + pytest tests/unittest/_torch/visual_gen/test_wan_transformer.py -v -s -k i2v + +Override checkpoint paths: + DIFFUSION_MODEL_PATH_WAN21_1_3B=/path/to/Wan2.1-T2V-1.3B-Diffusers \\ + DIFFUSION_MODEL_PATH_WAN21_I2V_480P=/path/to/Wan2.1-I2V-14B-480P-Diffusers \\ + pytest tests/unittest/_torch/visual_gen/test_wan_transformer.py -v -s +""" + +import os + +os.environ["TLLM_DISABLE_MPI"] = "1" + +import gc +from pathlib import Path +from types import SimpleNamespace + +import pytest +import torch +import torch.nn.functional as F +from diffusers import WanTransformer3DModel as HFWanTransformer3DModel + +from tensorrt_llm._torch.modules.linear import Linear +from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig, VisualGenArgs +from tensorrt_llm._torch.visual_gen.models.wan.transformer_wan import WanTransformer3DModel +from tensorrt_llm.models.modeling_utils import QuantConfig + + +@pytest.fixture(autouse=True, scope="module") +def _cleanup_mpi_env(): + yield + os.environ.pop("TLLM_DISABLE_MPI", None) + + +@pytest.fixture(autouse=True) +def _cleanup_gpu(): + gc.collect() + torch.cuda.empty_cache() + yield + gc.collect() + torch.cuda.empty_cache() + + +# ============================================================================ +# Path helpers +# ============================================================================ + + +def _llm_models_root() -> str: + """Return LLM_MODELS_ROOT path if set in env, assert when it's set but not a valid path.""" + root = Path("/home/scratch.trt_llm_data_ci/llm-models/") + if "LLM_MODELS_ROOT" in os.environ: + root = Path(os.environ["LLM_MODELS_ROOT"]) + if not root.exists(): + root = Path("/scratch.trt_llm_data/llm-models/") + assert root.exists(), ( + "Set LLM_MODELS_ROOT or ensure /home/scratch.trt_llm_data_ci/llm-models/ is accessible." + ) + return str(root) + + +def _checkpoint(env_var: str, default_name: str) -> str: + return os.environ.get(env_var) or os.path.join(_llm_models_root(), default_name) + + +WAN21_1_3B_PATH = _checkpoint("DIFFUSION_MODEL_PATH_WAN21_1_3B", "Wan2.1-T2V-1.3B-Diffusers") +WAN21_I2V_480P_PATH = _checkpoint( + "DIFFUSION_MODEL_PATH_WAN21_I2V_480P", "Wan2.1-I2V-14B-480P-Diffusers" +) + + +COS_SIM_THRESHOLD = 0.99 +DEVICE = "cuda" +DTYPE = torch.bfloat16 + + +# ============================================================================ +# Helpers +# ============================================================================ + + +def _cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float: + a = a.reshape(-1).float() + b = b.reshape(-1).float() + return F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item() + + +def _load_models(checkpoint_dir: str): + """Load both the HF reference transformer and our transformer from the same checkpoint. + + Both models receive identical weights (loaded from the HF state_dict). + Returns (hf_model, our_model) in eval mode on DEVICE with DTYPE. 
+ """ + hf_model = ( + HFWanTransformer3DModel.from_pretrained( + checkpoint_dir, + subfolder="transformer", + torch_dtype=DTYPE, + ) + .to(DEVICE) + .eval() + ) + + args = VisualGenArgs( + checkpoint_path=checkpoint_dir, + device=DEVICE, + dtype="bfloat16", + ) + model_config = DiffusionModelConfig.from_pretrained(checkpoint_dir, args=args) + our_model = WanTransformer3DModel(model_config=model_config).to(DEVICE).eval() + + # Initialize our model with the exact same weights as the HF model. + our_model.load_weights({k: v for k, v in hf_model.state_dict().items()}) + # Cast non-Linear embedder submodules (time_embedder, text_embedder) to target dtype. + our_model.post_load_weights() + + return hf_model, our_model + + +# ============================================================================ +# No-checkpoint unit tests +# ============================================================================ + +WAN_1_3B_CONFIG = { + "attention_head_dim": 128, + "cross_attn_norm": True, + "eps": 1e-06, + "ffn_dim": 8960, + "freq_dim": 256, + "in_channels": 16, + "num_attention_heads": 12, + "num_layers": 30, + "out_channels": 16, + "patch_size": [1, 2, 2], + "qk_norm": "rms_norm_across_heads", + "rope_max_seq_len": 1024, + "text_dim": 4096, + "torch_dtype": "bfloat16", +} + + +def _make_model_config(config_dict: dict) -> DiffusionModelConfig: + return DiffusionModelConfig( + pretrained_config=SimpleNamespace(**config_dict), + quant_config=QuantConfig(), + quant_config_dict=None, + dynamic_weight_quant=False, + force_dynamic_quantization=False, + skip_create_weights_in_init=False, + ) + + +def _is_fp32_layernorm_param(name: str) -> bool: + if not name.endswith((".weight", ".bias")): + return False + if ".norm" in name and "blocks." in name: + return any(p in name.split(".") for p in ("norm1", "norm2", "norm3")) + if name in ("norm_out.weight", "norm_out.bias"): + return True + if name.startswith("condition_embedder.") and ".norm" in name: + return True + return False + + +def _load_weights_from_hf(trtllm_model: WanTransformer3DModel, hf_sd: dict) -> int: + """Copy HuggingFace weights into TRT-LLM model. 
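+
+    Fused layers are remapped explicitly: attn1.qkv_proj consumes the HF
+    to_q/to_k/to_v triplet, ffn.up_proj / ffn.down_proj load from
+    ffn.net.0.proj / ffn.net.2, and scale_shift_table parameters are copied
+    with a reshape.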
Returns number of tensors loaded.""" + loaded = 0 + + def _load_linear(module, hf_key): + nonlocal loaded + if f"{hf_key}.weight" not in hf_sd: + return + wd = {"weight": hf_sd[f"{hf_key}.weight"]} + if f"{hf_key}.bias" in hf_sd: + wd["bias"] = hf_sd[f"{hf_key}.bias"] + module.load_weights([wd]) + loaded += 1 + + for name, module in trtllm_model.named_modules(): + if isinstance(module, Linear): + if "attn1.qkv_proj" in name: + base = name.replace(".qkv_proj", "") + q, k, v = f"{base}.to_q", f"{base}.to_k", f"{base}.to_v" + if f"{q}.weight" in hf_sd: + + def _qkv_entry(key): + d = {"weight": hf_sd[f"{key}.weight"]} + if f"{key}.bias" in hf_sd: + d["bias"] = hf_sd[f"{key}.bias"] + return d + + module.load_weights([_qkv_entry(q), _qkv_entry(k), _qkv_entry(v)]) + loaded += 1 + elif "ffn.up_proj" in name: + _load_linear(module, name.replace(".ffn.up_proj", ".ffn.net.0.proj")) + elif "ffn.down_proj" in name: + _load_linear(module, name.replace(".ffn.down_proj", ".ffn.net.2")) + else: + _load_linear(module, name) + elif hasattr(module, "weight") and f"{name}.weight" in hf_sd: + with torch.no_grad(): + module.weight.copy_(hf_sd[f"{name}.weight"]) + if getattr(module, "bias", None) is not None and f"{name}.bias" in hf_sd: + module.bias.copy_(hf_sd[f"{name}.bias"]) + loaded += 1 + + for name, param in trtllm_model.named_parameters(): + if "scale_shift_table" in name and name in hf_sd: + with torch.no_grad(): + param.copy_(hf_sd[name].view(param.shape)) + loaded += 1 + + return loaded + + +def _transformer_inputs(device: str = "cuda"): + """Small synthetic inputs for single-step transformer tests.""" + torch.manual_seed(42) + return ( + torch.randn(1, 16, 1, 64, 64, dtype=torch.bfloat16, device=device), + torch.tensor([500], dtype=torch.long, device=device), + torch.randn(1, 128, 4096, dtype=torch.bfloat16, device=device), + ) + + +@pytest.mark.integration +@pytest.mark.wan_t2v +class TestWanUnit: + """Fast unit tests — random weights, no checkpoint required.""" + + DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + def test_model_structure(self): + """FFN uses up_proj / down_proj naming (TRT-LLM convention).""" + cfg = { + **WAN_1_3B_CONFIG, + "num_layers": 1, + "hidden_size": WAN_1_3B_CONFIG["num_attention_heads"] + * WAN_1_3B_CONFIG["attention_head_dim"], + } + model = WanTransformer3DModel(model_config=_make_model_config(cfg)) + names = list(model.state_dict()) + assert any("ffn.up_proj" in n for n in names), "Missing ffn.up_proj" + assert any("ffn.down_proj" in n for n in names), "Missing ffn.down_proj" + + def test_sanity_forward(self): + """Model runs a forward pass without error (2 layers, random weights).""" + cfg = { + **WAN_1_3B_CONFIG, + "num_layers": 2, + "hidden_size": WAN_1_3B_CONFIG["num_attention_heads"] + * WAN_1_3B_CONFIG["attention_head_dim"], + } + model = ( + WanTransformer3DModel(model_config=_make_model_config(cfg)) + .to(self.DEVICE, dtype=torch.bfloat16) + .eval() + ) + hs, ts, enc = _transformer_inputs(str(self.DEVICE)) + with torch.inference_mode(): + out = model(hidden_states=hs, timestep=ts, encoder_hidden_states=enc) + assert out.shape == hs.shape + + @torch.no_grad() + def test_allclose_to_hf(self): + """TRT-LLM output matches HuggingFace when weights are shared (2 layers, random init).""" + cfg = { + **WAN_1_3B_CONFIG, + "num_layers": 2, + "hidden_size": WAN_1_3B_CONFIG["num_attention_heads"] + * WAN_1_3B_CONFIG["attention_head_dim"], + } + dtype = torch.bfloat16 + + hf = ( + HFWanTransformer3DModel( + patch_size=cfg["patch_size"], + 
num_attention_heads=cfg["num_attention_heads"], + attention_head_dim=cfg["attention_head_dim"], + in_channels=cfg["in_channels"], + out_channels=cfg["out_channels"], + text_dim=cfg["text_dim"], + freq_dim=cfg["freq_dim"], + ffn_dim=cfg["ffn_dim"], + num_layers=cfg["num_layers"], + cross_attn_norm=cfg["cross_attn_norm"], + qk_norm=cfg["qk_norm"], + eps=cfg["eps"], + ) + .to(self.DEVICE, dtype=dtype) + .eval() + ) + trtllm = ( + WanTransformer3DModel(model_config=_make_model_config(cfg)) + .to(self.DEVICE, dtype=dtype) + .eval() + ) + loaded = _load_weights_from_hf(trtllm, hf.state_dict()) + print(f"\n Loaded {loaded} weight tensors HF → TRT-LLM") + + hs, ts, enc = _transformer_inputs(str(self.DEVICE)) + with torch.backends.cuda.sdp_kernel( + enable_flash=False, enable_math=True, enable_mem_efficient=False + ): + hf_out = hf( + hidden_states=hs, timestep=ts, encoder_hidden_states=enc, return_dict=False + )[0].float() + trt_out = trtllm(hidden_states=hs, timestep=ts, encoder_hidden_states=enc).float() + + torch.testing.assert_close(trt_out, hf_out, atol=0.4, rtol=0.4) + + +# ============================================================================ +# T2V correctness test — Wan2.1-T2V-1.3B +# ============================================================================ + + +@pytest.mark.integration +@pytest.mark.wan_t2v +class TestWanT2VTransformerCorrectness: + """Output of our T2V transformer must match HF WanTransformer3DModel. + + Model: Wan2.1-T2V-1.3B-Diffusers + hidden_size=1536, 30 layers, 12 heads, patch_size=(1,2,2) + + Input shape: + latent (1, 16, 1, 60, 104) — 480/8=60, 832/8=104, 1 latent frame + text (1, 77, 4096) + timestep (1,) + + Post-patch sequence length: 1 * (60/2) * (104/2) = 1560 tokens. + """ + + @pytest.fixture(scope="class") + def t2v_models(self): + if not os.path.exists(WAN21_1_3B_PATH): + pytest.skip(f"Checkpoint not found: {WAN21_1_3B_PATH}") + hf_model, our_model = _load_models(WAN21_1_3B_PATH) + yield hf_model, our_model + del hf_model, our_model + torch.cuda.empty_cache() + + def test_cosine_similarity(self, t2v_models): + hf_model, our_model = t2v_models + + torch.manual_seed(42) + B, C, T, H, W = 1, 16, 1, 60, 104 + text_seq_len = 77 + + hidden_states = torch.randn(B, C, T, H, W, device=DEVICE, dtype=DTYPE) + timestep = torch.tensor([500.0], device=DEVICE, dtype=torch.float32) + encoder_hidden_states = torch.randn(B, text_seq_len, 4096, device=DEVICE, dtype=DTYPE) + + with torch.no_grad(): + hf_out = hf_model( + hidden_states=hidden_states, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + return_dict=False, + )[0] + + our_out = our_model( + hidden_states=hidden_states, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + ) + + cos_sim = _cosine_similarity(our_out, hf_out) + max_diff = (our_out.float() - hf_out.float()).abs().max().item() + print(f"\n T2V 480x832 cosine_similarity={cos_sim:.6f} max_diff={max_diff:.6f}") + print(f" our_out.shape={our_out.shape} hf_out.shape={hf_out.shape}") + + assert cos_sim >= COS_SIM_THRESHOLD, ( + f"T2V cosine similarity {cos_sim:.6f} < {COS_SIM_THRESHOLD}. max_diff={max_diff:.6f}" + ) + + +# ============================================================================ +# I2V correctness test — Wan2.1-I2V-14B-480P +# ============================================================================ + + +@pytest.mark.integration +@pytest.mark.wan_i2v +class TestWanI2VTransformerCorrectness: + """Output of our I2V transformer must match HF WanTransformer3DModel. 
+ + Model: Wan2.1-I2V-14B-480P-Diffusers + hidden_size=5120, 40 layers, 40 heads, patch_size=(1,2,2) + in_channels=36 (16 video + 4 mask + 16 condition), add_k_proj present + + Input shape: + latent (1, 36, 1, 60, 104) — 36ch I2V, 480/8=60, 832/8=104 + text (1, 512, 4096) — 512 required by hardcoded I2V split + image_embeds (1, 257, 1280) — CLIP ViT-H/14 (256 patches + CLS) + timestep (1,) + + Post-patch sequence length: 1 * 30 * 52 = 1560 tokens. + Cross-attention: image context = total_len - 512 = 257 tokens, + text context = 512 tokens. + """ + + @pytest.fixture(scope="class") + def i2v_models(self): + if not WAN21_I2V_480P_PATH or not os.path.exists(WAN21_I2V_480P_PATH): + pytest.skip( + "Checkpoint not found. " + "Set DIFFUSION_MODEL_PATH_WAN21_I2V_480P=/path/to/Wan2.1-I2V-14B-480P-Diffusers" + ) + hf_model, our_model = _load_models(WAN21_I2V_480P_PATH) + yield hf_model, our_model + del hf_model, our_model + torch.cuda.empty_cache() + + def test_cosine_similarity(self, i2v_models): + hf_model, our_model = i2v_models + + torch.manual_seed(42) + B, C, T, H, W = 1, 36, 1, 60, 104 + text_seq_len = 512 # hardcoded split in I2V cross-attention + img_seq_len = 257 # CLIP ViT-H/14 tokens + img_embed_dim = 1280 # CLIP ViT-H/14 embed dim + + hidden_states = torch.randn(B, C, T, H, W, device=DEVICE, dtype=DTYPE) + timestep = torch.tensor([500.0], device=DEVICE, dtype=torch.float32) + encoder_hidden_states = torch.randn(B, text_seq_len, 4096, device=DEVICE, dtype=DTYPE) + image_embeds = torch.randn(B, img_seq_len, img_embed_dim, device=DEVICE, dtype=DTYPE) + + with torch.no_grad(): + hf_out = hf_model( + hidden_states=hidden_states, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_image=image_embeds, + return_dict=False, + )[0] + + our_out = our_model( + hidden_states=hidden_states, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_image=image_embeds, + ) + + cos_sim = _cosine_similarity(our_out, hf_out) + max_diff = (our_out.float() - hf_out.float()).abs().max().item() + print(f"\n I2V 480x832 cosine_similarity={cos_sim:.6f} max_diff={max_diff:.6f}") + print(f" our_out.shape={our_out.shape} hf_out.shape={hf_out.shape}") + + assert cos_sim >= COS_SIM_THRESHOLD, ( + f"I2V cosine similarity {cos_sim:.6f} < {COS_SIM_THRESHOLD}. max_diff={max_diff:.6f}" + ) From 4dd1a3dae6c7358dd14442ffb7a8aa71a7579a90 Mon Sep 17 00:00:00 2001 From: Yechan Kim <161688079+yechank-nvidia@users.noreply.github.com> Date: Wed, 22 Apr 2026 04:16:28 +0900 Subject: [PATCH 3/6] [None][fix] Fix kv_layout for FLASHINFER backend (#13190) Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- tensorrt_llm/_torch/attention_backend/flashinfer.py | 7 ++++--- tensorrt_llm/_torch/models/modeling_radio.py | 10 +++++++++- .../modeling/test_modeling_nemotron_nano_v2_vl.py | 6 +++--- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/tensorrt_llm/_torch/attention_backend/flashinfer.py b/tensorrt_llm/_torch/attention_backend/flashinfer.py index cd9856c35367..956c1aa96807 100644 --- a/tensorrt_llm/_torch/attention_backend/flashinfer.py +++ b/tensorrt_llm/_torch/attention_backend/flashinfer.py @@ -670,10 +670,11 @@ def forward_impl( attention_mask_data=attention_mask_data, ) wrapper = metadata.get_ragged_prefill_wrapper(plan_params) + # cuDNN's ragged prefill kernel assumes contiguous NHD tensors. 
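+            # NHD means [num_tokens, num_heads, head_dim], matching the ragged
+            # q/k/v tensors produced upstream; .contiguous() is a no-op when a
+            # tensor already has that layout, so a copy is only made for
+            # strided/transposed views.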
wrapper.run( - q, - k, - v, + q.contiguous(), + k.contiguous(), + v.contiguous(), out=output.view(-1, self.num_heads, self.head_dim), ) return diff --git a/tensorrt_llm/_torch/models/modeling_radio.py b/tensorrt_llm/_torch/models/modeling_radio.py index 454dcc0801b6..02b88b65e97b 100644 --- a/tensorrt_llm/_torch/models/modeling_radio.py +++ b/tensorrt_llm/_torch/models/modeling_radio.py @@ -745,11 +745,19 @@ def __init__( self.metadata_cls = attention_utils.get_attention_backend( model_config.attn_backend).Metadata - self.attn_metadata = self.metadata_cls( + metadata_kwargs = dict( max_num_requests=8192, # TODO: Make this dynamic max_num_tokens=model_config.max_num_tokens, kv_cache_manager=None, ) + if model_config.attn_backend == "FLASHINFER": + # FlashInfer's original default kv_layout is "NHD". TRT-LLM changed + # the default to "HND" for paged KV cache paths (see PR #6917). + # For ModelingRadio ragged prefill (kv_cache_manager=None), we + # explicitly use "NHD" because ragged k/v tensors computed directly + # from input are always in NHD format ([tokens, heads, dim]). + metadata_kwargs["kv_layout"] = "NHD" + self.attn_metadata = self.metadata_cls(**metadata_kwargs) def prepare_attn_metadata(self, batch_size: int, seq_lengths: List[int], attn_metadata: AttentionMetadata): diff --git a/tests/unittest/_torch/modeling/test_modeling_nemotron_nano_v2_vl.py b/tests/unittest/_torch/modeling/test_modeling_nemotron_nano_v2_vl.py index 9ef7b39cd0d6..20bdd14dddaf 100644 --- a/tests/unittest/_torch/modeling/test_modeling_nemotron_nano_v2_vl.py +++ b/tests/unittest/_torch/modeling/test_modeling_nemotron_nano_v2_vl.py @@ -181,11 +181,11 @@ def test_nemotron_nano_v2_vl_model_sanity_check( "single": torch.tensor( [-8.9814e-01, -1.5258e-01, -7.6061e-04, -6.3735e-01, -3.1303e-02] ), - "multiple": torch.tensor([-0.4717, -0.7776, -0.0251, -1.2290, -1.0705]), + "multiple": torch.tensor([-0.5807, -0.7470, -0.0100, -0.1203, -0.0551]), }, "video": { - "single": torch.tensor([-1.4745, -0.0674, -1.4121, -0.2152, -1.6297]), - "multiple": torch.tensor([-0.9425, -0.2328, -0.0083, -1.6257, -0.6572]), + "single": torch.tensor([-0.6011, -0.0327, -0.8864, -0.3832, -0.5950]), + "multiple": torch.tensor([-0.4956, -0.8749, -0.0095, -1.2541, -0.9490]), }, } prompts = data_dict_fixture[modality][condition]["prompts"] From 7ad4dd15604dac89c94b386fb5dc13bbe89dbda2 Mon Sep 17 00:00:00 2001 From: tburt-nv <195370667+tburt-nv@users.noreply.github.com> Date: Tue, 21 Apr 2026 16:17:18 -0400 Subject: [PATCH 4/6] [None][chore] Update CI allowlist 2026-04-21 (#13289) Signed-off-by: Tyler Burt <195370667+tburt-nv@users.noreply.github.com> --- .github/workflows/blossom-ci.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/blossom-ci.yml b/.github/workflows/blossom-ci.yml index 70d39657f499..12b2b5405d24 100644 --- a/.github/workflows/blossom-ci.yml +++ b/.github/workflows/blossom-ci.yml @@ -51,6 +51,7 @@ jobs: "amirkl94", "amitz-nv", "amukkara", + "anikaj-eng", "anish-shanbhag", "arekay", "arysef", @@ -144,8 +145,8 @@ jobs: "Jackch-NV", "JadoTu", "jaedeok-nvidia", - "jdemouth-nvidia", "janbernloehr", + "jdemouth-nvidia", "JennyLiu-nv", "jershi425", "jgangani", @@ -172,8 +173,8 @@ jobs: "katec846", "Kefeng-Duan", "KingsleyLiu-NV", - "KrishnanPrash", "kris1025", + "KrishnanPrash", "kunlunl", "kxdc", "kyleliang-nv", @@ -303,6 +304,7 @@ jobs: "tijyojwad", "timlee0212", "timothygao8710", + "tingyangk", "Tom-Zheng", "tomeras91", "tongyuantongyu", @@ -371,8 +373,8 @@ jobs: "zerollzeng", "zhanga5", 
"zhangcl", - "zhaoyangwang-nvidia", "ZhanruiSunCh", + "zhaoyangwang-nvidia", "zhengd-nv", "zhenhuaw-me", "zheyuf", From 421422ff684838da4e05b2ed46fc760d5adf04b1 Mon Sep 17 00:00:00 2001 From: Erin <14718778+hchings@users.noreply.github.com> Date: Tue, 21 Apr 2026 13:23:41 -0700 Subject: [PATCH 5/6] [TRTLLM-10703][feature] abort, resume for Async RL in verl (#12272) Signed-off-by: Erin Ho <14718778+hchings@users.noreply.github.com> --- tensorrt_llm/_torch/async_llm.py | 17 +++++ tensorrt_llm/executor/ray_executor.py | 5 ++ tensorrt_llm/llmapi/rlhf_utils.py | 4 ++ tests/unittest/llmapi/test_async_llm.py | 94 +++++++++++++++++++++++++ 4 files changed, 120 insertions(+) diff --git a/tensorrt_llm/_torch/async_llm.py b/tensorrt_llm/_torch/async_llm.py index 8bd93bcdac97..d63ccb4ab3dd 100644 --- a/tensorrt_llm/_torch/async_llm.py +++ b/tensorrt_llm/_torch/async_llm.py @@ -36,6 +36,7 @@ def __init__( super().__init__(*args, **kwargs) self._async_initialized = False + self._paused = False async def setup_async(self): """Setup the LLM asynchronously.""" @@ -94,6 +95,22 @@ async def collective_rpc( method, args, kwargs, unique_reply_rank=unique_reply_rank, target_ranks=target_ranks ) + def generate_async(self, *args, **kwargs): + if self._paused: + raise RuntimeError( + "AsyncLLM is paused. Call resume_generation() before submitting new requests." + ) + return super().generate_async(*args, **kwargs) + + async def pause_generation(self) -> None: + """Abort all in-flight requests and block new ones until resume_generation() is called.""" + self._paused = True + self._executor.abort_all_requests() + + async def resume_generation(self) -> None: + """Allow new generation requests after a pause_generation() call.""" + self._paused = False + def __await__(self): return self.setup_async().__await__() diff --git a/tensorrt_llm/executor/ray_executor.py b/tensorrt_llm/executor/ray_executor.py index 97459a3325ad..5053b7dd0bed 100644 --- a/tensorrt_llm/executor/ray_executor.py +++ b/tensorrt_llm/executor/ray_executor.py @@ -309,6 +309,11 @@ def abort_request(self, request_id: int) -> None: async_call=False, request_id=request_id) + def abort_all_requests(self) -> None: + """Abort all active generation requests.""" + for result in list(self._results.values()): + result.abort() + def shutdown(self): if hasattr(self, '_shutdown_event') and self._shutdown_event.is_set(): return diff --git a/tensorrt_llm/llmapi/rlhf_utils.py b/tensorrt_llm/llmapi/rlhf_utils.py index e31ba18521b6..fdcb04f68c5a 100644 --- a/tensorrt_llm/llmapi/rlhf_utils.py +++ b/tensorrt_llm/llmapi/rlhf_utils.py @@ -146,6 +146,10 @@ def update_weights(self, ipc_handles: Optional[dict] = None): logger.error("Encountered an error in update_weights") raise e + def reset_prefix_cache(self) -> None: + """Invalidate the KV cache prefix reuse state after weight updates.""" + self.engine.reset_prefix_cache() + def check_weights_updated(self) -> bool: """Check if the weights are updated to 0.""" weights_updated = True diff --git a/tests/unittest/llmapi/test_async_llm.py b/tests/unittest/llmapi/test_async_llm.py index 5c4788ab2bf6..9468eedaa6d5 100644 --- a/tests/unittest/llmapi/test_async_llm.py +++ b/tests/unittest/llmapi/test_async_llm.py @@ -1,3 +1,4 @@ +import asyncio import os import pytest @@ -135,3 +136,96 @@ async def test_async_llm_placement_api(setup_ray_cluster, monkeypatch): llm.shutdown() if pg is not None: remove_placement_group(pg) + + +@pytest.mark.ray +@pytest.mark.asyncio +async def test_async_llm_reset_prefix_cache(): + llama_model_path = 
str(llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0") + kv_cache_config = KvCacheConfig(enable_block_reuse=True) + prompt = "The future of AI is " * 20 + sampling_params = SamplingParams(temperature=0, max_tokens=5, return_perf_metrics=True) + + async with AsyncLLM( + model=llama_model_path, + kv_cache_config=kv_cache_config, + cuda_graph_config=None, + ) as llm: + # Cold cache: first run should have no reused blocks + out1 = await llm.generate_async(prompt, sampling_params) + m1 = out1.outputs[0].request_perf_metrics + assert m1 is not None + assert m1.kv_cache_metrics.num_reused_blocks == 0, ( + f"Expected 0 reused blocks on cold cache, got {m1.kv_cache_metrics.num_reused_blocks}" + ) + + # Warm cache: same prompt should hit prefix cache + out2 = await llm.generate_async(prompt, sampling_params) + m2 = out2.outputs[0].request_perf_metrics + assert m2 is not None + assert m2.kv_cache_metrics.num_reused_blocks > 0, ( + f"Expected >0 reused blocks on warm cache, got {m2.kv_cache_metrics.num_reused_blocks}" + ) + + await llm.collective_rpc("reset_prefix_cache") + + # After reset: cache should be cold again + out3 = await llm.generate_async(prompt, sampling_params) + m3 = out3.outputs[0].request_perf_metrics + assert m3 is not None + assert m3.kv_cache_metrics.num_reused_blocks == 0, ( + f"Expected 0 reused blocks after reset_prefix_cache, " + f"got {m3.kv_cache_metrics.num_reused_blocks}" + ) + + +@pytest.mark.ray +@pytest.mark.asyncio +async def test_async_llm_pause_resume(): + llama_model_path = str(llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0") + prompt = "The future of AI is" + sampling_params = SamplingParams(temperature=0, max_tokens=10) + + async with AsyncLLM( + model=llama_model_path, + kv_cache_config=KvCacheConfig(enable_block_reuse=False), + cuda_graph_config=None, + ) as llm: + baseline = (await llm.generate_async(prompt, sampling_params)).outputs[0].text + assert baseline + + for _ in range(2): + await llm.pause_generation() + assert llm._paused + with pytest.raises(RuntimeError, match="paused"): + llm.generate_async(prompt, sampling_params) + + await llm.resume_generation() + assert not llm._paused + out = await llm.generate_async(prompt, sampling_params) + assert out.outputs[0].text == baseline + + +@pytest.mark.ray +@pytest.mark.asyncio +async def test_async_llm_pause_aborts_inflight(): + llama_model_path = str(llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0") + prompt = "The future of AI is" + inflight_params = SamplingParams(temperature=0, max_tokens=512) + normal_params = SamplingParams(temperature=0, max_tokens=10) + + async with AsyncLLM( + model=llama_model_path, + kv_cache_config=KvCacheConfig(enable_block_reuse=False), + cuda_graph_config=None, + ) as llm: + inflight = llm.generate_async(prompt, inflight_params) + + await llm.pause_generation() + + result = await asyncio.wait_for(inflight, timeout=30.0) + assert result.aborted + + await llm.resume_generation() + out = await llm.generate_async(prompt, normal_params) + assert out.outputs[0].text From 6e5a3392b4c9985ce6edc115b330904101c78ccd Mon Sep 17 00:00:00 2001 From: o-stoner <245287810+o-stoner@users.noreply.github.com> Date: Tue, 21 Apr 2026 13:38:33 -0700 Subject: [PATCH 6/6] [TRTLLM-12127][fix] VisualGen metadata updates (#12862) Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- .../visual_gen/attention_backend/trtllm.py | 45 ++++++++++++------- .../visual_gen/attention_backend/utils.py | 14 ++++++ tensorrt_llm/_torch/visual_gen/config.py 
| 11 +++++ .../models/ltx2/transformer_ltx2.py | 2 + .../_torch/visual_gen/modules/attention.py | 4 ++ .../visual_gen/multi_gpu/test_flux_ulysses.py | 4 ++ .../visual_gen/test_attention_integration.py | 9 +++- .../_torch/visual_gen/test_attention_perf.py | 9 +++- .../_torch/visual_gen/test_flux_attention.py | 8 +++- .../_torch/visual_gen/test_ltx2_attention.py | 8 +++- 10 files changed, 95 insertions(+), 19 deletions(-) diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py b/tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py index ec0e4b76142c..b5b72d76c00c 100644 --- a/tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py @@ -45,6 +45,8 @@ class TrtllmAttentionMetadata: max_batch_size: Initial batch size hint. Will grow automatically if exceeded. max_seq_len: Initial sequence length hint. Will grow automatically if exceeded. device: Target device for tensors. + attention_metadata_state: Mutable model-scoped state shared by all + attention layers in one model instance. """ def __init__( @@ -52,18 +54,21 @@ def __init__( max_batch_size: int = 16, max_seq_len: int = 4096, device: Optional[torch.device] = None, + attention_metadata_state: Optional[dict] = None, ): # These are initial hints, not hard limits - capacity grows as needed self.max_batch_size = max_batch_size self.max_seq_len = max_seq_len self.device = device or torch.device("cuda") + if attention_metadata_state is None: + raise ValueError( + "TRTLLM attention requires `attention_metadata_state` to be provided " + "by visual-gen config for model-scoped metadata sharing." + ) + self._metadata_state = attention_metadata_state # Lazily created BaseTrtllmAttentionMetadata - self._metadata: Optional[BaseTrtllmAttentionMetadata] = None - - # Track allocated capacity - self._allocated_batch_size = 0 - self._allocated_max_seq_len = 0 + self._metadata: Optional[BaseTrtllmAttentionMetadata] = self._metadata_state["metadata"] # Track prepared state self._cached_seq_lens: Optional[torch.Tensor] = None @@ -71,14 +76,20 @@ def __init__( def _needs_new_metadata(self, batch_size: int, max_seq_len: int) -> bool: """Check if we need to create new metadata (capacity change).""" + metadata = self._metadata_state["metadata"] + allocated_batch_size, allocated_max_seq_len = self._metadata_state["capacity"] return ( - self._metadata is None - or batch_size > self._allocated_batch_size - or max_seq_len > self._allocated_max_seq_len + metadata is None + or batch_size > allocated_batch_size + or max_seq_len > allocated_max_seq_len ) def _needs_prepare(self, batch_size: int, seq_lens: torch.Tensor) -> bool: - """Check if we need to call prepare() (seq_lens changed).""" + """Check if we need to call prepare() (seq_lens changed). + + Assumes uniform sequence length per batch; if per-sample lengths vary, + we may need to check seq_lens tensor instead. 
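+
+        Example: under that assumption, seq_lens = [512, 512] followed by
+        seq_lens = [512, 256] both summarize to (batch_size=2,
+        max_seq_len=512), so the second call may be treated as already
+        prepared.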
+ """ if not self._prepared: return True if self._cached_seq_lens is None: @@ -89,9 +100,9 @@ def _needs_prepare(self, batch_size: int, seq_lens: torch.Tensor) -> bool: def _create_metadata(self, batch_size: int, max_seq_len: int) -> None: """Create new metadata with given capacity.""" - # Allocate with some headroom to avoid frequent reallocation - alloc_batch = max(batch_size, self._allocated_batch_size) - alloc_seq_len = max(max_seq_len, self._allocated_max_seq_len) + prev_batch, prev_seq = self._metadata_state["capacity"] + alloc_batch = max(batch_size, prev_batch) + alloc_seq_len = max(max_seq_len, prev_seq) self._metadata = BaseTrtllmAttentionMetadata( max_num_requests=alloc_batch, @@ -102,8 +113,8 @@ def _create_metadata(self, batch_size: int, max_seq_len: int) -> None: runtime_features=AttentionRuntimeFeatures(), ) - self._allocated_batch_size = alloc_batch - self._allocated_max_seq_len = alloc_seq_len + self._metadata_state["metadata"] = self._metadata + self._metadata_state["capacity"] = (alloc_batch, alloc_seq_len) self._prepared = False # Reset prepare state on new metadata def prepare( @@ -116,7 +127,7 @@ def prepare( Lazy behavior: - Creates metadata only when capacity needs increase - - Calls prepare() only when seq_lens actually change + - Calls prepare() only when (batch_size, max_seq_len) actually change """ if isinstance(seq_lens, int): seq_lens_tensor = torch.full((batch_size,), seq_lens, dtype=torch.int32) @@ -127,6 +138,8 @@ def prepare( if self._needs_new_metadata(batch_size, max_seq_len): self._create_metadata(batch_size, max_seq_len) + else: + self._metadata = self._metadata_state["metadata"] if self._needs_prepare(batch_size, seq_lens_tensor): self._metadata.seq_lens = seq_lens_tensor @@ -165,6 +178,7 @@ def __init__( dtype: Optional[torch.dtype] = None, max_batch_size: int = 16, max_seq_len: int = 4096, + attention_metadata_state: Optional[dict] = None, ): num_kv_heads = num_kv_heads or num_heads @@ -183,6 +197,7 @@ def __init__( self.metadata = TrtllmAttentionMetadata( max_batch_size=max_batch_size, max_seq_len=max_seq_len, + attention_metadata_state=attention_metadata_state, ) # Needed to work with torch compile cause of attention metadata diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/utils.py b/tensorrt_llm/_torch/visual_gen/attention_backend/utils.py index afdfbfe156c2..fe8b8dfca2e1 100644 --- a/tensorrt_llm/_torch/visual_gen/attention_backend/utils.py +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/utils.py @@ -26,6 +26,7 @@ from tensorrt_llm.models.modeling_utils import QuantConfig +from ..config import AttentionConfig from .interface import AttentionBackend @@ -77,6 +78,8 @@ def create_attention( dtype: Optional[torch.dtype] = None, max_batch_size: int = 16, max_seq_len: int = 4096, + attention_config: Optional[AttentionConfig] = None, + attention_metadata_state: Optional[dict] = None, **kwargs, ) -> AttentionBackend: """ @@ -97,6 +100,9 @@ def create_attention( will automatically reallocate if larger batches are encountered. max_seq_len: Initial sequence length for metadata pre-allocation. The backend will automatically reallocate if longer sequences are encountered. + attention_config: Optional AttentionConfig + attention_metadata_state: Optional model-scoped metadata state from + visual-gen config. Required for TRTLLM backend. 
**kwargs: Additional backend-specific arguments Returns: @@ -104,6 +110,14 @@ def create_attention( """ attn_cls = get_visual_gen_attention_backend(backend) + if backend.upper() == "TRTLLM": + if attention_metadata_state is None: + raise ValueError( + "TRTLLM backend requires `attention_metadata_state` from " + "DiffusionModelConfig; creation path must not allocate metadata implicitly." + ) + kwargs["attention_metadata_state"] = attention_metadata_state + return attn_cls( layer_idx=layer_idx, num_heads=num_heads, diff --git a/tensorrt_llm/_torch/visual_gen/config.py b/tensorrt_llm/_torch/visual_gen/config.py index 1a259b8b0d2d..db3c71c3462c 100644 --- a/tensorrt_llm/_torch/visual_gen/config.py +++ b/tensorrt_llm/_torch/visual_gen/config.py @@ -536,6 +536,11 @@ def discover_pipeline_components(checkpoint_path: Path) -> Dict[str, Path]: return components +def create_attention_metadata_state() -> Dict[str, Any]: + """Create model-scoped attention metadata state for TRTLLM visual-gen backend.""" + return {"metadata": None, "capacity": (0, 0)} + + # ============================================================================= # DiffusionModelConfig - Internal configuration (merged/parsed) # ============================================================================= @@ -579,6 +584,7 @@ class DiffusionModelConfig(BaseModel): cuda_graph: CudaGraphConfig = PydanticField(default_factory=CudaGraphConfig) pipeline: PipelineConfig = PydanticField(default_factory=PipelineConfig) attention: AttentionConfig = PydanticField(default_factory=AttentionConfig) + attention_metadata_state: Optional[Dict[str, Any]] = None parallel: ParallelConfig = PydanticField(default_factory=ParallelConfig) cache: Optional[CacheConfig] = None @@ -935,6 +941,10 @@ def from_pretrained( NVFP4LinearMethod.use_tunable_quantize = True + attention_metadata_state = ( + create_attention_metadata_state() if attention_cfg.backend == "TRTLLM" else None + ) + return cls( pretrained_config=pretrained_config, quant_config=quant_config, @@ -947,6 +957,7 @@ def from_pretrained( cuda_graph=cuda_graph_cfg, pipeline=pipeline_cfg, attention=attention_cfg, + attention_metadata_state=attention_metadata_state, parallel=parallel_cfg, cache=cache_cfg, skip_create_weights_in_init=True, diff --git a/tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py b/tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py index 084271436f6a..b02ea5855412 100644 --- a/tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py +++ b/tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py @@ -131,6 +131,8 @@ def __init__( num_kv_heads=self.num_key_value_heads, quant_config=self.quant_config, dtype=self.dtype, + attention_config=config.attention, + attention_metadata_state=config.attention_metadata_state, ) self._has_dual_attn = True diff --git a/tensorrt_llm/_torch/visual_gen/modules/attention.py b/tensorrt_llm/_torch/visual_gen/modules/attention.py index f922f917affa..0b9aae75f0cc 100644 --- a/tensorrt_llm/_torch/visual_gen/modules/attention.py +++ b/tensorrt_llm/_torch/visual_gen/modules/attention.py @@ -95,6 +95,8 @@ def __init__( self._init_qkv_proj() + attention_metadata_state = getattr(config, "attention_metadata_state", None) + if self.qk_norm: # "full": norm over all heads combined (e.g. WAN, dim=q_dim) # "per_head": norm over each head independently (e.g. 
FLUX, dim=head_dim) @@ -141,6 +143,8 @@ def __init__( num_kv_heads=backend_num_kv_heads, quant_config=self.quant_config, dtype=self.dtype, + attention_config=config.attention, + attention_metadata_state=attention_metadata_state, ) # Wrap with parallelism strategies (orthogonal to backend choice) diff --git a/tests/unittest/_torch/visual_gen/multi_gpu/test_flux_ulysses.py b/tests/unittest/_torch/visual_gen/multi_gpu/test_flux_ulysses.py index 9d1aa69ee0f1..23d3ed87a21d 100644 --- a/tests/unittest/_torch/visual_gen/multi_gpu/test_flux_ulysses.py +++ b/tests/unittest/_torch/visual_gen/multi_gpu/test_flux_ulysses.py @@ -27,6 +27,7 @@ AttentionConfig, DiffusionModelConfig, TorchCompileConfig, + create_attention_metadata_state, ) from tensorrt_llm._torch.visual_gen.mapping import VisualGenMapping from tensorrt_llm._utils import get_free_port @@ -152,6 +153,9 @@ def _make_model_config(pretrained_dict, ulysses_size=1, backend="VANILLA"): attention=AttentionConfig(backend=backend), visual_gen_mapping=vgm, cache=None, + attention_metadata_state=( + create_attention_metadata_state() if backend.upper() == "TRTLLM" else None + ), skip_create_weights_in_init=False, ) config.mapping = vgm.to_llm_mapping() diff --git a/tests/unittest/_torch/visual_gen/test_attention_integration.py b/tests/unittest/_torch/visual_gen/test_attention_integration.py index a2810427f93b..2a9b5134343f 100644 --- a/tests/unittest/_torch/visual_gen/test_attention_integration.py +++ b/tests/unittest/_torch/visual_gen/test_attention_integration.py @@ -19,7 +19,11 @@ # Flash Attention 4 availability # ============================================================================ from tensorrt_llm._torch.visual_gen.attention_backend.flash_attn4 import _flash_attn_fwd as _fa4_fwd -from tensorrt_llm._torch.visual_gen.config import AttentionConfig, DiffusionModelConfig +from tensorrt_llm._torch.visual_gen.config import ( + AttentionConfig, + DiffusionModelConfig, + create_attention_metadata_state, +) # Import new integrated versions from tensorrt_llm._torch.visual_gen.modules.attention import Attention, QKVMode, apply_rotary_emb @@ -128,6 +132,9 @@ def create_model_config( attention=AttentionConfig(backend=attn_backend), skip_create_weights_in_init=False, ) + config.attention_metadata_state = ( + create_attention_metadata_state() if attn_backend == "TRTLLM" else None + ) return config diff --git a/tests/unittest/_torch/visual_gen/test_attention_perf.py b/tests/unittest/_torch/visual_gen/test_attention_perf.py index 570ffa4a02ea..49463627cdd7 100644 --- a/tests/unittest/_torch/visual_gen/test_attention_perf.py +++ b/tests/unittest/_torch/visual_gen/test_attention_perf.py @@ -43,7 +43,11 @@ from tensorrt_llm._torch.visual_gen.attention_backend.flash_attn4 import ( _flash_attn_fwd_import_error as _fa4_import_error, ) -from tensorrt_llm._torch.visual_gen.config import AttentionConfig, DiffusionModelConfig +from tensorrt_llm._torch.visual_gen.config import ( + AttentionConfig, + DiffusionModelConfig, + create_attention_metadata_state, +) from tensorrt_llm._torch.visual_gen.modules.attention import Attention, QKVMode _flash_attn4_available = _fa4_fwd is not None @@ -155,6 +159,9 @@ def create_model_config( attention=AttentionConfig(backend=attn_backend), skip_create_weights_in_init=False, ) + config.attention_metadata_state = ( + create_attention_metadata_state() if attn_backend == "TRTLLM" else None + ) return config diff --git a/tests/unittest/_torch/visual_gen/test_flux_attention.py b/tests/unittest/_torch/visual_gen/test_flux_attention.py 
b/tests/unittest/_torch/visual_gen/test_flux_attention.py
index 93621497e044..2d51d68297c1 100644 --- a/tests/unittest/_torch/visual_gen/test_flux_attention.py +++ b/tests/unittest/_torch/visual_gen/test_flux_attention.py @@ -20,7 +20,11 @@ import torch import torch.nn.functional as F -from tensorrt_llm._torch.visual_gen.config import AttentionConfig, DiffusionModelConfig +from tensorrt_llm._torch.visual_gen.config import ( + AttentionConfig, + DiffusionModelConfig, + create_attention_metadata_state, +) from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig @@ -103,6 +107,7 @@ def test_trtllm_backend_sanity(self): torch.manual_seed(42) config = self._create_config("TRTLLM") + config.attention_metadata_state = create_attention_metadata_state() attn = ( FluxJointAttention( @@ -175,6 +180,7 @@ def test_backend_equivalence(self): p.normal_(0, 0.02) config = self._create_config("TRTLLM") + config.attention_metadata_state = create_attention_metadata_state() trtllm_attn = ( FluxJointAttention( hidden_size=dim, diff --git a/tests/unittest/_torch/visual_gen/test_ltx2_attention.py b/tests/unittest/_torch/visual_gen/test_ltx2_attention.py index d8f964826ed5..bb91287b02b2 100644 --- a/tests/unittest/_torch/visual_gen/test_ltx2_attention.py +++ b/tests/unittest/_torch/visual_gen/test_ltx2_attention.py @@ -16,7 +16,11 @@ import torch import torch.nn.functional as F -from tensorrt_llm._torch.visual_gen.config import AttentionConfig, DiffusionModelConfig +from tensorrt_llm._torch.visual_gen.config import ( + AttentionConfig, + DiffusionModelConfig, + create_attention_metadata_state, +) from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig @@ -102,6 +106,7 @@ def test_trtllm_self_attention_sanity(self): torch.manual_seed(42) config = _create_config("TRTLLM") + config.attention_metadata_state = create_attention_metadata_state() attn = ( LTX2Attention( @@ -287,6 +292,7 @@ def test_backend_equivalence(self): # Create TRTLLM attention and copy weights config_trtllm = _create_config("TRTLLM") + config_trtllm.attention_metadata_state = create_attention_metadata_state() trtllm_attn = ( LTX2Attention( query_dim=query_dim,
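The common thread in PATCH 6 is that TRTLLM attention metadata becomes model-scoped rather than layer-scoped: DiffusionModelConfig owns a single mutable dict from create_attention_metadata_state(), and every attention layer's TrtllmAttentionMetadata reads and writes the same "metadata" and "capacity" entries. The following sketch of that wiring uses only names introduced in this patch series; the prepare() keyword names are inferred from the hunks above, and a CUDA device is assumed for the allocation path.

from tensorrt_llm._torch.visual_gen.attention_backend.trtllm import TrtllmAttentionMetadata
from tensorrt_llm._torch.visual_gen.config import create_attention_metadata_state

# One shared state dict per model instance: {"metadata": None, "capacity": (0, 0)}.
state = create_attention_metadata_state()

# Both layers wrap the same state, so the underlying BaseTrtllmAttentionMetadata
# is allocated once and grown only when (batch_size, max_seq_len) exceed capacity.
layer0 = TrtllmAttentionMetadata(max_batch_size=2, max_seq_len=1024, attention_metadata_state=state)
layer1 = TrtllmAttentionMetadata(max_batch_size=2, max_seq_len=1024, attention_metadata_state=state)

layer0.prepare(batch_size=2, seq_lens=512)  # allocates; shared capacity becomes (2, 512)
layer1.prepare(batch_size=2, seq_lens=512)  # fits in shared capacity; no new allocation

# Constructing without the state is now a hard error instead of a silent
# per-layer allocation (see the ValueError added in trtllm.py above).
try:
    TrtllmAttentionMetadata()
except ValueError as err:
    print(err)

The apparent motivation is that a diffusion transformer has many attention layers seeing the same (batch, sequence) shapes, so one shared allocation replaces many identical per-layer ones.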