From 89acff31cb363b3b48ce7f551ec39754109effdb Mon Sep 17 00:00:00 2001 From: tcherckez-nvidia <127761168+tcherckez-nvidia@users.noreply.github.com> Date: Mon, 9 Mar 2026 00:32:15 +0200 Subject: [PATCH 1/6] Model update 260308 (#12011) Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com> --- .../auto_deploy/model_registry/models.yaml | 120 +++++++----------- 1 file changed, 48 insertions(+), 72 deletions(-) diff --git a/examples/auto_deploy/model_registry/models.yaml b/examples/auto_deploy/model_registry/models.yaml index 28a57afaaae..6f6d630d16a 100644 --- a/examples/auto_deploy/model_registry/models.yaml +++ b/examples/auto_deploy/model_registry/models.yaml @@ -9,27 +9,22 @@ models: yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml'] - name: Qwen/Qwen3-0.6B yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml'] -# DISABLED: TorchDynamo compilation error - fake tensor dispatch failure -# - name: apple/OpenELM-270M-Instruct -# yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'openelm.yaml'] -# DISABLED: TorchDynamo compilation error - fake tensor dispatch failure -# - name: apple/OpenELM-1_1B-Instruct -# yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'openelm.yaml'] -# DISABLED: TorchDynamo compilation error - fake tensor dispatch failure -# - name: apple/OpenELM-3B-Instruct -# yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'openelm.yaml'] -# DISABLED: model not supporting installed transformers version - https://github.com/NVIDIA/TensorRT-LLM/issues/10980 -# - name: microsoft/Phi-4-mini-instruct -# yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml'] +- name: apple/OpenELM-270M-Instruct + yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'openelm.yaml'] +- name: apple/OpenELM-1_1B-Instruct + yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'openelm.yaml'] +- name: apple/OpenELM-3B-Instruct + yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'openelm.yaml'] +- name: microsoft/Phi-4-mini-instruct + yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml'] - name: microsoft/Phi-4-mini-reasoning yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml'] - name: google/gemma-3-1b-it yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'gemma3_1b.yaml'] - name: meta-llama/Llama-3.1-8B-Instruct yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] -# DISABLED: NOT SUPPORTED - https://github.com/NVIDIA/TensorRT-LLM/issues/10363 -# - name: casperhansen/llama-3-8b-instruct-awq -# yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] +- name: casperhansen/llama-3-8b-instruct-awq + yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] - name: meta-llama/Llama-3.2-1B-Instruct yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] - name: meta-llama/Llama-3.2-3B-Instruct @@ -40,9 +35,8 @@ models: yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] - name: Qwen/Qwen2.5-7B-Instruct yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] -# DISABLED: NOT SUPPORTED - https://github.com/NVIDIA/TensorRT-LLM/issues/10363 -# - name: Qwen/Qwen2.5-7B-Instruct-AWQ -# yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] +- name: Qwen/Qwen2.5-7B-Instruct-AWQ + yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] - name: Qwen/Qwen3-4B yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] - name: Qwen/Qwen3-8B @@ -97,9 +91,8 @@ models: yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] - name: meta-llama/Llama-2-7b-chat-hf 
yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] -# DISABLED: FakeTensorMode error in unified_attn export -# - name: nvidia/Llama-3.1-8B-Instruct-FP8 -# yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] +- name: nvidia/Llama-3.1-8B-Instruct-FP8 + yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] - name: nvidia/Llama-3.1-Minitron-4B-Depth-Base yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] - name: nvidia/Llama-3.1-Minitron-4B-Width-Base @@ -116,22 +109,18 @@ models: yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] - name: nvidia/NVIDIA-Nemotron-Nano-9B-v2-FP8 yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] -# DISABLED: NVFP4 quantization not supported for pre BLW - CW has only Hopper -# - name: nvidia/NVIDIA-Nemotron-Nano-9B-v2-NVFP4 -# yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] -# DISABLED: Not supported -# - name: nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-FP8 -# yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml', 'multimodal.yaml'] +- name: nvidia/NVIDIA-Nemotron-Nano-9B-v2-NVFP4 + yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] +- name: nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-FP8 + yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml', 'multimodal.yaml'] - name: google/gemma-3-27b-it yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml', 'multimodal.yaml'] - name: deepseek-ai/DeepSeek-V2.5 yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] -# DISABLED: Not supported -# - name: ai21labs/AI21-Jamba-1.5-Mini -# yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] -# DISABLED: NOT SUPPORTED - https://github.com/NVIDIA/TensorRT-LLM/issues/10977 -# - name: meta-llama/Llama-3.2-11B-Vision-Instruct -# yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml', 'multimodal.yaml'] +- name: ai21labs/AI21-Jamba-1.5-Mini + yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml'] +- name: meta-llama/Llama-3.2-11B-Vision-Instruct + yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml', 'multimodal.yaml'] - name: meta-llama/Llama-3.3-70B-Instruct yaml_extra: ['dashboard_default.yaml', 'world_size_4.yaml', 'llama3_3_70b.yaml'] - name: meta-llama/CodeLlama-34b-Instruct-hf @@ -170,61 +159,48 @@ models: yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml'] - name: deepseek-ai/DeepSeek-R1-Distill-Qwen-32B yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml'] -# DISABLED: stuck in graph capturing -# - name: mistralai/Mixtral-8x22B-Instruct-v0.1 -# yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml'] -# DISABLED: FakeTensorMode error in unified_attn export -# - name: nvidia/Llama-3.1-70B-Instruct-FP8 -# yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml'] -# DISABLED: FakeTensorMode error in unified_attn export -# - name: nvidia/Llama-3.1-405B-Instruct-FP8 -# yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml'] +- name: mistralai/Mixtral-8x22B-Instruct-v0.1 + yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml'] +- name: nvidia/Llama-3.1-70B-Instruct-FP8 + yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml'] +- name: nvidia/Llama-3.1-405B-Instruct-FP8 + yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml'] - name: nvidia/Llama-3.1-Nemotron-70B-Instruct-HF yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml'] -# DISABLED: Model loading failure - dynamic module registry issue -# - name: nvidia/Llama-3_1-Nemotron-51B-Instruct -# yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'simple_shard_only.yaml'] -# DISABLED: 
model not supporting installed transformers version - https://github.com/NVIDIA/TensorRT-LLM/issues/10980 -# - name: nvidia/Llama-3_1-Nemotron-Ultra-253B-v1 -# yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'simple_shard_only.yaml'] -# DISABLED: model not supporting installed transformers version - https://github.com/NVIDIA/TensorRT-LLM/issues/10980 -# - name: nvidia/Llama-3_1-Nemotron-Ultra-253B-v1-FP8 -# yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'simple_shard_only.yaml'] -# DISABLED: model not supporting installed transformers version - https://github.com/NVIDIA/TensorRT-LLM/issues/10980 -# - name: nvidia/Llama-3_3-Nemotron-Super-49B-v1 -# yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'simple_shard_only.yaml'] +- name: nvidia/Llama-3_1-Nemotron-51B-Instruct + yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'simple_shard_only.yaml'] +- name: nvidia/Llama-3_1-Nemotron-Ultra-253B-v1 + yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'simple_shard_only.yaml'] +- name: nvidia/Llama-3_1-Nemotron-Ultra-253B-v1-FP8 + yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'simple_shard_only.yaml'] +- name: nvidia/Llama-3_3-Nemotron-Super-49B-v1 + yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'simple_shard_only.yaml'] - name: Qwen/Qwen3-30B-A3B yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'simple_shard_only.yaml'] - name: Qwen/Qwen3-235B-A22B yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'simple_shard_only.yaml'] -# DISABLED: Auto-deploy compilation error - shape mismatch - https://github.com/NVIDIA/TensorRT-LLM/issues/10978 -# - name: deepseek-ai/DeepSeek-R1 -# yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'num_hidden_layers_5.yaml'] -# DISABLED: Auto-deploy compilation error - shape mismatch - https://github.com/NVIDIA/TensorRT-LLM/issues/10978 -# - name: deepseek-ai/DeepSeek-V3 -# yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'num_hidden_layers_5.yaml'] -# DISABLED: Auto-deploy compilation error - shape mismatch - https://github.com/NVIDIA/TensorRT-LLM/issues/10978 -# - name: deepseek-ai/DeepSeek-Coder-V2-Instruct -# yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml'] +- name: deepseek-ai/DeepSeek-R1 + yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'num_hidden_layers_5.yaml'] +- name: deepseek-ai/DeepSeek-V3 + yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'num_hidden_layers_5.yaml'] +- name: deepseek-ai/DeepSeek-Coder-V2-Instruct + yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml'] - name: Qwen/Qwen3-VL-8B-Instruct yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'multimodal.yaml', 'qwen3_vl.yaml'] -# DISABLED: NOT SUPPORTED - https://github.com/NVIDIA/TensorRT-LLM/issues/10363 -# - name: Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4 -# yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'multimodal.yaml'] +- name: Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4 + yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'multimodal.yaml'] - name: codellama/CodeLlama-70b-Instruct-hf yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml'] -# DISABLED: NOT SUPPORTED - https://github.com/NVIDIA/TensorRT-LLM/issues/10977 -# - name: meta-llama/Llama-3.2-90B-Vision-Instruct -# yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'multimodal.yaml'] +- name: meta-llama/Llama-3.2-90B-Vision-Instruct + yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'multimodal.yaml'] - name: openai/gpt-oss-120b 
yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'num_hidden_layers_5.yaml'] - name: meta-llama/Llama-4-Scout-17B-16E-Instruct yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'multimodal.yaml', 'llama4_scout.yaml'] - name: meta-llama/Llama-4-Maverick-17B-128E-Instruct yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'multimodal.yaml', 'llama4_maverick_lite.yaml'] -# DISABLED: Doesn't fit H100 -# - name: nvidia/NVIDIA-Nemotron-3-Super-120B-BF16-BF16KV-010726 -# yaml_extra: ['dashboard_default.yaml', 'world_size_4.yaml','super_v3.yaml'] +- name: nvidia/NVIDIA-Nemotron-3-Super-120B-BF16-BF16KV-010726 + yaml_extra: ['dashboard_default.yaml', 'world_size_4.yaml','super_v3.yaml'] - name: zai-org/GLM-4.7-Flash yaml_extra: ['glm-4.7-flash.yaml'] - name: Nanbeige/Nanbeige4.1-3B From 6ec0aad7caf358cbfe13a5e2697ac89d09f05354 Mon Sep 17 00:00:00 2001 From: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Date: Sun, 8 Mar 2026 20:31:36 -0400 Subject: [PATCH 2/6] [None][infra] Update AutoDeploy CODEOWNERS coverage (#12013) Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> --- .github/CODEOWNERS | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 24555024fd1..884e6fe90ee 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -59,8 +59,10 @@ /tensorrt_llm/_torch/pyexecutor @NVIDIA/trt-llm-torch-runtime-devs ## TensorRT-LLM Pytorch backend - AutoDeploy flow /tensorrt_llm/_torch/auto_deploy @NVIDIA/trt-llm-torch-autodeploy-devs -/examples/auto_deploy @NVIDIA/trt-llm-torch-autodeploy-devs @NVIDIA/trt-llm-doc-owners -/tests/unittest/_torch/auto_deploy @NVIDIA/trt-llm-torch-autodeploy-devs +/examples/auto_deploy @NVIDIA/trt-llm-torch-autodeploy-devs +/docs/source/features/auto_deploy @NVIDIA/trt-llm-torch-autodeploy-devs @NVIDIA/trt-llm-doc-owners +/tests/unittest/auto_deploy @NVIDIA/trt-llm-torch-autodeploy-devs +/tests/integration/defs/accuracy/test_llm_api_autodeploy.py @NVIDIA/trt-llm-torch-autodeploy-devs @NVIDIA/trt-llm-qa-function ## TensorRT-LLM Pytorch - Speculative Decoding /tensorrt_llm/_torch/speculative @NVIDIA/trt-llm-torch-spec-decoding From 4c15db0bfafdbcf566708208dbdeaba2c0b97f61 Mon Sep 17 00:00:00 2001 From: "Po-Han Huang (NVIDIA)" <53919306+nvpohanh@users.noreply.github.com> Date: Mon, 9 Mar 2026 09:08:58 +0800 Subject: [PATCH 3/6] [https://nvbugs/5732958][bug] Fix TestLlama4MinLatency::test_llama_allclose_to_hf failure (#10191) Signed-off-by: Po-Han Huang --- tensorrt_llm/_torch/models/modeling_llama.py | 32 +++++++++++++++++++ .../_torch/modeling/test_modeling_llama.py | 1 + .../test_modeling_llama_min_latency.py | 6 ++-- 3 files changed, 35 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index 54193a32c08..743e0b8ef50 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -449,6 +449,10 @@ def __init__( self.input_layernorm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype) + # When post_load_weights() chains layernorms across layers, + # this flag is set to True to skip the input layernorm in + # forward() since it is handled by the previous layer. 
+ self.skip_input_layernorm = False self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps, @@ -493,6 +497,8 @@ def forward( if residual is None: residual = hidden_states + + if not self.skip_input_layernorm: hidden_states = self.input_layernorm(hidden_states) # Self Attention @@ -668,6 +674,10 @@ def __init__( quantize_type="nvfp4" if not self.disable_nvfp4_layernorm_fusion and self.is_nvfp4 and not (differ_pp_stage_with_previous_layer) else None) + # When post_load_weights() chains layernorms across layers, + # this flag is set to True to skip the input layernorm in + # forward() since it is handled by the previous layer. + self.skip_input_layernorm = False self.post_attention_layernorm = RMSNorm( hidden_size=config.hidden_size, @@ -765,6 +775,8 @@ def forward( ) -> Union[torch.Tensor, Fp4QuantizedTensor]: if residual is None: residual = hidden_states + + if not self.skip_input_layernorm: hidden_states = self.input_layernorm(hidden_states) hidden_states = self.self_attn( @@ -936,6 +948,10 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]): self.norm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype) + # When post_load_weights() chains the final norm into the + # last decoder layer, this flag is set to True to skip + # applying it again in forward(). + self.skip_norm = False def forward( self, @@ -969,6 +985,10 @@ def forward( lora_params=lora_params, ) + # If self.norm is not handled by the last layer, apply it here. + if not self.skip_norm: + hidden_states = self.norm(hidden_states) + return hidden_states @@ -1033,6 +1053,10 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]): self.norm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype) + # When post_load_weights() chains the final norm into the + # last decoder layer, this flag is set to True to skip + # applying it again in forward(). + self.skip_norm = False def forward( self, @@ -1065,6 +1089,10 @@ def forward( lora_params=lora_params, ) + # If self.norm is not handled by the last layer, apply it here. 
+ if not self.skip_norm: + hidden_states = self.norm(hidden_states) + return hidden_states @@ -1082,9 +1110,11 @@ def post_load_weights(self): self.model.layers[:self.config.num_hidden_layers]): if idx == self.config.num_hidden_layers - 1: layer.next_layer_layernorm = self.model.norm + self.model.skip_norm = True else: layer.next_layer_layernorm = self.model.layers[ idx + 1].input_layernorm + self.model.layers[idx + 1].skip_input_layernorm = True layer.next_attn = self.model.layers[idx + 1].self_attn @@ -1456,9 +1486,11 @@ def post_load_weights(self): self.model.layers[:self.config.num_hidden_layers]): if idx == self.config.num_hidden_layers - 1: layer.next_layer_layernorm = self.model.norm + self.model.skip_norm = True else: layer.next_layer_layernorm = self.model.layers[ idx + 1].input_layernorm + self.model.layers[idx + 1].skip_input_layernorm = True layer.next_attn = self.model.layers[idx + 1].self_attn diff --git a/tests/unittest/_torch/modeling/test_modeling_llama.py b/tests/unittest/_torch/modeling/test_modeling_llama.py index ca503642c67..334e60a61e9 100644 --- a/tests/unittest/_torch/modeling/test_modeling_llama.py +++ b/tests/unittest/_torch/modeling/test_modeling_llama.py @@ -407,6 +407,7 @@ def test_llama_verification_with_kv_cache_relocation(self) -> None: llama = LlamaForCausalLM(model_config).to(dtype).to(device) llama.load_weights(hf_llama.state_dict()) + num_blocks = 2 tokens_per_block = 32 head_dim = llama.config.hidden_size // llama.config.num_attention_heads diff --git a/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py b/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py index 599b1be0211..0ce83559923 100644 --- a/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py +++ b/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py @@ -271,10 +271,7 @@ def test_llama_allclose_to_hf(self, scenario: AllCloseScenario) -> None: "The transformers between 4.55.0 and 4.56.1 have accuracy " "issues for Llama4. See: " "https://github.com/huggingface/transformers/pull/40609") - elif transformers.__version__ >= "4.57.1": - self.skipTest( - "Bumping transformers version to 4.57.1 has accuracy issues for Llama4. See: " - "http://nvbugs/5732958") + torch.random.manual_seed(0) config_dict = deepcopy(LLAMA_4_MAVERICK_TWO_LAYER_CONFIG) # 17B * sizeof(float16) plus some extra for activations @@ -301,6 +298,7 @@ def test_llama_allclose_to_hf(self, scenario: AllCloseScenario) -> None: weight_mapper.init_model_and_config(llama, model_config) llama.load_weights(hf_llama.state_dict(), weight_mapper=weight_mapper) + llama.post_load_weights() num_blocks = 1 tokens_per_block = 128 From db533cff870443aca02efb5d9703a96bfa455afb Mon Sep 17 00:00:00 2001 From: Leslie Fang Date: Mon, 9 Mar 2026 10:28:33 +0800 Subject: [PATCH 4/6] [None][chore] Unwaive some skip for trtllm moe backend (#11975) Signed-off-by: leslie-fang25 --- tests/unittest/_torch/modules/moe/moe_test_utils.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/unittest/_torch/modules/moe/moe_test_utils.py b/tests/unittest/_torch/modules/moe/moe_test_utils.py index 69c6418559f..6cdcc0d1990 100644 --- a/tests/unittest/_torch/modules/moe/moe_test_utils.py +++ b/tests/unittest/_torch/modules/moe/moe_test_utils.py @@ -284,13 +284,6 @@ def should_skip_trtllm( # These are known issues that need investigation. Skipping to avoid test failures # and CUDA errors that can cascade to subsequent tests. 
- # Issue: W4A8_NVFP4_FP8 with top_k=1 causes CUDA illegal memory access - if quant_algo == QuantAlgo.W4A8_NVFP4_FP8 and top_k == 1: - return ( - "[Potential Bug] TRTLLMGenFusedMoE W4A8_NVFP4_FP8 with top_k=1 " - "causes CUDA illegal memory access." - ) - # Issue: NVFP4 with large expert count + large hidden_size + seq_len=1 # has a single FP4BlockScaleMoERunner tactic with accuracy failure. # Observed: e256_k8_h7168_i2048, seq=1, bfloat16 — tactic[204] with tile @@ -324,11 +317,6 @@ def should_skip_trtllm( # Issue: W4A8_MXFP4_MXFP8 has accuracy issues on certain model configs if quant_algo == QuantAlgo.W4A8_MXFP4_MXFP8: - if intermediate_size >= 14336: - return ( - f"[Potential Bug] TRTLLMGenFusedMoE W4A8_MXFP4_MXFP8 with large " - f"intermediate_size has accuracy issues (intermediate_size={intermediate_size} >= 14336)." - ) if num_experts >= 60 and intermediate_size >= 1408: return ( f"[Potential Bug] TRTLLMGenFusedMoE W4A8_MXFP4_MXFP8 with many experts " From 02c8a948208eca28fdec57bdd27be61563891ab0 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang <4936589+zhenhuaw-me@users.noreply.github.com> Date: Mon, 9 Mar 2026 10:31:12 +0800 Subject: [PATCH 5/6] [TRTLLM-11134][feat] export VisualGen API and update doc (#11911) Signed-off-by: Zhenhua Wang --- .../commands/trtllm-serve/trtllm-serve.rst | 13 +- docs/source/developer-guide/overview.md | 4 + docs/source/features/visual-generation.md | 221 -------------- docs/source/index.rst | 2 +- docs/source/models/supported-models.md | 4 + docs/source/models/visual-generation.md | 173 +++++++++++ docs/source/overview.md | 3 +- docs/source/quick-start-guide.md | 13 + examples/visual_gen/README.md | 144 +++------ examples/visual_gen/hf_examples.sh | 192 ------------ examples/visual_gen/hf_flux.py | 142 --------- examples/visual_gen/hf_wan.py | 141 --------- examples/visual_gen/output_handler.py | 237 -------------- examples/visual_gen/quickstart_example.py | 27 ++ examples/visual_gen/serve/configs/flux1.yml | 2 - examples/visual_gen/serve/configs/wan.yml | 2 - examples/visual_gen/visual_gen_examples.sh | 288 ------------------ examples/visual_gen/visual_gen_flux.py | 73 ++--- examples/visual_gen/visual_gen_wan_i2v.py | 69 ++--- examples/visual_gen/visual_gen_wan_t2v.py | 71 ++--- tensorrt_llm/__init__.py | 6 +- tensorrt_llm/_torch/visual_gen/__init__.py | 4 +- tensorrt_llm/_torch/visual_gen/config.py | 112 ++++--- tensorrt_llm/_torch/visual_gen/executor.py | 41 +-- .../_torch/visual_gen/pipeline_loader.py | 32 +- .../_torch/visual_gen/pipeline_registry.py | 2 +- tensorrt_llm/bench/benchmark/visual_gen.py | 22 +- tensorrt_llm/commands/serve.py | 44 ++- tensorrt_llm/llmapi/visual_gen.py | 63 ++-- tensorrt_llm/serve/media_storage.py | 11 +- .../defs/examples/test_visual_gen.py | 34 ++- .../visual_gen/test_visual_gen_benchmark.py | 1 - .../integration/test_lists/test-db/l0_a10.yml | 1 + .../test_lists/test-db/l0_b200.yml | 2 + .../test_lists/test-db/l0_dgx_b200.yml | 2 + .../test_lists/test-db/l0_gb203.yml | 1 + .../test_lists/test-db/l0_gh200.yml | 1 + .../test_lists/test-db/l0_h100.yml | 1 + .../test_lists/test-db/l0_l40s.yml | 1 + .../test_lists/test-db/l0_sanity_check.yml | 1 + .../_torch/visual_gen/test_flux_pipeline.py | 44 +-- .../_torch/visual_gen/test_model_loader.py | 66 ++-- .../visual_gen/test_trtllm_serve_e2e.py | 1 - .../_torch/visual_gen/test_visual_gen_args.py | 183 +++++++++++ tests/unittest/_torch/visual_gen/test_wan.py | 78 ++--- .../_torch/visual_gen/test_wan_i2v.py | 38 +-- 46 files changed, 880 insertions(+), 1733 deletions(-) delete mode 
100644 docs/source/features/visual-generation.md
 create mode 100644 docs/source/models/visual-generation.md
 delete mode 100755 examples/visual_gen/hf_examples.sh
 delete mode 100755 examples/visual_gen/hf_flux.py
 delete mode 100755 examples/visual_gen/hf_wan.py
 delete mode 100644 examples/visual_gen/output_handler.py
 create mode 100644 examples/visual_gen/quickstart_example.py
 delete mode 100755 examples/visual_gen/visual_gen_examples.sh
 create mode 100644 tests/unittest/_torch/visual_gen/test_visual_gen_args.py

diff --git a/docs/source/commands/trtllm-serve/trtllm-serve.rst b/docs/source/commands/trtllm-serve/trtllm-serve.rst
index 4cc1a4d12d8..cdfa3cac9fc 100644
--- a/docs/source/commands/trtllm-serve/trtllm-serve.rst
+++ b/docs/source/commands/trtllm-serve/trtllm-serve.rst
@@ -215,19 +215,24 @@ model.
 Visual Generation Serving
 ~~~~~~~~~~~~~~~~~~~~~~~~~

-``trtllm-serve`` supports diffusion-based visual generation models (Wan2.1, Wan2.2) for image and video generation. When a diffusion model directory is provided (detected by the presence of ``model_index.json``), the server automatically launches in visual generation mode with dedicated endpoints.
+``trtllm-serve`` supports diffusion-based visual generation models (FLUX.1, FLUX.2, Wan2.1, Wan2.2) for image and video generation. When a diffusion model directory is provided (detected by the presence of ``model_index.json``), the server automatically launches in visual generation mode with dedicated endpoints.

 .. note::

-   This is the initial release of TensorRT-LLM VisualGen. APIs, supported models, and optimization options are actively evolving and may change in future releases.
+   VisualGen is in **prototype** stage. APIs, supported models, and optimization options are actively evolving and may change in future releases.

 .. code-block:: bash

-    trtllm-serve Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
+    # Video generation (Wan)
+    trtllm-serve Wan-AI/Wan2.2-T2V-A14B-Diffusers \
+        --extra_visual_gen_options config.yml
+
+    # Image generation (FLUX)
+    trtllm-serve black-forest-labs/FLUX.2-dev \
         --extra_visual_gen_options config.yml

 The ``--extra_visual_gen_options`` flag accepts a YAML file that configures quantization, parallelism, and TeaCache. Available visual generation endpoints include ``/v1/images/generations``, ``/v1/videos``, ``/v1/videos/generations``, and video management APIs.

-For full details, see the :doc:`../../features/visual-generation` feature documentation. Example client scripts are available in the `examples/visual_gen/serve/ <https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/visual_gen/serve>`_ directory.
+For full details, see the :doc:`../../models/visual-generation` feature documentation. Example client scripts are available in the `examples/visual_gen/serve/ <https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/visual_gen/serve>`_ directory.

 Multi-node Serving with Slurm
 -----------------------------
diff --git a/docs/source/developer-guide/overview.md b/docs/source/developer-guide/overview.md
index f1e9b7b3b6c..84a6ab52aea 100644
--- a/docs/source/developer-guide/overview.md
+++ b/docs/source/developer-guide/overview.md
@@ -73,3 +73,7 @@ if self.previous_batch is not None:
 ```

 This approach effectively reduces GPU idle time and improves overall hardware occupancy. While it introduces one extra decoding step into the pipeline, the resulting throughput gain is a significant trade-off. For this reason, the Overlap Scheduler is enabled by default in TensorRT LLM.
+
+## Visual Generation
+
+For diffusion-based visual generation (image/video), TensorRT-LLM provides a separate `VisualGen` API and `DiffusionExecutor` with its own pipeline architecture.
See the [Visual Generation](../models/visual-generation.md) feature documentation. diff --git a/docs/source/features/visual-generation.md b/docs/source/features/visual-generation.md deleted file mode 100644 index 266e36e7806..00000000000 --- a/docs/source/features/visual-generation.md +++ /dev/null @@ -1,221 +0,0 @@ -# Visual Generation (Diffusion Models) [Beta] - -- [Background and Motivation](#background-and-motivation) -- [Quick Start](#quick-start) - - [Python API](#python-api) - - [Usage with `trtllm-serve`](#usage-with-trtllm-serve) -- [Quantization](#quantization) -- [Developer Guide](#developer-guide) - - [Architecture Overview](#architecture-overview) - - [Implementing a New Diffusion Model](#implementing-a-new-diffusion-model) -- [Summary and Future Work](#summary-and-future-work) - - [Current Status](#current-status) - - [Future Work](#future-work) - -## Background and Motivation - -Visual generation models based on diffusion transformers (DiT) have become the standard for high-quality image and video synthesis. These models iteratively denoise latent representations through a learned transformer backbone, then decode the final latents with a VAE to produce pixels. As model sizes and output resolutions grow, efficient inference becomes critical — demanding multi-GPU parallelism, weight quantization, and runtime caching to achieve practical throughput and latency. - -TensorRT-LLM **VisualGen** module provides a unified inference stack for diffusion models. Key capabilities include (subject to change as the feature matures): - -- A shared pipeline abstraction for diffusion model families, covering the denoising loop, guidance strategies, and component loading. -- Pluggable attention backends. -- Quantization support (dynamic and static) using the [ModelOpt](https://github.com/NVIDIA/TensorRT-Model-Optimizer) configuration format. -- Multi-GPU parallelism strategies. -- **TeaCache** — a runtime caching optimization for the transformer backbone. -- `trtllm-serve` integration with OpenAI-compatible API endpoints. - -> **Note:** This is the initial release of TensorRT-LLM VisualGen. APIs, supported models, and optimization options are actively evolving and may change in future releases. - -## Quick Start - -### Prerequisites - -```bash -pip install -r requirements-dev.txt -pip install git+https://github.com/huggingface/diffusers.git -pip install av -``` - -### Python API - -The example scripts under `examples/visual_gen/` demonstrate direct Python usage. For Wan2.1 text-to-video generation: - -```bash -cd examples/visual_gen - -python visual_gen_wan_t2v.py \ - --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ - --prompt "A cute cat playing piano" \ - --height 480 --width 832 --num_frames 33 \ - --output_path output.mp4 -``` - -Run `python visual_gen_wan_t2v.py --help` for the full list of arguments. Key options control resolution, denoising steps, quantization mode, attention backend, parallelism, and TeaCache settings. - -### Usage with `trtllm-serve` - -The `trtllm-serve` command automatically detects diffusion models (by the presence of `model_index.json`) and launches an OpenAI-compatible visual generation server. - -**1. Create a YAML configuration file:** - -```yaml -# wan_config.yml -linear: - type: default -teacache: - enable_teacache: true - teacache_thresh: 0.2 -parallel: - dit_cfg_size: 1 - dit_ulysses_size: 1 -``` - -**2. Launch the server:** - -```bash -trtllm-serve Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ - --extra_visual_gen_options wan_config.yml -``` - -**3. 
Send requests** using curl or any OpenAI-compatible client: - -Synchronous video generation: - -```bash -curl -X POST "http://localhost:8000/v1/videos/generations" \ - -H "Content-Type: application/json" \ - -d '{ - "prompt": "A cool cat on a motorcycle in the night", - "seconds": 4.0, - "fps": 24, - "size": "480x832" - }' -o output.mp4 -``` - -Asynchronous video generation: - -```bash -# Submit the job -curl -X POST "http://localhost:8000/v1/videos" \ - -H "Content-Type: application/json" \ - -d '{ - "prompt": "A cool cat on a motorcycle in the night", - "seconds": 4.0, - "fps": 24, - "size": "480x832" - }' -# Returns: {"id": "", "status": "processing", ...} - -# Poll for status -curl -X GET "http://localhost:8000/v1/videos/" - -# Download when complete -curl -X GET "http://localhost:8000/v1/videos//content" -o output.mp4 -``` - -The server exposes OpenAI-compatible endpoints for image generation (`/v1/images/generations`), video generation (`/v1/videos`, `/v1/videos/generations`), video management, and standard health/model info endpoints. - -The `--extra_visual_gen_options` YAML file configures quantization (`linear`), TeaCache (`teacache`), and parallelism (`parallel`). See [`examples/visual_gen/serve/configs/`](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/visual_gen/serve/configs) for reference configurations. - -## Quantization - -TensorRT-LLM VisualGen supports both **dynamic quantization** (on-the-fly at weight-loading time from BF16 checkpoints) and **static quantization** (loading pre-quantized checkpoints with embedded scales). Both modes use the same [ModelOpt](https://github.com/NVIDIA/TensorRT-Model-Optimizer) `quantization_config` format. - -**Quick start — dynamic quantization via `--linear_type`:** - -```bash -python visual_gen_wan_t2v.py \ - --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ - --prompt "A cute cat playing piano" \ - --linear_type trtllm-fp8-per-tensor \ - --output_path output_fp8.mp4 -``` - -The `--linear_type` flag enables **dynamic quantization**, which quantizes linear layer weights on-the-fly during loading from an unquantized (BF16/FP16) checkpoint. No pre-quantized checkpoint is needed — the weights are converted to the target precision at load time. - -Supported `--linear_type` values: `default` (BF16/FP16, no quantization), `trtllm-fp8-per-tensor`, `trtllm-fp8-blockwise`, `trtllm-nvfp4`. - -**ModelOpt `quantization_config` format:** - -Both dynamic and static quantization use the [ModelOpt](https://github.com/NVIDIA/TensorRT-Model-Optimizer) `quantization_config` format — the same format found in a model's `config.json` under the `quantization_config` field. This config can be passed as a dict to `DiffusionArgs.quant_config` when constructing the pipeline programmatically: - -```python -from tensorrt_llm._torch.visual_gen.config import DiffusionArgs - -args = DiffusionArgs( - checkpoint_path="/path/to/model", - quant_config={"quant_algo": "FP8", "dynamic": True}, # dynamic FP8 -) -``` - -The `--linear_type` CLI flag is a convenience shorthand that maps to these configs internally (e.g., `trtllm-fp8-per-tensor` → `{"quant_algo": "FP8", "dynamic": True}`). - -Key fields: `"dynamic"` controls load-time quantization (`true`) vs pre-quantized checkpoint (`false`); `"ignore"` excludes specific modules from quantization. - -## Developer Guide - -This section describes the TensorRT-LLM VisualGen module architecture and guides developers on how to add support for new diffusion model families. 
- -### Architecture Overview - -The VisualGen module lives under `tensorrt_llm._torch.visual_gen`. At a high level, the flow is: - -1. **Config** — User-facing `DiffusionArgs` (CLI / YAML) is merged with checkpoint metadata into `DiffusionModelConfig`. -2. **Pipeline creation & loading** — `AutoPipeline` detects the model type from `model_index.json`, instantiates the matching `BasePipeline` subclass, and loads weights (with optional dynamic quantization) and standard components (VAE, text encoder, tokenizer, scheduler). -3. **Execution** — `DiffusionExecutor` coordinates multi-GPU inference via worker processes. - -> **Note:** Internal module structure is subject to change. Refer to inline docstrings in `tensorrt_llm/_torch/visual_gen/` for the latest details. - -### Implementing a New Diffusion Model - -Adding a new model (e.g., a hypothetical "MyDiT") requires four steps. The framework handles weight loading, parallelism, quantization, and serving automatically once the pipeline is registered. - -#### 1. Create the Transformer Module - -Create the DiT backbone in `tensorrt_llm/_torch/visual_gen/models/mydit/transformer_mydit.py`. It should be an `nn.Module` that: - -- Uses existing modules (e.g., `Attention` with configurable attention backend, `Linear` for builtin linear ops) wherever possible. -- Implements `load_weights(weights: Dict[str, torch.Tensor])` to map checkpoint weight names to module parameters. - -#### 2. Create the Pipeline Class - -Create a pipeline class extending `BasePipeline` in `tensorrt_llm/_torch/visual_gen/models/mydit/`. Override methods for transformer initialization, component loading, and inference. `BasePipeline` provides the denoising loop, CFG handling, and TeaCache integration — your pipeline only needs to implement model-specific logic. See `WanPipeline` for a reference implementation. - -#### 3. Register the Pipeline - -Use the `@register_pipeline("MyDiTPipeline")` decorator on your pipeline class to register it in the global `PIPELINE_REGISTRY`. Make sure to export it from `models/__init__.py`. - -#### 4. Update AutoPipeline Detection - -In `pipeline_registry.py`, add detection logic for your model's `_class_name` in `model_index.json`. - -After these steps, the framework automatically handles: - -- Weight loading with optional dynamic quantization via `PipelineLoader` -- Multi-GPU execution via `DiffusionExecutor` -- TeaCache integration (if you call `self._setup_teacache()` in `post_load_weights()`) -- Serving via `trtllm-serve` with the full endpoint set - -## Summary and Future Work - -### Current Status - -**Supported models:** Wan2.1 and Wan2.2 families (text-to-video, image-to-video; 1.3B and 14B variants). - -**Supported features:** - -| Feature | Status | -|---------|--------| -| **Multi-GPU Parallelism** | CFG parallel, Ulysses sequence parallel (more strategies planned) | -| **TeaCache** | Caches transformer outputs when timestep embeddings change slowly | -| **Quantization** | Dynamic (on-the-fly from BF16) and static (pre-quantized checkpoints), both via ModelOpt `quantization_config` format | -| **Attention Backends** | Vanilla (torch SDPA) and TRT-LLM optimized fused kernels | -| **`trtllm-serve`** | OpenAI-compatible endpoints for image/video generation (sync + async) | - -### Future Work - -- **Additional model support**: Extend to more diffusion model families. -- **More attention backends**: Support for additional attention backends. 
-- **Advanced parallelism**: Additional parallelism strategies for larger models and higher resolutions. -- **Serving enhancements**: Improved throughput and user experience for production serving workloads. diff --git a/docs/source/index.rst b/docs/source/index.rst index 141c1645624..80fa43d8c8b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -34,6 +34,7 @@ Welcome to TensorRT LLM's Documentation! :name: Models models/supported-models.md + models/visual-generation.md models/adding-new-model.md @@ -67,7 +68,6 @@ Welcome to TensorRT LLM's Documentation! features/long-sequence.md features/lora.md features/multi-modality.md - features/visual-generation.md features/overlap-scheduler.md features/paged-attention-ifb-scheduler.md features/parallel-strategy.md diff --git a/docs/source/models/supported-models.md b/docs/source/models/supported-models.md index d34a7ce5b74..74e38b380aa 100644 --- a/docs/source/models/supported-models.md +++ b/docs/source/models/supported-models.md @@ -81,3 +81,7 @@ Note: - I: Image - V: Video - A: Audio + +# Visual Generation Models + +For diffusion-based image and video generation models, see the [Visual Generation](./visual-generation.md) documentation. diff --git a/docs/source/models/visual-generation.md b/docs/source/models/visual-generation.md new file mode 100644 index 00000000000..7e0f2a3a1fc --- /dev/null +++ b/docs/source/models/visual-generation.md @@ -0,0 +1,173 @@ +# Visual Generation (Prototype) + +```{note} +This feature is in **prototype** stage. APIs, supported models, and optimization options are +actively evolving and may change in future releases. +``` + +## Background + +Visual generation models based on diffusion transformers (DiT) have become the standard for high-quality image and video synthesis. These models iteratively denoise latent representations through a learned transformer backbone, then decode the final latents with a VAE to produce pixels. + +TensorRT-LLM **VisualGen** provides a unified inference stack for diffusion models, with a pipeline architecture separate from the LLM inference path. Key capabilities include: + +- A shared pipeline abstraction covering the denoising loop, guidance strategies, and component loading. +- Pluggable attention backends (PyTorch SDPA and TRT-LLM optimized kernels). +- Quantization support (dynamic and static) using the [ModelOpt](https://github.com/NVIDIA/TensorRT-Model-Optimizer) configuration format. +- Multi-GPU parallelism (CFG parallel, Ulysses sequence parallel). +- **TeaCache** — a runtime caching optimization that skips transformer steps when timestep embeddings change slowly. +- `trtllm-serve` integration with OpenAI-compatible API endpoints for image and video generation. + +## Supported Models + +| HuggingFace Model ID | Tasks | +|---|---| +| `black-forest-labs/FLUX.1-dev` | Text-to-Image | +| `black-forest-labs/FLUX.2-dev` | Text-to-Image | +| `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` | Text-to-Video | +| `Wan-AI/Wan2.1-T2V-14B-Diffusers` | Text-to-Video | +| `Wan-AI/Wan2.1-I2V-14B-480P-Diffusers` | Image-to-Video | +| `Wan-AI/Wan2.1-I2V-14B-720P-Diffusers` | Image-to-Video | +| `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | Text-to-Video | +| `Wan-AI/Wan2.2-I2V-A14B-Diffusers` | Image-to-Video | + +Models are auto-detected from the `model_index.json` file in the checkpoint directory. The `AutoPipeline` registry selects the appropriate pipeline class automatically. 
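+
+For intuition, the detection step boils down to reading `_class_name` from `model_index.json` and looking it up in a pipeline registry. The sketch below illustrates the idea only; the helper and registry dict are stand-ins, not the actual `AutoPipeline` internals:
+
+```python
+import json
+from pathlib import Path
+
+# Illustrative stand-in for the real PIPELINE_REGISTRY: maps the diffusers
+# `_class_name` value from model_index.json to a registered pipeline name.
+EXAMPLE_REGISTRY = {"WanPipeline": "WanPipeline", "FluxPipeline": "FluxPipeline"}
+
+
+def detect_pipeline(checkpoint_dir: str) -> str:
+    """Return the registered pipeline name for a diffusers-style checkpoint."""
+    index = json.loads((Path(checkpoint_dir) / "model_index.json").read_text())
+    class_name = index["_class_name"]  # e.g. "WanPipeline"
+    if class_name not in EXAMPLE_REGISTRY:
+        raise ValueError(f"Unsupported pipeline class: {class_name}")
+    return EXAMPLE_REGISTRY[class_name]
+```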
+ +### Feature Matrix + +| Model | FP8 blockwise | NVFP4 | TeaCache | CFG Parallelism | Ulysses Parallelism | Parallel VAE | CUDA Graph | torch.compile | trtllm-serve | +|---|---|---|---|---|---|---|---|---|---| +| **FLUX.1** | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes | +| **FLUX.2** | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes | +| **Wan 2.1** | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | +| **Wan 2.2** | Yes | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | + +[^1]: FLUX models use embedded guidance and do not have a separate negative prompt path, so CFG parallelism is not applicable. + +## Quick Start + +Here is a simple example to generate a video with Wan 2.1: + +```{literalinclude} ../../../examples/visual_gen/quickstart_example.py + :language: python + :linenos: +``` + +To learn more about VisualGen, see [`examples/visual_gen/`](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/visual_gen) for more examples including text-to-image, image-to-video, and batch generation. + +### Usage with `trtllm-serve` + +The `trtllm-serve` command automatically detects diffusion models (by the presence of `model_index.json`) and launches an OpenAI-compatible visual generation server with image and video generation endpoints. + +See [`examples/visual_gen/serve/`](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/visual_gen/serve) for server launch instructions, example configurations, and API usage. + +### Serving Endpoints + +When served via `trtllm-serve`, the following OpenAI-compatible endpoints are available: + +| Endpoint | Method | Purpose | +|---|---|---| +| `/v1/images/generations` | POST | Synchronous image generation | +| `/v1/images/edits` | POST | Image editing | +| `/v1/videos` | POST | Asynchronous video generation | +| `/v1/videos/generations` | POST | Synchronous video generation | +| `/v1/videos/{id}` | GET | Video status / metadata | +| `/v1/videos/{id}/content` | GET | Download generated video | +| `/v1/videos/{id}` | DELETE | Delete generated video | +| `/v1/videos` | GET | List all videos | + +## Optimizations + +### Quantization + +VisualGen supports both **dynamic quantization** (on-the-fly at weight-loading time from BF16 checkpoints) and **static quantization** (loading pre-quantized checkpoints with embedded scales). Both modes use the [ModelOpt](https://github.com/NVIDIA/TensorRT-Model-Optimizer) `quantization_config` format. + +Dynamic quantization via `--linear_type`: + +```bash +python visual_gen_wan_t2v.py \ + --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ + --prompt "A cute cat playing piano" \ + --linear_type trtllm-fp8-per-tensor \ + --output_path output_fp8.mp4 +``` + +Supported `--linear_type` values: `default` (BF16/FP16), `trtllm-fp8-per-tensor`, `trtllm-fp8-blockwise`, `trtllm-nvfp4`. + +Programmatic usage via `VisualGenArgs.quant_config`: + +```python +from tensorrt_llm import VisualGenArgs + +args = VisualGenArgs( + checkpoint_path="/path/to/model", + quant_config={"quant_algo": "FP8", "dynamic": True}, +) +``` + +### TeaCache + +TeaCache caches transformer outputs when timestep embeddings change slowly between denoising steps, skipping redundant computation. Enable with `teacache.enable_teacache: true` (YAML config). The `teacache_thresh` parameter controls the similarity threshold. + +### Multi-GPU Parallelism + +Two parallelism modes can be combined: + +- **CFG Parallelism** (`--cfg_size 2`): Splits positive/negative guidance prompts across GPUs. 
+- **Ulysses Parallelism** (`--ulysses_size N`): Splits the sequence dimension across GPUs for longer sequences. + +Total GPU count = `cfg_size * ulysses_size`. + +## Developer Guide + +### Architecture Overview + +The VisualGen module lives under `tensorrt_llm._torch.visual_gen`. At a high level, the inference flow is: + +1. **Config** — User-facing `VisualGenArgs` (CLI / YAML) is merged with checkpoint metadata into `DiffusionModelConfig`. +2. **Pipeline creation & loading** — `AutoPipeline` detects the model type from `model_index.json`, instantiates the matching `BasePipeline` subclass, and loads weights (with optional dynamic quantization) and standard components (VAE, text encoder, tokenizer, scheduler). +3. **Execution** — `DiffusionExecutor` coordinates multi-GPU inference via worker processes communicating over ZeroMQ IPC. + +Key components: + +| Component | Location | Role | +|---|---|---| +| `VisualGen` | `tensorrt_llm/llmapi/visual_gen.py` | High-level API: manages workers, `generate()` / `generate_async()` | +| `DiffusionExecutor` | `visual_gen/executor.py` | Worker process: loads pipeline, processes requests via ZeroMQ | +| `BasePipeline` | `visual_gen/pipeline.py` | Base class: denoising loop, CFG handling, TeaCache, CUDA graph | +| `AutoPipeline` | `visual_gen/pipeline_registry.py` | Factory: auto-detects model type, selects pipeline class | +| `PipelineLoader` | `visual_gen/pipeline_loader.py` | Resolves checkpoint, loads config/weights, creates pipeline | +| `TeaCacheBackend` | `visual_gen/teacache.py` | Runtime caching for transformer outputs | +| `WeightLoader` | `visual_gen/checkpoints/` | Loads transformer weights from safetensors/bin | + +VisualGen is a parallel inference subsystem within TensorRT-LLM. It shares low-level primitives (`Mapping`, `QuantConfig`, `Linear`, `RMSNorm`, `ZeroMqQueue`, `TrtllmAttention`) but has its own executor, scheduler (diffusers-based), request types, and pipeline architecture separate from the LLM autoregressive decode path. + +### Implementing a New Diffusion Model + +Adding a new model (e.g., a hypothetical "MyDiT") requires four steps. The framework handles weight loading, parallelism, quantization, and serving automatically once the pipeline is registered. + +#### 1. Create the Transformer Module + +Create the DiT backbone in `tensorrt_llm/_torch/visual_gen/models/mydit/transformer_mydit.py`. It should be an `nn.Module` that: + +- Uses existing modules (e.g., `Attention` with configurable attention backend, `Linear` for builtin linear ops) wherever possible. +- Implements `load_weights(weights: Dict[str, torch.Tensor])` to map checkpoint weight names to module parameters. + +#### 2. Create the Pipeline Class + +Create a pipeline class extending `BasePipeline` in `tensorrt_llm/_torch/visual_gen/models/mydit/`. Override methods for transformer initialization, component loading, and inference. `BasePipeline` provides the denoising loop, CFG handling, and TeaCache integration — your pipeline only needs to implement model-specific logic. See `WanPipeline` for a reference implementation. + +#### 3. Register the Pipeline + +Use the `@register_pipeline("MyDiTPipeline")` decorator on your pipeline class to register it in the global `PIPELINE_REGISTRY`. Make sure to export it from `models/__init__.py`. + +#### 4. Update AutoPipeline Detection + +In `pipeline_registry.py`, add detection logic for your model's `_class_name` in `model_index.json`. 
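+
+Putting steps 2 and 3 together, a minimal registration sketch is shown below. Import paths follow the module layout in the table above; the override hook names are illustrative, so consult `WanPipeline` for the authoritative method set:
+
+```python
+from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline
+from tensorrt_llm._torch.visual_gen.pipeline_registry import register_pipeline
+
+
+@register_pipeline("MyDiTPipeline")  # key matched against `_class_name` in model_index.json
+class MyDiTPipeline(BasePipeline):
+    """Skeleton pipeline for a hypothetical MyDiT model family."""
+
+    def init_transformer(self):  # illustrative hook name
+        # Build the DiT backbone from TensorRT-LLM modules (Attention, Linear, ...).
+        ...
+
+    def load_components(self):  # illustrative hook name
+        # Load the VAE, text encoder, tokenizer, and scheduler from the checkpoint.
+        ...
+```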
+ +After these steps, the framework automatically handles: + +- Weight loading with optional dynamic quantization via `PipelineLoader` +- Multi-GPU execution via `DiffusionExecutor` +- TeaCache integration (if you call `self._setup_teacache()` in `post_load_weights()`) +- Serving via `trtllm-serve` with the full endpoint set diff --git a/docs/source/overview.md b/docs/source/overview.md index c058b65d2e9..c993f2fcb6a 100644 --- a/docs/source/overview.md +++ b/docs/source/overview.md @@ -23,10 +23,11 @@ TensorRT LLM delivers breakthrough performance on the latest NVIDIA GPUs: ### 🎯 **Comprehensive Model Support** -TensorRT LLM supports the latest and most popular LLM [architectures](https://nvidia.github.io/TensorRT-LLM/models/supported-models.html). +TensorRT LLM supports the latest and most popular LLM and DiT architectures. See [complete list](./models/supported-models.md). - **Language Models**: GPT-OSS, Deepseek-R1/V3, Llama 3/4, Qwen2/3, Gemma 3, Phi 4... - **Multi-modal Models**: LLaVA-NeXT, Qwen2-VL, VILA, Llama 3.2 Vision... +- **[Visual Generation](./models/visual-generation.md) Models**: FLUX, Wan2.1/2.2 for image and video generation. TensorRT LLM strives to support the most popular models on **Day 0**. diff --git a/docs/source/quick-start-guide.md b/docs/source/quick-start-guide.md index 03458cb08fd..b7ea0b49987 100644 --- a/docs/source/quick-start-guide.md +++ b/docs/source/quick-start-guide.md @@ -93,6 +93,7 @@ Pre-configured settings for deploying popular models with `trtllm-serve` can be ``` ## Run Offline Inference with LLM API + The LLM API is a Python API designed to facilitate setup and inference with TensorRT LLM directly within Python. It enables model optimization by simply specifying a HuggingFace repository name or a model checkpoint. The LLM API streamlines the process by managing model loading, optimization, and inference, all through a single `LLM` instance. Here is a simple example to show how to use the LLM API with TinyLlama. @@ -105,6 +106,18 @@ Here is a simple example to show how to use the LLM API with TinyLlama. You can also directly load pre-quantized models [quantized checkpoints on Hugging Face](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4) in the LLM constructor. To learn more about the LLM API, check out the [](llm-api/index) and [](examples/llm_api_examples). + +## Run Offline Inference with VisualGen API + +The VisualGen API provides a similar interface for diffusion-based image and video generation. Here is a simple example to generate a video with Wan 2.1. + +```{literalinclude} ../../examples/visual_gen/quickstart_example.py + :language: python + :linenos: +``` + +To learn more about VisualGen, check out the [Visual Generation](models/visual-generation.md) documentation and [`examples/visual_gen/`](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/visual_gen). + ## Next Steps In this Quick Start Guide, you have: diff --git a/examples/visual_gen/README.md b/examples/visual_gen/README.md index 7b356d0f079..be1b28845a8 100644 --- a/examples/visual_gen/README.md +++ b/examples/visual_gen/README.md @@ -1,6 +1,8 @@ # Visual Generation Examples -Quick reference for running visual generation models (FLUX, WAN). +Quick reference for running visual generation models. +Please refer to [the VisualGen doc](https://nvidia.github.io/TensorRT-LLM/models/visual-generation.html) +about the details of the feature. ## Prerequisites @@ -8,120 +10,53 @@ Quick reference for running visual generation models (FLUX, WAN). 
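+# Note: FFmpeg is needed to write .mp4 videos; without it, the WAN examples fall back to .avi (see "Output Formats" below).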
# Install dependencies (from repository root) pip install -r requirements-dev.txt pip install git+https://github.com/huggingface/diffusers.git -pip install av ``` -## Quick Start - -```bash -# Set MODEL_ROOT to your model directory (required for examples) -export MODEL_ROOT=/llm-models -# Optional: PROJECT_ROOT defaults to repo root when run from examples/visual_gen - -# Run all examples (auto-detects GPUs) -cd examples/visual_gen -./visual_gen_examples.sh -``` - - -## Environment Variables - -| Variable | Default | Description | -|----------|---------|-------------| -| `PROJECT_ROOT` | Auto-detected | Path to repository root (set when running from `examples/visual_gen`) | -| `MODEL_ROOT` | `/llm-models` | Path to model directory | -| `TLLM_LOG_LEVEL` | `INFO` | Logging level | - ---- ## FLUX (Text-to-Image) -Supports both FLUX.1-dev and FLUX.2-dev. The pipeline type is auto-detected from the model checkpoint (`model_index.json`). - ### Basic Usage -**FLUX.1-dev:** +**FLUX.1:** + ```bash python visual_gen_flux.py \ - --model_path ${MODEL_ROOT}/FLUX.1-dev \ + --model_path black-forest-labs/FLUX.1-dev \ --prompt "A cat sitting on a windowsill" \ + --height 1024 --width 1024 \ --guidance_scale 3.5 \ --output_path output.png ``` -**FLUX.2-dev:** -```bash -python visual_gen_flux.py \ - --model_path ${MODEL_ROOT}/FLUX.2-dev \ - --prompt "A cat sitting on a windowsill" \ - --guidance_scale 4.0 \ - --output_path output.png -``` +**With FP8 quantization:** -**With FP8 Quantization:** ```bash python visual_gen_flux.py \ - --model_path ${MODEL_ROOT}/FLUX.2-dev \ + --model_path black-forest-labs/FLUX.2-dev \ --prompt "A cat sitting on a windowsill" \ --linear_type trtllm-fp8-per-tensor \ - --output_path output.png + --output_path output_fp8.png ``` -**With TeaCache:** -```bash -python visual_gen_flux.py \ - --model_path ${MODEL_ROOT}/FLUX.1-dev \ - --prompt "A cat sitting on a windowsill" \ - --enable_teacache \ - --output_path output.png -``` - -### Batch Mode - -Generate multiple images from a prompts file (one prompt per line): +**Batch mode (multiple prompts from file):** ```bash python visual_gen_flux.py \ - --model_path ${MODEL_ROOT}/FLUX.1-dev \ + --model_path black-forest-labs/FLUX.1-dev \ --prompts_file prompts.txt \ - --output_dir results/bf16/ \ - --seed 42 + --output_dir results/ --seed 42 ``` -```bash -# With FP8 quantization -python visual_gen_flux.py \ - --model_path ${MODEL_ROOT}/FLUX.2-dev \ - --prompts_file prompts.txt \ - --output_dir results/fp8/ \ - --linear_type trtllm-fp8-per-tensor -``` - -Images are saved as `00.png`, `01.png`, etc. with a `timing.json` summary. - -### Multi-GPU Parallelism - -FLUX supports CFG and Ulysses parallelism, same as WAN. 
- -**CFG + Ulysses (4 GPUs):** -```bash -python visual_gen_flux.py \ - --model_path ${MODEL_ROOT}/FLUX.1-dev \ - --prompts_file prompts.txt \ - --output_dir results/ \ - --cfg_size 2 --ulysses_size 2 -``` - ---- ## WAN (Text-to-Video) ### Basic Usage **Single GPU:** + ```bash python visual_gen_wan_t2v.py \ - --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \ + --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ --prompt "A cute cat playing piano" \ --height 480 --width 832 --num_frames 33 \ --output_path output.mp4 @@ -130,7 +65,7 @@ python visual_gen_wan_t2v.py \ **With TeaCache:** ```bash python visual_gen_wan_t2v.py \ - --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \ + --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ --prompt "A cute cat playing piano" \ --height 480 --width 832 --num_frames 33 \ --enable_teacache \ @@ -147,7 +82,7 @@ WAN supports two parallelism modes that can be combined: **Ulysses Only (2 GPUs):** ```bash python visual_gen_wan_t2v.py \ - --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \ + --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ --prompt "A cute cat playing piano" \ --height 480 --width 832 --num_frames 33 \ --attention_backend TRTLLM \ @@ -159,7 +94,7 @@ GPU Layout: GPU 0-1 share sequence (6 heads each) **CFG Only (2 GPUs):** ```bash python visual_gen_wan_t2v.py \ - --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \ + --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ --prompt "A cute cat playing piano" \ --height 480 --width 832 --num_frames 33 \ --attention_backend TRTLLM \ @@ -171,7 +106,7 @@ GPU Layout: GPU 0 (positive) | GPU 1 (negative) **CFG + Ulysses (4 GPUs):** ```bash python visual_gen_wan_t2v.py \ - --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \ + --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ --prompt "A cute cat playing piano" \ --height 480 --width 832 --num_frames 33 \ --attention_backend TRTLLM \ @@ -183,16 +118,26 @@ GPU Layout: GPU 0-1 (positive, Ulysses) | GPU 2-3 (negative, Ulysses) **Large-Scale (8 GPUs):** ```bash python visual_gen_wan_t2v.py \ - --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \ + --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ --prompt "A cute cat playing piano" \ --height 480 --width 832 --num_frames 33 \ --attention_backend TRTLLM \ --cfg_size 2 --ulysses_size 4 \ --output_path output.mp4 ``` -GPU Layout: GPU 0-3 (positive) | GPU 4-7 (negative) ---- + +## WAN (Image-to-Video) + +```bash +python visual_gen_wan_i2v.py \ + --model_path Wan-AI/Wan2.1-I2V-14B-480P-Diffusers \ + --image_path input_image.jpg \ + --prompt "She turns around and smiles" \ + --height 480 --width 832 --num_frames 81 \ + --output_path output_i2v.mp4 +``` + ## Common Arguments @@ -200,19 +145,16 @@ GPU Layout: GPU 0-3 (positive) | GPU 4-7 (negative) |----------|------|-----|---------|-------------| | `--height` | ✓ | ✓ | 1024 / 720 | Output height | | `--width` | ✓ | ✓ | 1024 / 1280 | Output width | -| `--num_frames` | | ✓ | 81 | Number of frames | +| `--num_frames` | — | ✓ | 81 | Number of frames | | `--steps` | ✓ | ✓ | 50 | Denoising steps | -| `--guidance_scale` | ✓ | ✓ | 3.5 / 5.0 | CFG guidance strength | +| `--guidance_scale` | ✓ | ✓ | 3.5 / 5.0 | Guidance strength | | `--seed` | ✓ | ✓ | 42 | Random seed | | `--enable_teacache` | ✓ | ✓ | False | Cache optimization | | `--teacache_thresh` | ✓ | ✓ | 0.2 | TeaCache similarity threshold | | `--attention_backend` | ✓ | ✓ | VANILLA | VANILLA or TRTLLM | -| `--cfg_size` | ✓ | ✓ | 1 | CFG parallelism | +| `--cfg_size` | — | ✓ | 1 | CFG parallelism | | `--ulysses_size` | ✓ | 
✓ | 1 | Sequence parallelism |
 | `--linear_type` | ✓ | ✓ | default | Quantization type |
-| `--prompts_file` | ✓ |  | — | Batch mode prompts file |
-| `--output_dir` | ✓ |  | — | Batch mode output directory |
-| `--disable_torch_compile` | ✓ | ✓ | False | Disable torch.compile |
 
 ## Troubleshooting
 
@@ -239,18 +181,8 @@ GPU Layout: GPU 0-3 (positive) | GPU 4-7 (negative)
 ## Output Formats
 
 - **FLUX**: `.png` (image)
-- **WAN**: `.mp4` (video), `.gif` (animated), `.png` (single frame)
-
-## Baseline Validation
-
-Compare with official HuggingFace Diffusers implementation:
+- **WAN**: `.mp4` if FFmpeg is installed, otherwise `.avi` (video)
 
-```bash
-# Run HuggingFace baselines
-./hf_examples.sh
-
-# Or run individual models
-python hf_wan.py --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers
-```
+## Serving
 
-Compare outputs with same seed for correctness verification.
+See [`serve/README.md`](serve/README.md) for `trtllm-serve` examples including image generation (FLUX), video generation (WAN T2V/I2V), and an API endpoint reference.
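+
+## Python API
+
+The CLI scripts above wrap a small Python API. A minimal sketch of programmatic
+use, mirroring [`quickstart_example.py`](quickstart_example.py) (parameter
+values are illustrative; omitted `VisualGenParams` fields fall back to their
+defaults):
+
+```python
+from tensorrt_llm import VisualGen, VisualGenParams
+from tensorrt_llm.serve.media_storage import MediaStorage
+
+visual_gen = VisualGen(model_path="Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
+params = VisualGenParams(height=480, width=832, num_frames=81, seed=42)
+output = visual_gen.generate(inputs="A cute cat playing piano", params=params)
+MediaStorage.save_video(output.video, "output.mp4", frame_rate=params.frame_rate)
+```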
diff --git a/examples/visual_gen/hf_examples.sh b/examples/visual_gen/hf_examples.sh
deleted file mode 100755
index f2bb84dfd4f..00000000000
--- a/examples/visual_gen/hf_examples.sh
+++ /dev/null
@@ -1,192 +0,0 @@
-#!/bin/bash
-# HuggingFace Baseline Tests - Official Diffusers Implementation
-#
-# Usage:
-#   export PROJECT_ROOT=/path/to/tekit
-#   export MODEL_ROOT=/path/to/models
-#   ./hf_examples.sh
-#
-# Or inline:
-#   PROJECT_ROOT=/workspace/gitlab/tekit-b200 MODEL_ROOT=/llm-models ./hf_examples.sh
-
-set -e  # Exit on error
-
-# Environment variables with defaults
-PROJECT_ROOT=${PROJECT_ROOT:-"/workspace/gitlab/tekit-b200"}
-MODEL_ROOT=${MODEL_ROOT:-"/llm-models"}
-
-# Log configuration
-export TLLM_LOG_LEVEL=${TLLM_LOG_LEVEL:-"INFO"}
-
-echo "============================================"
-echo "HuggingFace Diffusers Baseline Tests"
-echo "============================================"
-echo "PROJECT_ROOT: $PROJECT_ROOT"
-echo "MODEL_ROOT: $MODEL_ROOT"
-echo "LOG_LEVEL: $TLLM_LOG_LEVEL"
-echo ""
-echo "Purpose: Establish baseline results using"
-echo "         official diffusers implementations"
-echo "============================================"
-echo ""
-
-# Check Python dependencies
-echo "Checking dependencies..."
-MISSING_DEPS=""
-
-if ! python -c "import diffusers" 2>/dev/null; then
-    echo "❌ ERROR: diffusers not found"
-    MISSING_DEPS="$MISSING_DEPS diffusers"
-fi
-
-if ! python -c "import torch" 2>/dev/null; then
-    echo "❌ ERROR: torch not found"
-    MISSING_DEPS="$MISSING_DEPS torch"
-fi
-
-if [ -n "$MISSING_DEPS" ]; then
-    echo ""
-    echo "❌ Missing required dependencies:$MISSING_DEPS"
-    echo "Install with: pip install$MISSING_DEPS"
-    exit 1
-fi
-
-echo "✅ All required dependencies found"
-echo ""
-
-# Detect GPU
-if command -v nvidia-smi &> /dev/null; then
-    GPU_COUNT=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
-    echo "Detected $GPU_COUNT GPU(s)"
-    GPU_NAME=$(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)
-    echo "GPU: $GPU_NAME"
-else
-    echo "⚠️  WARNING: nvidia-smi not found"
-    echo "   Continuing with CPU (very slow!)"
-    GPU_COUNT=0
-fi
-echo ""
-
-# Create output directory (in current directory)
-OUTPUT_DIR="./baseline_outputs"
-mkdir -p "$OUTPUT_DIR"
-echo "Output directory: $OUTPUT_DIR ($(pwd)/baseline_outputs)"
-echo ""
-
-#############################################
-# WAN (Wan2.1) Baseline Test
-#############################################
-
-echo "============================================"
-echo "1/3: WAN Baseline Test"
-echo "============================================"
-echo ""
-
-WAN_MODEL="${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/"
-WAN_OUTPUT="${OUTPUT_DIR}/wan_baseline.gif"
-
-if [ -d "$WAN_MODEL" ]; then
-    echo "Testing WAN with official diffusers..."
-    python ${PROJECT_ROOT}/examples/visual_gen/hf_wan.py \
-        --model_path "$WAN_MODEL" \
-        --output_path "$WAN_OUTPUT" \
-        --prompt "A cute cat playing piano" \
-        --height 480 \
-        --width 832 \
-        --num_frames 33 \
-        --steps 50 \
-        --guidance_scale 7.0 \
-        --seed 42
-    echo ""
-    echo "✅ WAN baseline test completed"
-    echo "   Output: $WAN_OUTPUT"
-else
-    echo "⚠️  SKIPPED: WAN model not found at $WAN_MODEL"
-fi
-
-echo ""
-
-#############################################
-# FLUX.1 Baseline Test
-#############################################
-
-echo "============================================"
-echo "2/3: FLUX.1 Baseline Test"
-echo "============================================"
-echo ""
-
-FLUX1_MODEL="${MODEL_ROOT}/FLUX.1-dev/"
-FLUX1_OUTPUT="${OUTPUT_DIR}/flux1_baseline.png"
-
-if [ -d "$FLUX1_MODEL" ]; then
-    echo "Testing FLUX.1 with official diffusers..."
-    python ${PROJECT_ROOT}/examples/visual_gen/hf_flux.py \
-        --model_path "$FLUX1_MODEL" \
-        --output_path "$FLUX1_OUTPUT" \
-        --prompt "A cat holding a sign that says hello world" \
-        --height 1024 \
-        --width 1024 \
-        --steps 50 \
-        --guidance_scale 3.5 \
-        --seed 42
-    echo ""
-    echo "✅ FLUX.1 baseline test completed"
-    echo "   Output: $FLUX1_OUTPUT"
-else
-    echo "⚠️  SKIPPED: FLUX.1 model not found at $FLUX1_MODEL"
-fi
-
-echo ""
-
-#############################################
-# FLUX.2 Baseline Test
-#############################################
-
-echo "============================================"
-echo "3/3: FLUX.2 Baseline Test"
-echo "============================================"
-echo ""
-
-FLUX2_MODEL="${MODEL_ROOT}/FLUX.2-dev/"
-FLUX2_OUTPUT="${OUTPUT_DIR}/flux2_baseline.png"
-
-if [ -d "$FLUX2_MODEL" ]; then
-    echo "Testing FLUX.2 with official diffusers..."
- python ${PROJECT_ROOT}/examples/visual_gen/hf_flux2.py \ - --model_path "$FLUX2_MODEL" \ - --output_path "$FLUX2_OUTPUT" \ - --prompt "A cat holding a sign that says hello world" \ - --height 1024 \ - --width 1024 \ - --steps 50 \ - --guidance_scale 3.5 \ - --seed 42 - echo "" - echo "✅ FLUX.2 baseline test completed" - echo " Output: $FLUX2_OUTPUT" -else - echo "⚠️ SKIPPED: FLUX.2 model not found at $FLUX2_MODEL" -fi - -echo "" - -############################################# -# Summary -############################################# - -echo "============================================" -echo "Baseline Tests Complete!" -echo "============================================" -echo "" -echo "Output files saved to: $OUTPUT_DIR" -echo "" -ls -lh "$OUTPUT_DIR" 2>/dev/null || echo "No outputs generated" -echo "" -echo "Next Steps:" -echo " 1. Verify outputs are correct (images/videos generated)" -echo " 2. Compare with custom implementation outputs" -echo " 3. Use these as reference/baseline for debugging" -echo "" -echo "Comparison command:" -echo " diff -r $OUTPUT_DIR " -echo "============================================" diff --git a/examples/visual_gen/hf_flux.py b/examples/visual_gen/hf_flux.py deleted file mode 100755 index aba1848837d..00000000000 --- a/examples/visual_gen/hf_flux.py +++ /dev/null @@ -1,142 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -"""Baseline test for FLUX.1 using official diffusers library.""" - -import sys - -import torch -from output_handler import OutputHandler - -from tensorrt_llm._torch.visual_gen import MediaOutput - - -def test_flux_baseline( - model_path: str, - output_path: str, - prompt: str = "A cat holding a sign that says hello world", - height: int = 1024, - width: int = 1024, - num_inference_steps: int = 50, - guidance_scale: float = 3.5, - seed: int = 42, -): - """Test FLUX.1 image generation with official diffusers.""" - from diffusers import FluxPipeline - - print("=" * 80) - print("FLUX.1 Baseline Test (Official Diffusers)") - print("=" * 80) - print() - - # Load pipeline - print(f"Loading FLUX.1 pipeline from {model_path}...") - pipe = FluxPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16) - pipe.to("cuda") - print("✅ Pipeline loaded") - print() - - # Check model states - print("Model Training States:") - print(f" text_encoder.training: {pipe.text_encoder.training}") - if hasattr(pipe, "text_encoder_2") and pipe.text_encoder_2 is not None: - print(f" text_encoder_2.training: {pipe.text_encoder_2.training}") - print(f" transformer.training: {pipe.transformer.training}") - print(f" vae.training: {pipe.vae.training}") - print() - - # Generate image - print(f"Generating image: '{prompt}'") - print(f"Parameters: {height}x{width}, {num_inference_steps} steps, guidance={guidance_scale}") - print() - - # Set random seed - generator = torch.Generator(device="cuda").manual_seed(seed) - - result = pipe( - prompt=prompt, - height=height, - width=width, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - generator=generator, - ) - - # Extract PIL image and convert to (H, W, C) uint8 tensor - import numpy as np - - pil_image = result.images[0] - image = torch.from_numpy(np.array(pil_image)) - - print("=" * 80) - print("Generation Complete!") - print("=" * 80) - print(f"Image shape: {image.shape}") - print(f"Image dtype: {image.dtype}") - print() - - # Save output - print(f"Saving output to 
{output_path}...") - OutputHandler.save(output=MediaOutput(image=image), output_path=output_path) - print(f"✅ Saved to {output_path}") - print() - - print("=" * 80) - print("FLUX.1 BASELINE TEST PASSED ✅") - print("=" * 80) - return image - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser( - description="HuggingFace Baseline - FLUX.1 Text-to-Image Generation" - ) - - # Model & Input - parser.add_argument( - "--model_path", - type=str, - default="/llm-models/FLUX.1-dev/", - help="Path to FLUX.1 model", - ) - parser.add_argument( - "--prompt", - type=str, - default="A cat holding a sign that says hello world", - help="Text prompt for generation", - ) - parser.add_argument( - "--output_path", type=str, default="flux1_baseline.png", help="Output file path" - ) - - # Generation parameters - parser.add_argument("--height", type=int, default=1024, help="Image height") - parser.add_argument("--width", type=int, default=1024, help="Image width") - parser.add_argument("--steps", type=int, default=50, help="Number of denoising steps") - parser.add_argument( - "--guidance_scale", type=float, default=3.5, help="Guidance scale (embedded guidance)" - ) - parser.add_argument("--seed", type=int, default=42, help="Random seed") - - args = parser.parse_args() - - try: - test_flux_baseline( - args.model_path, - args.output_path, - prompt=args.prompt, - height=args.height, - width=args.width, - num_inference_steps=args.steps, - guidance_scale=args.guidance_scale, - seed=args.seed, - ) - except Exception as e: - print(f"\n❌ ERROR: {e}") - import traceback - - traceback.print_exc() - sys.exit(1) diff --git a/examples/visual_gen/hf_wan.py b/examples/visual_gen/hf_wan.py deleted file mode 100755 index 39197940529..00000000000 --- a/examples/visual_gen/hf_wan.py +++ /dev/null @@ -1,141 +0,0 @@ -#!/usr/bin/env python3 -"""Baseline test for WAN using official diffusers library.""" - -import sys - -import torch -from output_handler import OutputHandler, postprocess_hf_video_tensor - -from tensorrt_llm._torch.visual_gen import MediaOutput - - -def test_wan_baseline( - model_path: str, - output_path: str, - prompt: str = "A cute cat playing piano", - height: int = 480, - width: int = 832, - num_frames: int = 33, - num_inference_steps: int = 50, - guidance_scale: float = 7.0, - seed: int = 42, -): - """Test WAN video generation with official diffusers.""" - from diffusers import WanPipeline - - print("=" * 80) - print("WAN Baseline Test (Official Diffusers)") - print("=" * 80) - print() - - # Load pipeline - print(f"Loading WAN pipeline from {model_path}...") - pipe = WanPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16) - pipe.to("cuda") - print("✅ Pipeline loaded") - print() - - # Check model states - print("Model Training States:") - print(f" text_encoder.training: {pipe.text_encoder.training}") - print(f" transformer.training: {pipe.transformer.training}") - print(f" vae.training: {pipe.vae.training}") - print() - - # Generate video - print(f"Generating video: '{prompt}'") - print( - f"Parameters: {height}x{width}, {num_frames} frames, {num_inference_steps} steps, guidance={guidance_scale}" - ) - print() - - # Set random seed - generator = torch.Generator(device="cuda").manual_seed(seed) - - result = pipe( - prompt=prompt, - height=height, - width=width, - num_frames=num_frames, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - generator=generator, - output_type="pt", - return_dict=False, - ) - - video = result[0] - - # Post-process video 
tensor: (B, T, C, H, W) -> (T, H, W, C) uint8 - video = postprocess_hf_video_tensor(video, remove_batch_dim=True) - - print("=" * 80) - print("Generation Complete!") - print("=" * 80) - print(f"Video shape: {video.shape}") - print(f"Video dtype: {video.dtype}") - print() - - # Save output - print(f"Saving output to {output_path}...") - OutputHandler.save(output=MediaOutput(video=video), output_path=output_path, frame_rate=24.0) - print(f"✅ Saved to {output_path}") - print() - - print("=" * 80) - print("WAN BASELINE TEST PASSED ✅") - print("=" * 80) - return video - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser( - description="HuggingFace Baseline - WAN Text-to-Video Generation" - ) - - # Model & Input - parser.add_argument( - "--model_path", - type=str, - default="/llm-models/Wan2.1-T2V-1.3B-Diffusers/", - help="Path to WAN model", - ) - parser.add_argument( - "--prompt", type=str, default="A cute cat playing piano", help="Text prompt for generation" - ) - parser.add_argument( - "--output_path", type=str, default="wan_baseline.gif", help="Output file path" - ) - - # Generation parameters - parser.add_argument("--height", type=int, default=480, help="Video height") - parser.add_argument("--width", type=int, default=832, help="Video width") - parser.add_argument("--num_frames", type=int, default=33, help="Number of frames to generate") - parser.add_argument("--steps", type=int, default=50, help="Number of denoising steps") - parser.add_argument( - "--guidance_scale", type=float, default=7.0, help="Classifier-free guidance scale" - ) - parser.add_argument("--seed", type=int, default=42, help="Random seed") - - args = parser.parse_args() - - try: - test_wan_baseline( - args.model_path, - args.output_path, - prompt=args.prompt, - height=args.height, - width=args.width, - num_frames=args.num_frames, - num_inference_steps=args.steps, - guidance_scale=args.guidance_scale, - seed=args.seed, - ) - except Exception as e: - print(f"\n❌ ERROR: {e}") - import traceback - - traceback.print_exc() - sys.exit(1) diff --git a/examples/visual_gen/output_handler.py b/examples/visual_gen/output_handler.py deleted file mode 100644 index a360d681f9f..00000000000 --- a/examples/visual_gen/output_handler.py +++ /dev/null @@ -1,237 +0,0 @@ -"""Unified output handler for diffusion model outputs.""" - -import os -from typing import Optional - -import torch -from PIL import Image - -from tensorrt_llm import logger -from tensorrt_llm.llmapi.visual_gen import MediaOutput - - -def postprocess_hf_video_tensor(video: torch.Tensor, remove_batch_dim: bool = True) -> torch.Tensor: - """Post-process video tensor from HuggingFace pipeline output to final format. - - HuggingFace pipelines with output_type="pt" return videos in (B, T, C, H, W) format, - which is different from VAE decoder output format. - - Args: - video: Video tensor in (B, T, C, H, W) format from HuggingFace pipeline - remove_batch_dim: Whether to remove batch dimension. Default True for typical - single-batch video generation. - - Returns: - Post-processed video tensor: - - If remove_batch_dim=True: (T, H, W, C) uint8 tensor - - If remove_batch_dim=False: (B, T, H, W, C) uint8 tensor - - Note: - Assumes video values are in [-1, 1] range (standard pipeline output). 
- """ - # Remove batch dimension first if requested - if remove_batch_dim: - video = video[0] # (B, T, C, H, W) -> (T, C, H, W) - video = video.permute(0, 2, 3, 1) # (T, C, H, W) -> (T, H, W, C) - else: - video = video.permute(0, 1, 3, 4, 2) # (B, T, C, H, W) -> (B, T, H, W, C) - - # Normalize to [0, 1] range - video = (video / 2 + 0.5).clamp(0, 1) - - # Convert to uint8 - video = (video * 255).round().to(torch.uint8) - - return video - - -def postprocess_hf_image_tensor(image: torch.Tensor) -> torch.Tensor: - """Post-process image tensor from HuggingFace pipeline output to final format. - - HuggingFace pipelines with output_type="pt" return images in (B, C, H, W) format. - - Args: - image: Image tensor in (B, C, H, W) or (C, H, W) format from HuggingFace pipeline - - Returns: - Post-processed image tensor in (H, W, C) uint8 format - - Note: - Assumes image values are in [-1, 1] range (standard pipeline output). - """ - # Remove batch dimension if present - if image.ndim == 4: - image = image[0] # (B, C, H, W) -> (C, H, W) - - # Convert to (H, W, C) format - image = image.permute(1, 2, 0) # (C, H, W) -> (H, W, C) - - # Normalize to [0, 1] range - image = (image / 2 + 0.5).clamp(0, 1) - - # Convert to uint8 - image = (image * 255).round().to(torch.uint8) - - return image - - -class OutputHandler: - """Handle saving of generated outputs in various formats. - - Supports MediaOutput from all models: - - Video models (WAN): MediaOutput(video=torch.Tensor) - - Image models: MediaOutput(image=torch.Tensor) - - Video+Audio models: MediaOutput(video=torch.Tensor, audio=torch.Tensor) - - Supported output formats: - - .png: Save single image or middle frame - - .gif: Save video as animated GIF (no audio) - - .mp4: Save video with audio (requires diffusers export_utils) - """ - - @staticmethod - def save(output: MediaOutput, output_path: str, frame_rate: float = 24.0): - """Save output based on content type and file extension. - - Args: - output: MediaOutput containing model outputs (image/video/audio) - output_path: Path to save the output file - frame_rate: Frames per second for video output (default: 24.0) - """ - if not isinstance(output, MediaOutput): - raise ValueError(f"Expected output to be MediaOutput, got {type(output)}") - - file_ext = os.path.splitext(output_path)[1].lower() - - # Determine content type - if output.image is not None: - OutputHandler._save_image(output.image, output_path, file_ext) - elif output.video is not None: - OutputHandler._save_video(output.video, output.audio, output_path, file_ext, frame_rate) - else: - raise ValueError("Unknown output format. MediaOutput has no image or video data.") - - @staticmethod - def _save_image(image: torch.Tensor, output_path: str, file_ext: str): - """Save single image output. - - Args: - image: Image as torch tensor (H, W, C) uint8 - output_path: Path to save the image - file_ext: File extension (.png, .jpg, etc.) - """ - if file_ext not in [".png", ".jpg", ".jpeg"]: - logger.warning(f"Image output requested with {file_ext}, defaulting to .png") - output_path = output_path.replace(file_ext, ".png") - - # Convert torch.Tensor to PIL Image and save - image_np = image.cpu().numpy() - Image.fromarray(image_np).save(output_path) - logger.info(f"Saved image to {output_path}") - - @staticmethod - def _save_video( - video: torch.Tensor, - audio: Optional[torch.Tensor], - output_path: str, - file_ext: str, - frame_rate: float, - ): - """Save video output with optional audio. 
- - Args: - video: Video frames as torch tensor (T, H, W, C) with dtype uint8 - audio: Optional audio as torch tensor - output_path: Path to save the video - file_ext: File extension (.mp4, .gif, .png) - frame_rate: Frames per second - """ - if file_ext == ".mp4": - OutputHandler._save_mp4(video, audio, output_path, frame_rate) - elif file_ext == ".gif": - OutputHandler._save_gif(video, output_path, frame_rate) - elif file_ext == ".png": - OutputHandler._save_middle_frame(video, output_path) - else: - logger.warning(f"Unsupported video output format: {file_ext}, defaulting to .png") - output_path = output_path.replace(file_ext, ".png") - OutputHandler._save_middle_frame(video, output_path) - - @staticmethod - def _save_mp4( - video: torch.Tensor, audio: Optional[torch.Tensor], output_path: str, frame_rate: float - ): - """Save video with optional audio as MP4. - - Args: - video: Video frames as torch tensor (T, H, W, C) uint8 - audio: Optional audio as torch tensor (float32) - output_path: Output path for MP4 - frame_rate: Frames per second - """ - try: - from diffusers.pipelines.ltx2.export_utils import encode_video - - # Prepare audio if present - audio_prepared = audio.float() if audio is not None else None - - # encode_video expects (T, H, W, C) uint8 video and float32 audio - encode_video( - video, - fps=frame_rate, - audio=audio_prepared, - audio_sample_rate=24000 if audio_prepared is not None else None, - output_path=output_path, - ) - logger.info(f"Saved video{' with audio' if audio is not None else ''} to {output_path}") - - except ImportError: - logger.warning( - "diffusers export_utils (encode_video) not available. " - "Falling back to saving middle frame as PNG." - ) - png_path = output_path.replace(".mp4", ".png") - OutputHandler._save_middle_frame(video, png_path) - - @staticmethod - def _save_gif(video: torch.Tensor, output_path: str, frame_rate: float): - """Save video as animated GIF. - - Args: - video: Video frames as torch tensor (T, H, W, C) uint8 - output_path: Output path for GIF - frame_rate: Frames per second - """ - # Convert torch.Tensor to numpy for PIL - video_np = video.cpu().numpy() - - # Convert to list of PIL Images - frames = [Image.fromarray(video_np[i]) for i in range(video_np.shape[0])] - - # Save as animated GIF - duration_ms = int(1000 / frame_rate) - frames[0].save( - output_path, - save_all=True, - append_images=frames[1:], - optimize=False, - duration=duration_ms, - loop=0, - ) - logger.info(f"Saved video as GIF to {output_path} ({len(frames)} frames)") - - @staticmethod - def _save_middle_frame(video: torch.Tensor, output_path: str): - """Save middle frame of video as PNG. - - Args: - video: Video frames as torch tensor (T, H, W, C) uint8 - output_path: Output path for PNG - """ - # Convert torch.Tensor to numpy for PIL - video_np = video.cpu().numpy() - - # Extract middle frame - frame_idx = video_np.shape[0] // 2 - Image.fromarray(video_np[frame_idx]).save(output_path) - logger.info(f"Saved frame {frame_idx} to {output_path}") diff --git a/examples/visual_gen/quickstart_example.py b/examples/visual_gen/quickstart_example.py new file mode 100644 index 00000000000..5b60059a2ab --- /dev/null +++ b/examples/visual_gen/quickstart_example.py @@ -0,0 +1,27 @@ +#! /usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0
+
+from tensorrt_llm import VisualGen, VisualGenParams
+from tensorrt_llm.serve.media_storage import MediaStorage
+
+
+def main():
+    # Build the pipeline; model_path may be a local checkpoint dir or an HF model ID
+    visual_gen = VisualGen(model_path="Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
+    # Generation parameters for a 480x832, 81-frame clip; fixed seed for reproducibility
+    params = VisualGenParams(
+        height=480,
+        width=832,
+        num_frames=81,
+        guidance_scale=5.0,
+        num_inference_steps=50,
+        seed=42,
+    )
+    # Run text-to-video generation from a single prompt
+    output = visual_gen.generate(
+        inputs="A cat sitting on a windowsill",
+        params=params,
+    )
+    # Write the generated frames to disk at the params' frame rate
+    MediaStorage.save_video(output.video, "output.avi", frame_rate=params.frame_rate)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/visual_gen/serve/configs/flux1.yml b/examples/visual_gen/serve/configs/flux1.yml
index 57aa695e46c..945c27a0340 100644
--- a/examples/visual_gen/serve/configs/flux1.yml
+++ b/examples/visual_gen/serve/configs/flux1.yml
@@ -1,5 +1,3 @@
-linear:
-  type: default
 teacache:
   enable_teacache: true
   teacache_thresh: 0.2
diff --git a/examples/visual_gen/serve/configs/wan.yml b/examples/visual_gen/serve/configs/wan.yml
index 71286fb6e93..0aacfd56a75 100644
--- a/examples/visual_gen/serve/configs/wan.yml
+++ b/examples/visual_gen/serve/configs/wan.yml
@@ -1,5 +1,3 @@
-linear:
-  type: default
 teacache:
   enable_teacache: true
   teacache_thresh: 0.2
diff --git a/examples/visual_gen/visual_gen_examples.sh b/examples/visual_gen/visual_gen_examples.sh
deleted file mode 100755
index a55342ad8f2..00000000000
--- a/examples/visual_gen/visual_gen_examples.sh
+++ /dev/null
@@ -1,288 +0,0 @@
-#!/bin/bash
-# Visual Generation Examples - Test different models and configurations
-#
-# This script runs a comprehensive suite of visual generation examples including:
-# - WAN T2V: Baseline, TeaCache, CFG parallelism, Ulysses parallelism, and combinations
-# - WAN I2V: Baseline, TeaCache, CFG parallelism, Ulysses parallelism, and combinations
-#
-# The script automatically detects GPU count and runs appropriate examples:
-# - 1 GPU: Single-GPU examples only
-# - 2 GPUs: + CFG parallelism, Ulysses parallelism
-# - 4 GPUs: + CFG + Ulysses combined
-# - 8 GPUs: + Large-scale high-resolution examples
-#
-# Usage:
-#   export MODEL_ROOT=/path/to/models  # required
-#   # Optional: PROJECT_ROOT auto-detected when run from examples/visual_gen
-#   cd examples/visual_gen && ./visual_gen_examples.sh
-#
-# Or inline:
-#   MODEL_ROOT=/llm-models ./visual_gen_examples.sh

-set -e  # Exit on error
-
-# Environment variables with defaults
-# PROJECT_ROOT: auto-detect repo root when run from examples/visual_gen
-SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
-PROJECT_ROOT=${PROJECT_ROOT:-"$(cd "${SCRIPT_DIR}/../.." 
&& pwd)"} -MODEL_ROOT=${MODEL_ROOT:-"/llm-models"} - -# Log configuration -export TLLM_LOG_LEVEL=${TLLM_LOG_LEVEL:-"INFO"} - -echo "============================================" -echo "Visual Generation Examples" -echo "============================================" -echo "PROJECT_ROOT: $PROJECT_ROOT" -echo "MODEL_ROOT: $MODEL_ROOT" -echo "LOG_LEVEL: $TLLM_LOG_LEVEL" -echo "============================================" -echo "" - - -# Detect GPU count -if command -v nvidia-smi &> /dev/null; then - GPU_COUNT=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) - echo "Detected $GPU_COUNT GPU(s)" - if [ "$GPU_COUNT" -lt 2 ]; then - echo "Note: Multi-GPU examples will be skipped" - SKIP_MULTI_GPU=1 - elif [ "$GPU_COUNT" -ge 8 ]; then - echo "Note: Will run all examples including 8-GPU configurations" - elif [ "$GPU_COUNT" -ge 4 ]; then - echo "Note: Will run examples up to 4-GPU configurations" - else - echo "Note: Will run 2-GPU examples only" - fi -else - echo "WARNING: nvidia-smi not found. Assuming single GPU." - GPU_COUNT=1 - SKIP_MULTI_GPU=1 -fi -echo "" - -############################################# -# WAN (Wan2.1) Text-to-Video Examples -############################################# -# Demonstrates: -# - Single GPU: Baseline and TeaCache -# - 2 GPUs: CFG only, Ulysses only -# - 4 GPUs: CFG + Ulysses combined -# - 8 GPUs: Large-scale parallelism -############################################# - -echo "=== WAN Example 1: Baseline (no optimization) ===" -python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \ - --height 480 \ - --width 832 \ - --num_frames 33 \ - --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/ \ - --prompt "A cute cat playing piano" \ - --output_path wan_cat_piano.png - -echo "" -echo "=== WAN Example 2: With TeaCache ===" -python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \ - --height 480 \ - --width 832 \ - --num_frames 33 \ - --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers \ - --prompt "A cute cat playing piano" \ - --output_path wan_cat_piano_teacache.png \ - --enable_teacache - -if [ -z "$SKIP_MULTI_GPU" ]; then - echo "" - echo "=== WAN Example 3: CFG Only (2 GPUs) ===" - python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \ - --height 480 \ - --width 832 \ - --num_frames 33 \ - --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/ \ - --prompt "A cute cat playing piano" \ - --output_path wan_cfg_2gpu.mp4 \ - --attention_backend TRTLLM \ - --cfg_size 2 \ - --ulysses_size 1 -else - echo "" - echo "=== WAN Example 3: Skipped (requires 2 GPUs) ===" -fi - -if [ -z "$SKIP_MULTI_GPU" ]; then - echo "" - echo "=== WAN Example 4: Ulysses Only (2 GPUs) ===" - python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \ - --height 480 \ - --width 832 \ - --num_frames 33 \ - --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/ \ - --prompt "A cute cat playing piano" \ - --output_path wan_ulysses_2gpu.mp4 \ - --attention_backend TRTLLM \ - --cfg_size 1 \ - --ulysses_size 2 -else - echo "" - echo "=== WAN Example 4: Skipped (requires 2 GPUs) ===" -fi - -if [ "$GPU_COUNT" -ge 4 ]; then - echo "" - echo "=== WAN Example 5: CFG + Ulysses (4 GPUs) ===" - python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \ - --height 480 \ - --width 832 \ - --num_frames 33 \ - --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/ \ - --prompt "A cute cat playing piano" \ - --output_path wan_cfg_ulysses_4gpu.mp4 \ - --attention_backend TRTLLM \ - --cfg_size 2 \ - --ulysses_size 2 -else - echo "" - echo "=== WAN Example 5: 
Skipped (requires 4 GPUs) ===" -fi - -if [ "$GPU_COUNT" -ge 8 ]; then - echo "" - echo "=== WAN Example 6: Large-Scale (8 GPUs) ===" - python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \ - --height 480 \ - --width 832 \ - --num_frames 33 \ - --model_path ${MODEL_ROOT}/Wan2.1-T2V-1.3B-Diffusers/ \ - --prompt "A cute cat playing piano" \ - --output_path wan_cfg_ulysses_8gpu.mp4 \ - --attention_backend TRTLLM \ - --cfg_size 2 \ - --ulysses_size 4 -else - echo "" - echo "=== WAN Example 6: Skipped (requires 8 GPUs) ===" -fi - -############################################# -# WAN 2.2 (Two-Stage) Text-to-Video Examples -############################################# - -echo "" -echo "=== WAN 2.2 T2V Example: Two-stage with optimizations (FP8 + TRT-LLM + TeaCache) ===" -python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_t2v.py \ - --height 720 \ - --width 1280 \ - --num_frames 81 \ - --model_path ${MODEL_ROOT}/Wan2.2-T2V-A14B-Diffusers \ - --prompt "A cute cat playing piano" \ - --output_path wan22_t2v_cat_piano_optimized.gif \ - --linear_type trtllm-fp8-blockwise \ - --attention_backend TRTLLM \ - --enable_teacache \ - --teacache_thresh 0.2 \ - --guidance_scale 3.0 \ - --guidance_scale_2 2.5 \ - --boundary_ratio 0.85 - -############################################# -# WAN 2.1 Image-to-Video Examples -############################################# - -echo "" -echo "=== WAN 2.1 I2V Example: Single-stage with optimizations (FP8 + TRT-LLM + TeaCache) ===" -python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_i2v.py \ - --height 480 \ - --width 832 \ - --num_frames 33 \ - --model_path ${MODEL_ROOT}/Wan2.1-I2V-14B-480P-Diffusers \ - --image_path ${PROJECT_ROOT}/examples/visual_gen/cat_piano.png \ - --prompt "It snows as the cat plays piano, lots of snow \ - appearing all over the screen, snowflakes, blizzard, - gradually more snow" \ - --negative_prompt "blurry, low quality" \ - --output_path wan21_i2v_cat_piano_optimized.gif \ - --linear_type trtllm-fp8-per-tensor \ - --attention_backend TRTLLM \ - --enable_teacache \ - --teacache_thresh 0.2 \ - --guidance_scale 6.0 - -############################################# -# WAN 2.2 (Two-Stage) Image-to-Video Examples -############################################# - -echo "" -echo "=== WAN 2.2 I2V Example: Two-stage with optimizations (FP8 + TRT-LLM + TeaCache) ===" -python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_wan_i2v.py \ - --height 480 \ - --width 832 \ - --num_frames 81 \ - --model_path ${MODEL_ROOT}/Wan2.2-I2V-A14B-Diffusers \ - --image_path ${PROJECT_ROOT}/examples/visual_gen/cat_piano.png \ - --prompt "It snows as the cat plays piano, lots of snow \ - appearing all over the screen, snowflakes, blizzard, - gradually more snow" \ - --negative_prompt "blurry, low quality" \ - --output_path wan22_i2v_cat_piano_optimized.gif \ - --linear_type trtllm-fp8-blockwise \ - --attention_backend TRTLLM \ - --enable_teacache \ - --teacache_thresh 0.2 \ - --guidance_scale 6.0 \ - --guidance_scale_2 5.0 \ - --boundary_ratio 0.85 - -############################################# -# FLUX.1 Text-to-Image Examples -############################################# - -echo "" -echo "=== FLUX.1 Example 1: Baseline ===" -python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_flux.py \ - --height 1024 \ - --width 1024 \ - --prompt "A cat holding a sign that says hello world" \ - --output_path flux1_cat_sign.png \ - --model_path ${MODEL_ROOT}/FLUX.1-dev/ \ - --guidance_scale 3.5 - -echo "" -echo "=== FLUX.1 Example 2: With FP8 Quantization ===" 
-python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_flux.py \ - --height 1024 \ - --width 1024 \ - --prompt "A cat holding a sign that says hello world" \ - --output_path flux1_cat_sign_fp8.png \ - --model_path ${MODEL_ROOT}/FLUX.1-dev/ \ - --guidance_scale 3.5 \ - --linear_type trtllm-fp8-per-tensor - -############################################# -# FLUX.2 Text-to-Image Examples -############################################# - -echo "" -echo "=== FLUX.2 Example 1: Baseline ===" -python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_flux.py \ - --height 1024 \ - --width 1024 \ - --prompt "A cat holding a sign that says hello world" \ - --output_path flux2_cat_sign.png \ - --model_path ${MODEL_ROOT}/FLUX.2-dev/ \ - --guidance_scale 4.0 - -echo "" -echo "=== FLUX.2 Example 2: With TeaCache ===" -python ${PROJECT_ROOT}/examples/visual_gen/visual_gen_flux.py \ - --height 1024 \ - --width 1024 \ - --prompt "A cat holding a sign that says hello world" \ - --output_path flux2_cat_sign_teacache.png \ - --model_path ${MODEL_ROOT}/FLUX.2-dev/ \ - --guidance_scale 4.0 \ - --enable_teacache - -echo "" -echo "============================================" -echo "All examples completed successfully!" -echo "============================================" diff --git a/examples/visual_gen/visual_gen_flux.py b/examples/visual_gen/visual_gen_flux.py index 9284d9ce15f..cf108dbd2fa 100755 --- a/examples/visual_gen/visual_gen_flux.py +++ b/examples/visual_gen/visual_gen_flux.py @@ -38,10 +38,8 @@ import os import time -from output_handler import OutputHandler - -from tensorrt_llm import logger -from tensorrt_llm.llmapi.visual_gen import VisualGen, VisualGenParams +from tensorrt_llm import VisualGen, VisualGenArgs, VisualGenParams, logger +from tensorrt_llm.serve.media_storage import MediaStorage logger.set_level("info") @@ -200,61 +198,52 @@ def load_prompts(prompts_file, num_prompts=None): return prompts -def build_diffusion_config(args): - """Build diffusion_config dict from parsed args.""" - quant_config = None - if args.linear_type == "trtllm-fp8-per-tensor": - quant_config = {"quant_algo": "FP8", "dynamic": True} - elif args.linear_type == "trtllm-fp8-blockwise": - quant_config = {"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True} - elif args.linear_type == "trtllm-nvfp4": - quant_config = {"quant_algo": "NVFP4", "dynamic": True} - - diffusion_config = { - "revision": args.revision, - "attention": { - "backend": args.attention_backend, - }, - "teacache": { +def _linear_type_to_quant_config(linear_type: str): + """Map --linear_type CLI shortcut to quant_config dict for VisualGenArgs.""" + mapping = { + "trtllm-fp8-per-tensor": {"quant_algo": "FP8", "dynamic": True}, + "trtllm-fp8-blockwise": {"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True}, + "trtllm-nvfp4": {"quant_algo": "NVFP4", "dynamic": True}, + } + return mapping.get(linear_type) + + +def build_diffusion_args(args) -> VisualGenArgs: + """Build VisualGenArgs from parsed CLI args.""" + kwargs = dict( + revision=args.revision, + attention={"backend": args.attention_backend}, + teacache={ "enable_teacache": args.enable_teacache, "teacache_thresh": args.teacache_thresh, "use_ret_steps": args.use_ret_steps, }, - "parallel": { + parallel={ "dit_ulysses_size": args.ulysses_size, }, - "torch_compile": { + torch_compile={ "enable_torch_compile": not args.disable_torch_compile, "enable_fullgraph": args.enable_fullgraph, "enable_autotune": not args.disable_autotune, }, - "cuda_graph": { - "enable_cuda_graph": args.enable_cudagraph, - }, - "pipeline": { - 
"enable_layerwise_nvtx_marker": args.enable_layerwise_nvtx_marker, - }, - } - + cuda_graph={"enable_cuda_graph": args.enable_cudagraph}, + pipeline={"enable_layerwise_nvtx_marker": args.enable_layerwise_nvtx_marker}, + ) + quant_config = _linear_type_to_quant_config(args.linear_type) if quant_config is not None: - diffusion_config["quant_config"] = quant_config - - return diffusion_config + kwargs["quant_config"] = quant_config + return VisualGenArgs(**kwargs) def main(): args = parse_args() - n_workers = args.ulysses_size - diffusion_config = build_diffusion_config(args) + diffusion_args = build_diffusion_args(args) - logger.info( - f"Initializing VisualGen: world_size={n_workers} (ulysses_size={args.ulysses_size})" - ) + logger.info(f"Initializing VisualGen: ulysses_size={diffusion_args.parallel.dit_ulysses_size}") visual_gen = VisualGen( model_path=args.model_path, - n_workers=n_workers, - diffusion_config=diffusion_config, + diffusion_args=diffusion_args, ) try: @@ -285,7 +274,7 @@ def main(): elapsed = time.time() - start_time output_path = os.path.join(args.output_dir, f"{i:02d}.png") - OutputHandler.save(output, output_path) + MediaStorage.save_image(output.image, output_path) logger.info(f" Saved {output_path} ({elapsed:.1f}s)") timing_records.append( @@ -343,7 +332,7 @@ def main(): logger.info(f"Generation completed in {time.time() - start_time:.2f}s") - OutputHandler.save(output, args.output_path) + MediaStorage.save_image(output.image, args.output_path) finally: visual_gen.shutdown() diff --git a/examples/visual_gen/visual_gen_wan_i2v.py b/examples/visual_gen/visual_gen_wan_i2v.py index 050215b5100..f561bb2adf4 100644 --- a/examples/visual_gen/visual_gen_wan_i2v.py +++ b/examples/visual_gen/visual_gen_wan_i2v.py @@ -1,13 +1,14 @@ #!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + """WAN Image-to-Video generation using TensorRT-LLM Visual Generation.""" import argparse import time -from output_handler import OutputHandler - -from tensorrt_llm import logger -from tensorrt_llm.llmapi.visual_gen import VisualGen, VisualGenParams +from tensorrt_llm import VisualGen, VisualGenArgs, VisualGenParams, logger +from tensorrt_llm.serve.media_storage import MediaStorage logger.set_level("info") @@ -161,59 +162,53 @@ def parse_args(): return parser.parse_args() +def _linear_type_to_quant_config(linear_type: str): + """Map --linear_type CLI shortcut to quant_config dict for VisualGenArgs.""" + mapping = { + "trtllm-fp8-per-tensor": {"quant_algo": "FP8", "dynamic": True}, + "trtllm-fp8-blockwise": {"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True}, + "trtllm-nvfp4": {"quant_algo": "NVFP4", "dynamic": True}, + } + return mapping.get(linear_type) + + def main(): args = parse_args() - n_workers = args.cfg_size * args.ulysses_size - - # Convert linear_type to quant_config - quant_config = None - if args.linear_type == "trtllm-fp8-per-tensor": - quant_config = {"quant_algo": "FP8", "dynamic": True} - elif args.linear_type == "trtllm-fp8-blockwise": - quant_config = {"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True} - elif args.linear_type == "trtllm-nvfp4": - quant_config = {"quant_algo": "NVFP4", "dynamic": True} - - diffusion_config = { - "model_type": "wan2", - "attention": { - "backend": args.attention_backend, - }, - "teacache": { + kwargs = dict( + attention={"backend": args.attention_backend}, + teacache={ "enable_teacache": args.enable_teacache, "teacache_thresh": args.teacache_thresh, "use_ret_steps": args.use_ret_steps, }, - "parallel": { + parallel={ "dit_cfg_size": args.cfg_size, "dit_ulysses_size": args.ulysses_size, "enable_parallel_vae": not args.disable_parallel_vae, }, - "torch_compile": { + torch_compile={ "enable_torch_compile": not args.disable_torch_compile, "enable_fullgraph": args.enable_fullgraph, "enable_autotune": not args.disable_autotune, }, - "cuda_graph": { - "enable_cuda_graph": args.enable_cudagraph, - }, - "pipeline": { - "enable_layerwise_nvtx_marker": args.enable_layerwise_nvtx_marker, - }, - } - + cuda_graph={"enable_cuda_graph": args.enable_cudagraph}, + pipeline={"enable_layerwise_nvtx_marker": args.enable_layerwise_nvtx_marker}, + ) + quant_config = _linear_type_to_quant_config(args.linear_type) if quant_config is not None: - diffusion_config["quant_config"] = quant_config + kwargs["quant_config"] = quant_config + + diffusion_args = VisualGenArgs(**kwargs) logger.info( - f"Initializing VisualGen: world_size={n_workers} " - f"(cfg_size={args.cfg_size}, ulysses_size={args.ulysses_size})" + f"Initializing VisualGen: " + f"cfg_size={diffusion_args.parallel.dit_cfg_size}, " + f"ulysses_size={diffusion_args.parallel.dit_ulysses_size}" ) visual_gen = VisualGen( model_path=args.model_path, - n_workers=n_workers, - diffusion_config=diffusion_config, + diffusion_args=diffusion_args, ) try: @@ -249,7 +244,7 @@ def main(): logger.info(f"Generation completed in {time.time() - start_time:.2f}s") - OutputHandler.save(output, args.output_path, frame_rate=16.0) + MediaStorage.save_video(output.video, args.output_path, audio=output.audio, frame_rate=16.0) finally: visual_gen.shutdown() diff --git a/examples/visual_gen/visual_gen_wan_t2v.py b/examples/visual_gen/visual_gen_wan_t2v.py index 29c1da66da9..b77ca00f5ea 100755 --- a/examples/visual_gen/visual_gen_wan_t2v.py +++ b/examples/visual_gen/visual_gen_wan_t2v.py @@ -1,13 
+1,14 @@ #!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + """WAN Text-to-Video generation using TensorRT-LLM Visual Generation.""" import argparse import time -from output_handler import OutputHandler - -from tensorrt_llm import logger -from tensorrt_llm.llmapi.visual_gen import VisualGen, VisualGenParams +from tensorrt_llm import VisualGen, VisualGenArgs, VisualGenParams, logger +from tensorrt_llm.serve.media_storage import MediaStorage logger.set_level("info") @@ -161,11 +162,19 @@ def parse_args(): return parser.parse_args() +def _linear_type_to_quant_config(linear_type: str): + """Map --linear_type CLI shortcut to quant_config dict for VisualGenArgs.""" + mapping = { + "trtllm-fp8-per-tensor": {"quant_algo": "FP8", "dynamic": True}, + "trtllm-fp8-blockwise": {"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True}, + "trtllm-nvfp4": {"quant_algo": "NVFP4", "dynamic": True}, + } + return mapping.get(linear_type) + + def main(): args = parse_args() - n_workers = args.cfg_size * args.ulysses_size - if args.ulysses_size > 1: num_heads = 12 logger.info( @@ -174,55 +183,41 @@ def main(): f"{num_heads // args.ulysses_size} heads per GPU" ) - # Convert linear_type to quant_config - quant_config = None - if args.linear_type == "trtllm-fp8-per-tensor": - quant_config = {"quant_algo": "FP8", "dynamic": True} - elif args.linear_type == "trtllm-fp8-blockwise": - quant_config = {"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True} - elif args.linear_type == "trtllm-nvfp4": - quant_config = {"quant_algo": "NVFP4", "dynamic": True} - - diffusion_config = { - "model_type": "wan2", - "revision": args.revision, - "attention": { - "backend": args.attention_backend, - }, - "teacache": { + kwargs = dict( + revision=args.revision, + attention={"backend": args.attention_backend}, + teacache={ "enable_teacache": args.enable_teacache, "teacache_thresh": args.teacache_thresh, "use_ret_steps": args.use_ret_steps, }, - "parallel": { + parallel={ "dit_cfg_size": args.cfg_size, "dit_ulysses_size": args.ulysses_size, "enable_parallel_vae": not args.disable_parallel_vae, }, - "torch_compile": { + torch_compile={ "enable_torch_compile": not args.disable_torch_compile, "enable_fullgraph": args.enable_fullgraph, "enable_autotune": not args.disable_autotune, }, - "cuda_graph": { - "enable_cuda_graph": args.enable_cudagraph, - }, - "pipeline": { - "enable_layerwise_nvtx_marker": args.enable_layerwise_nvtx_marker, - }, - } - + cuda_graph={"enable_cuda_graph": args.enable_cudagraph}, + pipeline={"enable_layerwise_nvtx_marker": args.enable_layerwise_nvtx_marker}, + ) + quant_config = _linear_type_to_quant_config(args.linear_type) if quant_config is not None: - diffusion_config["quant_config"] = quant_config + kwargs["quant_config"] = quant_config + + diffusion_args = VisualGenArgs(**kwargs) logger.info( - f"Initializing VisualGen: world_size={n_workers} " - f"(cfg_size={args.cfg_size}, ulysses_size={args.ulysses_size})" + f"Initializing VisualGen: " + f"cfg_size={diffusion_args.parallel.dit_cfg_size}, " + f"ulysses_size={diffusion_args.parallel.dit_ulysses_size}" ) visual_gen = VisualGen( model_path=args.model_path, - n_workers=n_workers, - diffusion_config=diffusion_config, + diffusion_args=diffusion_args, ) try: @@ -253,7 +248,7 @@ def main(): logger.info(f"Generation completed in {time.time() - start_time:.2f}s") - OutputHandler.save(output, args.output_path, frame_rate=16.0) + 
MediaStorage.save_video(output.video, args.output_path, audio=output.audio, frame_rate=16.0) finally: visual_gen.shutdown() diff --git a/tensorrt_llm/__init__.py b/tensorrt_llm/__init__.py index 7f4a25dd838..f53b9c2f8f2 100644 --- a/tensorrt_llm/__init__.py +++ b/tensorrt_llm/__init__.py @@ -114,6 +114,7 @@ def _setup_vendored_triton_kernels(): from ._common import _init, default_net, default_trtnet, precision from ._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo +from ._torch.visual_gen.config import VisualGenArgs from ._utils import (default_gpus_per_node, local_mpi_rank, local_mpi_size, mpi_barrier, mpi_comm, mpi_rank, mpi_world_size, set_mpi_comm, str_dtype_to_torch, str_dtype_to_trt, @@ -121,7 +122,7 @@ def _setup_vendored_triton_kernels(): from .builder import BuildConfig, Builder, BuilderConfig, build from .disaggregated_params import DisaggregatedParams from .functional import Tensor, constant -from .llmapi import LLM, AsyncLLM, MultimodalEncoder +from .llmapi import LLM, AsyncLLM, MultimodalEncoder, VisualGen, VisualGenParams from .llmapi.llm_args import LlmArgs, TorchLlmArgs, TrtLlmArgs from .logger import logger from .mapping import Mapping @@ -179,9 +180,12 @@ def _setup_vendored_triton_kernels(): 'TorchLlmArgs', 'TrtLlmArgs', 'SamplingParams', + 'VisualGenArgs', 'DisaggregatedParams', 'KvCacheConfig', 'math_utils', + 'VisualGen', + 'VisualGenParams', '__version__', ] diff --git a/tensorrt_llm/_torch/visual_gen/__init__.py b/tensorrt_llm/_torch/visual_gen/__init__.py index b639783cfe9..5926a61681b 100644 --- a/tensorrt_llm/_torch/visual_gen/__init__.py +++ b/tensorrt_llm/_torch/visual_gen/__init__.py @@ -12,13 +12,13 @@ from .config import ( AttentionConfig, CudaGraphConfig, - DiffusionArgs, DiffusionModelConfig, ParallelConfig, PipelineComponent, PipelineConfig, TeaCacheConfig, TorchCompileConfig, + VisualGenArgs, discover_pipeline_components, ) from .models import AutoPipeline, BasePipeline, WanPipeline @@ -28,7 +28,7 @@ # Config classes "TorchCompileConfig", "CudaGraphConfig", - "DiffusionArgs", + "VisualGenArgs", "DiffusionModelConfig", "ParallelConfig", "PipelineComponent", diff --git a/tensorrt_llm/_torch/visual_gen/config.py b/tensorrt_llm/_torch/visual_gen/config.py index 3111957cb70..71621c6dfba 100644 --- a/tensorrt_llm/_torch/visual_gen/config.py +++ b/tensorrt_llm/_torch/visual_gen/config.py @@ -1,15 +1,16 @@ import json -import os from enum import Enum from pathlib import Path from types import SimpleNamespace -from typing import Any, Dict, List, Literal, Optional, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple, Union import torch +import yaml from pydantic import BaseModel, ConfigDict, model_validator from pydantic import Field as PydanticField from tensorrt_llm.functional import AllReduceStrategy +from tensorrt_llm.llmapi.utils import StrictBaseModel, set_api_status from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig from tensorrt_llm.quantization.mode import QuantAlgo @@ -38,11 +39,11 @@ class PipelineComponent(str, Enum): # ============================================================================= -# Sub-configuration classes for DiffusionArgs +# Sub-configuration classes for VisualGenArgs # ============================================================================= -class AttentionConfig(BaseModel): +class AttentionConfig(StrictBaseModel): """Configuration for Attention layers.""" backend: Literal["VANILLA", "TRTLLM"] = PydanticField( @@ -50,7 +51,7 @@ class 
AttentionConfig(BaseModel): ) -class ParallelConfig(BaseModel): +class ParallelConfig(StrictBaseModel): """Configuration for distributed parallelism. Currently Supported: @@ -122,27 +123,31 @@ def to_mapping(self) -> Mapping: cp_size=self.dit_cp_size, ) - @model_validator(mode="after") - def validate_parallel_sizes(self) -> "ParallelConfig": - """Validate configuration against current environment.""" - if torch.cuda.is_available(): - world_size = int(os.environ.get("WORLD_SIZE", 1)) - total_parallel = ( - self.dit_tp_size - * self.dit_ulysses_size - * self.dit_ring_size - * self.dit_cp_size - * self.dit_dp_size - * self.dit_cfg_size + @property + def total_parallel_size(self) -> int: + """Total parallelism across all DiT dimensions.""" + return ( + self.dit_tp_size + * self.dit_ulysses_size + * self.dit_ring_size + * self.dit_cp_size + * self.dit_dp_size + * self.dit_cfg_size + ) + + def validate_world_size(self, world_size: int) -> None: + """Validate that the parallel config is compatible with the given world size. + + Called at launch time when WORLD_SIZE is known (not at config construction). + """ + if self.total_parallel_size > world_size: + raise ValueError( + f"Total DiT parallel size ({self.total_parallel_size}) " + f"exceeds world_size ({world_size})" ) - if total_parallel > world_size: - raise ValueError( - f"Total DiT parallel size ({total_parallel}) exceeds WORLD_SIZE ({world_size})" - ) - return self -class TeaCacheConfig(BaseModel): +class TeaCacheConfig(StrictBaseModel): """Configuration for TeaCache runtime optimization. TeaCache speeds up diffusion by caching transformer outputs when timestep @@ -198,7 +203,7 @@ def validate_teacache(self) -> "TeaCacheConfig": return self -class TorchCompileConfig(BaseModel): +class TorchCompileConfig(StrictBaseModel): """Configuration for torch.compile and autotuning.""" enable_torch_compile: bool = True @@ -206,13 +211,13 @@ class TorchCompileConfig(BaseModel): enable_autotune: bool = True -class CudaGraphConfig(BaseModel): +class CudaGraphConfig(StrictBaseModel): """Configuration for CUDA graph capture/replay.""" enable_cuda_graph: bool = False -class PipelineConfig(BaseModel): +class PipelineConfig(StrictBaseModel): """Model-specific pipeline configuration.""" fuse_qkv: bool = True @@ -225,18 +230,18 @@ class PipelineConfig(BaseModel): # ============================================================================= -# DiffusionArgs - User-facing configuration (CLI / YAML) +# VisualGenArgs - User-facing configuration (CLI / YAML) # ============================================================================= -class DiffusionArgs(BaseModel): +class VisualGenArgs(StrictBaseModel): """User-facing configuration for diffusion model loading and inference. This is the main config class used in CLI args and YAML config files. PipelineLoader converts this to DiffusionModelConfig internally. 
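+    Values can also be loaded from a YAML file via from_yaml() (defined below);
+    nested sections map onto the sub-config fields. An illustrative YAML file
+    (field values here are examples only):
+
+        attention:
+          backend: TRTLLM
+        teacache:
+          enable_teacache: true
+          teacache_thresh: 0.2
+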
Example: - args = DiffusionArgs( + args = VisualGenArgs( checkpoint_path="/path/to/model", quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True}, parallel=ParallelConfig(dit_tp_size=2), @@ -277,6 +282,9 @@ class DiffusionArgs(BaseModel): ), ) + # Skip warmup inference after loading (useful for testing or fast startup) + skip_warmup: bool = False + # Sub-configs (dict input for quant_config is coerced to QuantConfig in model_validator) quant_config: QuantConfig = PydanticField(default_factory=QuantConfig) torch_compile: TorchCompileConfig = PydanticField(default_factory=TorchCompileConfig) @@ -293,7 +301,7 @@ class DiffusionArgs(BaseModel): @model_validator(mode="before") @classmethod def _parse_quant_config_dict(cls, data: Any) -> Any: - """Parse user-facing DiffusionArgs.quant_config (dict or None) into QuantConfig and dynamic flags. + """Parse user-facing VisualGenArgs.quant_config (dict or None) into QuantConfig and dynamic flags. User input is ModelOpt-format dict (e.g. {"quant_algo": "FP8", "dynamic": True}). We coerce it to QuantConfig + dynamic_weight_quant + force_dynamic_quantization so that @@ -324,21 +332,31 @@ def to_dict(self) -> Dict[str, Any]: """Convert to dictionary.""" return self.model_dump() + @set_api_status("prototype") @classmethod - def from_dict(cls, config_dict: Dict[str, Any]) -> "DiffusionArgs": + def from_dict(cls, config_dict: Dict[str, Any]) -> "VisualGenArgs": """Create from dictionary with automatic nested config parsing. - Pydantic automatically handles nested configs, but we keep this method - for backward compatibility and to filter unknown fields. + Unknown fields cause a ValidationError (extra="forbid"). """ - # Get valid field names for DiffusionArgs - valid_fields = set(cls.model_fields.keys()) + return cls(**config_dict) - # Filter to only include valid fields (ignore unknown fields) - filtered_dict = {k: v for k, v in config_dict.items() if k in valid_fields} + @set_api_status("prototype") + @classmethod + def from_yaml(cls, yaml_path: Union[str, Path], **overrides: Any) -> "VisualGenArgs": + """Load configuration from a YAML file. - # Pydantic automatically converts nested dicts to their respective config classes - return cls(**filtered_dict) + Args: + yaml_path: Path to the YAML configuration file. + **overrides: Keyword arguments that override values from the YAML file. + + Returns: + A validated VisualGenArgs instance. + """ + with open(yaml_path, "r") as f: + config_dict = yaml.safe_load(f) or {} + config_dict.update(overrides) + return cls(**config_dict) # ============================================================================= @@ -378,11 +396,11 @@ def discover_pipeline_components(checkpoint_path: Path) -> Dict[str, Path]: class DiffusionModelConfig(BaseModel): """Internal ModelConfig for diffusion models. - This is created by PipelineLoader from DiffusionArgs + checkpoint. + This is created by PipelineLoader from VisualGenArgs + checkpoint. 
Contains merged/parsed config from: - pretrained_config: From checkpoint/config.json - quant_config: From checkpoint or user quant config - - Sub-configs: From DiffusionArgs (pipeline, attention, parallel, teacache) + - Sub-configs: From VisualGenArgs (pipeline, attention, parallel, teacache) """ model_config = ConfigDict(arbitrary_types_allowed=True) @@ -399,7 +417,7 @@ class DiffusionModelConfig(BaseModel): dynamic_weight_quant: bool = False - # Sub-configs from DiffusionArgs (merged during from_pretrained) + # Sub-configs from VisualGenArgs (merged during from_pretrained) quant_config: QuantConfig = PydanticField(default_factory=QuantConfig) # Per-layer quant (from load_diffusion_quant_config layer_quant_config; None until mixed-precision parsing exists) quant_config_dict: Optional[Dict[str, QuantConfig]] = None @@ -497,13 +515,13 @@ def load_diffusion_quant_config( def from_pretrained( cls, checkpoint_dir: str, - args: Optional["DiffusionArgs"] = None, + args: Optional["VisualGenArgs"] = None, **kwargs, ) -> "DiffusionModelConfig": """ Load config from pretrained checkpoint. - Called by PipelineLoader with DiffusionArgs: + Called by PipelineLoader with VisualGenArgs: config = DiffusionModelConfig.from_pretrained( checkpoint_dir=args.checkpoint_path, args=args, @@ -511,7 +529,7 @@ def from_pretrained( Args: checkpoint_dir: Path to checkpoint - args: DiffusionArgs containing user config + args: VisualGenArgs containing user config - (torch_compile, cuda_graph, pipeline, attention, parallel, teacache) **kwargs: Additional config options (e.g., mapping) """ @@ -566,7 +584,7 @@ def from_pretrained( if args and args.quant_config.quant_algo is not None: quant_config = args.quant_config quant_config_dict = ( - None # DiffusionArgs has no per-layer dict; only from checkpoint parse + None # VisualGenArgs has no per-layer dict; only from checkpoint parse ) dynamic_weight_quant = args.dynamic_weight_quant dynamic_activation_quant = args.force_dynamic_quantization @@ -587,7 +605,7 @@ def from_pretrained( quant_config_dict=quant_config_dict, dynamic_weight_quant=dynamic_weight_quant, force_dynamic_quantization=dynamic_activation_quant, - # Sub-configs from DiffusionArgs + # Sub-configs from VisualGenArgs torch_compile=torch_compile_cfg, cuda_graph=cuda_graph_cfg, pipeline=pipeline_cfg, diff --git a/tensorrt_llm/_torch/visual_gen/executor.py b/tensorrt_llm/_torch/visual_gen/executor.py index 98e00b4746a..1414443a75d 100644 --- a/tensorrt_llm/_torch/visual_gen/executor.py +++ b/tensorrt_llm/_torch/visual_gen/executor.py @@ -9,7 +9,7 @@ import torch.distributed as dist import zmq -from tensorrt_llm._torch.visual_gen.config import DiffusionArgs +from tensorrt_llm._torch.visual_gen.config import VisualGenArgs from tensorrt_llm._torch.visual_gen.output import MediaOutput from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader from tensorrt_llm.executor.ipc import ZeroMqQueue @@ -68,17 +68,15 @@ class DiffusionExecutor: def __init__( self, - model_path: str, request_queue_addr: str, response_queue_addr: str, device_id: int, - diffusion_config: Optional[dict] = None, + diffusion_args: "VisualGenArgs", ): - self.model_path = model_path self.request_queue_addr = request_queue_addr self.response_queue_addr = response_queue_addr self.device_id = device_id - self.diffusion_config = diffusion_config + self.diffusion_args = diffusion_args self.pipeline = None # initialized in _load_pipeline self.requests_ipc = None @@ -126,27 +124,15 @@ def _sender_loop(self): def _load_pipeline(self): """ Load 
pipeline using proper flow: - DiffusionArgs → PipelineLoader → DiffusionModelConfig → AutoPipeline → BasePipeline + VisualGenArgs → PipelineLoader → DiffusionModelConfig → AutoPipeline → BasePipeline """ logger.info(f"Worker {self.device_id}: Loading pipeline") try: - # Convert diffusion_config dict to DiffusionArgs - config_dict = self.diffusion_config.copy() - config_dict["checkpoint_path"] = self.model_path - config_dict["device"] = f"cuda:{self.device_id}" - - # Create DiffusionArgs from dict (handles nested configs) - args = DiffusionArgs.from_dict(config_dict) - - # Use PipelineLoader for proper pipeline creation flow: - # PipelineLoader.load() internally: - # 1. Creates DiffusionModelConfig.from_pretrained() - # 2. Creates pipeline via AutoPipeline.from_config() - # 3. Loads weights with quantization support - # 4. Calls post_load_weights() + args = self.diffusion_args.model_copy(update={"device": f"cuda:{self.device_id}"}) + loader = PipelineLoader(args) - self.pipeline = loader.load() + self.pipeline = loader.load(skip_warmup=args.skip_warmup) except Exception as e: logger.error(f"Worker {self.device_id}: Failed to load pipeline: {e}") @@ -215,13 +201,16 @@ def run_diffusion_worker( world_size: int, master_addr: str, master_port: int, - model_path: str, request_queue_addr: str, response_queue_addr: str, - diffusion_config: Optional[dict] = None, + diffusion_args: "VisualGenArgs", + log_level: str = "info", ): """Entry point for worker process.""" try: + # Set log level before any other work so loading logs are visible + logger.set_level(log_level) + # Setup distributed env — use PyTorch distributed, not MPI os.environ["TLLM_DISABLE_MPI"] = "1" os.environ["MASTER_ADDR"] = master_addr @@ -229,6 +218,9 @@ def run_diffusion_worker( os.environ["RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size) + # Runtime check: parallel config vs actual world size + diffusion_args.parallel.validate_world_size(world_size) + # Calculate device_id before init_process_group device_id = rank % torch.cuda.device_count() if torch.cuda.is_available() else 0 if torch.cuda.is_available(): @@ -243,11 +235,10 @@ def run_diffusion_worker( ) executor = DiffusionExecutor( - model_path=model_path, request_queue_addr=request_queue_addr, response_queue_addr=response_queue_addr, device_id=device_id, - diffusion_config=diffusion_config, + diffusion_args=diffusion_args, ) executor.serve_forever() if executor.pipeline is not None: diff --git a/tensorrt_llm/_torch/visual_gen/pipeline_loader.py b/tensorrt_llm/_torch/visual_gen/pipeline_loader.py index e5bbc26f5b4..6ac5df9d22f 100644 --- a/tensorrt_llm/_torch/visual_gen/pipeline_loader.py +++ b/tensorrt_llm/_torch/visual_gen/pipeline_loader.py @@ -15,6 +15,7 @@ """ import os +import time from typing import TYPE_CHECKING, Optional import torch @@ -26,7 +27,7 @@ from tensorrt_llm.mapping import Mapping from .checkpoints import WeightLoader -from .config import DiffusionArgs, DiffusionModelConfig, PipelineComponent +from .config import DiffusionModelConfig, PipelineComponent, VisualGenArgs from .models import AutoPipeline if TYPE_CHECKING: @@ -42,7 +43,7 @@ class PipelineLoader: on-the-fly during loading. 
     Example:
-        args = DiffusionArgs(
+        args = VisualGenArgs(
             checkpoint_path="/path/to/model",
-            linear=LinearConfig(type="trtllm-fp8-blockwise"),
+            quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True},
             parallel=ParallelConfig(dit_tp_size=2),
@@ -52,7 +53,7 @@ class PipelineLoader:
 
     def __init__(
         self,
-        args: Optional[DiffusionArgs] = None,
+        args: Optional[VisualGenArgs] = None,
         *,
         mapping: Optional[Mapping] = None,
         device: str = "cuda",
@@ -61,7 +62,7 @@ def __init__(
         Initialize model loader.
 
         Args:
-            args: DiffusionArgs containing all configuration (preferred)
+            args: VisualGenArgs containing all configuration (preferred)
             mapping: Tensor parallel mapping (fallback if args is None)
             device: Device to load model on (fallback if args is None)
         """
@@ -134,15 +135,17 @@ def load(
         # Resolve checkpoint_dir
         checkpoint_dir = checkpoint_dir or (self.args.checkpoint_path if self.args else None)
         if not checkpoint_dir:
-            raise ValueError("checkpoint_dir must be provided or set in DiffusionArgs")
+            raise ValueError("checkpoint_dir must be provided or set in VisualGenArgs")
         checkpoint_dir = self._resolve_checkpoint_dir(str(checkpoint_dir))
 
         # Get loading options from args
         skip_components = self.args.skip_components if self.args else []
 
+        load_start = time.time()
+
         # =====================================================================
         # STEP 1: Load Config (includes quant config parsing)
-        # Merge pretrained checkpoint config with user-provided DiffusionArgs
+        # Merge pretrained checkpoint config with user-provided VisualGenArgs
         # =====================================================================
         logger.info(f"Loading config from {checkpoint_dir}")
         config = DiffusionModelConfig.from_pretrained(
@@ -202,13 +205,16 @@ def load(
         # These are NOT quantized - loaded as-is from checkpoint
         # =====================================================================
         pipeline.load_standard_components(checkpoint_dir, self.device, skip_components)
-
-        if config.parallel.enable_parallel_vae:
-            pipeline.setup_parallel_vae()
+        logger.info(f"Model loaded successfully in {time.time() - load_start:.2f}s")
 
         # =====================================================================
         # STEP 5: Post-load Hooks (TeaCache setup, etc.)
         # =====================================================================
+
+        t0 = time.time()
+        if config.parallel.enable_parallel_vae:
+            pipeline.setup_parallel_vae()
+
         if hasattr(pipeline, "post_load_weights"):
             pipeline.post_load_weights()
 
@@ -227,6 +233,9 @@ def load(
                 pipeline.warmup()
             else:
                 pipeline.warmup()
+            logger.info(f"Warmup completed in {time.time() - t0:.2f}s")
+        else:
+            logger.info("Warmup skipped (skip_warmup=True)")
 
         if config.pipeline.enable_layerwise_nvtx_marker:
             from tensorrt_llm._torch.pyexecutor.layerwise_nvtx_marker import LayerwiseNvtxMarker
@@ -237,7 +246,10 @@ def load(
                 logger.info(f"Registering layerwise NVTX markers for {transformer_component}")
                 marker.register_hooks(getattr(pipeline, transformer_component), module_prefix)
 
-        logger.info(f"Pipeline loaded: {pipeline.__class__.__name__}")
+        logger.info(
+            f"Pipeline loaded: {pipeline.__class__.__name__} "
+            f"(total load time: {time.time() - load_start:.2f}s)"
+        )
         return pipeline
 
     def _materialize_meta_tensors(self, module: torch.nn.Module) -> None:
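
A minimal usage sketch of the timed load path above, assuming an illustrative local checkpoint path; skip_warmup is the new VisualGenArgs field added in this patch, and load(skip_warmup=...) matches the call sites in executor.py and the unit tests:

    from tensorrt_llm._torch.visual_gen.config import VisualGenArgs
    from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader

    # Hypothetical checkpoint path, for illustration only.
    args = VisualGenArgs(checkpoint_path="/ckpts/Wan2.1-T2V-1.3B", skip_warmup=True)
    pipeline = PipelineLoader(args).load(skip_warmup=args.skip_warmup)
    # The load logs now report load time, warmup time (or "Warmup skipped"), and total.

diff --git a/tensorrt_llm/_torch/visual_gen/pipeline_registry.py b/tensorrt_llm/_torch/visual_gen/pipeline_registry.py
index 25ee0701091..fd8abb9ee3e 100644
--- a/tensorrt_llm/_torch/visual_gen/pipeline_registry.py
+++ b/tensorrt_llm/_torch/visual_gen/pipeline_registry.py
@@ -1,6 +1,6 @@
 """Pipeline registry for unified config flow.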
-Follows: DiffusionArgs → PipelineLoader → DiffusionModelConfig → AutoPipeline → BasePipeline +Follows: VisualGenArgs → PipelineLoader → DiffusionModelConfig → AutoPipeline → BasePipeline All pipelines (Wan, Flux, Flux2, LTX2) register via @register_pipeline decorator. """ diff --git a/tensorrt_llm/bench/benchmark/visual_gen.py b/tensorrt_llm/bench/benchmark/visual_gen.py index 212b1cb31e0..45dcea2e1de 100644 --- a/tensorrt_llm/bench/benchmark/visual_gen.py +++ b/tensorrt_llm/bench/benchmark/visual_gen.py @@ -196,7 +196,8 @@ def visual_gen_command( """Benchmark VisualGen (image/video generation) models offline.""" import yaml - from tensorrt_llm.commands.utils import get_visual_gen_model_type, get_visual_gen_num_gpus + from tensorrt_llm._torch.visual_gen.config import VisualGenArgs + from tensorrt_llm.commands.utils import get_visual_gen_num_gpus from tensorrt_llm.llmapi.visual_gen import VisualGen, VisualGenParams if prompt is None and prompt_file is None: @@ -207,18 +208,16 @@ def visual_gen_command( model = bench_env.model model_path = str(bench_env.checkpoint_path or model) - # Build diffusion config (same pattern as trtllm-serve _serve_visual_gen) - visual_gen_config: dict = { - "model": model_path, - "model_type": get_visual_gen_model_type(model_path), - } + # Build VisualGenArgs (same pattern as trtllm-serve _serve_visual_gen) + extra_args: dict = {} if extra_visual_gen_options is not None: with open(extra_visual_gen_options, "r") as f: - visual_gen_extra_args = yaml.safe_load(f) or {} - visual_gen_config.update(visual_gen_extra_args) + extra_args = yaml.safe_load(f) or {} + + diffusion_args = VisualGenArgs(**extra_args) if extra_args else None - n_workers = get_visual_gen_num_gpus(visual_gen_config) - parallel_config = visual_gen_config.get("parallel", {}) + n_workers = get_visual_gen_num_gpus(extra_args) + parallel_config = extra_args.get("parallel", {}) if parallel_config: logger.info(f"World size: {n_workers}") logger.info(f"CFG size: {parallel_config.get('dit_cfg_size', 1)}") @@ -265,8 +264,7 @@ def visual_gen_command( logger.info(f"Initializing VisualGen ({model_path})") visual_gen = VisualGen( model_path=model_path, - n_workers=n_workers, - diffusion_config=visual_gen_config, + diffusion_args=diffusion_args, ) try: diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py index eb20f513cd5..19747fef2a7 100644 --- a/tensorrt_llm/commands/serve.py +++ b/tensorrt_llm/commands/serve.py @@ -18,10 +18,9 @@ from tensorrt_llm import LLM as PyTorchLLM from tensorrt_llm import MultimodalEncoder from tensorrt_llm._tensorrt_engine import LLM +from tensorrt_llm._torch.visual_gen.config import VisualGenArgs from tensorrt_llm._utils import mpi_rank -from tensorrt_llm.commands.utils import (get_is_diffusion_model, - get_visual_gen_model_type, - get_visual_gen_num_gpus) +from tensorrt_llm.commands.utils import get_is_diffusion_model from tensorrt_llm.executor.utils import LlmLauncherEnvs from tensorrt_llm.inputs.multimodal import MultimodalServerConfig from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy, @@ -452,7 +451,8 @@ def launch_mm_encoder_server( def launch_visual_gen_server( host: str, port: int, - visual_gen_config: dict, + model: str, + diffusion_args: Optional[VisualGenArgs] = None, metadata_server_cfg: Optional[MetadataServerConfig] = None, ): """Launch a VISUAL_GEN model server for image/video generation. @@ -460,23 +460,22 @@ def launch_visual_gen_server( Args: host: Server hostname. port: Server port. 
-        visual_gen_config: Arguments for VISUAL_GEN model initialization.
+        model: Model path or HuggingFace Hub model ID.
+        diffusion_args: Optional validated VisualGenArgs for model configuration.
         metadata_server_cfg: Optional metadata server configuration.
     """
-    model = visual_gen_config["model"]
     logger.info(f"Initializing VisualGen ({model})")
 
-    n_workers = get_visual_gen_num_gpus(visual_gen_config)
-    parallel_config = visual_gen_config.get("parallel", {})
-    if parallel_config:
-        logger.info(f"World size: {n_workers}")
-        logger.info(f"CFG size: {parallel_config.get('dit_cfg_size', 1)}")
-        logger.info(
-            f"Ulysses size: {parallel_config.get('dit_ulysses_size', 1)}")
-    visual_gen_model = VisualGen(model_path=model,
-                                 n_workers=n_workers,
-                                 diffusion_config=visual_gen_config)
+    visual_gen_model = VisualGen(model_path=model,
+                                 diffusion_args=diffusion_args)
+
+    n_workers = visual_gen_model.diffusion_args.parallel.n_workers
+    logger.info(f"World size: {n_workers}")
+    logger.info(
+        f"CFG size: {visual_gen_model.diffusion_args.parallel.dit_cfg_size}")
+    logger.info(
+        f"Ulysses size: {visual_gen_model.diffusion_args.parallel.dit_ulysses_size}"
+    )
 
     server = OpenAIServer(generator=visual_gen_model,
                           model=model,
@@ -873,22 +872,17 @@ def _serve_llm():
                   served_model_name=served_model_name)
 
     def _serve_visual_gen():
-        visual_gen_config = {
-            "model": model,
-            "model_type": get_visual_gen_model_type(model),
-        }
-
-        visual_gen_extra_args = {}
+        extra_args = {}
         if extra_visual_gen_options is not None:
             with open(extra_visual_gen_options, 'r') as f:
-                visual_gen_extra_args = yaml.safe_load(f)
+                extra_args = yaml.safe_load(f) or {}
 
-        visual_gen_config.update(visual_gen_extra_args)
+        diffusion_args = VisualGenArgs(**extra_args) if extra_args else None
 
         metadata_server_cfg = parse_metadata_server_config_file(
             metadata_server_config_file)
 
-        launch_visual_gen_server(host, port, visual_gen_config,
+        launch_visual_gen_server(host, port, model, diffusion_args,
                                  metadata_server_cfg)
 
     if get_is_diffusion_model(model):
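
For reference, a standalone sketch of the construction path that _serve_visual_gen implements above (the options file name and model path are illustrative; unknown YAML keys now raise a pydantic ValidationError because VisualGenArgs uses extra="forbid"):

    import yaml

    from tensorrt_llm._torch.visual_gen.config import VisualGenArgs
    from tensorrt_llm.llmapi.visual_gen import VisualGen

    # Hypothetical options file; it may contain only valid VisualGenArgs fields.
    with open("visual_gen_options.yaml", "r") as f:
        extra_args = yaml.safe_load(f) or {}

    diffusion_args = VisualGenArgs(**extra_args) if extra_args else None

    # VisualGen copies model_path into checkpoint_path on a model_copy of the args,
    # and the worker count now comes from diffusion_args.parallel.n_workers.
    visual_gen = VisualGen(model_path="/ckpts/Wan2.1-T2V-1.3B",
                           diffusion_args=diffusion_args)

diff --git a/tensorrt_llm/llmapi/visual_gen.py b/tensorrt_llm/llmapi/visual_gen.py
index 8e742911cee..2f17e6db17d 100644
--- a/tensorrt_llm/llmapi/visual_gen.py
+++ b/tensorrt_llm/llmapi/visual_gen.py
@@ -1,23 +1,27 @@
 import asyncio
+import atexit
 import queue
 import socket
 import threading
 import time
 import traceback
+import weakref
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 
 import torch.multiprocessing as mp
 import zmq
 
 from tensorrt_llm._torch.visual_gen import DiffusionRequest, DiffusionResponse
+from tensorrt_llm._torch.visual_gen.config import VisualGenArgs
 from tensorrt_llm._torch.visual_gen.executor import run_diffusion_worker
 from tensorrt_llm._torch.visual_gen.output import MediaOutput
 
 __all__ = ["VisualGen", "VisualGenParams", "MediaOutput"]
 
 from tensorrt_llm.executor.ipc import ZeroMqQueue
 from tensorrt_llm.inputs.data import VisualGenInputs
+from tensorrt_llm.llmapi.utils import set_api_status
 from tensorrt_llm.logger import logger
 
 # Timeouts (seconds)
@@ -50,13 +54,10 @@ class DiffusionRemoteClient:
 
     def __init__(
         self,
-        model_path: Union[str, Path],
-        n_workers: int = 1,
-        diffusion_config: Optional[dict] = None,
+        diffusion_args: VisualGenArgs,
     ):
-        self.model_path = str(model_path)
-        self.n_workers = n_workers
-        self.diffusion_config = diffusion_config
+        self.diffusion_args = diffusion_args
+        self.n_workers = diffusion_args.parallel.n_workers
 
         # Setup distributed env
         self.master_addr = "127.0.0.1"
@@ -91,7 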
+92,8 @@ def __init__( # Wait for the background thread to initialize the event loop self.event_loop_ready.wait() - # Launch workers + # Launch workers (VisualGenArgs is pickled via mp.Process spawn context) + n_workers = self.n_workers logger.info(f"DiffusionClient: Launching {n_workers} workers") ctx = mp.get_context("spawn") self.worker_processes = [] @@ -103,10 +105,10 @@ def __init__( "world_size": n_workers, "master_addr": self.master_addr, "master_port": self.master_port, - "model_path": self.model_path, "request_queue_addr": self.req_addr_connect, "response_queue_addr": self.resp_addr_connect, - "diffusion_config": self.diffusion_config, + "diffusion_args": self.diffusion_args, + "log_level": logger.level, }, ) p.start() @@ -404,6 +406,7 @@ def cancel(self): @dataclass +@set_api_status("prototype") class VisualGenParams: """Parameters for visual generation. @@ -442,7 +445,7 @@ class VisualGenParams: # Image-specific parameters num_images_per_prompt: int = 1 - ## Image edit parameters + # Image edit parameters image: Optional[List[str]] = None mask: Optional[str] = None @@ -459,23 +462,25 @@ class VisualGenParams: class VisualGen: """High-level API for visual generation.""" + @set_api_status("prototype") def __init__( self, model_path: Union[str, Path], - n_workers: int = 1, - diffusion_config: Optional[dict] = None, + diffusion_args: Optional[VisualGenArgs] = None, ): self.model_path = str(model_path) - self.n_workers = n_workers - self.diffusion_config = diffusion_config + self.diffusion_args = (diffusion_args or VisualGenArgs()).model_copy( + update={"checkpoint_path": self.model_path} + ) self.executor = DiffusionRemoteClient( - model_path=self.model_path, - n_workers=self.n_workers, - diffusion_config=self.diffusion_config, + diffusion_args=self.diffusion_args, ) self.req_counter = 0 + atexit.register(VisualGen._atexit_shutdown, weakref.ref(self)) + + @set_api_status("prototype") def generate( self, inputs: VisualGenInputs, @@ -503,6 +508,7 @@ def generate( raise RuntimeError(f"Generation failed: {response.error_msg}") return response.output + @set_api_status("prototype") def generate_async( self, inputs: VisualGenInputs, @@ -553,7 +559,28 @@ def generate_async( self.executor.enqueue_requests([request]) return DiffusionGenerationResult(req_id, self.executor) + @staticmethod + def _atexit_shutdown(self_ref): + instance = self_ref() + if instance is not None: + instance.shutdown() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback) -> Literal[False]: + del exc_value, traceback + self.shutdown() + return False + + def __del__(self): + self.shutdown() + + @set_api_status("prototype") def shutdown(self): """Shutdown executor and cleanup.""" + if not hasattr(self, "executor") or self.executor is None: + return logger.info("VisualGen: Shutting down") self.executor.shutdown() + self.executor = None diff --git a/tensorrt_llm/serve/media_storage.py b/tensorrt_llm/serve/media_storage.py index 143b5206b1c..ae934dfa7f5 100644 --- a/tensorrt_llm/serve/media_storage.py +++ b/tensorrt_llm/serve/media_storage.py @@ -502,15 +502,8 @@ def get_video_encoder() -> Optional["VideoEncoder"]: """ global _VIDEO_ENCODER if _VIDEO_ENCODER is None: - if _check_ffmpeg_available(): - logger.info("Using ffmpeg CLI for video encoding") - _VIDEO_ENCODER = FfmpegCliEncoder() - else: - logger.warning( - "FFmpeg is unavailable so no MP4 generation support." 
- "Using pure Python MJPEG/AVI encoder (no audio support)" - ) - _VIDEO_ENCODER = PurePythonEncoder() + _VIDEO_ENCODER = FfmpegCliEncoder() if _check_ffmpeg_available() else PurePythonEncoder() + logger.info(f"Using {_VIDEO_ENCODER.__class__.__name__} for video encoding") return _VIDEO_ENCODER diff --git a/tests/integration/defs/examples/test_visual_gen.py b/tests/integration/defs/examples/test_visual_gen.py index 65bdb2bedff..dd9f7eb6332 100644 --- a/tests/integration/defs/examples/test_visual_gen.py +++ b/tests/integration/defs/examples/test_visual_gen.py @@ -375,17 +375,23 @@ def test_vbench_dimension_score_wan22_a14b_nvfp4( ) -def test_visual_gen_benchmark_serving(llm_venv): - """Run benchmark_visual_gen.py against a live trtllm-serve visual-gen server.""" - test_root = conftest.unittest_path() / "_torch" / "visual_gen" - llm_venv.run_cmd( - [ - "-m", - "pytest", - "-v", - str( - test_root / "_test_trtllm_serve_visual_gen_benchmark.py" - "::test_visual_gen_benchmark_video[openai-videos]" - ), - ] - ) +def test_visual_gen_quickstart(_visual_gen_deps, llm_root, llm_venv): + """Run examples/visual_gen/quickstart_example.py end-to-end.""" + scratch_space = conftest.llm_models_root() + model_src = os.path.join(scratch_space, WAN_T2V_MODEL_SUBPATH) + if not os.path.isdir(model_src): + pytest.skip( + f"Model not found: {model_src} " + f"(set LLM_MODELS_ROOT or place {WAN_T2V_MODEL_SUBPATH} under scratch)" + ) + + model_dst = os.path.join(llm_venv.get_working_directory(), "Wan-AI", WAN_T2V_MODEL_SUBPATH) + if not os.path.islink(model_dst): + os.makedirs(os.path.dirname(model_dst), exist_ok=True) + os.symlink(model_src, model_dst, target_is_directory=True) + + script_path = os.path.join(llm_root, "examples", "visual_gen", "quickstart_example.py") + venv_check_call(llm_venv, [script_path]) + + output_path = os.path.join(llm_venv.get_working_directory(), "output.avi") + assert os.path.isfile(output_path), f"Quickstart did not produce output.avi at {output_path}" diff --git a/tests/integration/defs/visual_gen/test_visual_gen_benchmark.py b/tests/integration/defs/visual_gen/test_visual_gen_benchmark.py index 19cae81bd24..d10d3fdc47d 100644 --- a/tests/integration/defs/visual_gen/test_visual_gen_benchmark.py +++ b/tests/integration/defs/visual_gen/test_visual_gen_benchmark.py @@ -67,7 +67,6 @@ def _wan_t2v_path() -> Path: def _make_visual_gen_options(**extra) -> dict: """Build a minimal VisualGen YAML config dict.""" config = { - "linear": {"type": "default"}, "parallel": {"dit_cfg_size": 1, "dit_ulysses_size": 1}, } config.update(extra) diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 50586d34167..6a7db3cba11 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -174,6 +174,7 @@ l0_a10: - unittest/trt/quantization - unittest/trt/functional # 37 mins - llmapi/test_llm_examples.py::test_llmapi_quickstart_atexit + - examples/test_visual_gen.py::test_visual_gen_quickstart - unittest/api_stability - unittest/bindings - unittest/test_model_runner_cpp.py diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index 9542bf835ec..fbdb2a56797 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -122,6 +122,8 @@ l0_b200: - unittest/_torch/modeling/test_modeling_exaone4.py::TestEXAONE4::test_llm_load_1_FP8 - unittest/kv_cache_manager_v2_tests/ # ------------- 
Visual Gen tests --------------- + - unittest/_torch/visual_gen/test_visual_gen_args.py + - unittest/_torch/visual_gen/test_teacache.py - unittest/_torch/visual_gen/test_fused_qkv.py - unittest/_torch/visual_gen/test_quant_ops.py - unittest/_torch/visual_gen/test_attention_integration.py diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index e92141a04cf..f2881dd66fc 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -263,9 +263,11 @@ l0_dgx_b200: - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTEDSL-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False-enable_gemm_allreduce_fusion=False] + - examples/test_visual_gen.py::test_visual_gen_quickstart - examples/test_visual_gen.py::test_vbench_dimension_score_wan - examples/test_visual_gen.py::test_vbench_dimension_score_wan22_a14b_fp8 - examples/test_visual_gen.py::test_vbench_dimension_score_wan22_a14b_nvfp4 + - visual_gen/test_visual_gen_benchmark.py # ------------- AutoDeploy Backend Stages --------------- - condition: ranges: diff --git a/tests/integration/test_lists/test-db/l0_gb203.yml b/tests/integration/test_lists/test-db/l0_gb203.yml index bd2f60eca6b..692da412689 100644 --- a/tests/integration/test_lists/test-db/l0_gb203.yml +++ b/tests/integration/test_lists/test-db/l0_gb203.yml @@ -30,6 +30,7 @@ l0_gb203: # - examples/test_qwen.py::test_llm_qwen1_5_7b_single_gpu_lora[qwen1.5_7b_chat-Qwen1.5-7B-Chat-750Mb-lora] # https://nvbugs/5234573 # - examples/test_qwen.py::test_llm_qwen_single_gpu_summary[qwen2.5_1.5b_instruct-enable_paged_kv_cache-enable_remove_input_padding-enable_weight_only-enable_fmha_fp32_acc] # https://nvbugs/5234573 - llmapi/test_llm_examples.py::test_llmapi_quickstart + - examples/test_visual_gen.py::test_visual_gen_quickstart - llmapi/test_llm_examples.py::test_llmapi_example_inference - llmapi/test_llm_examples.py::test_llmapi_example_inference_async - llmapi/test_llm_examples.py::test_llmapi_example_inference_async_streaming diff --git a/tests/integration/test_lists/test-db/l0_gh200.yml b/tests/integration/test_lists/test-db/l0_gh200.yml index 52a46a07154..028b9d04128 100644 --- a/tests/integration/test_lists/test-db/l0_gh200.yml +++ b/tests/integration/test_lists/test-db/l0_gh200.yml @@ -23,6 +23,7 @@ l0_gh200: - unittest/bindings - unittest/llmapi/test_llm_quant.py - llmapi/test_llm_examples.py::test_llmapi_quickstart_atexit + - examples/test_visual_gen.py::test_visual_gen_quickstart - unittest/test_model_runner_cpp.py - accuracy/test_cli_flow.py::TestGptNext::test_auto_dtype - accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index bc3665dd633..ac423a410c7 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -240,6 +240,7 @@ l0_h100: - unittest/llmapi/test_llm_quant.py # 5.5 mins on H100 - test_e2e.py::test_mistral_large_hidden_vocab_size - llmapi/test_llm_examples.py::test_llmapi_quickstart_atexit + - examples/test_visual_gen.py::test_visual_gen_quickstart - 
unittest/trt/attention/test_gpt_attention_IFB.py - accuracy/test_cli_flow.py::TestLlama3_1_8BInstruct::test_fp8_prequantized - accuracy/test_cli_flow.py::TestLlama2_7B::test_fp8 diff --git a/tests/integration/test_lists/test-db/l0_l40s.yml b/tests/integration/test_lists/test-db/l0_l40s.yml index 76c6d6e360c..a80e53c000b 100644 --- a/tests/integration/test_lists/test-db/l0_l40s.yml +++ b/tests/integration/test_lists/test-db/l0_l40s.yml @@ -65,6 +65,7 @@ l0_l40s: - examples/test_nemotron_nas.py::test_nemotron_nas_summary_1gpu[DeciLM-7B] - examples/test_gpt.py::test_llm_gpt_starcoder_lora_1gpu[peft-lora-starcoder2-15b-unity-copilot-starcoder2-lora_fp16-base_fp16] - llmapi/test_llm_examples.py::test_llmapi_quickstart + - examples/test_visual_gen.py::test_visual_gen_quickstart - llmapi/test_llm_examples.py::test_llmapi_example_inference - llmapi/test_llm_examples.py::test_llmapi_example_inference_async - llmapi/test_llm_examples.py::test_llmapi_example_inference_async_streaming diff --git a/tests/integration/test_lists/test-db/l0_sanity_check.yml b/tests/integration/test_lists/test-db/l0_sanity_check.yml index 21aafd1e97f..6c75eeb7be5 100644 --- a/tests/integration/test_lists/test-db/l0_sanity_check.yml +++ b/tests/integration/test_lists/test-db/l0_sanity_check.yml @@ -19,6 +19,7 @@ l0_sanity_check: linux_distribution_name: ubuntu* tests: - llmapi/test_llm_examples.py::test_llmapi_quickstart + - examples/test_visual_gen.py::test_visual_gen_quickstart - llmapi/test_llm_examples.py::test_llmapi_example_inference - llmapi/test_llm_examples.py::test_llmapi_example_inference_async - llmapi/test_llm_examples.py::test_llmapi_example_inference_async_streaming diff --git a/tests/unittest/_torch/visual_gen/test_flux_pipeline.py b/tests/unittest/_torch/visual_gen/test_flux_pipeline.py index 22cc6776d5d..6b4f36ea1d7 100644 --- a/tests/unittest/_torch/visual_gen/test_flux_pipeline.py +++ b/tests/unittest/_torch/visual_gen/test_flux_pipeline.py @@ -25,7 +25,7 @@ import torch.nn.functional as F from tensorrt_llm._torch.modules.linear import Linear -from tensorrt_llm._torch.visual_gen.config import AttentionConfig, DiffusionArgs, PipelineConfig +from tensorrt_llm._torch.visual_gen.config import AttentionConfig, PipelineConfig, VisualGenArgs from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader @@ -155,7 +155,7 @@ class TestFluxPipelineLoading: @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_load_flux1_pipeline_basic(self, flux1_checkpoint_exists): """Test loading FLUX.1 pipeline.""" - args = DiffusionArgs( + args = VisualGenArgs( checkpoint_path=FLUX1_CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -176,7 +176,7 @@ def test_load_flux1_pipeline_basic(self, flux1_checkpoint_exists): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_load_flux2_pipeline_basic(self, flux2_checkpoint_exists): """Test loading FLUX.2 pipeline.""" - args = DiffusionArgs( + args = VisualGenArgs( checkpoint_path=FLUX2_CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -197,7 +197,7 @@ def test_load_flux2_pipeline_basic(self, flux2_checkpoint_exists): @pytest.mark.parametrize("backend", ["VANILLA", "TRTLLM"]) def test_load_flux1_with_attention_backend(self, flux1_checkpoint_exists, backend: str): """Test loading FLUX.1 with different attention backends.""" - args = DiffusionArgs( + args = VisualGenArgs( checkpoint_path=FLUX1_CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -226,7 +226,7 @@ class TestFluxQuantization: 
@pytest.mark.parametrize("quant_algo", ["FP8", "FP8_BLOCK_SCALES"]) def test_load_flux1_with_quantization(self, flux1_checkpoint_exists, quant_algo: str): """Test loading FLUX.1 with FP8 quantization and verify FP8 weights.""" - args = DiffusionArgs( + args = VisualGenArgs( checkpoint_path=FLUX1_CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -270,7 +270,7 @@ def test_load_flux1_with_quantization(self, flux1_checkpoint_exists, quant_algo: @pytest.mark.parametrize("quant_algo", ["FP8", "FP8_BLOCK_SCALES"]) def test_load_flux2_with_quantization(self, flux2_checkpoint_exists, quant_algo: str): """Test loading FLUX.2 with FP8 quantization and verify FP8 weights.""" - args = DiffusionArgs( + args = VisualGenArgs( checkpoint_path=FLUX2_CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -331,7 +331,7 @@ def test_fp8_vs_bf16_single_layer(self, flux1_checkpoint_exists, quant_algo: str """ # Load BF16 pipeline (reference) print(f"\n[Compare {quant_algo}] Loading BF16 pipeline...") - args_bf16 = DiffusionArgs( + args_bf16 = VisualGenArgs( checkpoint_path=FLUX1_CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -341,7 +341,7 @@ def test_fp8_vs_bf16_single_layer(self, flux1_checkpoint_exists, quant_algo: str # Load FP8 pipeline print(f"[Compare {quant_algo}] Loading {quant_algo} pipeline...") - args_fp8 = DiffusionArgs( + args_fp8 = VisualGenArgs( checkpoint_path=FLUX1_CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -413,7 +413,7 @@ def test_fp8_vs_bf16_full_transformer_e2e(self, flux1_checkpoint_exists, quant_a """ # Load BF16 transformer (reference) print("\n[E2E] Loading BF16 transformer...") - args_bf16 = DiffusionArgs( + args_bf16 = VisualGenArgs( checkpoint_path=FLUX1_CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -424,7 +424,7 @@ def test_fp8_vs_bf16_full_transformer_e2e(self, flux1_checkpoint_exists, quant_a # Load FP8 transformer print(f"[E2E] Loading {quant_algo} transformer...") - args_fp8 = DiffusionArgs( + args_fp8 = VisualGenArgs( checkpoint_path=FLUX1_CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -528,7 +528,7 @@ def get_module_memory_gb(module): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() - args_bf16 = DiffusionArgs( + args_bf16 = VisualGenArgs( checkpoint_path=FLUX1_CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -548,7 +548,7 @@ def get_module_memory_gb(module): # Load FP8 torch.cuda.reset_peak_memory_stats() - args_fp8 = DiffusionArgs( + args_fp8 = VisualGenArgs( checkpoint_path=FLUX1_CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -596,7 +596,7 @@ def test_attention_backend_comparison(self, flux1_checkpoint_exists): # Run VANILLA first, save output, then free before loading TRTLLM # (two full transformers don't fit in GPU memory simultaneously) print("\n[Attention Backend Test] Loading baseline transformer (VANILLA)...") - args_baseline = DiffusionArgs( + args_baseline = VisualGenArgs( checkpoint_path=FLUX1_CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -619,7 +619,7 @@ def test_attention_backend_comparison(self, flux1_checkpoint_exists): # Load and run TRTLLM backend print("[Attention Backend Test] Loading TRTLLM transformer...") - args_trtllm = DiffusionArgs( + args_trtllm = VisualGenArgs( checkpoint_path=FLUX1_CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -705,7 +705,7 @@ def test_flux1_e2e_vs_hf(self, flux1_checkpoint_exists): torch.cuda.empty_cache() # 2. 
Load TRT-LLM pipeline (full, no skip_components) - args = DiffusionArgs( + args = VisualGenArgs( checkpoint_path=FLUX1_CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -758,7 +758,7 @@ def test_flux2_e2e_vs_hf(self, flux2_checkpoint_exists): torch.cuda.empty_cache() # 2. Load TRT-LLM pipeline (full, no skip_components) - args = DiffusionArgs( + args = VisualGenArgs( checkpoint_path=FLUX2_CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -821,11 +821,11 @@ def _run_ulysses_worker(rank, world_size, checkpoint_path, inputs_cpu, return_di try: _setup_distributed(rank, world_size) - from tensorrt_llm._torch.visual_gen.config import DiffusionArgs, ParallelConfig + from tensorrt_llm._torch.visual_gen.config import ParallelConfig, VisualGenArgs from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader # Load pipeline with Ulysses parallelism - args = DiffusionArgs( + args = VisualGenArgs( checkpoint_path=checkpoint_path, device=f"cuda:{rank}", dtype="bfloat16", @@ -885,7 +885,7 @@ def test_ulysses_2gpu_correctness(self, flux1_checkpoint_exists): # Load single-GPU reference print("\n[1/3] Loading single-GPU reference (ulysses_size=1) on GPU 0...") - args_baseline = DiffusionArgs( + args_baseline = VisualGenArgs( checkpoint_path=FLUX1_CHECKPOINT_PATH, device="cuda:0", dtype="bfloat16", @@ -967,15 +967,15 @@ def _run_all_optimizations_worker(rank, world_size, checkpoint_path, inputs_cpu, from tensorrt_llm._torch.visual_gen.config import ( AttentionConfig, - DiffusionArgs, ParallelConfig, TeaCacheConfig, + VisualGenArgs, ) from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader from tensorrt_llm.quantization.mode import QuantAlgo # Load pipeline with ALL optimizations - args = DiffusionArgs( + args = VisualGenArgs( checkpoint_path=checkpoint_path, device=f"cuda:{rank}", dtype="bfloat16", @@ -1062,7 +1062,7 @@ def test_all_optimizations_combined(self, flux1_checkpoint_exists): # Load baseline on GPU 0 (no optimizations) print("\n[1/3] Loading baseline on GPU 0 (BF16, no optimizations)...") - args_baseline = DiffusionArgs( + args_baseline = VisualGenArgs( checkpoint_path=FLUX1_CHECKPOINT_PATH, device="cuda:0", dtype="bfloat16", diff --git a/tests/unittest/_torch/visual_gen/test_model_loader.py b/tests/unittest/_torch/visual_gen/test_model_loader.py index 6502996f273..1003fee8431 100644 --- a/tests/unittest/_torch/visual_gen/test_model_loader.py +++ b/tests/unittest/_torch/visual_gen/test_model_loader.py @@ -1,4 +1,4 @@ -"""Test PipelineLoader with DiffusionArgs API.""" +"""Test PipelineLoader with VisualGenArgs API.""" import os from pathlib import Path @@ -50,12 +50,12 @@ def test_meta_init_mode_creates_meta_tensors(checkpoint_exists): pytest.skip("Checkpoint not available") from tensorrt_llm._torch.models.modeling_utils import MetaInitMode - from tensorrt_llm._torch.visual_gen import DiffusionArgs + from tensorrt_llm._torch.visual_gen import VisualGenArgs from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig from tensorrt_llm._torch.visual_gen.models import AutoPipeline # Load config directly - args = DiffusionArgs(checkpoint_path=CHECKPOINT_PATH) + args = VisualGenArgs(checkpoint_path=CHECKPOINT_PATH) config = DiffusionModelConfig.from_pretrained( CHECKPOINT_PATH, args=args, @@ -71,15 +71,15 @@ def test_meta_init_mode_creates_meta_tensors(checkpoint_exists): def test_load_wan_pipeline_basic(checkpoint_exists): - """Test basic loading without quantization using DiffusionArgs.""" + """Test basic loading without quantization using 
VisualGenArgs.""" if not checkpoint_exists: pytest.skip("Checkpoint not available") - from tensorrt_llm._torch.visual_gen import DiffusionArgs, PipelineLoader + from tensorrt_llm._torch.visual_gen import PipelineLoader, VisualGenArgs - # Simple one-liner with DiffusionArgs + # Simple one-liner with VisualGenArgs # Skip text_encoder/vae to speed up test (focus on transformer) - args = DiffusionArgs( + args = VisualGenArgs( checkpoint_path=CHECKPOINT_PATH, skip_components=SKIP_HEAVY_COMPONENTS, ) @@ -100,7 +100,7 @@ def test_load_wan_pipeline_basic(checkpoint_exists): def test_load_wan_pipeline_with_fp8_dynamic_quant(checkpoint_exists): - """Test loading with FP8 dynamic quantization using DiffusionArgs. + """Test loading with FP8 dynamic quantization using VisualGenArgs. Verifies the dynamic quantization flow: 1. Config has dynamic_weight_quant=True when linear.type="trtllm-fp8-per-tensor" @@ -112,11 +112,11 @@ def test_load_wan_pipeline_with_fp8_dynamic_quant(checkpoint_exists): pytest.skip("Checkpoint not available") from tensorrt_llm._torch.modules.linear import Linear - from tensorrt_llm._torch.visual_gen import DiffusionArgs, PipelineLoader + from tensorrt_llm._torch.visual_gen import PipelineLoader, VisualGenArgs - # Use DiffusionArgs with FP8 quantization + # Use VisualGenArgs with FP8 quantization # Skip text_encoder/vae to speed up test (focus on transformer quantization) - args = DiffusionArgs( + args = VisualGenArgs( checkpoint_path=CHECKPOINT_PATH, quant_config={"quant_algo": "FP8", "dynamic": True}, skip_components=SKIP_HEAVY_COMPONENTS, @@ -146,15 +146,15 @@ def test_load_wan_pipeline_with_fp8_dynamic_quant(checkpoint_exists): def test_load_wan_pipeline_with_fp8_blockwise(checkpoint_exists): - """Test loading with FP8 blockwise quantization using DiffusionArgs.""" + """Test loading with FP8 blockwise quantization using VisualGenArgs.""" if not checkpoint_exists: pytest.skip("Checkpoint not available") from tensorrt_llm._torch.modules.linear import Linear - from tensorrt_llm._torch.visual_gen import DiffusionArgs, PipelineLoader + from tensorrt_llm._torch.visual_gen import PipelineLoader, VisualGenArgs # Skip text_encoder/vae to speed up test (focus on transformer quantization) - args = DiffusionArgs( + args = VisualGenArgs( checkpoint_path=CHECKPOINT_PATH, quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True}, skip_components=SKIP_HEAVY_COMPONENTS, @@ -172,16 +172,16 @@ def test_load_wan_pipeline_with_fp8_blockwise(checkpoint_exists): def test_diffusion_args_to_quant_config(): - """Test that DiffusionArgs correctly parses quant_config dict to QuantConfig.""" - from tensorrt_llm._torch.visual_gen import DiffusionArgs + """Test that VisualGenArgs correctly parses quant_config dict to QuantConfig.""" + from tensorrt_llm._torch.visual_gen import VisualGenArgs from tensorrt_llm.quantization.mode import QuantAlgo # Default - no quantization - args = DiffusionArgs(checkpoint_path="/fake/path") + args = VisualGenArgs(checkpoint_path="/fake/path") assert args.quant_config.quant_algo is None # FP8 per-tensor (dict is coerced to QuantConfig by model_validator) - args = DiffusionArgs( + args = VisualGenArgs( checkpoint_path="/fake/path", quant_config={"quant_algo": "FP8", "dynamic": True}, ) @@ -191,7 +191,7 @@ def test_diffusion_args_to_quant_config(): assert args.dynamic_weight_quant is True # FP8 blockwise - args = DiffusionArgs( + args = VisualGenArgs( checkpoint_path="/fake/path", quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True}, ) @@ -199,7 +199,7 @@ def 
test_diffusion_args_to_quant_config(): assert qc.quant_algo == QuantAlgo.FP8_BLOCK_SCALES # NVFP4 - args = DiffusionArgs( + args = VisualGenArgs( checkpoint_path="/fake/path", quant_config={"quant_algo": "NVFP4", "dynamic": True}, ) @@ -207,7 +207,7 @@ def test_diffusion_args_to_quant_config(): assert qc.quant_algo == QuantAlgo.NVFP4 # With ignore patterns (exclude_modules) - args = DiffusionArgs( + args = VisualGenArgs( checkpoint_path="/fake/path", quant_config={ "quant_algo": "FP8", @@ -228,14 +228,14 @@ def test_diffusion_args_to_quant_config(): def test_diffusion_args_to_mapping(): - """Test that DiffusionArgs correctly generates Mapping from ParallelConfig.""" - from tensorrt_llm._torch.visual_gen import DiffusionArgs, ParallelConfig + """Test that VisualGenArgs correctly generates Mapping from ParallelConfig.""" + from tensorrt_llm._torch.visual_gen import ParallelConfig, VisualGenArgs # ParallelConfig validator requires WORLD_SIZE >= total parallel (tp*cp = 4) old_world = os.environ.get("WORLD_SIZE") try: os.environ["WORLD_SIZE"] = "4" - args = DiffusionArgs( + args = VisualGenArgs( checkpoint_path="/fake/path", parallel=ParallelConfig(dit_tp_size=2, dit_cp_size=2), ) @@ -257,11 +257,11 @@ def test_load_without_quant_config_no_fp8(checkpoint_exists): pytest.skip("Checkpoint not available") from tensorrt_llm._torch.modules.linear import Linear - from tensorrt_llm._torch.visual_gen import DiffusionArgs, PipelineLoader + from tensorrt_llm._torch.visual_gen import PipelineLoader, VisualGenArgs # No quantization specified # Skip text_encoder/vae to speed up test (focus on transformer) - args = DiffusionArgs( + args = VisualGenArgs( checkpoint_path=CHECKPOINT_PATH, skip_components=SKIP_HEAVY_COMPONENTS, ) @@ -283,8 +283,8 @@ def test_load_without_quant_config_no_fp8(checkpoint_exists): def test_diffusion_args_from_dict(): - """Test DiffusionArgs can be created from a dictionary.""" - from tensorrt_llm._torch.visual_gen import DiffusionArgs + """Test VisualGenArgs can be created from a dictionary.""" + from tensorrt_llm._torch.visual_gen import VisualGenArgs from tensorrt_llm.quantization.mode import QuantAlgo config_dict = { @@ -297,7 +297,7 @@ def test_diffusion_args_from_dict(): old_world = os.environ.get("WORLD_SIZE") try: os.environ["WORLD_SIZE"] = "2" - args = DiffusionArgs.from_dict(config_dict) + args = VisualGenArgs.from_dict(config_dict) assert args.checkpoint_path == "/path/to/model" assert args.quant_config.quant_algo == QuantAlgo.FP8 assert args.dynamic_weight_quant is True @@ -343,7 +343,7 @@ def test_fp8_vs_bf16_memory_comparison(checkpoint_exists): if not checkpoint_exists: pytest.skip("Checkpoint not available") - from tensorrt_llm._torch.visual_gen import DiffusionArgs, PipelineLoader + from tensorrt_llm._torch.visual_gen import PipelineLoader, VisualGenArgs # ========================================================================= # Test 1: Load BF16 (no quantization) @@ -351,7 +351,7 @@ def test_fp8_vs_bf16_memory_comparison(checkpoint_exists): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() - args_bf16 = DiffusionArgs( + args_bf16 = VisualGenArgs( checkpoint_path=CHECKPOINT_PATH, skip_components=SKIP_HEAVY_COMPONENTS, ) @@ -374,7 +374,7 @@ def test_fp8_vs_bf16_memory_comparison(checkpoint_exists): # ========================================================================= torch.cuda.reset_peak_memory_stats() - args_fp8 = DiffusionArgs( + args_fp8 = VisualGenArgs( checkpoint_path=CHECKPOINT_PATH, quant_config={"quant_algo": "FP8", "dynamic": True}, 
skip_components=SKIP_HEAVY_COMPONENTS, @@ -430,7 +430,7 @@ def test_fp8_vs_bf16_memory_comparison(checkpoint_exists): # ========================================================================= torch.cuda.reset_peak_memory_stats() - args_fp8_block = DiffusionArgs( + args_fp8_block = VisualGenArgs( checkpoint_path=CHECKPOINT_PATH, quant_config={"quant_algo": "FP8_BLOCK_SCALES", "dynamic": True}, skip_components=SKIP_HEAVY_COMPONENTS, diff --git a/tests/unittest/_torch/visual_gen/test_trtllm_serve_e2e.py b/tests/unittest/_torch/visual_gen/test_trtllm_serve_e2e.py index d44e5fd9378..3e641a586f9 100644 --- a/tests/unittest/_torch/visual_gen/test_trtllm_serve_e2e.py +++ b/tests/unittest/_torch/visual_gen/test_trtllm_serve_e2e.py @@ -183,7 +183,6 @@ def _ffmpeg_available() -> bool: def _make_visual_gen_options(**extra) -> dict: """Build the YAML dict passed via ``--extra_visual_gen_options``.""" config = { - "linear": {"type": "default"}, "parallel": {"dit_cfg_size": 1, "dit_ulysses_size": 1}, } config.update(extra) diff --git a/tests/unittest/_torch/visual_gen/test_visual_gen_args.py b/tests/unittest/_torch/visual_gen/test_visual_gen_args.py new file mode 100644 index 00000000000..008c5e5d22d --- /dev/null +++ b/tests/unittest/_torch/visual_gen/test_visual_gen_args.py @@ -0,0 +1,183 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for VisualGenArgs construction, validation, and serialization.""" + +import pickle + +import pytest +from pydantic import ValidationError + +from tensorrt_llm._torch.visual_gen.config import ( + AttentionConfig, + CudaGraphConfig, + ParallelConfig, + PipelineConfig, + TeaCacheConfig, + TorchCompileConfig, + VisualGenArgs, +) + + +class TestVisualGenArgsStrictValidation: + """extra='forbid' rejects unknown fields at every nesting level.""" + + def test_unknown_top_level_field_rejected(self): + with pytest.raises(ValidationError, match="Extra inputs are not permitted"): + VisualGenArgs(checkpoint_path="/tmp/model", unknown_field="bad") + + def test_typo_field_rejected(self): + with pytest.raises(ValidationError, match="Extra inputs are not permitted"): + VisualGenArgs(checkpoint_path="/tmp/model", chekpoint_path="/typo") + + def test_nested_parallel_unknown_field_rejected(self): + with pytest.raises(ValidationError, match="Extra inputs are not permitted"): + VisualGenArgs( + checkpoint_path="/tmp/model", + parallel={"dit_cfg_size": 1, "nonexistent_param": 42}, + ) + + def test_nested_attention_unknown_field_rejected(self): + with pytest.raises(ValidationError, match="Extra inputs are not permitted"): + AttentionConfig(backend="VANILLA", extra_key="bad") + + def test_nested_teacache_unknown_field_rejected(self): + with pytest.raises(ValidationError, match="Extra inputs are not permitted"): + TeaCacheConfig(enable_teacache=True, unknown_opt=True) + + def test_nested_torch_compile_unknown_field_rejected(self): + with pytest.raises(ValidationError, match="Extra inputs are not permitted"): + TorchCompileConfig(enable_torch_compile=True, bad_key=1) + + def test_nested_cuda_graph_unknown_field_rejected(self): + with pytest.raises(ValidationError, match="Extra inputs are not permitted"): + CudaGraphConfig(enable_cuda_graph=False, extra=True) + + def test_nested_pipeline_unknown_field_rejected(self): + with pytest.raises(ValidationError, match="Extra inputs are not permitted"): + PipelineConfig(fuse_qkv=True, invalid_flag=True) + + def 
test_legacy_linear_field_rejected(self): + """The removed 'linear' YAML field must now cause an error.""" + with pytest.raises(ValidationError, match="Extra inputs are not permitted"): + VisualGenArgs( + checkpoint_path="/tmp/model", + linear={"type": "default"}, + ) + + +class TestVisualGenArgsFromDict: + """from_dict now enforces strict validation (no silent drops).""" + + def test_valid_dict(self): + args = VisualGenArgs.from_dict( + { + "checkpoint_path": "/tmp/model", + "parallel": {"dit_cfg_size": 2, "dit_ulysses_size": 1}, + } + ) + assert args.checkpoint_path == "/tmp/model" + assert args.parallel.dit_cfg_size == 2 + + def test_unknown_field_raises(self): + with pytest.raises(ValidationError, match="Extra inputs are not permitted"): + VisualGenArgs.from_dict( + { + "checkpoint_path": "/tmp/model", + "bad_key": 123, + } + ) + + def test_nested_dict_auto_coerced(self): + args = VisualGenArgs.from_dict( + { + "checkpoint_path": "/tmp/model", + "attention": {"backend": "TRTLLM"}, + "teacache": {"enable_teacache": True, "teacache_thresh": 0.3}, + } + ) + assert isinstance(args.attention, AttentionConfig) + assert args.attention.backend == "TRTLLM" + assert args.teacache.enable_teacache is True + assert args.teacache.teacache_thresh == 0.3 + + def test_quant_config_dict_coerced(self): + args = VisualGenArgs.from_dict( + { + "checkpoint_path": "/tmp/model", + "quant_config": {"quant_algo": "FP8", "dynamic": True}, + } + ) + assert args.quant_config.quant_algo is not None + assert args.dynamic_weight_quant is True + + +class TestVisualGenArgsFromYaml: + """from_yaml round-trips through a YAML file.""" + + def test_from_yaml_basic(self, tmp_path): + yaml_path = tmp_path / "config.yml" + yaml_path.write_text( + "checkpoint_path: /tmp/model\nparallel:\n dit_cfg_size: 2\n dit_ulysses_size: 1\n" + ) + args = VisualGenArgs.from_yaml(yaml_path) + assert args.checkpoint_path == "/tmp/model" + assert args.parallel.dit_cfg_size == 2 + + def test_from_yaml_with_overrides(self, tmp_path): + yaml_path = tmp_path / "config.yml" + yaml_path.write_text("checkpoint_path: /tmp/model\ndtype: float16\n") + args = VisualGenArgs.from_yaml(yaml_path, dtype="bfloat16") + assert args.dtype == "bfloat16" + + def test_from_yaml_unknown_field_raises(self, tmp_path): + yaml_path = tmp_path / "bad.yml" + yaml_path.write_text("checkpoint_path: /tmp/model\nlinear:\n type: default\n") + with pytest.raises(ValidationError, match="Extra inputs are not permitted"): + VisualGenArgs.from_yaml(yaml_path) + + +class TestParallelConfigValidation: + """ParallelConfig no longer checks WORLD_SIZE at construction time.""" + + def test_large_parallel_no_env_check(self): + pc = ParallelConfig(dit_cfg_size=2, dit_ulysses_size=4) + assert pc.total_parallel_size == 8 + assert pc.n_workers == 8 + + def test_validate_world_size_passes(self): + pc = ParallelConfig(dit_cfg_size=2, dit_ulysses_size=2) + pc.validate_world_size(4) + + def test_validate_world_size_fails(self): + pc = ParallelConfig(dit_cfg_size=2, dit_ulysses_size=4) + with pytest.raises(ValueError, match="exceeds world_size"): + pc.validate_world_size(4) + + +class TestVisualGenArgsPickle: + """VisualGenArgs must survive pickle round-trip (mp.Process spawn).""" + + def test_pickle_roundtrip(self): + args = VisualGenArgs( + checkpoint_path="/tmp/model", + dtype="float16", + parallel=ParallelConfig(dit_cfg_size=2, dit_ulysses_size=1), + attention=AttentionConfig(backend="TRTLLM"), + quant_config={"quant_algo": "FP8", "dynamic": True}, + ) + data = pickle.dumps(args) + restored = 
pickle.loads(data) + + assert restored.checkpoint_path == args.checkpoint_path + assert restored.dtype == args.dtype + assert restored.parallel.dit_cfg_size == 2 + assert restored.attention.backend == "TRTLLM" + assert restored.quant_config.quant_algo is not None + assert restored.dynamic_weight_quant is True + + def test_model_copy_device_override(self): + args = VisualGenArgs(checkpoint_path="/tmp/model", device="cuda") + updated = args.model_copy(update={"device": "cuda:3"}) + assert updated.device == "cuda:3" + assert args.device == "cuda" diff --git a/tests/unittest/_torch/visual_gen/test_wan.py b/tests/unittest/_torch/visual_gen/test_wan.py index 4215cae8b33..ccc7da5287f 100644 --- a/tests/unittest/_torch/visual_gen/test_wan.py +++ b/tests/unittest/_torch/visual_gen/test_wan.py @@ -20,11 +20,11 @@ from tensorrt_llm._torch.modules.linear import Linear from tensorrt_llm._torch.visual_gen.config import ( AttentionConfig, - DiffusionArgs, DiffusionModelConfig, ParallelConfig, PipelineComponent, TeaCacheConfig, + VisualGenArgs, ) from tensorrt_llm._torch.visual_gen.models.wan.transformer_wan import WanTransformer3DModel from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader @@ -149,11 +149,11 @@ def _run_cfg_worker(rank, world_size, checkpoint_path, inputs_list, return_dict) try: setup_distributed(rank, world_size) - from tensorrt_llm._torch.visual_gen.config import DiffusionArgs, ParallelConfig + from tensorrt_llm._torch.visual_gen.config import ParallelConfig, VisualGenArgs from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader # Load pipeline with CFG parallel - args = DiffusionArgs( + args = VisualGenArgs( checkpoint_path=checkpoint_path, device=f"cuda:{rank}", dtype="bfloat16", @@ -253,7 +253,7 @@ def _run_all_optimizations_worker(rank, world_size, checkpoint_path, inputs_list setup_distributed(rank, world_size) # Load pipeline with ALL optimizations - args_full = DiffusionArgs( + args_full = VisualGenArgs( checkpoint_path=checkpoint_path, device=f"cuda:{rank}", dtype="bfloat16", @@ -768,7 +768,7 @@ def test_load_wan_pipeline_basic(self, checkpoint_exists): "This test requires Wan 2.1 checkpoint (single-stage). Use DIFFUSION_MODEL_PATH with '2.1' in the path." ) - args = DiffusionArgs( + args = VisualGenArgs( checkpoint_path=CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -821,7 +821,7 @@ def test_load_wan_pipeline_with_quantization(self, checkpoint_exists, quant_algo "This test requires Wan 2.1 checkpoint. Use DIFFUSION_MODEL_PATH with '2.1' in the path." 
) - args = DiffusionArgs( + args = VisualGenArgs( checkpoint_path=CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -856,7 +856,7 @@ def test_load_wan_pipeline_with_nvfp4_quantization(self, checkpoint_exists): from tensorrt_llm.quantization.utils import fp4_utils - args = DiffusionArgs( + args = VisualGenArgs( checkpoint_path=CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -910,7 +910,7 @@ def test_fp8_vs_bf16_numerical_correctness(self, checkpoint_exists, quant_algo): # ===================================================================== print(f"\n[Compare {quant_algo}] Loading BF16 pipeline...") - args_bf16 = DiffusionArgs( + args_bf16 = VisualGenArgs( checkpoint_path=CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -922,7 +922,7 @@ def test_fp8_vs_bf16_numerical_correctness(self, checkpoint_exists, quant_algo): # ===================================================================== print(f"[Compare {quant_algo}] Loading {quant_algo} pipeline...") - args_fp8 = DiffusionArgs( + args_fp8 = VisualGenArgs( checkpoint_path=CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -1028,7 +1028,7 @@ def get_module_memory_gb(module): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() - args_bf16 = DiffusionArgs( + args_bf16 = VisualGenArgs( checkpoint_path=CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -1047,7 +1047,7 @@ def get_module_memory_gb(module): # Load FP8 torch.cuda.reset_peak_memory_stats() - args_fp8 = DiffusionArgs( + args_fp8 = VisualGenArgs( checkpoint_path=CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -1098,7 +1098,7 @@ def test_fp8_vs_bf16_full_transformer_e2e(self, checkpoint_exists, quant_algo): # ===================================================================== print("\n[E2E] Loading BF16 transformer...") - args_bf16 = DiffusionArgs( + args_bf16 = VisualGenArgs( checkpoint_path=CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -1111,7 +1111,7 @@ def test_fp8_vs_bf16_full_transformer_e2e(self, checkpoint_exists, quant_algo): # ===================================================================== print(f"[E2E] Loading {quant_algo} transformer...") - args_fp8 = DiffusionArgs( + args_fp8 = VisualGenArgs( checkpoint_path=CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -1268,7 +1268,7 @@ def test_attention_backend_comparison(self, checkpoint_exists): from tensorrt_llm._torch.visual_gen.config import AttentionConfig - args_baseline = DiffusionArgs( + args_baseline = VisualGenArgs( checkpoint_path=CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -1282,7 +1282,7 @@ def test_attention_backend_comparison(self, checkpoint_exists): # ===================================================================== print("[Attention Backend Test] Loading VANILLA transformer (explicit)...") - args_vanilla = DiffusionArgs( + args_vanilla = VisualGenArgs( checkpoint_path=CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -1395,7 +1395,7 @@ def test_attention_backend_comparison(self, checkpoint_exists): # ===================================================================== print("\n[Attention Backend Test] Loading TRTLLM transformer...") - args_trtllm = DiffusionArgs( + args_trtllm = VisualGenArgs( checkpoint_path=CHECKPOINT_PATH, device="cuda", dtype="bfloat16", @@ -1505,7 +1505,7 @@ def test_fp8_mixed_quant_numerical_correctness(self, checkpoint_exists, quant_al # Load Models # ===================================================================== print("\n[Mixed Quant Accuracy] Loading BF16 model (reference)...") - args_bf16 = DiffusionArgs( + 
+        args_bf16 = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH,
             device="cuda",
             dtype="bfloat16",
@@ -1514,7 +1514,7 @@ def test_fp8_mixed_quant_numerical_correctness(self, checkpoint_exists, quant_al
         pipeline_bf16 = PipelineLoader(args_bf16).load(skip_warmup=True)

         print(f"[Mixed Quant Accuracy] Loading mixed {quant_algo} model...")
-        args_fp8_mixed = DiffusionArgs(
+        args_fp8_mixed = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH,
             device="cuda",
             dtype="bfloat16",
@@ -1630,7 +1630,7 @@ def test_fp8_vs_bf16_accuracy(self, wan22_both_checkpoints_exist):

         # Load BF16 reference model
         print(f"\n[BF16] Loading from {CHECKPOINT_PATH_WAN22_BF16}")
-        args_bf16 = DiffusionArgs(
+        args_bf16 = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH_WAN22_BF16,
             device="cuda",
             dtype="bfloat16",
@@ -1640,7 +1640,7 @@ def test_fp8_vs_bf16_accuracy(self, wan22_both_checkpoints_exist):

         # Load FP8 static quantized model (from pre-quantized checkpoint)
         print(f"\n[FP8 Static] Loading from {CHECKPOINT_PATH_WAN22_FP8}")
-        args_fp8_static = DiffusionArgs(
+        args_fp8_static = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH_WAN22_FP8,
             device="cuda",
             dtype="bfloat16",
@@ -1650,7 +1650,7 @@ def test_fp8_vs_bf16_accuracy(self, wan22_both_checkpoints_exist):

         # Load FP8 dynamic quantized model (from BF16 checkpoint with on-the-fly quant)
         print(f"\n[FP8 Dynamic] Loading from {CHECKPOINT_PATH_WAN22_BF16} with dynamic quant")
-        args_fp8_dynamic = DiffusionArgs(
+        args_fp8_dynamic = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH_WAN22_BF16,
             device="cuda",
             dtype="bfloat16",
@@ -1839,7 +1839,7 @@ def test_nvfp4_vs_bf16_accuracy(self, wan22_nvfp4_bf16_checkpoints_exist):

         # Load BF16 reference model
         print(f"\n[BF16] Loading from {CHECKPOINT_PATH_WAN22_BF16}")
-        args_bf16 = DiffusionArgs(
+        args_bf16 = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH_WAN22_BF16,
             device="cuda",
             dtype="bfloat16",
@@ -1849,7 +1849,7 @@ def test_nvfp4_vs_bf16_accuracy(self, wan22_nvfp4_bf16_checkpoints_exist):

         # Load NVFP4 static quantized model (from pre-quantized checkpoint)
         print(f"\n[NVFP4 Static] Loading from {CHECKPOINT_PATH_WAN22_NVFP4}")
-        args_nvfp4_static = DiffusionArgs(
+        args_nvfp4_static = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH_WAN22_NVFP4,
             device="cuda",
             dtype="bfloat16",
@@ -1859,7 +1859,7 @@ def test_nvfp4_vs_bf16_accuracy(self, wan22_nvfp4_bf16_checkpoints_exist):

         # Load NVFP4 dynamic quantized model (from BF16 checkpoint with on-the-fly quant)
         print(f"\n[NVFP4 Dynamic] Loading from {CHECKPOINT_PATH_WAN22_BF16} with dynamic quant")
-        args_nvfp4_dynamic = DiffusionArgs(
+        args_nvfp4_dynamic = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH_WAN22_BF16,
             device="cuda",
             dtype="bfloat16",
@@ -2076,7 +2076,7 @@ def test_nvfp4_vs_bf16_accuracy_mixed_quant(self, wan22_t2v_bf16_checkpoint_exis

         # Load BF16 reference model
         print(f"\n[BF16] Loading from {CHECKPOINT_PATH_WAN22_T2V}")
-        args_bf16 = DiffusionArgs(
+        args_bf16 = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH_WAN22_T2V,
             device="cuda",
             dtype="bfloat16",
@@ -2090,7 +2090,7 @@ def test_nvfp4_vs_bf16_accuracy_mixed_quant(self, wan22_t2v_bf16_checkpoint_exis
         static_bf16_modules = 0
         if have_nvfp4_static:
             print(f"\n[NVFP4 Static] Loading from {CHECKPOINT_PATH_WAN22_T2V_NVFP4}")
-            args_nvfp4_static = DiffusionArgs(
+            args_nvfp4_static = VisualGenArgs(
                 checkpoint_path=CHECKPOINT_PATH_WAN22_T2V_NVFP4,
                 device="cuda",
                 dtype="bfloat16",
@@ -2124,7 +2124,7 @@ def test_nvfp4_vs_bf16_accuracy_mixed_quant(self, wan22_t2v_bf16_checkpoint_exis
             f"\n[NVFP4 Dynamic] Loading from {CHECKPOINT_PATH_WAN22_T2V} "
             f"with dynamic quant + ignore patterns"
         )
-        args_nvfp4_dynamic = DiffusionArgs(
+        args_nvfp4_dynamic = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH_WAN22_T2V,
             device="cuda",
             dtype="bfloat16",
@@ -2435,7 +2435,7 @@ def test_teacache_multi_step(self):

         # Load HuggingFace baseline
         print("\n[1/4] Loading HuggingFace baseline...")
-        args_trtllm = DiffusionArgs(
+        args_trtllm = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH,
             device="cuda",
             dtype="bfloat16",
@@ -2479,7 +2479,7 @@ def test_teacache_multi_step(self):

         # Load TeaCache-enabled pipeline
         print("\n[2/4] Loading TeaCache-enabled TRT-LLM pipeline...")
-        args_teacache = DiffusionArgs(
+        args_teacache = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH,
             device="cuda",
             dtype="bfloat16",
@@ -2625,7 +2625,7 @@ def test_cfg_2gpu_correctness(self):

         # Load standard CFG baseline on GPU 0
         print("\n[1/3] Loading standard CFG baseline (cfg_size=1) on GPU 0...")
-        args_baseline = DiffusionArgs(
+        args_baseline = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH,
             device="cuda:0",
             dtype="bfloat16",
@@ -2819,7 +2819,7 @@ def test_all_optimizations_combined(self):

         # Load baseline on GPU 0 (no optimizations, standard CFG)
         print("\n[1/3] Loading baseline on GPU 0 (standard CFG, no optimizations)...")
-        args_baseline = DiffusionArgs(
+        args_baseline = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH,
             device="cuda:0",
             dtype="bfloat16",
@@ -2988,7 +2988,7 @@ def test_two_stage_pipeline_initialization(self):
         print("WAN 2.2 TWO-STAGE PIPELINE INITIALIZATION TEST")
         print("=" * 80)

-        args = DiffusionArgs(
+        args = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH_WAN22_T2V,
             device="cuda",
             dtype="bfloat16",
@@ -3036,7 +3036,7 @@ def test_two_stage_transformer_selection_logic(self):
         print("WAN 2.2 TRANSFORMER SELECTION LOGIC TEST")
         print("=" * 80)

-        args = DiffusionArgs(
+        args = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH_WAN22_T2V,
             device="cuda",
             dtype="bfloat16",
@@ -3132,7 +3132,7 @@ def test_two_stage_with_custom_boundary_ratio(self):
         print("WAN 2.2 CUSTOM BOUNDARY_RATIO TEST")
         print("=" * 80)

-        args = DiffusionArgs(
+        args = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH_WAN22_T2V,
             device="cuda",
             dtype="bfloat16",
@@ -3183,7 +3183,7 @@ def test_two_stage_guidance_scale_2(self):
         print("WAN 2.2 GUIDANCE_SCALE_2 SUPPORT TEST")
         print("=" * 80)

-        args = DiffusionArgs(
+        args = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH_WAN22_T2V,
             device="cuda",
             dtype="bfloat16",
@@ -3219,7 +3219,7 @@ def test_two_stage_with_fp8_quantization(self):
         print("WAN 2.2 TWO-STAGE + FP8 QUANTIZATION TEST")
         print("=" * 80)

-        args = DiffusionArgs(
+        args = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH_WAN22_T2V,
             device="cuda",
             dtype="bfloat16",
@@ -3271,7 +3271,7 @@ def test_two_stage_with_trtllm_attention(self):
         print("WAN 2.2 TWO-STAGE + TRTLLM ATTENTION TEST")
         print("=" * 80)

-        args = DiffusionArgs(
+        args = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH_WAN22_T2V,
             device="cuda",
             dtype="bfloat16",
@@ -3338,7 +3338,7 @@ def test_two_stage_all_optimizations(self):
         print("FP8 + TRTLLM Attention (TeaCache not supported for Wan 2.2)")
         print("=" * 80)

-        args = DiffusionArgs(
+        args = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH_WAN22_T2V,
             device="cuda",
             dtype="bfloat16",
@@ -3422,7 +3422,7 @@ def tearDown(self):
     def test_invalid_quant_config(self):
         """Test that invalid quantization config raises appropriate error."""
         with pytest.raises((ValueError, KeyError)):
-            args = DiffusionArgs(
+            args = VisualGenArgs(
                 checkpoint_path=CHECKPOINT_PATH,
                 device="cuda",
                 dtype="bfloat16",
diff --git a/tests/unittest/_torch/visual_gen/test_wan_i2v.py b/tests/unittest/_torch/visual_gen/test_wan_i2v.py
index 6a887315482..a6367e2adb9 100644
--- a/tests/unittest/_torch/visual_gen/test_wan_i2v.py
+++ b/tests/unittest/_torch/visual_gen/test_wan_i2v.py
@@ -33,10 +33,10 @@

 from tensorrt_llm._torch.visual_gen.config import (
     AttentionConfig,
-    DiffusionArgs,
     DiffusionModelConfig,
     ParallelConfig,
     TeaCacheConfig,
+    VisualGenArgs,
 )
 from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan_i2v import WanImageToVideoPipeline
 from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader
@@ -103,7 +103,7 @@ def wan21_i2v_pipeline_bf16():
     if not is_wan21_checkpoint():
         pytest.skip("This fixture requires Wan 2.1 checkpoint")

-    args = DiffusionArgs(
+    args = VisualGenArgs(
         checkpoint_path=CHECKPOINT_PATH,
         device="cuda",
         dtype="bfloat16",
@@ -123,7 +123,7 @@ def wan21_i2v_pipeline_fp8():
     if not is_wan21_checkpoint():
         pytest.skip("This fixture requires Wan 2.1 checkpoint")

-    args = DiffusionArgs(
+    args = VisualGenArgs(
         checkpoint_path=CHECKPOINT_PATH,
         device="cuda",
         dtype="bfloat16",
@@ -144,7 +144,7 @@ def wan21_i2v_pipeline_fp8_blockwise():
     if not is_wan21_checkpoint():
         pytest.skip("This fixture requires Wan 2.1 checkpoint")

-    args = DiffusionArgs(
+    args = VisualGenArgs(
         checkpoint_path=CHECKPOINT_PATH,
         device="cuda",
         dtype="bfloat16",
@@ -165,7 +165,7 @@ def wan21_i2v_pipeline_with_image_encoder():
     if not is_wan21_checkpoint():
         pytest.skip("This fixture requires Wan 2.1 checkpoint")

-    args = DiffusionArgs(
+    args = VisualGenArgs(
         checkpoint_path=CHECKPOINT_PATH,
         device="cuda",
         dtype="bfloat16",
@@ -185,7 +185,7 @@ def wan22_i2v_pipeline_bf16():
     if not is_wan22_checkpoint():
         pytest.skip("This fixture requires Wan 2.2 checkpoint")

-    args = DiffusionArgs(
+    args = VisualGenArgs(
         checkpoint_path=CHECKPOINT_PATH,
         device="cuda",
         dtype="bfloat16",
@@ -205,7 +205,7 @@ def wan22_i2v_pipeline_fp8():
     if not is_wan22_checkpoint():
         pytest.skip("This fixture requires Wan 2.2 checkpoint")

-    args = DiffusionArgs(
+    args = VisualGenArgs(
         checkpoint_path=CHECKPOINT_PATH,
         device="cuda",
         dtype="bfloat16",
@@ -269,11 +269,11 @@ def _run_cfg_worker_i2v(rank, world_size, checkpoint_path, inputs_list, return_d
     try:
         setup_distributed(rank, world_size)

-        from tensorrt_llm._torch.visual_gen.config import DiffusionArgs, ParallelConfig
+        from tensorrt_llm._torch.visual_gen.config import ParallelConfig, VisualGenArgs
         from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader

         # Load I2V pipeline with CFG parallel
-        args = DiffusionArgs(
+        args = VisualGenArgs(
             checkpoint_path=checkpoint_path,
             device=f"cuda:{rank}",
             dtype="bfloat16",
@@ -376,7 +376,7 @@ def _run_all_optimizations_worker_i2v(rank, world_size, checkpoint_path, inputs_
     setup_distributed(rank, world_size)

     # Load I2V pipeline with ALL optimizations
-    args_full = DiffusionArgs(
+    args_full = VisualGenArgs(
         checkpoint_path=checkpoint_path,
         device=f"cuda:{rank}",
         dtype="bfloat16",
@@ -641,7 +641,7 @@ def test_attention_backends(self, backend):
         if not is_wan21_checkpoint():
             pytest.skip("This test requires Wan 2.1 checkpoint")

-        args = DiffusionArgs(
+        args = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH,
             device="cuda",
             dtype="bfloat16",
@@ -693,7 +693,7 @@ def test_teacache(self):
         if not is_wan21_checkpoint():
             pytest.skip("This test requires Wan 2.1 checkpoint")

-        args = DiffusionArgs(
+        args = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH,
             device="cuda",
             dtype="bfloat16",
@@ -738,7 +738,7 @@ def test_all_optimizations_combined(self):
         if not is_wan21_checkpoint():
             pytest.skip("This test requires Wan 2.1 checkpoint")

-        args = DiffusionArgs(
+        args = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH,
             device="cuda",
             dtype="bfloat16",
@@ -973,8 +973,8 @@ def test_two_stage_with_all_optimizations(self, wan22_i2v_pipeline_fp8):
         ):
             pytest.skip("Not a two-stage checkpoint")

-        # Load pipeline with all supported optimizations (no TeaCache for Wan 2.2)
-        args = DiffusionArgs(
+        # Load pipeline with all optimizations
+        args = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH,
             device="cuda",
             dtype="bfloat16",
@@ -1021,7 +1021,7 @@ class TestWanI2VRobustness:
     def test_invalid_quant_config(self):
         """Test that invalid quantization config raises appropriate error."""
         with pytest.raises((ValueError, KeyError)):
-            args = DiffusionArgs(
+            args = VisualGenArgs(
                 checkpoint_path=CHECKPOINT_PATH,
                 device="cuda",
                 dtype="bfloat16",
@@ -1036,7 +1036,7 @@ def test_mismatched_image_size(self, test_image):
         if not CHECKPOINT_PATH or not os.path.exists(CHECKPOINT_PATH):
             pytest.skip("DIFFUSION_MODEL_PATH not set")

-        args = DiffusionArgs(
+        args = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH,
             device="cuda",
             dtype="bfloat16",
@@ -1115,7 +1115,7 @@ def test_cfg_2gpu_correctness(self):

         # Load standard CFG baseline on GPU 0
         print("\n[1/3] Loading standard CFG I2V baseline (cfg_size=1) on GPU 0...")
-        args_baseline = DiffusionArgs(
+        args_baseline = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH,
             device="cuda:0",
             dtype="bfloat16",
@@ -1326,7 +1326,7 @@ def test_all_optimizations_combined(self):

         # Load baseline on GPU 0 (no optimizations, standard CFG)
         print("\n[1/3] Loading I2V baseline on GPU 0 (standard CFG, no optimizations)...")
-        args_baseline = DiffusionArgs(
+        args_baseline = VisualGenArgs(
             checkpoint_path=CHECKPOINT_PATH,
             device="cuda:0",
             dtype="bfloat16",

From 06ec49235a7620dbee5fdc0983249fabdd5e567a Mon Sep 17 00:00:00 2001
From: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
Date: Mon, 9 Mar 2026 03:17:11 +0000
Subject: [PATCH 6/6] [None][infra] Check in most recent lock file from nightly pipeline

Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
---
 security_scanning/metadata.json | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/security_scanning/metadata.json b/security_scanning/metadata.json
index d44eec7de5b..77a076768bf 100644
--- a/security_scanning/metadata.json
+++ b/security_scanning/metadata.json
@@ -1,4 +1,4 @@
 {
-  "commit_hash": "6b049733311d552d507ecb4b04feda19066fc160",
-  "timestamp": "2026-03-08T02:47:25Z"
+  "commit_hash": "02c8a948208eca28fdec57bdd27be61563891ab0",
+  "timestamp": "2026-03-09T02:46:55Z"
 }