From 22007ca14db8c8f3cf3a7f1fcd87af05dc8d0e08 Mon Sep 17 00:00:00 2001 From: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com> Date: Thu, 5 Mar 2026 14:18:27 +0800 Subject: [PATCH 1/7] [None][fix] remove leak check for kimi (#11825) Signed-off-by: Xin He (SW-GPU) <200704525+xinhe-nv@users.noreply.github.com> --- .../defs/accuracy/test_disaggregated_serving.py | 7 +++---- tests/integration/defs/accuracy/test_llm_api_pytorch.py | 1 - 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index d74f829f380..3c6084e16e9 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -1563,13 +1563,12 @@ def test_mixed_ctx_gen_model(self, ctx_pp, gen_tp): @pytest.mark.timeout(10800) @skip_pre_blackwell class TestKimiK2(LlmapiAccuracyTestHarness): - MODEL_NAME = "moonshotai/Kimi-K2-Instruct" - MODEL_PATH = f"{llm_models_root()}/Kimi-K2-Instruct" + MODEL_NAME = "moonshotai/Kimi-K2-Thinking" + MODEL_PATH = f"{llm_models_root()}/Kimi-K2-Thinking-NVFP4" @pytest.mark.skip_less_device(8) @pytest.mark.skip_less_device_memory(200000) def test_nvfp4(self): - model_path = f"{llm_models_root()}/Kimi-K2-Thinking-NVFP4" ctx_server_config = { "max_batch_size": 16, "disable_overlap_scheduler": True, @@ -1611,7 +1610,7 @@ def test_nvfp4(self): } with launch_disaggregated_llm(disaggregated_server_config, ctx_server_config, gen_server_config, - model_path) as llm: + self.MODEL_PATH) as llm: run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"]) diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index c212a55a9ab..f9e6d231a1b 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -3223,7 +3223,6 @@ def 
test_nvfp4_2_model_mtp(self, tp_size, cuda_graph, overlap_scheduler, task.evaluate(llm) -@pytest.mark.threadleak(enabled=False) @pytest.mark.timeout(10800) @pytest.mark.skip_less_device_memory(100000) class TestKimiK2(LlmapiAccuracyTestHarness): From 12f2f3938d509293788155a7c30aefbf9c36dded Mon Sep 17 00:00:00 2001 From: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> Date: Thu, 5 Mar 2026 14:23:22 +0800 Subject: [PATCH 2/7] [https://nvbugs/5907477][chore] unwaive test (#11896) Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> --- tests/integration/test_lists/waives.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 3107cb0b7eb..d2495def9e2 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -317,7 +317,6 @@ full:sm89/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ full:sm89/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[GSM8K] SKIP (https://nvbugs/5893116) full:sm89/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[MMLU] SKIP (https://nvbugs/5893116) accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] SKIP (https://nvbugspro.nvidia.com/bug/5896577) -accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype[False] SKIP (https://nvbugspro.nvidia.com/bug/5907477) unittest/_torch/thop/serial/test_moe.py::TestMoeFp4::test_no_autotune[use_score_as_input-RoutingDSv3-swiglu-1024-1024-1] SKIP (https://nvbugspro.nvidia.com/bug/5908070) unittest/_torch/thop/serial/test_moe.py::TestMoeFp4::test_no_autotune[use_score_as_input-RoutingRenormalize_qwen_next-swiglu-1024-1024-150] SKIP (https://nvbugspro.nvidia.com/bug/5908070) 
unittest/_torch/thop/serial/test_moe.py::TestMoeFp4::test_no_autotune[use_score_as_input-RoutingRenormalize_topk_4-swiglu-1024-1024-150] SKIP (https://nvbugspro.nvidia.com/bug/5908070) From 17921f8d5e54ed3f4861cba079c27143e73254a0 Mon Sep 17 00:00:00 2001 From: Abby Wei Date: Thu, 5 Mar 2026 15:05:59 +0800 Subject: [PATCH 3/7] [TRTLLM-10956][infra] Support build-only mode for GenPostMergeBuilds job (#11895) Signed-off-by: Abby Wei --- jenkins/L0_MergeRequest.groovy | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/jenkins/L0_MergeRequest.groovy b/jenkins/L0_MergeRequest.groovy index 99d6ee68766..cd75b768c79 100644 --- a/jenkins/L0_MergeRequest.groovy +++ b/jenkins/L0_MergeRequest.groovy @@ -66,6 +66,8 @@ boolean enableFailFast = !(env.JOB_NAME ==~ /.*PostMerge.*/ || env.JOB_NAME ==~ boolean isReleaseCheckMode = (gitlabParamsFromBot.get("run_mode", "full") == "release_check") +GEN_POST_MERGE_BUILDS_ONLY = (env.JOB_NAME?.contains("GenPostMergeBuilds") ?: false) + BUILD_STATUS_NAME = isReleaseCheckMode ? "Jenkins Release Check" : "Jenkins Full Build" def trimForStageList(stageNameList) @@ -1048,6 +1050,10 @@ def launchStages(pipeline, reuseBuild, testFilter, enableFailFast, globalVars) stages = [ "Release-Check": { script { + if (GEN_POST_MERGE_BUILDS_ONLY) { + echo "Skipping Release-Check (GenPostMergeBuilds mode: builds only)" + return + } launchReleaseCheck(this) } }, @@ -1063,6 +1069,11 @@ def launchStages(pipeline, reuseBuild, testFilter, enableFailFast, globalVars) launchJob("/LLM/helpers/Build-x86_64", reuseBuild, enableFailFast, globalVars, "x86_64", additionalParameters) } + if (GEN_POST_MERGE_BUILDS_ONLY) { + echo "Skipping x86_64 tests (GenPostMergeBuilds mode: builds only)" + return + } + testStageName = "[Test-x86_64-Single-GPU] ${env.localJobCredentials ? 
"Remote Run" : "Run"}" def singleGpuTestFailed = false stage(testStageName) { @@ -1168,6 +1179,11 @@ def launchStages(pipeline, reuseBuild, testFilter, enableFailFast, globalVars) launchJob("/LLM/helpers/Build-SBSA", reuseBuild, enableFailFast, globalVars, "SBSA", additionalParameters) } + if (GEN_POST_MERGE_BUILDS_ONLY) { + echo "Skipping SBSA tests (GenPostMergeBuilds mode: builds only)" + return + } + testStageName = "[Test-SBSA-Single-GPU] ${env.localJobCredentials ? "Remote Run" : "Run"}" def singleGpuTestFailed = false stage(testStageName) { @@ -1278,10 +1294,10 @@ def launchStages(pipeline, reuseBuild, testFilter, enableFailFast, globalVars) } ] - if (env.JOB_NAME ==~ /.*PostMerge.*/) { + if (env.JOB_NAME ==~ /.*PostMerge.*/ && !GEN_POST_MERGE_BUILDS_ONLY) { stages += dockerBuildJob } - if (testFilter[(TEST_STAGE_LIST)]?.contains("Build-Docker-Images") || testFilter[(EXTRA_STAGE_LIST)]?.contains("Build-Docker-Images")) { + if (!GEN_POST_MERGE_BUILDS_ONLY && (testFilter[(TEST_STAGE_LIST)]?.contains("Build-Docker-Images") || testFilter[(EXTRA_STAGE_LIST)]?.contains("Build-Docker-Images"))) { stages += dockerBuildJob testFilter[(TEST_STAGE_LIST)]?.remove("Build-Docker-Images") testFilter[(EXTRA_STAGE_LIST)]?.remove("Build-Docker-Images") @@ -1374,5 +1390,18 @@ pipeline { } } } + stage("Upload Build Info") { + steps { + script { + def buildInfo = "commit=${env.gitlabCommit}\n" + + "branch=${env.gitlabTargetBranch ?: env.BRANCH_NAME ?: 'unknown'}\n" + + "date=${new Date().format('yyyy-MM-dd HH:mm:ss z', TimeZone.getTimeZone('UTC'))}\n" + + "jenkins_url=${env.BUILD_URL}" + writeFile file: 'build_info.txt', text: buildInfo + trtllm_utils.uploadArtifacts("build_info.txt", "${UPLOAD_PATH}/") + echo "Build info: https://urm.nvidia.com/artifactory/${UPLOAD_PATH}/build_info.txt" + } + } + } } // stages } // pipeline From 9da717a495a2fd4aa254007136f8cef03b4068eb Mon Sep 17 00:00:00 2001 From: Bala Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> Date: Wed, 4 
Mar 2026 23:11:39 -0800 Subject: [PATCH 4/7] [#11755][feat] AutoDeploy onboarding agent + Kimi K2.5 AD modeling code (#11780) Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> --- .claude/agents/ad-debug-agent.md | 41 + .claude/agents/ad-onboard-reviewer.md | 123 +++ .claude/skills/ad-model-onboard/SKILL.md | 104 ++ .../model_registry/configs/kimi_k2.yaml | 22 + .../custom_ops/mla/flashinfer_mla.py | 24 +- .../custom_ops/mla/torch_backend_mla.py | 36 +- .../auto_deploy/models/custom/__init__.py | 3 + .../models/custom/modeling_kimi_k2.py | 989 ++++++++++++++++++ .../defs/accuracy/test_llm_api_autodeploy.py | 66 ++ .../_utils_test/_model_test_utils.py | 27 + .../singlegpu/models/test_kimi_k2_modeling.py | 764 ++++++++++++++ 11 files changed, 2190 insertions(+), 9 deletions(-) create mode 100644 .claude/agents/ad-debug-agent.md create mode 100644 .claude/agents/ad-onboard-reviewer.md create mode 100644 .claude/skills/ad-model-onboard/SKILL.md create mode 100644 examples/auto_deploy/model_registry/configs/kimi_k2.yaml create mode 100644 tensorrt_llm/_torch/auto_deploy/models/custom/modeling_kimi_k2.py create mode 100644 tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_kimi_k2_modeling.py diff --git a/.claude/agents/ad-debug-agent.md b/.claude/agents/ad-debug-agent.md new file mode 100644 index 00000000000..3d2f9525e4a --- /dev/null +++ b/.claude/agents/ad-debug-agent.md @@ -0,0 +1,41 @@ +--- +name: ad-debug-agent +description: Debug the AutoDeploy model onboarding process +tools: Read, Grep, Glob, Bash, Edit, Write +model: sonnet +--- + +Usually, we run a model with auto deploy using this command. If you are not given the model-id and config, ask the user first. + +And ask if you want to rerun it to get fresh log and IR. +Keep log and IR dump directory $PWD. + +Workflow: +1. Run the ad flow with the user given model-id and yaml using the below command. 
+How to run: +```bash +AD_DUMP_GRAPHS_DIR={AD_DUMP_GRAPHS_DIR} python examples/auto_deploy/build_and_run_ad.py \ + --model {MODEL_HF_ID} \ + --args.yaml-extra examples/auto_deploy/model_registry/configs/{CONFIG_YAML_FILE} \ + 2>&1 | tee {LOG_FILE} +``` +Where `AD_DUMP_GRAPHS_DIR={AD_DUMP_GRAPHS_DIR}` is the directory where the graphs will be dumped (will be auto-created by the script), `{MODEL_HF_ID}` is the HF model-id of the model we want to run (it can also be a local path to a model checkpoint), and `{CONFIG_YAML_FILE}` is the configuration file for the model. + +If there's any error, we check the log file `{LOG_FILE}` and IR files in the `AD_DUMP_GRAPHS_DIR` directory to see what went wrong. + +2. If you hit an error and notice something wrong, first inform the user what you observed. Then analyze the issue and think of possible root causes. Don't jump to fixing anything yet. + +3. Based on the discussion with the user, implement the fix and run again and iterate. + + +Remember to use your own tools - Read, Grep, Glob, Bash, Edit, Write + +Some common strategies to iterate faster and debug issues: +* use fewer hidden layers - can be done by updating the yaml file with model_kwargs. Usually it'll be simple but it needs to match what model config expects - some models might have alternating layer patterns like - 1 full attention, 1 linear attention etc. Then update the yaml file with model_kwargs accordingly. +* enable / disable sharding - can be done by updating the yaml file with world_size = 1 or world_size >1 (say 2) + +Common pit-falls: +* weights in HF safetensors are not matching what AD custom modeling code expects. So weight loading will fail. Usually there'll be load hooks registered in ad modeling code, but you can verify that. The HF safetensors json will be a helpful reference. +* custom model has different module hierarchies than what the checkpoint safetensors expect. In that case we update the ad custom modeling code to match the expected hierarchy. 
+ +Remember to use your own tools - Read, Grep, Glob, Bash, Edit, Write diff --git a/.claude/agents/ad-onboard-reviewer.md b/.claude/agents/ad-onboard-reviewer.md new file mode 100644 index 00000000000..3749391428e --- /dev/null +++ b/.claude/agents/ad-onboard-reviewer.md @@ -0,0 +1,123 @@ +--- +name: ad-onboard-reviewer +description: Independent reviewer for AutoDeploy model onboarding. Validates created model and test files against all onboarding requirements. Use after completing model onboarding work. +tools: Read, Grep, Glob +model: sonnet +--- + +You are an independent code reviewer for AutoDeploy model onboarding. + +**Your role is adversarial.** You exist because the implementing agent misses details. +Do NOT trust any claims from the caller. You will be given a model name and file paths. +Read every file yourself, line by line, and verify each checklist item with concrete evidence. + +## Inputs You Will Receive + +- `model_name`: The model being onboarded +- `model_file`: Path to the created `modeling_*.py` +- `test_file`: Path to the created `test_*_modeling.py` +- `init_file`: Always `tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py` + +## Validation Checklist + +Read the actual source code for each check. Cite `file:line_number` for every PASS and FAIL. + + +### B. 
Self-Containment + +| # | Check | How to verify | +|---|-------|---------------| +| B1 | No imports from other AD custom models (`from .modeling_*`) | Grep for `from .modeling_` — only `from .` imports of non-model utilities are OK (e.g., `mla_rope_utils`) | +| B2 | Config class is defined in the file OR imported from transformers (not from another AD model) | Check where the config class comes from | +| B3 | If config not in installed transformers, file has `AutoConfig.register()` | Grep for `AutoConfig.register` | + +### BA Checkpoint compatibility +| BA1 | Make sure the custom modeling code nn.module hierarchy matches the model hierarchy that is expected in the checkpoint safetensor json. | +| BA2 | If our modeling code has expert-list style moe experts and the checkpoint has fused moe experts, add a load hook to load the safetensors correctly into our expert list weights. + +### C. Ops & Compatibility + +| # | Check | How to verify | +|---|-------|---------------| +| C1 | Only uses `torch_*` reference ops from `auto_deploy.custom_ops` or plain PyTorch | Grep for `torch.ops.` calls — only `torch.ops.auto_deploy.torch_*` allowed | +| C2 | No `triton_*`, `flashinfer_*`, `trtllm.*` ops (no exception for routers or router gemms all must be CPU compatible torch ops) | Grep for these prefixes | +| C3 | No KV cache logic (no `past_key_values`, no cache classes) | Grep for `past_key_value`, `cache`, `DynamicCache` | +| C4 | No training paths (no `self.training` checks, no `dropout`) | Grep for `self.training`, `dropout`, `Dropout` | +| C5 | No flash attention variants (`flash_attn`, `sdpa`, `_flash_attention`) | Grep for these strings | + +### D. 
RoPE & MoE Conventions + +| # | Check | How to verify | +|---|-------|---------------| +| D1 | RoPE buffers use `_ad_` prefix (`_ad_cos_cached`, `_ad_sin_cached`) | Grep for `register_buffer` calls with `_ad_` | +| D2 | RoPE `forward()` returns full table (not sliced by seq_len) | Read the RoPE forward method — should return full cached tensors | +| D3 | Position slicing happens downstream (in attention, by `position_ids`) | Check attention forward for `cos[position_ids]` or similar pattern | +| D4 | MoE experts use `nn.ModuleList` (not stacked tensor parameters) | Grep for `nn.ModuleList` in MoE class | +| D5 | Each expert has individual `gate_proj`, `up_proj`, `down_proj` weights | Check expert structure | + +Note: D1-D3 only apply if the model uses RoPE. D4-D5 only apply if the model has MoE. +Mark as N/A with justification if the model doesn't have the relevant component. + +### F. Test File — Structure + +| # | Check | How to verify | +|---|-------|---------------| +| F1 | Uses small config (hidden_size ~64, num_hidden_layers 2-3, vocab_size ~1000) | Read the test config creation | +| F2 | No smoke tests — every test has meaningful assertions (`assert_close`, `assert_rmse_close`, shape checks, finiteness checks) | Check each test for substantive assertions | +| F3 | Do not rely on only `isnan`/`isinf` checks; include functional equivalence assertions | Check tests use `assert_close` or `assert_rmse_close` against reference outputs | +| F4 | Test imports must be self-contained (transformers imports or copied reference classes only); no hardcoded local/temp path imports | Inspect imports and helper loaders | + +### G. Test File — Hierarchical Levels + +| # | Check | How to verify | +|---|-------|---------------| +| G1 | **Block equivalence**: Tests individual blocks (MLP, Attention, MoE, Norm) comparing AD output vs HF output. Blocks with identical math (plain MLP, Norm) should use `torch.testing.assert_close` with tight tolerance. 
Blocks with fused custom ops (Attention with MLA/RoPE, MoE with fused routing) must use `assert_rmse_close` from `_model_test_utils` with appropriate `rmse_ratio_tol` (attention: 0.10, MoE: 0.02). | Look for per-block test functions loading same weights into both implementations; verify correct comparison function and tolerance | +| G2 | **Layer equivalence**: Tests a full decoder layer (if model has heterogeneous layers like dense vs MoE, tests each type). Must use `assert_rmse_close` with `rmse_ratio_tol=0.05`. | Look for layer-level test with `assert_rmse_close` | +| G3 | **Full model equivalence**: End-to-end logits comparison AD vs HF with same weights with minimum number layers. Must use `assert_rmse_close` with `rmse_ratio_tol=0.05`. Also, need to be able to run on CPU. | Look for full model test with logits `assert_rmse_close` | +| G4 | **Export test**: Uses `torch_export_to_gm` with `Dim.DYNAMIC` for both batch and sequence dimensions | Grep for `torch_export_to_gm` and `Dim.DYNAMIC` | +| G6 | Export test runs a second forward with different shape to verify dynamic dims work | Look for a second input with different B, S values | + +### H. Test File — Weight Conversion + +| # | Check | How to verify | +|---|-------|---------------| +| H1 | If MoE model: has state_dict converter from HF stacked format to per-expert format | Look for conversion function | +| H2 | Equivalence tests load identical weights into both HF and AD models before comparing | Check that `load_state_dict` is called with converted weights | + +## Output Format + +```text +REVIEW RESULT: PASS | FAIL + +=== A. Structure & Hierarchy === +A1 PASS modeling_foo.py:45 — FooPreTrainedModel(PreTrainedModel) +A2 PASS modeling_foo.py:30 — @dataclass FooCausalLMOutput(ModelOutput) +A3 FAIL modeling_foo.py:120 — forward(self, input_ids, attention_mask, ...) — missing position_ids +A4 PASS modeling_foo.py:135 — returns FooCausalLMOutput(logits=logits) + +=== B. 
Self-Containment === +B1 PASS No `from .modeling_` imports found +B2 PASS modeling_foo.py:15 — FooConfig defined in file +B3 PASS modeling_foo.py:80 — AutoConfig.register("foo", FooConfig, exist_ok=True) + +=== C. Ops & Compatibility === +... + +=== Summary === +PASSED: 22/26 +FAILED: 4/26 + +Failed items requiring fixes: +1. A3 — Forward signature missing position_ids parameter (modeling_foo.py:120) +2. G2 — No layer equivalence test found +3. G4 — Export test missing Dim.DYNAMIC +4. H1 — No MoE weight converter despite model having MoE layers +``` + +## Rules + +1. Be strict. If something is ambiguous or borderline, mark it FAIL and explain why. +2. A PASS result means EVERY SINGLE item passed. Even one FAIL means overall FAIL. +3. Always cite file:line_number. No exceptions. +4. Read the actual files. Never infer or assume based on the caller's description. +5. If a check is not applicable (e.g., D4 for a non-MoE model), mark it N/A with justification. diff --git a/.claude/skills/ad-model-onboard/SKILL.md b/.claude/skills/ad-model-onboard/SKILL.md new file mode 100644 index 00000000000..90ed2f74045 --- /dev/null +++ b/.claude/skills/ad-model-onboard/SKILL.md @@ -0,0 +1,104 @@ +--- +name: ad-model-onboard +description: Translates a HuggingFace model into a prefill-only AutoDeploy custom model using reference custom ops, validates with hierarchical equivalence tests. +--- + +# AutoDeploy Model Onboarding + +**Input:** HuggingFace model ID. **Output:** prefill-only custom model file + hierarchical tests + summary report. + +## Phase 0 — Gather All Resources Upfront +Web/GitHub fetches require user approval and the user may leave. Do ALL network access now and save locally before proceeding. + +**Step 1 — Check local transformers install first:** +```bash +python -c "import transformers; print(transformers.__file__)" +``` +Look for `models/{model_type}/modeling_*.py` under that path. If found, use it directly — no network needed. 
+ +**Step 2 — If not found, download the HF repo (code only, skip weights):** +```bash +huggingface-cli download {org}/{model} --exclude "*.safetensors" "*.bin" "*.pt" "*.gguf" +``` +This downloads config, code, and tokenizer files into the standard HF cache (`$HF_HOME` or `~/.cache/huggingface/`) while skipping large weight files. Files cached here are automatically found by `transformers.AutoConfig.from_pretrained` and similar APIs — no extra path wiring needed. Once downloaded you can work fully offline — read `config.json` and `modeling_*.py` from the cache snapshot directory printed by the command. + +## Phase 1 — Analyze HF Model +Study the locally-available `config.json` and `modeling_*.py` (NOT from `tensorrt_llm/_torch/models/`). Identify attention type (MHA/GQA/MLA), MoE config, RoPE variant, normalization, activation, and any data-dependent ops that break `torch.export` (e.g. `torch.nonzero`, data-conditioned `if`). + +## Phase 2 — Write Prefill-Only Model +Create `tensorrt_llm/_torch/auto_deploy/models/custom/modeling_{name}.py`. Use `modeling_glm4_moe_lite.py` as a **structural template only** (class layout, dataclass outputs, forward signature). Strip: KV cache, training paths, dropout, flash attention variants. Keep: `PreTrainedModel` hierarchy, `ModelOutput` dataclass, minimal forward `(input_ids, position_ids, inputs_embeds=None, **kwargs)`. + +**Critical** +Make sure the custom modeling code matches the model hierarchy that is expected in the checkpoint safetensor json. + +**Critical rule: Do NOT import or reuse existing AD custom model code** (e.g. `from .modeling_deepseek import ...`). Every `modeling_{name}.py` must be self-contained. Use the HF source (`$CLONE_DIR/modeling_*.py`) as the source of truth for the model's logic and translate it fresh — even if a structurally similar AD model already exists. This prevents hidden coupling, makes each model auditable on its own, and ensures model-specific quirks are captured correctly. 
+ +## Phase 3 — Use Reference Custom Ops Only +Replace HF ops with `torch_*` prefixed AD reference ops. **Never** use `triton_*`/`flashinfer_*`/`trtllm_*` — backend selection happens later in AD transforms. Browse `tensorrt_llm/_torch/auto_deploy/custom_ops/` for all available reference ops and their exact signatures. For vanilla components (RMSNorm, MLP), plain PyTorch is also fine — AD fusion passes replace them. + +## Phase 4 — Register +1. Bottom of model file: `AutoModelForCausalLMFactory.register_custom_model_cls("ConfigClassName", ForCausalLM)`. +2. Add import + `__all__` entry in `models/custom/__init__.py`. +3. If config not in installed transformers, bundle config class and `AutoConfig.register(model_type, ConfigCls, exist_ok=True)`. + +## Phase 5 — Model Input Contract +The custom model's forward signature must follow these rules: + +1. **Always `input_ids`** — The top-level model always receives `input_ids`. A submodule graph may internally receive `inputs_embeds` (e.g., after the embedding layer), but the exported entry point takes token IDs. +2. **Always `position_ids`** — Vanilla sequential `position_ids` are always provided. If the model uses a non-standard RoPE variant or custom position encoding, the model must compute it internally on top of these vanilla `position_ids`. +3. **Multi-modal inputs** — If the model supports vision/audio/etc., those additional inputs are passed during prefill alongside `input_ids`. +4. **No attention mask, no cache inputs, no HF-runtime features** — Do not accept `attention_mask`, `past_key_values`, `use_cache`, or similar HF-runtime arguments. AD manages masking and caching via its own transforms and runtime. + +## Phase 6 — Hierarchical Tests +Create `tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_{name}_modeling.py`. Use `test_glm4_moe_lite_modeling.py` as template. **No smoke tests.** Small config (hidden=64, layers=2-3, vocab=1000). Use `pytest.skip` if HF class unavailable. 
+ +**HF Reference Strategy:** Equivalence tests compare our custom implementation against the HF reference with identical weights and inputs. +- **If HF modules exist in the installed `transformers`**: import them directly (e.g., `from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3ForCausalLM`). Wrap imports in `_get_hf_*_class()` try/except helpers that return `None` on `ImportError`, and use `pytest.skip` when `None`. +- **If HF modules are NOT in the installed `transformers`**: copy the minimal module definitions from the HF `modeling_*.py` source into the test file as standalone reference classes. This keeps tests self-contained without requiring a specific `transformers` version. +- **Weight conversion helpers**: Write test-only helpers for any weight format differences between HF and custom (e.g., RoPE de-interleaving, stacked-to-per-expert MoE weights, gate weight key remapping). For full-model tests, prefer using `load_state_dict` pre-hooks already registered on the custom model. + +**Numerical comparison:** For equivalence tests comparing custom ops against HF reference, use the shared `assert_rmse_close` utility from `_model_test_utils`: +```python +from _model_test_utils import assert_rmse_close +``` +This computes `rmse(actual - expected) / rmse(expected)` — more robust than per-element `torch.testing.assert_close` since a few outlier elements won't fail the test. Use `torch.testing.assert_close` only for blocks with identical math (e.g., plain MLP with no custom ops). + +Recommended `rmse_ratio_tol` values for bfloat16: +- **Identical math** (MLP, Norm): use `torch.testing.assert_close` with tight rtol/atol (1e-3) +- **MoE block** (fused routing): `0.02` +- **Decoder layer / MoE layer / full model**: `0.05` +- **Attention**: `0.10` + +**Bottom-up levels (each must pass before next):** +1. 
**Block equivalence** — Test MLP, Attention, MoE, Norm individually: same weights + same input → `assert_rmse_close` (or `torch.testing.assert_close` for identical-math blocks). +2. **Layer equivalence** — Full decoder layer. If model has heterogeneous layers (dense vs MoE, attention vs SSM), test each type separately. +3. **Full model equivalence** — End-to-end logits comparison. Use a small config with <10 layers that covers the essence of the architecture (e.g., at least one of each layer type). +4. **Export test** — `torch_export_to_gm` with `Dim.DYNAMIC` for batch+seq, verify finite output, test a second shape. + +## Phase 7 — Independent Review (MANDATORY) + +Invoke the `ad-onboard-reviewer` subagent with ONLY the following information: +- Model name +- Path to the model file created +- Path to the test file created + +**Do NOT include your own assessment of correctness. Do NOT summarize what you did.** Let the reviewer read the files and judge independently. + +If the reviewer returns **FAIL** on any item: +1. Read the reviewer's specific failure reasons and file:line references +2. Fix each failed item +3. Invoke the reviewer again with the same minimal inputs +4. Repeat until you get a full **PASS** + +Do NOT proceed to Phase 8 until the reviewer returns PASS. + +## Phase 8 — Summary Report +Print (not file) after completion: (1) model overview + unique features, (2) tricky parts needing human review, (3) files created/modified, (4) test results table (name | validates | PASS/FAIL), (5) known limitations, (6) reviewer result (PASS + how many review iterations it took). + +## Key Gotchas +- **Self-contained files only**: Never import from other AD custom models. Each `modeling_{name}.py` is a standalone translation from HF source. +- RoPE buffers: `_ad_` prefix, return full table (not sliced), slice by `position_ids` downstream. +- MoE weights: use `nn.ModuleList` per-expert for checkpoint compatibility. 
Write test-only state_dict converters for HF stacked format. +- `noaux_tc` routers (DeepSeek-V3 style): use vanilla PyTorch (sigmoid + bias + group topk + normalize + scale). AD transforms can replace with fused `trtllm` kernels at deployment time. +- Vision towers are typically **not** exported. Keep vision logic in eager PyTorch and export only the text path unless explicitly requested otherwise. +- Model code and tests must run on CPU. Use only torch reference ops in AutoDeploy (e.g., `torch_rmsnorm`, `torch_mla`, `torch_moe`) and avoid CUDA-only kernels in the modeling path. diff --git a/examples/auto_deploy/model_registry/configs/kimi_k2.yaml b/examples/auto_deploy/model_registry/configs/kimi_k2.yaml new file mode 100644 index 00000000000..2d9cda2199b --- /dev/null +++ b/examples/auto_deploy/model_registry/configs/kimi_k2.yaml @@ -0,0 +1,22 @@ +# Configuration for Kimi-K2.5 VLM (moonshotai/Kimi-K2.5) +# Uses minimum layers for validation: 1 dense + 2 MoE = 3 total +runtime: trtllm +compile_backend: torch-cudagraph +max_seq_len: 4096 +max_num_tokens: 4096 +max_batch_size: 64 +world_size: 8 +enable_chunked_prefill: true +cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64] +kv_cache_config: + dtype: bfloat16 + enable_block_reuse: false + free_gpu_memory_fraction: 0.7 + tokens_per_block: 64 +model_kwargs: + torch_dtype: bfloat16 +transforms: + export_to_gm: + num_moe_experts_for_export: 2 + fuse_nvfp4_moe: + allow_different_input_scales: true diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py index 961135ff9c6..06ac62ff5d4 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py @@ -30,6 +30,7 @@ from .....llmapi.llm_args import KvCacheConfig from ...utils.cuda_graph import cuda_graph_state +from ...utils.logger import ad_logger from ..attention_interface import ( AttentionDescriptor, 
AttentionLayout, @@ -526,16 +527,15 @@ def flashinfer_mla_with_cache( compressed_kv_flat = compressed_kv.contiguous().view(bs, kv_lora_rank) kpe_flat = kpe.contiguous().view(bs, qk_rope_head_dim) - # Convert cache dtype if needed - if ckv_cache.dtype == torch.float8_e4m3fn: - compressed_kv_flat = compressed_kv_flat.to(torch.float8_e4m3fn) - kpe_flat = kpe_flat.to(torch.float8_e4m3fn) + # Cast to cache dtype for writes (no-op when dtypes already match). + compressed_kv_for_cache = compressed_kv_flat.to(ckv_cache.dtype) + kpe_for_cache = kpe_flat.to(kpe_cache.dtype) # Append to paged cache using FlashInfer's append function # Note: caches are guaranteed contiguous by CachedSequenceInterface._create_kv_cache_manager flashinfer.page.append_paged_mla_kv_cache( - compressed_kv_flat, - kpe_flat, + compressed_kv_for_cache, + kpe_for_cache, flashinfer_batch_indices, flashinfer_positions, ckv_cache, @@ -932,7 +932,17 @@ def get_cache_initializers( if qk_rope_head_dim != 64: raise ValueError("qk_rope_head_dim must be 64 for flashinfer_mla") - cache_dtype = cls.resolve_cache_dtype(cache_config.dtype, compressed_kv_fake.dtype) + model_dtype = compressed_kv_fake.dtype + cache_dtype = cls.resolve_cache_dtype(cache_config.dtype, model_dtype) + + # FlashInfer MLA kernels currently require BF16 cache dtype. 
+ if cache_dtype != torch.bfloat16: + ad_logger.warning( + "FlashInfer MLA requires BF16 KV cache; overriding %s to %s.", + cache_dtype, + torch.bfloat16, + ) + cache_dtype = torch.bfloat16 # FlashInfer MLA uses two separate paged caches with no num_heads dimension return { diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py index 28cda4cb0ef..a66ad63df68 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py @@ -52,6 +52,12 @@ def _update_mla_cache( - First kv_lora_rank dims: compressed KV latent (before kv_b_proj) - Last qk_rope_head_dim dims: key positional encoding """ + cache_dtype = mla_cache.dtype + if compressed_kv.dtype != cache_dtype: + compressed_kv = compressed_kv.to(cache_dtype) + if kpe.dtype != cache_dtype: + kpe = kpe.to(cache_dtype) + for idx in range(seq_len.shape[0]): start = seq_start[idx].item() length = seq_len[idx].item() @@ -102,6 +108,13 @@ def _torch_mla_generate_with_absorption( compressed_kv_flat = compressed_kv.squeeze(1) # [B, kv_lora_rank] kpe_flat = kpe.squeeze(1).squeeze(1) # [B, qk_rope_head_dim] + # Cast to cache dtype if needed (e.g. BF16 -> FP8) + cache_dtype = mla_cache.dtype + if compressed_kv_flat.dtype != cache_dtype: + compressed_kv_flat = compressed_kv_flat.to(cache_dtype) + if kpe_flat.dtype != cache_dtype: + kpe_flat = kpe_flat.to(cache_dtype) + for i in range(b): cache_idx = slot_idx[i].item() pos = input_pos[i].item() @@ -122,6 +135,13 @@ def _torch_mla_generate_with_absorption( compressed_kv_cached = cached_data[:, :kv_lora_rank] # [seq_len, kv_lora_rank] kpe_cached = cached_data[:, kv_lora_rank:] # [seq_len, qk_rope_head_dim] + # Cast from cache dtype (e.g. 
FP8) to compute dtype + compute_dtype = q_nope.dtype + if compressed_kv_cached.dtype != compute_dtype: + compressed_kv_cached = compressed_kv_cached.to(compute_dtype) + if kpe_cached.dtype != compute_dtype: + kpe_cached = kpe_cached.to(compute_dtype) + # ===================================================================== # Weight absorption for Q_nope part # ===================================================================== @@ -230,6 +250,13 @@ def _torch_mla_context_with_expansion( compressed_kv_cached = cached_data[:, :kv_lora_rank] # [kv_seq_len, kv_lora_rank] kpe_cached = cached_data[:, kv_lora_rank:] # [kv_seq_len, qk_rope_head_dim] + # Cast from cache dtype (e.g. FP8) to compute dtype + compute_dtype = q_nope.dtype + if compressed_kv_cached.dtype != compute_dtype: + compressed_kv_cached = compressed_kv_cached.to(compute_dtype) + if kpe_cached.dtype != compute_dtype: + kpe_cached = kpe_cached.to(compute_dtype) + # ===================================================================== # Expand compressed_kv using kv_b_proj_weight for this sequence # ===================================================================== @@ -496,12 +523,17 @@ def get_cache_initializers( kv_lora_rank = compressed_kv_fake.shape[-1] qk_rope_head_dim = kpe_fake.shape[-1] - # FlashInfer MLA cache: [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim] + model_dtype = compressed_kv_fake.dtype + cache_dtype = cls.resolve_cache_dtype(cache_config.dtype, model_dtype) + + # Torch MLA supports configured/model cache dtypes; no dtype override needed. 
+ + # MLA cache: [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim] # No num_heads dimension - this is the key MLA optimization return { "mla_cache": UnpagedResourceHandler( kv_lora_rank + qk_rope_head_dim, - dtype=cls.resolve_cache_dtype(cache_config.dtype, compressed_kv_fake.dtype), + dtype=cache_dtype, ), } diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py b/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py index 6a383ad8768..5feae9dbc17 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py @@ -1,5 +1,6 @@ from .modeling_deepseek import DeepSeekV3ForCausalLM from .modeling_glm4_moe_lite import Glm4MoeLiteForCausalLM +from .modeling_kimi_k2 import KimiK2ForCausalLM, KimiK25ForConditionalGeneration from .modeling_nemotron_flash import NemotronFlashForCausalLM, NemotronFlashPreTrainedTokenizerFast from .modeling_nemotron_h import NemotronHForCausalLM from .modeling_qwen3_5_moe import Qwen3_5MoeForCausalLM, Qwen3_5MoeForConditionalGeneration @@ -7,6 +8,8 @@ __all__ = ( "DeepSeekV3ForCausalLM", "Glm4MoeLiteForCausalLM", + "KimiK2ForCausalLM", + "KimiK25ForConditionalGeneration", "NemotronFlashForCausalLM", "NemotronFlashPreTrainedTokenizerFast", "NemotronHForCausalLM", diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_kimi_k2.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_kimi_k2.py new file mode 100644 index 00000000000..3f1907f0c48 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_kimi_k2.py @@ -0,0 +1,989 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Slimmed down PyTorch Kimi-K2.5 model implementation for auto_deploy export. 
+ +Source: +https://huggingface.co/moonshotai/Kimi-K2.5 + +This implementation differs from the original HuggingFace version in the following ways: +* Bundled config classes (KimiK2Config, KimiK25Config) for transformers compatibility +* Simplified for prefill-only inference (no KV caching) +* Uses auto_deploy custom ops for export compatibility (torch_mla, torch_moe, + torch_rope_with_qk_interleaving) +* Vanilla PyTorch MoE routing (sigmoid + bias + group top-k); AD transforms + can replace with fused kernels at deployment time +* Removed flash attention variants +* Removed gradient checkpointing and training code paths +* Removed attention dropout (inference only) +* Vision tower kept in eager mode; only text path is exported + +The Kimi-K2.5 text model is a DeepSeek-V3-style architecture with: +* Multi-head Latent Attention (MLA) +* Mixture of Experts (MoE) with noaux_tc routing (sigmoid + bias + top-k) +* YaRN rotary position embeddings +* SwiGLU activation in FFN layers +""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import AutoConfig +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from transformers.generation import GenerationMixin +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput + +from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory +from tensorrt_llm._torch.utils import ActivationType + +# ============================================================================= +# Configuration +# ============================================================================= + + +class KimiK2Config(PretrainedConfig): + """Configuration class for Kimi-K2 text model (DeepSeek-V3 variant). 
+ + This config class is bundled with the custom model implementation to enable + loading on transformers versions that don't natively have Kimi-K2 registered. + """ + + model_type = "kimi_k2" + + def __init__( + self, + vocab_size: int = 163840, + hidden_size: int = 7168, + intermediate_size: int = 18432, + moe_intermediate_size: int = 2048, + num_hidden_layers: int = 61, + num_attention_heads: int = 64, + num_key_value_heads: int = 64, + hidden_act: str = "silu", + max_position_embeddings: int = 262144, + rms_norm_eps: float = 1e-5, + # MLA parameters + q_lora_rank: int = 1536, + kv_lora_rank: int = 512, + qk_nope_head_dim: int = 128, + qk_rope_head_dim: int = 64, + v_head_dim: int = 128, + # MoE parameters + n_routed_experts: int = 384, + n_shared_experts: int = 1, + num_experts_per_tok: int = 8, + moe_layer_freq: int = 1, + first_k_dense_replace: int = 1, + n_group: int = 1, + topk_group: int = 1, + routed_scaling_factor: float = 2.827, + norm_topk_prob: bool = True, + scoring_func: str = "sigmoid", + topk_method: str = "noaux_tc", + # RoPE parameters + rope_theta: float = 50000.0, + rope_scaling: Optional[dict] = None, + # Other parameters + attention_bias: bool = False, + attention_dropout: float = 0.0, + tie_word_embeddings: bool = False, + pad_token_id: Optional[int] = None, + bos_token_id: int = 163584, + eos_token_id: int = 163585, + initializer_range: float = 0.02, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.rms_norm_eps = rms_norm_eps + # MLA + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = 
qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + # MoE + self.n_routed_experts = n_routed_experts + self.n_shared_experts = n_shared_experts + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.n_group = n_group + self.topk_group = topk_group + self.routed_scaling_factor = routed_scaling_factor + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.topk_method = topk_method + # RoPE + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + # Other + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class KimiK25Config(PretrainedConfig): + """Configuration class for Kimi-K2.5 (vision-language model wrapper). + + The text model config is stored as text_config and uses DeepSeek-V3 architecture. + Vision config is stored but not used for AD export (vision stays in eager mode). 
+ """ + + model_type = "kimi_k25" + + def __init__( + self, + text_config=None, + vision_config=None, + ignore_index: int = -100, + media_placeholder_token_id: int = 163605, + pad_token_id: int = 163839, + use_unified_vision_chunk: bool = True, + video_placeholder: str = "<|kimi_k25_video_placeholder|>", + **kwargs, + ): + if isinstance(text_config, dict): + text_config = KimiK2Config(**text_config) + if text_config is None: + text_config = KimiK2Config() + self.text_config = text_config + self.vision_config = vision_config + self.ignore_index = ignore_index + self.media_placeholder_token_id = media_placeholder_token_id + self.use_unified_vision_chunk = use_unified_vision_chunk + self.video_placeholder = video_placeholder + + super().__init__(pad_token_id=pad_token_id, **kwargs) + + +# Register configs with transformers' AutoConfig +try: + AutoConfig.register("kimi_k2", KimiK2Config, exist_ok=True) +except TypeError: + try: + AutoConfig.register("kimi_k2", KimiK2Config) + except ValueError: + pass + +try: + AutoConfig.register("kimi_k25", KimiK25Config, exist_ok=True) +except TypeError: + try: + AutoConfig.register("kimi_k25", KimiK25Config) + except ValueError: + pass + + +# ============================================================================= +# Model Components +# ============================================================================= + + +class KimiK2RMSNorm(nn.Module): + """RMS Normalization for Kimi-K2.""" + + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class 
KimiK2RotaryEmbedding(nn.Module): + """Rotary Position Embedding for Kimi-K2. + + Returns full cached cos/sin (not sliced by seq_len) to enable export. + Uses _ad_ prefix for buffer names for AutoDeploy lift_to_meta compatibility. + """ + + def __init__( + self, + dim: int, + max_position_embeddings: int = 2048, + base: float = 10000.0, + attention_scaling: float = 1.0, + ): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.attention_scaling = attention_scaling + + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + self._set_cos_sin_cache(max_position_embeddings) + + def _set_cos_sin_cache(self, seq_len: int): + self.max_seq_len_cached = seq_len + t = torch.arange(seq_len, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_ad_cos_cached", emb.cos() * self.attention_scaling, persistent=False) + self.register_buffer("_ad_sin_cached", emb.sin() * self.attention_scaling, persistent=False) + + def forward( + self, x: torch.Tensor, seq_len: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + return ( + self._ad_cos_cached.to(dtype=x.dtype, device=x.device), + self._ad_sin_cached.to(dtype=x.dtype, device=x.device), + ) + + +class KimiK2YarnRotaryEmbedding(KimiK2RotaryEmbedding): + """YaRN-extended rotary embedding for Kimi-K2.""" + + def __init__( + self, + dim: int, + max_position_embeddings: int = 2048, + base: float = 10000.0, + scaling_factor: float = 1.0, + original_max_position_embeddings: int = 4096, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1.0, + mscale_all_dim: float = 0.0, + attention_scaling: float = 1.0, + ): + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow 
+ self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + super().__init__(dim, max_position_embeddings, base, attention_scaling) + + def _set_cos_sin_cache(self, seq_len: int): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + freq_inter = 1.0 / ( + self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim) + ) + + low, high = self._yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - self._yarn_linear_ramp_mask(low, high, dim // 2) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + + _mscale = float( + self._yarn_get_mscale(self.scaling_factor, self.mscale) + / self._yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_ad_cos_cached", (emb.cos() * _mscale), persistent=False) + self.register_buffer("_ad_sin_cached", (emb.sin() * _mscale), persistent=False) + + @staticmethod + def _yarn_find_correction_dim( + num_rotations: float, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, + ) -> float: + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + def _yarn_find_correction_range( + self, + low_rot: int, + high_rot: int, + dim: int, + base: float, + max_position_embeddings: int, + ) -> Tuple[int, int]: + low = math.floor( + self._yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + self._yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) + + @staticmethod + def _yarn_get_mscale(scale: float = 1.0, 
mscale: float = 1.0) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + @staticmethod + def _yarn_linear_ramp_mask(min_val: float, max_val: float, dim: int) -> torch.Tensor: + if min_val == max_val: + max_val += 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val) + return torch.clamp(linear_func, 0, 1) + + +# ============================================================================= +# MLP / MoE +# ============================================================================= + + +class KimiK2MLP(nn.Module): + """MLP layer (SwiGLU activation).""" + + def __init__( + self, + config, + hidden_size: Optional[int] = None, + intermediate_size: Optional[int] = None, + ): + super().__init__() + self.config = config + self.hidden_size = hidden_size or config.hidden_size + self.intermediate_size = intermediate_size or config.intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class KimiK2MoEGate(nn.Module): + """MoE Gating with noaux_tc routing (sigmoid + bias + group top-k). + + Vanilla PyTorch implementation of DeepSeek-V3-style routing: + sigmoid scoring → bias-adjusted group selection → top-k → normalize → scale. + AutoDeploy transforms can replace this with fused kernels at deployment time. 
+ """ + + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.n_group = config.n_group + self.topk_group = config.topk_group + self.norm_topk_prob = getattr(config, "norm_topk_prob", True) + + self.weight = nn.Parameter( + torch.empty((self.n_routed_experts, config.hidden_size), dtype=torch.float32) + ) + self.register_buffer( + "e_score_correction_bias", + torch.zeros(self.n_routed_experts, dtype=torch.float32), + ) + self.reset_parameters() + + def reset_parameters(self) -> None: + torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + @torch.no_grad() + def _get_topk_indices(self, scores: torch.Tensor) -> torch.Tensor: + """Select top-k expert indices using group-based routing with bias.""" + scores_for_choice = scores.view( + -1, self.n_routed_experts + ) + self.e_score_correction_bias.to(device=scores.device).unsqueeze(0) + group_scores = ( + scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.n_group, self.n_routed_experts // self.n_group) + .reshape(-1, self.n_routed_experts) + ) + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + return topk_indices + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + bsz, seq_len, hidden_dim = hidden_states.shape + hidden_states_flat = hidden_states.view(-1, hidden_dim) + + # Router GEMM in float32 + router_logits = F.linear( + hidden_states_flat.to(torch.float32), self.weight.to(torch.float32) + 
) + + # Sigmoid scoring + scores = router_logits.sigmoid() + + # Group-based top-k selection (uses bias-adjusted scores for selection) + topk_indices = self._get_topk_indices(scores) + + # Gather original scores (not bias-adjusted) for the selected experts + topk_weights = scores.gather(1, topk_indices) + + # Normalize + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights = topk_weights / denominator + + # Scale + topk_weights = topk_weights * self.routed_scaling_factor + + return topk_indices, topk_weights + + +class KimiK2MoE(nn.Module): + """Mixture of Experts layer for Kimi-K2.""" + + def __init__(self, config): + super().__init__() + self.config = config + self.num_experts_per_tok = config.num_experts_per_tok + + self.experts = nn.ModuleList( + [ + KimiK2MLP(config, intermediate_size=config.moe_intermediate_size) + for _ in range(config.n_routed_experts) + ] + ) + + self.gate = KimiK2MoEGate(config) + + if config.n_shared_experts is not None and config.n_shared_experts > 0: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = KimiK2MLP(config, intermediate_size=intermediate_size) + else: + self.shared_experts = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + identity = hidden_states + orig_shape = hidden_states.shape + + selected_experts, routing_weights = self.gate(hidden_states) + + final_hidden_states = torch.ops.auto_deploy.torch_moe( + hidden_states.view(-1, hidden_states.shape[-1]), + selected_experts, + routing_weights, + w1_weight=[expert.gate_proj.weight for expert in self.experts], + w2_weight=[expert.down_proj.weight for expert in self.experts], + w3_weight=[expert.up_proj.weight for expert in self.experts], + is_gated_mlp=True, + act_fn=int(ActivationType.Silu), + ) + + final_hidden_states = final_hidden_states.view(*orig_shape) + + if self.shared_experts is not None: + final_hidden_states = final_hidden_states + 
self.shared_experts(identity) + + return final_hidden_states + + +# ============================================================================= +# Attention +# ============================================================================= + + +class KimiK2Attention(nn.Module): + """Multi-head Latent Attention (MLA) for Kimi-K2. + + Uses compressed KV representation with latent projections, identical + to the DeepSeek-V3 MLA mechanism. + """ + + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.q_lora_rank = config.q_lora_rank + self.kv_lora_rank = config.kv_lora_rank + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.v_head_dim = config.v_head_dim + self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + + # Q projection (with optional LoRA) + if self.q_lora_rank is None: + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.qk_head_dim, bias=False) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, config.q_lora_rank, bias=config.attention_bias + ) + self.q_a_layernorm = KimiK2RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear( + config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False + ) + + # KV projection with MQA + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = KimiK2RMSNorm(self.kv_lora_rank) + self.kv_b_proj = nn.Linear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + + # Output projection + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + ) + + # Softmax scale with mscale adjustment + self.softmax_scale = self.qk_head_dim ** (-0.5) + if ( + 
config.rope_scaling is not None + and isinstance(config.rope_scaling, dict) + and "factor" in config.rope_scaling + ): + mscale_all_dim = config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = config.rope_scaling["factor"] + if mscale_all_dim: + mscale = KimiK2YarnRotaryEmbedding._yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + bsz, q_len, _ = hidden_states.size() + + # Q projection + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + + # Shape: [B, S, N, qk_head_dim] (BSND layout) + q = q.view(bsz, q_len, self.num_heads, self.qk_head_dim) + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + # KV projection + kv_a_output = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + kv_a_output, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_layernorm(compressed_kv) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim) + + # Get cos/sin from position_embeddings + cos = position_embeddings[0] # Full table: [max_seq_len, head_dim] + sin = position_embeddings[1] # Full table: [max_seq_len, head_dim] + cos = cos[position_ids] # [B, S, head_dim] + sin = sin[position_ids] # [B, S, head_dim] + + # Apply RoPE on native interleaved q/k channels. 
+ q_pe_rotated, kpe = torch.ops.auto_deploy.torch_rope_with_qk_interleaving( + q_pe, + k_pe, + cos, + sin, + 2, # unsqueeze_dim=2 for BSND layout + ) + + # MLA with compressed KV + attn_output = torch.ops.auto_deploy.torch_mla( + q_nope, # [B, S, N, qk_nope_head_dim] + q_pe_rotated, # [B, S, N, qk_rope_head_dim] + compressed_kv, # [B, S, kv_lora_rank] + kpe, # [B, S, 1, qk_rope_head_dim] + self.kv_b_proj.weight, # [N*(qk_nope+v), kv_lora_rank] + True, # is_causal + self.softmax_scale, + "bsnd", # layout + ) + + # Output: [B, S, N, v_head_dim] -> [B, S, N * v_head_dim] + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + attn_output = self.o_proj(attn_output) + + return attn_output + + +# ============================================================================= +# Decoder Layer +# ============================================================================= + + +class KimiK2DecoderLayer(nn.Module): + """Transformer decoder layer for Kimi-K2.""" + + def __init__(self, config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + self.self_attn = KimiK2Attention(config, layer_idx=layer_idx) + + # Layer 0 to first_k_dense_replace-1 use dense MLP, rest use MoE + use_moe = ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + if use_moe: + self.mlp = KimiK2MoE(config) + else: + self.mlp = KimiK2MLP(config) + + self.input_layernorm = KimiK2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = KimiK2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + # Self attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(hidden_states, 
position_ids, position_embeddings) + hidden_states = residual + hidden_states + + # MLP/MoE + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +# ============================================================================= +# Model Outputs +# ============================================================================= + + +@dataclass +class KimiK2ModelOutput(ModelOutput): + last_hidden_state: Optional[torch.FloatTensor] = None + + +@dataclass +class KimiK2CausalLMOutput(ModelOutput): + logits: Optional[torch.FloatTensor] = None + + +@dataclass +class KimiK25ConditionalOutput(ModelOutput): + logits: Optional[torch.FloatTensor] = None + + +# ============================================================================= +# Full Models +# ============================================================================= + + +class KimiK2PreTrainedModel(PreTrainedModel): + """Base class for Kimi-K2 models.""" + + config_class = KimiK2Config + base_model_prefix = "model" + _no_split_modules = ["KimiK2DecoderLayer"] + supports_gradient_checkpointing = False + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class KimiK2Model(KimiK2PreTrainedModel): + """Kimi-K2 transformer decoder model.""" + + def __init__(self, config): + super().__init__(config) + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + 
[KimiK2DecoderLayer(config, layer_idx=idx) for idx in range(config.num_hidden_layers)] + ) + self.norm = KimiK2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Shared rotary embedding at model level + self.rotary_emb = self._init_rope(config) + + self.post_init() + + def _init_rope(self, config): + qk_rope_head_dim = config.qk_rope_head_dim + + attention_scaling = 1.0 + if ( + config.rope_scaling is not None + and isinstance(config.rope_scaling, dict) + and "factor" in config.rope_scaling + ): + mscale_all_dim = config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = config.rope_scaling["factor"] + if mscale_all_dim: + mscale = KimiK2YarnRotaryEmbedding._yarn_get_mscale(scaling_factor, mscale_all_dim) + attention_scaling = mscale + + use_yarn = ( + config.rope_scaling is not None + and isinstance(config.rope_scaling, dict) + and "factor" in config.rope_scaling + ) + + if not use_yarn: + return KimiK2RotaryEmbedding( + qk_rope_head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + attention_scaling=attention_scaling, + ) + else: + scaling_factor = config.rope_scaling["factor"] + kwargs = { + key: config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in config.rope_scaling + } + return KimiK2YarnRotaryEmbedding( + qk_rope_head_dim, + max_position_embeddings=config.max_position_embeddings, + scaling_factor=scaling_factor, + base=config.rope_theta, + attention_scaling=attention_scaling, + **kwargs, + ) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> KimiK2ModelOutput: + if input_ids is not None and inputs_embeds is not None: + raise 
ValueError("Cannot specify both input_ids and inputs_embeds") + elif input_ids is None and inputs_embeds is None: + raise ValueError("Must specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = inputs_embeds.shape[:2] + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand(batch_size, -1) + + # Compute position embeddings once from shared rotary embedding + position_embeddings = self.rotary_emb(inputs_embeds) + + hidden_states = inputs_embeds + + for decoder_layer in self.layers: + hidden_states = decoder_layer(hidden_states, position_ids, position_embeddings) + + hidden_states = self.norm(hidden_states) + + return KimiK2ModelOutput(last_hidden_state=hidden_states) + + +class KimiK2ForCausalLM(KimiK2PreTrainedModel, GenerationMixin): + """Kimi-K2 model with language modeling head. 
+ + Weight layout matches DeepseekV3ForCausalLM from the HF checkpoint: + model.embed_tokens, model.layers.*, model.norm, lm_head + """ + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config, **kwargs): + super().__init__(config) + self.model = KimiK2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> KimiK2CausalLMOutput: + outputs = self.model( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states).float() + + return KimiK2CausalLMOutput(logits=logits) + + +# ============================================================================= +# VLM Wrapper (Conditional Generation) +# ============================================================================= + + +class KimiK25PreTrainedModel(PreTrainedModel): + """Base class for Kimi-K2.5 VLM models.""" + + config_class = KimiK25Config + base_model_prefix = "language_model" + _no_split_modules = ["KimiK2DecoderLayer"] + supports_gradient_checkpointing = False + + def _init_weights(self, module): + std = getattr(self.config.text_config, "initializer_range", 0.02) + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + 
module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class KimiK25ForConditionalGeneration(KimiK25PreTrainedModel): + """Kimi-K2.5 conditional generation model (text-only path for AD export). + + Weight layout matches HF checkpoint: + language_model.model.embed_tokens, language_model.model.layers.*, + language_model.model.norm, language_model.lm_head + Vision tower weights are ignored during export. + """ + + def __init__(self, config: KimiK25Config, **kwargs): + super().__init__(config) + self.language_model = KimiK2ForCausalLM(config.text_config) + + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def get_decoder(self): + return self.language_model.get_decoder() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> KimiK25ConditionalOutput: + outputs = self.language_model( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + return KimiK25ConditionalOutput(logits=outputs.logits) + + +# ============================================================================= +# Registration +# ============================================================================= + +# Register text model for direct text-only usage +AutoModelForCausalLMFactory.register_custom_model_cls("KimiK2Config", KimiK2ForCausalLM) + +# Register VLM wrapper for full KimiK25 config (used by HF's auto_map) +AutoModelForCausalLMFactory.register_custom_model_cls( + "KimiK25Config", 
KimiK25ForConditionalGeneration +) diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py index c844b3aad70..ecebe94126d 100644 --- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py +++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py @@ -769,3 +769,69 @@ def test_autodeploy_from_registry(self, model_name, config_overrides, tasks, task.evaluate(llm, sampling_params=sampling_params) except (AssertionError, RuntimeError, ValueError) as e: raise type(e)(f"[{task_cls.__name__}] {e}") from None + + +class TestKimiK2_5(LlmapiAccuracyTestHarness): + """Accuracy regression tests for Kimi-K2.5 via AutoDeploy. + + Runs the model via AutoDeploy and verifies benchmark performance on MMLU and GSM8K. + Configuration derived from examples/auto_deploy/model_registry/configs/kimi_k2.yaml. + """ + + MODEL_NAME = "nvidia/Kimi-K2.5-NVFP4" + MAX_SEQ_LEN = max(MMLU.MAX_INPUT_LEN + MMLU.MAX_OUTPUT_LEN, + GSM8K.MAX_INPUT_LEN + GSM8K.MAX_OUTPUT_LEN) + + def get_default_kwargs(self): + return { + "skip_tokenizer_init": False, + "trust_remote_code": True, + "enable_chunked_prefill": True, + "compile_backend": "torch-cudagraph", + "max_batch_size": 64, + "max_seq_len": self.MAX_SEQ_LEN, + "max_num_tokens": self.MAX_SEQ_LEN, + "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64], + "kv_cache_config": { + "dtype": "bfloat16", + "enable_block_reuse": False, + "free_gpu_memory_fraction": 0.7, + "tokens_per_block": 64, + }, + "model_kwargs": { + "torch_dtype": "bfloat16", + }, + "transforms": { + "export_to_gm": { + "num_moe_experts_for_export": 2, + }, + "fuse_nvfp4_moe": { + "allow_different_input_scales": True, + }, + }, + } + + def get_default_sampling_params(self): + eos_id = -1 + beam_width = 1 + return SamplingParams(end_id=eos_id, + pad_id=eos_id, + n=beam_width, + use_beam_search=beam_width > 1) + + @pytest.mark.skip_less_device_memory(180000) + @pytest.mark.parametrize("world_size", [8]) + 
def test_nvfp4(self, world_size): + if get_device_count() < world_size: + pytest.skip("Not enough devices for world size, skipping test") + kwargs = self.get_default_kwargs() + sampling_params = self.get_default_sampling_params() + with AutoDeployLLM(model=self.MODEL_NAME, + tokenizer=self.MODEL_NAME, + dtype="bfloat16", + world_size=world_size, + **kwargs) as llm: + task = MMLU(self.MODEL_NAME) + task.evaluate(llm, sampling_params=sampling_params) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) diff --git a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py index e1922ae751b..ebbd6ad4c63 100644 --- a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py +++ b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py @@ -252,6 +252,33 @@ def forward(self, x): return torch.bmm(x, dynamic_weights) +def assert_rmse_close( + actual: torch.Tensor, + expected: torch.Tensor, + rmse_ratio_tol: float, + msg: str = "", +) -> None: + """Assert that the RMSE between two tensors is small relative to the reference signal. + + Computes: rmse(actual - expected) / rmse(expected) + This is more robust than per-element rtol/atol checks since a few outlier + elements won't fail the test if the overall signal is faithfully reproduced. + + Recommended tolerances for bfloat16 custom-op equivalence tests: + - Attention (fused MLA + RoPE de-interleaving): 0.10 + - Decoder layer / MoE layer / full model: 0.05 + - MoE block (fused routing): 0.02 + """ + diff = actual.float() - expected.float() + rmse_diff = torch.sqrt(torch.mean(diff**2)) + rmse_ref = torch.sqrt(torch.mean(expected.float() ** 2)) + ratio = (rmse_diff / rmse_ref).item() + assert ratio < rmse_ratio_tol, ( + f"{msg}RMSE ratio {ratio:.6f} exceeds tolerance {rmse_ratio_tol}. 
" + f"(rmse_diff={rmse_diff.item():.6f}, rmse_ref={rmse_ref.item():.6f})" + ) + + FP8_MAX = torch.finfo(torch.float8_e4m3fn).max diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_kimi_k2_modeling.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_kimi_k2_modeling.py new file mode 100644 index 00000000000..b58197f46bb --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_kimi_k2_modeling.py @@ -0,0 +1,764 @@ +"""Tests for Kimi-K2 / Kimi-K2.5 custom model implementation. + +This module tests the custom Kimi-K2 model implementation (DeepSeek-V3 variant) +which uses auto_deploy custom ops (torch_mla, torch_moe, etc.) for export +compatibility. + +Hierarchical test levels: +1. Block equivalence — MLP, MoE, Attention individually +2. Layer equivalence — Full decoder layer (dense and MoE) +3. Full model equivalence — End-to-end logits comparison +4. Export test — torch_export_to_gm with dynamic shapes +""" + +import pytest +import torch +from _model_test_utils import assert_rmse_close +from torch.export import Dim + +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.models.custom.modeling_kimi_k2 import ( + KimiK2Attention, + KimiK2Config, + KimiK2DecoderLayer, + KimiK2ForCausalLM, + KimiK2MLP, + KimiK2MoE, + KimiK2RotaryEmbedding, + KimiK25Config, + KimiK25ForConditionalGeneration, +) +from tensorrt_llm._torch.auto_deploy.utils._graph import move_to_device + +_BATCH_AND_SEQUENCE_TEST_CASES = ((2, 6), (1, 8)) + + +@pytest.fixture(scope="function", autouse=True) +def set_seed(): + torch.manual_seed(42) + + +# ============================================================================= +# Helpers +# ============================================================================= + + +def _create_small_text_config() -> KimiK2Config: + """Create a small Kimi-K2 text config for testing.""" + return KimiK2Config( + vocab_size=1000, + hidden_size=64, + 
intermediate_size=128, + moe_intermediate_size=32, + num_hidden_layers=3, # Layer 0 dense, layers 1-2 MoE + num_attention_heads=4, + num_key_value_heads=4, + hidden_act="silu", + max_position_embeddings=512, + rms_norm_eps=1e-5, + # MLA params (scaled down) + q_lora_rank=32, + kv_lora_rank=32, + qk_nope_head_dim=8, + qk_rope_head_dim=8, + v_head_dim=16, + # MoE params (scaled down) + n_routed_experts=4, + n_shared_experts=1, + num_experts_per_tok=2, + moe_layer_freq=1, + first_k_dense_replace=1, + n_group=1, + topk_group=1, + routed_scaling_factor=1.0, + norm_topk_prob=True, + scoring_func="sigmoid", + topk_method="noaux_tc", + # RoPE + rope_theta=10000.0, + rope_scaling=None, + # Other + attention_bias=False, + attention_dropout=0.0, + pad_token_id=0, + ) + + +def _create_small_vlm_config() -> KimiK25Config: + """Create a small KimiK25 config wrapping the text config.""" + text_config = _create_small_text_config() + return KimiK25Config( + text_config=text_config, + vision_config=None, + pad_token_id=0, + ) + + +def _create_moe_layer(config: KimiK2Config) -> KimiK2MoE: + """Create a MoE layer from config with reproducible gate weights.""" + moe = KimiK2MoE(config) + moe.gate.weight = torch.nn.Parameter(torch.randn_like(moe.gate.weight)) + return moe + + +# ============================================================================= +# Export Tests +# ============================================================================= + + +@torch.no_grad() +def test_kimi_k2_text_model_can_be_exported(): + """Test that KimiK2ForCausalLM can be exported with torch_export_to_gm. + + Verifies: + 1. The model exports successfully without graph breaks + 2. The exported graph module produces numerically equivalent output to eager + 3. 
Dynamic shapes work correctly with different input sizes + """ + device = "cpu" + dtype = torch.bfloat16 + config = _create_small_text_config() + + model = KimiK2ForCausalLM(config) + model.to(device=device, dtype=dtype) + model.eval() + + B, S = 2, 8 + input_ids = torch.randint(0, config.vocab_size, (B, S), device=device) + position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1) + + batch_size_dynamic = Dim.DYNAMIC + seq_len_dynamic = Dim.DYNAMIC + dynamic_shapes = ( + {0: batch_size_dynamic, 1: seq_len_dynamic}, + {0: batch_size_dynamic, 1: seq_len_dynamic}, + ) + + gm = torch_export_to_gm( + model, + args=tuple(), + kwargs={"input_ids": input_ids, "position_ids": position_ids}, + dynamic_shapes=dynamic_shapes, + ) + + move_to_device(gm, device) + + # Eager reference output + eager_out = model(input_ids=input_ids, position_ids=position_ids) + + # Exported graph output + out_gm = gm(input_ids=input_ids, position_ids=position_ids) + + assert "logits" in out_gm, "Output should contain 'logits' key" + logits = out_gm["logits"] + assert logits.shape == (B, S, config.vocab_size), ( + f"Expected shape {(B, S, config.vocab_size)}, got {logits.shape}" + ) + torch.testing.assert_close(logits.float(), eager_out.logits.float(), rtol=1e-3, atol=1e-3) + + # Test with different input shape to verify dynamic shapes work + B2, S2 = 1, 4 + input_ids2 = torch.randint(0, config.vocab_size, (B2, S2), device=device) + position_ids2 = torch.arange(S2, device=device).unsqueeze(0).expand(B2, -1) + + eager_out2 = model(input_ids=input_ids2, position_ids=position_ids2) + out_gm2 = gm(input_ids=input_ids2, position_ids=position_ids2) + + logits2 = out_gm2["logits"] + expected_shape = (B2, S2, config.vocab_size) + assert logits2.shape == expected_shape, ( + f"Dynamic shape test failed: expected {expected_shape}, got {logits2.shape}" + ) + torch.testing.assert_close(logits2.float(), eager_out2.logits.float(), rtol=1e-3, atol=1e-3) + + +# 
============================================================================= +# Structural Tests +# ============================================================================= + + +def test_kimi_k2_config_registration(): + """Test that configs are properly instantiated with correct model_type.""" + text_config = _create_small_text_config() + assert text_config.model_type == "kimi_k2" + assert hasattr(text_config, "hidden_size") + assert hasattr(text_config, "n_routed_experts") + assert hasattr(text_config, "kv_lora_rank") + assert hasattr(text_config, "qk_rope_head_dim") + assert hasattr(text_config, "moe_layer_freq") + + vlm_config = _create_small_vlm_config() + assert vlm_config.model_type == "kimi_k25" + assert isinstance(vlm_config.text_config, KimiK2Config) + + +def test_kimi_k2_layer_types(): + """Test that layer 0 uses dense MLP and later layers use MoE.""" + config = _create_small_text_config() + model = KimiK2ForCausalLM(config) + + layer0_mlp = model.model.layers[0].mlp + assert type(layer0_mlp).__name__ == "KimiK2MLP", ( + f"Layer 0 should use KimiK2MLP, got {type(layer0_mlp).__name__}" + ) + + for i in range(1, config.num_hidden_layers): + layer_mlp = model.model.layers[i].mlp + assert type(layer_mlp).__name__ == "KimiK2MoE", ( + f"Layer {i} should use KimiK2MoE, got {type(layer_mlp).__name__}" + ) + + +def test_kimi_k2_expert_structure(): + """Test that experts have correct structure for checkpoint loading.""" + config = _create_small_text_config() + moe = KimiK2MoE(config) + + assert isinstance(moe.experts, torch.nn.ModuleList), "experts should be nn.ModuleList" + assert len(moe.experts) == config.n_routed_experts, ( + f"Expected {config.n_routed_experts} experts, got {len(moe.experts)}" + ) + + for i, expert in enumerate(moe.experts): + assert hasattr(expert, "gate_proj"), f"Expert {i} missing gate_proj" + assert hasattr(expert, "up_proj"), f"Expert {i} missing up_proj" + assert hasattr(expert, "down_proj"), f"Expert {i} missing down_proj" + + 
state_dict = moe.state_dict() + expected_keys = [ + "experts.0.gate_proj.weight", + "experts.0.up_proj.weight", + "experts.0.down_proj.weight", + ] + for key in expected_keys: + assert key in state_dict, ( + f"Expected key '{key}' in state_dict, got keys: {list(state_dict.keys())[:10]}..." + ) + + +def test_kimi_k25_weight_layout(): + """Test that VLM wrapper has correct weight prefix for checkpoint compatibility.""" + config = _create_small_vlm_config() + model = KimiK25ForConditionalGeneration(config) + + state_dict = model.state_dict() + # Check that weights are under language_model.* prefix + assert any(k.startswith("language_model.model.") for k in state_dict), ( + "Expected weights under 'language_model.model.*' prefix" + ) + assert any(k.startswith("language_model.lm_head.") for k in state_dict), ( + "Expected weights under 'language_model.lm_head.*' prefix" + ) + + +def test_kimi_k2_shared_experts(): + """Test that shared experts are present when n_shared_experts > 0.""" + config = _create_small_text_config() + moe = KimiK2MoE(config) + + assert moe.shared_experts is not None, "shared_experts should be present" + assert isinstance(moe.shared_experts, KimiK2MLP), "shared_experts should be KimiK2MLP" + + # Shared expert intermediate size = moe_intermediate_size * n_shared_experts + expected_intermediate = config.moe_intermediate_size * config.n_shared_experts + assert moe.shared_experts.intermediate_size == expected_intermediate, ( + f"Expected shared expert intermediate_size={expected_intermediate}, " + f"got {moe.shared_experts.intermediate_size}" + ) + + +# ============================================================================= +# HF Reference Import Helpers (for numerical equivalence tests) +# ============================================================================= + + +def _get_hf_config_class(): + """Get the HF DeepseekV3Config class. + + Returns None if transformers doesn't have deepseek_v3 (requires v4.57+). 
+ """ + try: + from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config + + return DeepseekV3Config + except ImportError: + return None + + +def _get_hf_model_class(): + """Get the HF DeepseekV3ForCausalLM class.""" + try: + from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3ForCausalLM + + return DeepseekV3ForCausalLM + except ImportError: + return None + + +def _get_hf_moe_class(): + """Get the HF DeepseekV3MoE class.""" + try: + from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE + + return DeepseekV3MoE + except ImportError: + return None + + +def _get_hf_mlp_class(): + """Get the HF DeepseekV3MLP class.""" + try: + from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MLP + + return DeepseekV3MLP + except ImportError: + return None + + +def _get_hf_attention_class(): + """Get the HF DeepseekV3Attention class.""" + try: + from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3Attention + + return DeepseekV3Attention + except ImportError: + return None + + +def _get_hf_decoder_layer_class(): + """Get the HF DeepseekV3DecoderLayer class.""" + try: + from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3DecoderLayer + + return DeepseekV3DecoderLayer + except ImportError: + return None + + +# ============================================================================= +# HF Config and Weight Conversion Helpers +# ============================================================================= + + +def _create_hf_config(): + """Create HF DeepseekV3Config matching our small test config. + + Returns None if DeepseekV3Config is not available. 
+ """ + HFConfig = _get_hf_config_class() + if HFConfig is None: + return None + + config = HFConfig( + vocab_size=1000, + hidden_size=64, + intermediate_size=128, + moe_intermediate_size=32, + num_hidden_layers=3, + num_attention_heads=4, + num_key_value_heads=4, + hidden_act="silu", + max_position_embeddings=512, + rms_norm_eps=1e-5, + # MLA params + q_lora_rank=32, + kv_lora_rank=32, + qk_nope_head_dim=8, + qk_rope_head_dim=8, + v_head_dim=16, + # MoE params + n_routed_experts=4, + n_shared_experts=1, + num_experts_per_tok=2, + n_group=1, + topk_group=1, + routed_scaling_factor=1.0, + norm_topk_prob=True, + first_k_dense_replace=1, + # RoPE + rope_theta=10000.0, + rope_scaling=None, + rope_interleave=True, # HF default: interleaved RoPE format + # Other + attention_bias=False, + attention_dropout=0.0, + pad_token_id=0, + ) + + # Use eager attention for deterministic comparison + config._attn_implementation = "eager" + + return config + + +def _deinterleave_attention_weights(state_dict, config, prefix=""): + """De-interleave RoPE weight columns for attention weights. + + Applies the same transformation as mla_rope_utils._rope_deinterleave_load_hook + but for a single layer's attention weights. The prefix parameter allows reuse + for both block-level (prefix="") and layer-level (prefix="self_attn.") tests. 
+ """ + d = config.qk_rope_head_dim + perm = torch.cat([torch.arange(0, d, 2), torch.arange(1, d, 2)]) + qk_head_dim = config.qk_nope_head_dim + d + + # --- q_b_proj.weight --- + q_key = f"{prefix}q_b_proj.weight" + if q_key in state_dict: + w = state_dict[q_key] + w = w.view(config.num_attention_heads, qk_head_dim, -1) + w_nope = w[:, : config.qk_nope_head_dim, :] + w_rope = w[:, config.qk_nope_head_dim :, :] + w_rope = w_rope[:, perm, :] + w = torch.cat([w_nope, w_rope], dim=1) + state_dict[q_key] = w.view(-1, w.shape[-1]) + + # --- kv_a_proj_with_mqa.weight --- + kv_key = f"{prefix}kv_a_proj_with_mqa.weight" + if kv_key in state_dict: + w = state_dict[kv_key] + w_kv = w[: config.kv_lora_rank, :] + w_pe = w[config.kv_lora_rank :, :] + w_pe = w_pe[perm, :] + state_dict[kv_key] = torch.cat([w_kv, w_pe], dim=0) + + # --- kv_a_proj_with_mqa.bias (if present) --- + kv_bias_key = f"{prefix}kv_a_proj_with_mqa.bias" + if kv_bias_key in state_dict: + b = state_dict[kv_bias_key] + b_kv = b[: config.kv_lora_rank] + b_pe = b[config.kv_lora_rank :] + b_pe = b_pe[perm] + state_dict[kv_bias_key] = torch.cat([b_kv, b_pe]) + + return state_dict + + +def _create_causal_mask(B, S, device, dtype): + """Create a 4D causal attention mask for HF eager attention. + + Returns a [B, 1, S, S] mask with 0 for attended positions and -inf for masked. + """ + mask = torch.full((S, S), torch.finfo(dtype).min, device=device, dtype=dtype) + mask = torch.triu(mask, diagonal=1) + return mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S) + + +# ============================================================================= +# Numerical Equivalence Tests +# These tests compare our custom Kimi-K2 implementation against the HF +# DeepseekV3 reference with identical weights and inputs. 
+# ============================================================================= + +# --- Level 1: Block Equivalence --- + + +@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@torch.no_grad() +def test_kimi_k2_mlp_numerical_equivalence(B, S, dtype): + """Test MLP produces numerically equivalent output to HF DeepseekV3MLP.""" + HFMLP = _get_hf_mlp_class() + if HFMLP is None: + pytest.skip("transformers doesn't have DeepseekV3 (requires v4.57+)") + + device = "cpu" + config = _create_small_text_config() + hf_config = _create_hf_config() + + # Create HF MLP + hf_mlp = HFMLP(hf_config) + hf_mlp.to(device=device, dtype=dtype) + hf_mlp.eval() + + # Create custom MLP and load same weights (identical structure) + custom_mlp = KimiK2MLP(config) + custom_mlp.to(device=device, dtype=dtype) + custom_mlp.load_state_dict(hf_mlp.state_dict()) + custom_mlp.eval() + + # Create input + H = config.hidden_size + x = torch.randn(B, S, H, device=device, dtype=dtype) + + # Run both + hf_out = hf_mlp(x) + custom_out = custom_mlp(x) + + # Compare — identical math, tight tolerance + rtol, atol = 1e-3, 1e-3 + torch.testing.assert_close(custom_out, hf_out, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@torch.no_grad() +def test_kimi_k2_moe_numerical_equivalence(B, S, dtype): + """Test MoE produces numerically equivalent output to HF DeepseekV3MoE.""" + HFMoE = _get_hf_moe_class() + if HFMoE is None: + pytest.skip("transformers doesn't have DeepseekV3 (requires v4.57+)") + + device = "cpu" + config = _create_small_text_config() + hf_config = _create_hf_config() + + # Create HF MoE and initialize gate weights for reproducibility + hf_moe = HFMoE(hf_config) + hf_moe.gate.weight = torch.nn.Parameter(torch.randn_like(hf_moe.gate.weight)) + hf_moe.to(device=device, dtype=dtype) + hf_moe.eval() + + # Create custom MoE and load same 
weights + # State dict keys match: gate.weight, gate.e_score_correction_bias, + # experts.{i}.{gate,up,down}_proj.weight, shared_experts.* + custom_moe = KimiK2MoE(config) + custom_moe.to(device=device, dtype=dtype) + custom_moe.load_state_dict(hf_moe.state_dict()) + custom_moe.eval() + + # Create input + H = config.hidden_size + x = torch.randn(B, S, H, device=device, dtype=dtype) + + # Run both + hf_out = hf_moe(x) + custom_out = custom_moe(x) + + # Compare — torch_moe expert computation vs HF Python loop + assert_rmse_close(custom_out, hf_out, rmse_ratio_tol=0.02, msg="MoE: ") + + +@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@torch.no_grad() +def test_kimi_k2_attention_numerical_equivalence(B, S, dtype): + """Test attention produces numerically equivalent output to HF DeepseekV3Attention. + + HF uses rope_interleave=True (interleaved RoPE format in weights). + Our model uses NeoX format. Weights are de-interleaved before loading. 
+ """ + HFAttn = _get_hf_attention_class() + if HFAttn is None: + pytest.skip("transformers doesn't have DeepseekV3 (requires v4.57+)") + + device = "cpu" + config = _create_small_text_config() + hf_config = _create_hf_config() + + # Create HF attention + hf_attn = HFAttn(hf_config, layer_idx=0) + hf_attn.to(device=device, dtype=dtype) + hf_attn.eval() + + # Create custom attention with de-interleaved weights + custom_attn = KimiK2Attention(config, layer_idx=0) + custom_attn.to(device=device, dtype=dtype) + + # Copy weights, de-interleaving RoPE dimensions + hf_state_dict = dict(hf_attn.state_dict()) + _deinterleave_attention_weights(hf_state_dict, config) + custom_attn.load_state_dict(hf_state_dict) + custom_attn.eval() + + # Create inputs + H = config.hidden_size + x = torch.randn(B, S, H, device=device, dtype=dtype) + position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1) + + # Create position embeddings using our rotary embedding (full table) + rotary_emb = KimiK2RotaryEmbedding( + config.qk_rope_head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + full_cos, full_sin = rotary_emb(x) # [max_seq_len, head_dim] + + # HF expects position-indexed cos/sin: [B, S, head_dim] + hf_cos = full_cos[position_ids] + hf_sin = full_sin[position_ids] + + # Causal mask for HF eager attention: [B, 1, S, S] + causal_mask = _create_causal_mask(B, S, device, dtype) + + # Run HF attention + hf_out, _ = hf_attn( + hidden_states=x, + position_embeddings=(hf_cos, hf_sin), + attention_mask=causal_mask, + ) + + # Run custom attention (takes full table + position_ids) + custom_out = custom_attn(x, position_ids, (full_cos, full_sin)) + + # Compare — RoPE format conversion + fused MLA vs eager attention. + # Higher tolerance due to RoPE de-interleaving + fused MLA vs eager path. 
+ assert_rmse_close(custom_out, hf_out, rmse_ratio_tol=0.10, msg="Attention: ") + + +# --- Level 2: Layer Equivalence --- + + +@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@torch.no_grad() +def test_kimi_k2_dense_layer_numerical_equivalence(B, S, dtype): + """Test dense decoder layer (layer 0) matches HF DeepseekV3DecoderLayer.""" + HFLayer = _get_hf_decoder_layer_class() + if HFLayer is None: + pytest.skip("transformers doesn't have DeepseekV3 (requires v4.57+)") + + device = "cpu" + config = _create_small_text_config() + hf_config = _create_hf_config() + + # Create HF layer (layer 0 = dense MLP) + hf_layer = HFLayer(hf_config, layer_idx=0) + hf_layer.to(device=device, dtype=dtype) + hf_layer.eval() + + # Create custom layer with de-interleaved attention weights + custom_layer = KimiK2DecoderLayer(config, layer_idx=0) + custom_layer.to(device=device, dtype=dtype) + + hf_state_dict = dict(hf_layer.state_dict()) + _deinterleave_attention_weights(hf_state_dict, config, prefix="self_attn.") + custom_layer.load_state_dict(hf_state_dict) + custom_layer.eval() + + # Create inputs + H = config.hidden_size + x = torch.randn(B, S, H, device=device, dtype=dtype) + position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1) + + # Position embeddings + rotary_emb = KimiK2RotaryEmbedding( + config.qk_rope_head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + full_cos, full_sin = rotary_emb(x) + hf_cos = full_cos[position_ids] + hf_sin = full_sin[position_ids] + + causal_mask = _create_causal_mask(B, S, device, dtype) + + # Run HF layer + hf_out = hf_layer( + hidden_states=x, + attention_mask=causal_mask, + position_embeddings=(hf_cos, hf_sin), + ) + + # Run custom layer + custom_out = custom_layer(x, position_ids, (full_cos, full_sin)) + + # Compare + assert_rmse_close(custom_out, hf_out, rmse_ratio_tol=0.05, msg="Dense layer: ") + + 
+@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@torch.no_grad() +def test_kimi_k2_moe_layer_numerical_equivalence(B, S, dtype): + """Test MoE decoder layer (layer 1) matches HF DeepseekV3DecoderLayer.""" + HFLayer = _get_hf_decoder_layer_class() + if HFLayer is None: + pytest.skip("transformers doesn't have DeepseekV3 (requires v4.57+)") + + device = "cpu" + config = _create_small_text_config() + hf_config = _create_hf_config() + + # Create HF layer (layer 1 = MoE) + hf_layer = HFLayer(hf_config, layer_idx=1) + # Initialize gate weights for reproducibility + hf_layer.mlp.gate.weight = torch.nn.Parameter(torch.randn_like(hf_layer.mlp.gate.weight)) + hf_layer.to(device=device, dtype=dtype) + hf_layer.eval() + + # Create custom layer with de-interleaved attention weights + custom_layer = KimiK2DecoderLayer(config, layer_idx=1) + custom_layer.to(device=device, dtype=dtype) + + hf_state_dict = dict(hf_layer.state_dict()) + _deinterleave_attention_weights(hf_state_dict, config, prefix="self_attn.") + custom_layer.load_state_dict(hf_state_dict) + custom_layer.eval() + + # Create inputs + H = config.hidden_size + x = torch.randn(B, S, H, device=device, dtype=dtype) + position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1) + + rotary_emb = KimiK2RotaryEmbedding( + config.qk_rope_head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + full_cos, full_sin = rotary_emb(x) + hf_cos = full_cos[position_ids] + hf_sin = full_sin[position_ids] + + causal_mask = _create_causal_mask(B, S, device, dtype) + + # Run HF layer + hf_out = hf_layer( + hidden_states=x, + attention_mask=causal_mask, + position_embeddings=(hf_cos, hf_sin), + ) + + # Run custom layer + custom_out = custom_layer(x, position_ids, (full_cos, full_sin)) + + # Compare — includes MoE routing differences + assert_rmse_close(custom_out, hf_out, rmse_ratio_tol=0.05, msg="MoE layer: ") + + 
+# --- Level 3: Full Model Equivalence --- + + +@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@torch.no_grad() +def test_kimi_k2_full_model_numerical_equivalence(B, S, dtype): + """Test full model produces numerically equivalent logits to HF DeepseekV3ForCausalLM. + + Weight conversion: no RoPE de-interleaving is needed for this model path. + The attention path uses native interleaved q/k channels. + """ + HFModel = _get_hf_model_class() + if HFModel is None: + pytest.skip("transformers doesn't have DeepseekV3 (requires v4.57+)") + + device = "cpu" + config = _create_small_text_config() + hf_config = _create_hf_config() + + # Create HF model + hf_model = HFModel(hf_config) + # Initialize all gate weights for reproducibility + for module in hf_model.modules(): + if hasattr(module, "gate") and hasattr(module.gate, "weight"): + module.gate.weight = torch.nn.Parameter(torch.randn_like(module.gate.weight)) + hf_model.to(device=device, dtype=dtype) + hf_model.eval() + + # Create custom model and load matching HF weights directly. 
+ custom_model = KimiK2ForCausalLM(config) + custom_model.to(device=device, dtype=dtype) + hf_state_dict = hf_model.state_dict() + custom_model.load_state_dict(hf_state_dict) + custom_model.eval() + + # Create input + input_ids = torch.randint(0, config.vocab_size, (B, S), device=device) + position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1) + + # Run both + hf_out = hf_model(input_ids=input_ids, position_ids=position_ids) + custom_out = custom_model(input_ids=input_ids, position_ids=position_ids) + + # Compare logits — cast to float32 since HF may return bfloat16 + assert_rmse_close(custom_out.logits, hf_out.logits, rmse_ratio_tol=0.02, msg="Full model: ") From 5b0e8a9290bda03d7cbdc761df28ae9a91018c5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EB=85=B8=EB=9E=80=ED=86=A0=EB=81=BC?= <83907395+Bias92@users.noreply.github.com> Date: Thu, 5 Mar 2026 17:36:35 +0900 Subject: [PATCH 5/7] [None][fix] Prevent RuntimeError from dict mutation during iteration in EXAONE MoE weight mapper (#11862) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 노란토끼 <83907395+Bias92@users.noreply.github.com> Co-authored-by: Zhenhua Wang <4936589+zhenhuaw-me@users.noreply.github.com> Co-authored-by: Yechan Kim <161688079+yechank-nvidia@users.noreply.github.com> --- .../_torch/models/checkpoints/hf/exaone_moe_weight_mapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/exaone_moe_weight_mapper.py b/tensorrt_llm/_torch/models/checkpoints/hf/exaone_moe_weight_mapper.py index 072d73362a2..62ac71468ab 100644 --- a/tensorrt_llm/_torch/models/checkpoints/hf/exaone_moe_weight_mapper.py +++ b/tensorrt_llm/_torch/models/checkpoints/hf/exaone_moe_weight_mapper.py @@ -28,7 +28,7 @@ def __init__(self): def preprocess_weights(self, weights: dict): mtp_layer_offset = self.config.pretrained_config.num_hidden_layers - for name in weights.keys(): + for name in 
list(weights.keys()): if name.startswith("mtp.layers."): # mtp.layers.{idx}.* -> model.layers.{offset + idx}.* _, _, mtp_layer_idx, module_name = name.split(".", 3) From 2f4ed7dc845f11eadd882da7b1269037c6646cb5 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang <4936589+zhenhuaw-me@users.noreply.github.com> Date: Thu, 5 Mar 2026 16:36:52 +0800 Subject: [PATCH 6/7] [TRTLLM-11101][feat] VisualGen benchmarking script (#11651) Signed-off-by: Zhenhua Wang --- examples/visual_gen/README.md | 118 +++- .../visual_gen/serve/benchmark_visual_gen.sh | 167 +++++ examples/visual_gen/serve/configs/flux1.yml | 2 +- examples/visual_gen/serve/configs/flux2.yml | 24 + tensorrt_llm/_torch/visual_gen/config.py | 4 + tensorrt_llm/bench/benchmark/visual_gen.py | 471 +++++++++++++ .../bench/benchmark/visual_gen_utils.py | 232 +++++++ tensorrt_llm/commands/bench.py | 2 + tensorrt_llm/commands/serve.py | 7 +- tensorrt_llm/commands/utils.py | 14 + tensorrt_llm/llmapi/visual_gen.py | 35 +- .../serve/scripts/benchmark_visual_gen.py | 616 ++++++++++++++++++ .../defs/examples/test_visual_gen.py | 22 +- tests/integration/defs/visual_gen/__init__.py | 0 .../visual_gen/test_visual_gen_benchmark.py | 418 ++++++++++++ 15 files changed, 2097 insertions(+), 35 deletions(-) create mode 100644 examples/visual_gen/serve/benchmark_visual_gen.sh create mode 100644 examples/visual_gen/serve/configs/flux2.yml create mode 100644 tensorrt_llm/bench/benchmark/visual_gen.py create mode 100644 tensorrt_llm/bench/benchmark/visual_gen_utils.py create mode 100644 tensorrt_llm/serve/scripts/benchmark_visual_gen.py create mode 100644 tests/integration/defs/visual_gen/__init__.py create mode 100644 tests/integration/defs/visual_gen/test_visual_gen_benchmark.py diff --git a/examples/visual_gen/README.md b/examples/visual_gen/README.md index 4dfd5e07e06..7b356d0f079 100644 --- a/examples/visual_gen/README.md +++ b/examples/visual_gen/README.md @@ -1,6 +1,6 @@ # Visual Generation Examples -Quick reference for running visual 
generation models (WAN). +Quick reference for running visual generation models (FLUX, WAN). ## Prerequisites @@ -34,6 +34,86 @@ cd examples/visual_gen --- +## FLUX (Text-to-Image) + +Supports both FLUX.1-dev and FLUX.2-dev. The pipeline type is auto-detected from the model checkpoint (`model_index.json`). + +### Basic Usage + +**FLUX.1-dev:** +```bash +python visual_gen_flux.py \ + --model_path ${MODEL_ROOT}/FLUX.1-dev \ + --prompt "A cat sitting on a windowsill" \ + --guidance_scale 3.5 \ + --output_path output.png +``` + +**FLUX.2-dev:** +```bash +python visual_gen_flux.py \ + --model_path ${MODEL_ROOT}/FLUX.2-dev \ + --prompt "A cat sitting on a windowsill" \ + --guidance_scale 4.0 \ + --output_path output.png +``` + +**With FP8 Quantization:** +```bash +python visual_gen_flux.py \ + --model_path ${MODEL_ROOT}/FLUX.2-dev \ + --prompt "A cat sitting on a windowsill" \ + --linear_type trtllm-fp8-per-tensor \ + --output_path output.png +``` + +**With TeaCache:** +```bash +python visual_gen_flux.py \ + --model_path ${MODEL_ROOT}/FLUX.1-dev \ + --prompt "A cat sitting on a windowsill" \ + --enable_teacache \ + --output_path output.png +``` + +### Batch Mode + +Generate multiple images from a prompts file (one prompt per line): + +```bash +python visual_gen_flux.py \ + --model_path ${MODEL_ROOT}/FLUX.1-dev \ + --prompts_file prompts.txt \ + --output_dir results/bf16/ \ + --seed 42 +``` + +```bash +# With FP8 quantization +python visual_gen_flux.py \ + --model_path ${MODEL_ROOT}/FLUX.2-dev \ + --prompts_file prompts.txt \ + --output_dir results/fp8/ \ + --linear_type trtllm-fp8-per-tensor +``` + +Images are saved as `00.png`, `01.png`, etc. with a `timing.json` summary. + +### Multi-GPU Parallelism + +FLUX supports CFG and Ulysses parallelism, same as WAN. 
+ +**CFG + Ulysses (4 GPUs):** +```bash +python visual_gen_flux.py \ + --model_path ${MODEL_ROOT}/FLUX.1-dev \ + --prompts_file prompts.txt \ + --output_dir results/ \ + --cfg_size 2 --ulysses_size 2 +``` + +--- + ## WAN (Text-to-Video) ### Basic Usage @@ -116,25 +196,28 @@ GPU Layout: GPU 0-3 (positive) | GPU 4-7 (negative) ## Common Arguments -| Argument | WAN | Default | Description | -|----------|-----|---------|-------------| -| `--height` | ✓ | 720 | Output height | -| `--width` | ✓ | 1280 | Output width | -| `--num_frames` | ✓ | 81 | Number of frames | -| `--steps` | ✓ | 50 | Denoising steps | -| `--guidance_scale` | ✓ | 5.0 | CFG guidance strength | -| `--seed` | ✓ | 42 | Random seed | -| `--enable_teacache` | ✓ | False | Cache optimization | -| `--teacache_thresh` | ✓ | 0.2 | TeaCache similarity threshold | -| `--attention_backend` | ✓ | VANILLA | VANILLA or TRTLLM | -| `--cfg_size` | ✓ | 1 | CFG parallelism | -| `--ulysses_size` | ✓ | 1 | Sequence parallelism | -| `--linear_type` | ✓ | default | Quantization type | +| Argument | FLUX | WAN | Default | Description | +|----------|------|-----|---------|-------------| +| `--height` | ✓ | ✓ | 1024 / 720 | Output height | +| `--width` | ✓ | ✓ | 1024 / 1280 | Output width | +| `--num_frames` | | ✓ | 81 | Number of frames | +| `--steps` | ✓ | ✓ | 50 | Denoising steps | +| `--guidance_scale` | ✓ | ✓ | 3.5 / 5.0 | CFG guidance strength | +| `--seed` | ✓ | ✓ | 42 | Random seed | +| `--enable_teacache` | ✓ | ✓ | False | Cache optimization | +| `--teacache_thresh` | ✓ | ✓ | 0.2 | TeaCache similarity threshold | +| `--attention_backend` | ✓ | ✓ | VANILLA | VANILLA or TRTLLM | +| `--cfg_size` | ✓ | ✓ | 1 | CFG parallelism | +| `--ulysses_size` | ✓ | ✓ | 1 | Sequence parallelism | +| `--linear_type` | ✓ | ✓ | default | Quantization type | +| `--prompts_file` | ✓ | | — | Batch mode prompts file | +| `--output_dir` | ✓ | | — | Batch mode output directory | +| `--disable_torch_compile` | ✓ | ✓ | False | Disable 
torch.compile | ## Troubleshooting **Out of Memory:** -- Use quantization: `--linear_type trtllm-fp8-blockwise` +- Use quantization: `--linear_type trtllm-fp8-blockwise` (WAN) or `--linear_type trtllm-fp8-per-tensor` (FLUX) - Reduce resolution or frames - Enable TeaCache: `--enable_teacache` - Use Ulysses parallelism with more GPUs @@ -149,12 +232,13 @@ GPU Layout: GPU 0-3 (positive) | GPU 4-7 (negative) - Install necessary dependencies, e.g., `pip install -r requirements-dev.txt` **Ulysses Errors:** -- `ulysses_size` must divide 12 (WAN heads) +- `ulysses_size` must divide the model's head count (12 for WAN) - Total GPUs = `cfg_size × ulysses_size` - Sequence length must be divisible by `ulysses_size` ## Output Formats +- **FLUX**: `.png` (image) - **WAN**: `.mp4` (video), `.gif` (animated), `.png` (single frame) ## Baseline Validation diff --git a/examples/visual_gen/serve/benchmark_visual_gen.sh b/examples/visual_gen/serve/benchmark_visual_gen.sh new file mode 100644 index 00000000000..20664eb72a5 --- /dev/null +++ b/examples/visual_gen/serve/benchmark_visual_gen.sh @@ -0,0 +1,167 @@ +#!/bin/bash +# Benchmark VisualGen serving with trtllm-serve +# +# This script demonstrates how to: +# 1. Start a trtllm-serve server for VisualGen +# 2. 
Run the benchmark_visual_gen.py client against it +# +# Usage: +# # Set model path (HF model ID or local path) +# export MODEL=Wan-AI/Wan2.2-T2V-A14B-Diffusers +# +# # Optional: customize server config +# export SERVER_CONFIG=./configs/wan.yml +# +# # Run the benchmark +# ./benchmark_visual_gen.sh +# +# Requirements: +# pip install git+https://github.com/huggingface/diffusers.git +# pip install av + +set -euo pipefail + +# --------------------------------------------------------------------------- +# Configuration (override via environment variables) +# --------------------------------------------------------------------------- + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PROJECT_ROOT=${PROJECT_ROOT:-"$(cd "${SCRIPT_DIR}/../../.." && pwd)"} + +MODEL=${MODEL:-"Wan-AI/Wan2.2-T2V-A14B-Diffusers"} +SERVER_CONFIG=${SERVER_CONFIG:-"${SCRIPT_DIR}/configs/wan.yml"} +BACKEND=${BACKEND:-"openai-videos"} +HOST=${HOST:-"127.0.0.1"} +PORT=${PORT:-8000} + +# Generation defaults +SIZE=${SIZE:-"720x1280"} +NUM_FRAMES=${NUM_FRAMES:-81} +FPS=${FPS:-16} +NUM_INFERENCE_STEPS=${NUM_INFERENCE_STEPS:-50} +GUIDANCE_SCALE=${GUIDANCE_SCALE:-5.0} +SEED=${SEED:-42} + +# Benchmark defaults +NUM_PROMPTS=${NUM_PROMPTS:-3} +MAX_CONCURRENCY=${MAX_CONCURRENCY:-1} +PROMPT=${PROMPT:-"A cat walks through a field of flowers, with the wind blowing gently"} + +# Output +RESULT_DIR=${RESULT_DIR:-"./benchmark_results"} + +# --------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------- + +wait_for_server() { + local url="http://${HOST}:${PORT}/health" + local max_wait=${SERVER_TIMEOUT:-3600} # 60 minutes for model loading + warmup on NFS + local elapsed=0 + local interval=5 + + echo "Waiting for server at ${url} ..." 
+ while [ $elapsed -lt $max_wait ]; do + if curl -s -o /dev/null -w "%{http_code}" "$url" 2>/dev/null | grep -q "200"; then + echo "Server is ready (took ${elapsed}s)" + return 0 + fi + sleep $interval + elapsed=$((elapsed + interval)) + if [ $((elapsed % 30)) -eq 0 ]; then + echo " Still waiting... (${elapsed}s elapsed)" + fi + done + echo "ERROR: Server did not become ready within ${max_wait}s" + return 1 +} + +cleanup() { + if [ -n "${SERVER_PID:-}" ]; then + echo "Stopping server (PID: $SERVER_PID)..." + kill "$SERVER_PID" 2>/dev/null || true + wait "$SERVER_PID" 2>/dev/null || true + fi +} + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +echo "============================================" +echo "VisualGen Serving Benchmark" +echo "============================================" +echo "Model: $MODEL" +echo "Backend: $BACKEND" +echo "Server: http://${HOST}:${PORT}" +echo "Size: $SIZE" +if [ "$BACKEND" = "openai-videos" ]; then +echo "Num frames: $NUM_FRAMES" +echo "FPS: $FPS" +fi +echo "Inference steps: $NUM_INFERENCE_STEPS" +echo "Guidance scale: $GUIDANCE_SCALE" +echo "Num prompts: $NUM_PROMPTS" +echo "Max concurrency: $MAX_CONCURRENCY" +echo "Result dir: $RESULT_DIR" +echo "============================================" +echo "" + +# Step 1: Start server +SERVER_CMD="trtllm-serve ${MODEL} --host ${HOST} --port ${PORT}" +if [ -n "$SERVER_CONFIG" ]; then + SERVER_CMD="${SERVER_CMD} --extra_visual_gen_options ${SERVER_CONFIG}" +fi + +echo "Step 1: Starting server..." +echo " Command: ${SERVER_CMD}" + +SERVER_LOG="${RESULT_DIR}/server.log" +mkdir -p "${RESULT_DIR}" + +$SERVER_CMD > "$SERVER_LOG" 2>&1 & +SERVER_PID=$! +trap cleanup EXIT + +echo " Server PID: $SERVER_PID" +echo " Server log: $SERVER_LOG" + +wait_for_server + +# Step 2: Run benchmark +echo "" +echo "Step 2: Running benchmark..." 
+ +BENCHMARK_CMD="python -m tensorrt_llm.serve.scripts.benchmark_visual_gen \ + --model ${MODEL} \ + --backend ${BACKEND} \ + --host ${HOST} \ + --port ${PORT} \ + --prompt \"${PROMPT}\" \ + --num-prompts ${NUM_PROMPTS} \ + --size ${SIZE} \ + --num-inference-steps ${NUM_INFERENCE_STEPS} \ + --guidance-scale ${GUIDANCE_SCALE} \ + --seed ${SEED} \ + --max-concurrency ${MAX_CONCURRENCY} \ + --save-result \ + --save-detailed \ + --result-dir ${RESULT_DIR} \ + --metric-percentiles 50,90,99" + +if [ "$BACKEND" = "openai-videos" ]; then + BENCHMARK_CMD="${BENCHMARK_CMD} --num-frames ${NUM_FRAMES} --fps ${FPS}" +fi + +BENCHMARK_LOG="${RESULT_DIR}/benchmark.log" + +echo " Command: ${BENCHMARK_CMD}" +echo " Benchmark log: ${BENCHMARK_LOG}" +echo "" + +eval $BENCHMARK_CMD 2>&1 | tee "${BENCHMARK_LOG}" + +echo "" +echo "============================================" +echo "Benchmark complete. Results in: ${RESULT_DIR}" +echo "============================================" diff --git a/examples/visual_gen/serve/configs/flux1.yml b/examples/visual_gen/serve/configs/flux1.yml index f97f0016e10..57aa695e46c 100644 --- a/examples/visual_gen/serve/configs/flux1.yml +++ b/examples/visual_gen/serve/configs/flux1.yml @@ -1,7 +1,7 @@ linear: type: default teacache: - enable_teacache: false + enable_teacache: true teacache_thresh: 0.2 attention: backend: VANILLA diff --git a/examples/visual_gen/serve/configs/flux2.yml b/examples/visual_gen/serve/configs/flux2.yml new file mode 100644 index 00000000000..d492e382512 --- /dev/null +++ b/examples/visual_gen/serve/configs/flux2.yml @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +linear: + type: default +teacache: + enable_teacache: true + teacache_thresh: 0.2 +attention: + backend: VANILLA +parallel: + dit_cfg_size: 1 + dit_ulysses_size: 2 diff --git a/tensorrt_llm/_torch/visual_gen/config.py b/tensorrt_llm/_torch/visual_gen/config.py index 4f655930b20..e177ae3451d 100644 --- a/tensorrt_llm/_torch/visual_gen/config.py +++ b/tensorrt_llm/_torch/visual_gen/config.py @@ -108,6 +108,10 @@ class ParallelConfig(BaseModel): t5_fsdp_size: int = 1 + @property + def n_workers(self) -> int: + return self.dit_cfg_size * self.dit_ulysses_size + def to_mapping(self) -> Mapping: """Convert to TRT-LLM Mapping.""" world_size = self.dit_tp_size * self.dit_cp_size diff --git a/tensorrt_llm/bench/benchmark/visual_gen.py b/tensorrt_llm/bench/benchmark/visual_gen.py new file mode 100644 index 00000000000..212b1cb31e0 --- /dev/null +++ b/tensorrt_llm/bench/benchmark/visual_gen.py @@ -0,0 +1,471 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Offline benchmark for VisualGen (image/video generation) models. + +Usage: + trtllm-bench --model Wan-AI/Wan2.2-T2V-A14B-Diffusers \ + --model_path /path/to/checkpoint \ + visual-gen --extra_visual_gen_options config.yaml +""" + +import json +import os +import time +from datetime import datetime +from typing import Optional + +import click + +from tensorrt_llm.bench.benchmark.visual_gen_utils import ( + VisualGenBenchmarkMetrics, + VisualGenRequestOutput, + build_visual_gen_result_dict, + calculate_metrics, + load_visual_gen_prompts, + print_visual_gen_results, +) +from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment +from tensorrt_llm.logger import logger + + +def _parse_size(size_str: str) -> tuple[Optional[int], Optional[int]]: + """Parse WxH size string into (width, height). Returns (None, None) for 'auto'.""" + if size_str.lower() == "auto": + return None, None + parts = size_str.lower().split("x") + if len(parts) != 2: + raise click.BadParameter( + f"Size must be 'auto' or WxH format (e.g. 480x832), got '{size_str}'" + ) + return int(parts[0]), int(parts[1]) + + +@click.command(name="visual-gen", context_settings={"show_default": True}) +@click.option( + "--extra_visual_gen_options", + type=str, + default=None, + help="Path to a YAML file with extra VisualGen model options " + "(same format as trtllm-serve --extra_visual_gen_options).", +) +@click.option( + "--prompt", + type=str, + default=None, + help="Single text prompt (repeated --num_prompts times).", +) +@click.option( + "--prompt_file", + type=str, + default=None, + help="Path to prompt file. 
Supports plain text (one prompt per line) " + "or JSONL with 'text'/'prompt' field.", +) +@click.option( + "--num_prompts", + type=int, + default=5, + help="Number of prompts to benchmark.", +) +@click.option( + "--size", + type=str, + default="auto", + help="Output resolution in WxH format (e.g. 480x832) or 'auto'.", +) +@click.option( + "--seconds", + type=float, + default=4.0, + help="Video duration in seconds.", +) +@click.option( + "--fps", + type=int, + default=16, + help="Frames per second.", +) +@click.option( + "--num_frames", + type=int, + default=None, + help="Total frames to generate. Overrides --seconds (computed as num_frames / fps).", +) +@click.option( + "--num_inference_steps", + type=int, + default=None, + help="Number of diffusion denoising steps.", +) +@click.option( + "--guidance_scale", + type=float, + default=None, + help="Classifier-free guidance scale.", +) +@click.option( + "--seed", + type=int, + default=42, + help="Random seed for reproducibility.", +) +@click.option( + "--negative_prompt", + type=str, + default=None, + help="Negative prompt (concepts to avoid).", +) +@click.option( + "--max_concurrency", + type=int, + default=1, + help="Maximum concurrent generation requests.", +) +@click.option( + "--warmup", + type=int, + default=1, + help="Number of warmup requests before benchmarking.", +) +@click.option( + "--save_result", + is_flag=True, + default=False, + help="Save results to a JSON file.", +) +@click.option( + "--save_detailed", + is_flag=True, + default=False, + help="Include per-request details (latencies, errors) in saved results.", +) +@click.option( + "--result_dir", + type=str, + default=None, + help="Directory for result files.", +) +@click.option( + "--result_filename", + type=str, + default=None, + help="Custom result filename.", +) +@click.option( + "--metric_percentiles", + type=str, + default="50,90,99", + help="Comma-separated percentile values.", +) +@click.pass_obj +def visual_gen_command( + bench_env: 
BenchmarkEnvironment, + extra_visual_gen_options: Optional[str], + prompt: Optional[str], + prompt_file: Optional[str], + num_prompts: int, + size: str, + seconds: float, + fps: int, + num_frames: Optional[int], + num_inference_steps: Optional[int], + guidance_scale: Optional[float], + seed: int, + negative_prompt: Optional[str], + max_concurrency: int, + warmup: int, + save_result: bool, + save_detailed: bool, + result_dir: Optional[str], + result_filename: Optional[str], + metric_percentiles: str, +) -> None: + """Benchmark VisualGen (image/video generation) models offline.""" + import yaml + + from tensorrt_llm.commands.utils import get_visual_gen_model_type, get_visual_gen_num_gpus + from tensorrt_llm.llmapi.visual_gen import VisualGen, VisualGenParams + + if prompt is None and prompt_file is None: + raise click.UsageError("Either --prompt or --prompt_file must be specified.") + if prompt is not None and prompt_file is not None: + raise click.UsageError("--prompt and --prompt_file are mutually exclusive.") + + model = bench_env.model + model_path = str(bench_env.checkpoint_path or model) + + # Build diffusion config (same pattern as trtllm-serve _serve_visual_gen) + visual_gen_config: dict = { + "model": model_path, + "model_type": get_visual_gen_model_type(model_path), + } + if extra_visual_gen_options is not None: + with open(extra_visual_gen_options, "r") as f: + visual_gen_extra_args = yaml.safe_load(f) or {} + visual_gen_config.update(visual_gen_extra_args) + + n_workers = get_visual_gen_num_gpus(visual_gen_config) + parallel_config = visual_gen_config.get("parallel", {}) + if parallel_config: + logger.info(f"World size: {n_workers}") + logger.info(f"CFG size: {parallel_config.get('dit_cfg_size', 1)}") + logger.info(f"Ulysses size: {parallel_config.get('dit_ulysses_size', 1)}") + + # Parse generation parameters + width, height = _parse_size(size) + if num_frames is not None: + seconds = num_frames / fps + logger.info(f"Computed seconds={seconds:.3f} from 
num_frames={num_frames} / fps={fps}") + + gen_params_kwargs: dict = {"seed": seed, "frame_rate": float(fps)} + if height is not None: + gen_params_kwargs["height"] = height + if width is not None: + gen_params_kwargs["width"] = width + if num_frames is not None: + gen_params_kwargs["num_frames"] = num_frames + if num_inference_steps is not None: + gen_params_kwargs["num_inference_steps"] = num_inference_steps + if guidance_scale is not None: + gen_params_kwargs["guidance_scale"] = guidance_scale + + gen_params = VisualGenParams(**gen_params_kwargs) + + gen_params_for_report = { + "size": size, + "seconds": seconds, + "fps": fps, + } + if num_inference_steps is not None: + gen_params_for_report["num_inference_steps"] = num_inference_steps + if guidance_scale is not None: + gen_params_for_report["guidance_scale"] = guidance_scale + if negative_prompt is not None: + gen_params_for_report["negative_prompt"] = negative_prompt + gen_params_for_report["seed"] = seed + + # Load prompts + input_requests = load_visual_gen_prompts(prompt, prompt_file, num_prompts) + selected_percentiles = [float(p) for p in metric_percentiles.split(",")] + + # Initialize VisualGen + logger.info(f"Initializing VisualGen ({model_path})") + visual_gen = VisualGen( + model_path=model_path, + n_workers=n_workers, + diffusion_config=visual_gen_config, + ) + + try: + # Warmup + if warmup > 0: + logger.info(f"Running {warmup} warmup request(s)...") + for i in range(warmup): + warmup_prompt = input_requests[i % len(input_requests)].prompt + visual_gen.generate(inputs=warmup_prompt, params=gen_params) + logger.info("Warmup complete.") + + # Main benchmark + logger.info( + f"Starting benchmark: {len(input_requests)} requests, max_concurrency={max_concurrency}" + ) + + benchmark_start = time.perf_counter() + outputs = _run_benchmark( + visual_gen=visual_gen, + input_requests=input_requests, + gen_params=gen_params, + negative_prompt=negative_prompt, + max_concurrency=max_concurrency, + ) + 
benchmark_duration = time.perf_counter() - benchmark_start + + finally: + visual_gen.shutdown() + + metrics = calculate_metrics( + outputs=outputs, + dur_s=benchmark_duration, + selected_percentiles=selected_percentiles, + num_gpus=n_workers, + ) + + print_visual_gen_results( + backend="offline", + model_id=model_path, + benchmark_duration=benchmark_duration, + metrics=metrics, + ) + + if save_result: + _save_results( + backend="offline", + model_id=model_path, + benchmark_duration=benchmark_duration, + metrics=metrics, + outputs=outputs, + gen_params=gen_params_for_report, + num_prompts=num_prompts, + max_concurrency=max_concurrency, + num_gpus=n_workers, + save_detailed=save_detailed, + result_dir=result_dir, + result_filename=result_filename, + ) + + +def _run_benchmark( + visual_gen, + input_requests, + gen_params, + negative_prompt: Optional[str], + max_concurrency: int, +) -> list[VisualGenRequestOutput]: + """Run the benchmark loop, dispatching requests with concurrency control.""" + import asyncio + + outputs: list[VisualGenRequestOutput] = [] + + if max_concurrency <= 1: + outputs = _run_sequential(visual_gen, input_requests, gen_params, negative_prompt) + else: + outputs = asyncio.run( + _run_concurrent( + visual_gen, + input_requests, + gen_params, + negative_prompt, + max_concurrency, + ) + ) + + return outputs + + +def _run_sequential( + visual_gen, input_requests, gen_params, negative_prompt +) -> list[VisualGenRequestOutput]: + """Run requests one at a time, measuring per-request latency.""" + outputs = [] + + for req in input_requests: + output = VisualGenRequestOutput() + inputs = ( + {"prompt": req.prompt, "negative_prompt": negative_prompt} + if negative_prompt + else req.prompt + ) + st = time.perf_counter() + try: + visual_gen.generate(inputs=inputs, params=gen_params) + output.e2e_latency = time.perf_counter() - st + output.success = True + except Exception as e: + output.e2e_latency = time.perf_counter() - st + output.success = False + 
output.error = str(e) + output.exception_type = e.__class__.__name__ + logger.error(f"Request failed: {e}") + + outputs.append(output) + + return outputs + + +async def _run_concurrent( + visual_gen, input_requests, gen_params, negative_prompt, max_concurrency +) -> list[VisualGenRequestOutput]: + """Run requests concurrently using generate_async with a semaphore.""" + import asyncio + + semaphore = asyncio.Semaphore(max_concurrency) + outputs: list[VisualGenRequestOutput] = [VisualGenRequestOutput() for _ in input_requests] + + async def _generate_one(idx, req): + inputs = ( + {"prompt": req.prompt, "negative_prompt": negative_prompt} + if negative_prompt + else req.prompt + ) + async with semaphore: + output = outputs[idx] + st = time.perf_counter() + try: + future = visual_gen.generate_async(inputs=inputs, params=gen_params) + await future.result() + output.e2e_latency = time.perf_counter() - st + output.success = True + except Exception as e: + output.e2e_latency = time.perf_counter() - st + output.success = False + output.error = str(e) + output.exception_type = e.__class__.__name__ + logger.error(f"Request {idx} failed: {e}") + + tasks = [_generate_one(i, req) for i, req in enumerate(input_requests)] + await asyncio.gather(*tasks) + + return outputs + + +def _save_results( + backend: str, + model_id: str, + benchmark_duration: float, + metrics: VisualGenBenchmarkMetrics, + outputs: list[VisualGenRequestOutput], + gen_params: dict, + num_prompts: int, + max_concurrency: int, + num_gpus: int, + save_detailed: bool, + result_dir: Optional[str], + result_filename: Optional[str], +) -> None: + """Save benchmark results to a JSON file.""" + current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") + + result_json = build_visual_gen_result_dict( + backend=backend, + model_id=model_id, + benchmark_duration=benchmark_duration, + metrics=metrics, + outputs=outputs, + gen_params=gen_params, + ) + + result_json["date"] = current_dt + result_json["num_prompts"] = num_prompts 
+ result_json["max_concurrency"] = max_concurrency + result_json["num_gpus"] = num_gpus + + if not save_detailed: + for field_name in ["e2e_latencies", "errors"]: + result_json.pop(field_name, None) + + base_model_id = model_id.split("/")[-1] + concurrency_str = f"-concurrency{max_concurrency}" if max_concurrency is not None else "" + file_name = f"offline{concurrency_str}-{base_model_id}-{current_dt}.json" + if result_filename: + file_name = result_filename + if result_dir: + os.makedirs(result_dir, exist_ok=True) + file_name = os.path.join(result_dir, file_name) + + with open(file_name, "w", encoding="utf-8") as outfile: + json.dump(result_json, outfile, indent=2) + + print(f"Results saved to: {file_name}") diff --git a/tensorrt_llm/bench/benchmark/visual_gen_utils.py b/tensorrt_llm/bench/benchmark/visual_gen_utils.py new file mode 100644 index 00000000000..292de4805aa --- /dev/null +++ b/tensorrt_llm/bench/benchmark/visual_gen_utils.py @@ -0,0 +1,232 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Shared utilities for VisualGen benchmarking (online and offline).""" + +import json +import warnings +from dataclasses import dataclass +from typing import Optional + +import numpy as np + +SECONDS_TO_MILLISECONDS = 1000 + + +@dataclass +class VisualGenSampleRequest: + """A single prompt for visual generation benchmarking.""" + + prompt: str + + +@dataclass +class VisualGenRequestOutput: + """Timing and status result for a single visual generation request.""" + + success: bool = False + e2e_latency: float = 0.0 + ttff: float = -1.0 + gen_fps: float = -1.0 + error: str = "" + exception_type: Optional[str] = None + + +@dataclass +class VisualGenBenchmarkMetrics: + """Aggregated benchmark metrics across all requests.""" + + completed: int + total_requests: int + request_throughput: float + mean_e2e_latency_ms: float + median_e2e_latency_ms: float + std_e2e_latency_ms: float + min_e2e_latency_ms: float + max_e2e_latency_ms: float + percentiles_e2e_latency_ms: list[tuple[float, float]] + num_gpus: int = 1 + per_gpu_throughput: float = 0.0 + mean_ttff_ms: float = -1.0 + mean_gen_fps: float = -1.0 + + +def calculate_metrics( + outputs: list[VisualGenRequestOutput], + dur_s: float, + selected_percentiles: list[float], + num_gpus: int = 1, +) -> VisualGenBenchmarkMetrics: + """Compute aggregate metrics from per-request outputs.""" + e2e_latencies: list[float] = [] + error_counts: dict[str, int] = {} + completed = 0 + + for out in outputs: + if out.exception_type: + error_counts[out.exception_type] = error_counts.get(out.exception_type, 0) + 1 + if out.success: + e2e_latencies.append(out.e2e_latency) + completed += 1 + + total_error_count = sum(error_counts.values()) + for exception_type, count in error_counts.items(): + print(f"Error type: {exception_type}, Count: {count} requests") + if total_error_count: + print(f"Total failed requests: {total_error_count}") + + if completed == 0: + warnings.warn( + "All requests failed. 
This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2, + ) + + e2e_ms = [v * SECONDS_TO_MILLISECONDS for v in e2e_latencies] + + request_throughput = completed / dur_s if dur_s > 0 else 0 + return VisualGenBenchmarkMetrics( + completed=completed, + total_requests=len(outputs), + request_throughput=request_throughput, + mean_e2e_latency_ms=float(np.mean(e2e_ms)) if e2e_ms else 0, + median_e2e_latency_ms=float(np.median(e2e_ms)) if e2e_ms else 0, + std_e2e_latency_ms=float(np.std(e2e_ms)) if e2e_ms else 0, + min_e2e_latency_ms=float(np.min(e2e_ms)) if e2e_ms else 0, + max_e2e_latency_ms=float(np.max(e2e_ms)) if e2e_ms else 0, + percentiles_e2e_latency_ms=( + [(p, float(np.percentile(e2e_ms, p))) for p in selected_percentiles] + if e2e_ms + else [(p, 0.0) for p in selected_percentiles] + ), + num_gpus=num_gpus, + per_gpu_throughput=request_throughput / num_gpus, + ) + + +def print_visual_gen_results( + backend: str, + model_id: str, + benchmark_duration: float, + metrics: VisualGenBenchmarkMetrics, +) -> None: + """Print benchmark results to stdout.""" + print("{s:{c}^{n}}".format(s=" Benchmark Result (VisualGen) ", n=60, c="=")) + print("{:<40} {:<10}".format("Backend:", backend)) + print("{:<40} {:<10}".format("Model:", model_id)) + print("{:<40} {:<10}".format("Total requests:", metrics.total_requests)) + print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) + print("{:<40} {:<10}".format("Failed requests:", metrics.total_requests - metrics.completed)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) + print("{:<40} {:<10.4f}".format("Request throughput (req/s):", metrics.request_throughput)) + print("{:<40} {:<10}".format("Number of GPUs:", metrics.num_gpus)) + print("{:<40} {:<10.4f}".format("Per-GPU throughput (req/s/GPU):", metrics.per_gpu_throughput)) + + if metrics.total_requests - metrics.completed > 0: + print("=" * 60) + print( + f" !!! 
{metrics.total_requests - metrics.completed} " + "FAILED REQUESTS - CHECK LOG FOR ERRORS !!!" + ) + print("=" * 60) + + print("{s:{c}^{n}}".format(s=" E2E Latency ", n=60, c="-")) + print("{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)) + print("{:<40} {:<10.2f}".format("Median E2E Latency (ms):", metrics.median_e2e_latency_ms)) + print("{:<40} {:<10.2f}".format("Std Dev E2E Latency (ms):", metrics.std_e2e_latency_ms)) + print("{:<40} {:<10.2f}".format("Min E2E Latency (ms):", metrics.min_e2e_latency_ms)) + print("{:<40} {:<10.2f}".format("Max E2E Latency (ms):", metrics.max_e2e_latency_ms)) + for p, v in metrics.percentiles_e2e_latency_ms: + p_word = str(int(p)) if int(p) == p else str(p) + print("{:<40} {:<10.2f}".format(f"P{p_word} E2E Latency (ms):", v)) + + print("{s:{c}^{n}}".format(s=" Placeholder Metrics ", n=60, c="-")) + print("{:<40} {:<10}".format("TTFF (ms):", "N/A (placeholder)")) + print("{:<40} {:<10}".format("GenFPS:", "N/A (placeholder)")) + print("=" * 60) + + +def load_visual_gen_prompts( + prompt: Optional[str], + prompt_file: Optional[str], + num_prompts: int, +) -> list[VisualGenSampleRequest]: + """Load prompts from a single string or a file. + + Args: + prompt: Single text prompt (repeated to fill num_prompts). + prompt_file: Path to prompt file. Supports plain text (one per line) + or JSONL with ``text`` / ``prompt`` field. + num_prompts: Number of prompts to return. + + Returns: + List of ``VisualGenSampleRequest`` of length *num_prompts*. 
+ """ + prompts: list[str] = [] + + if prompt_file: + with open(prompt_file, encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + try: + data = json.loads(line) + prompts.append(data.get("text", data.get("prompt", line))) + except json.JSONDecodeError: + prompts.append(line) + elif prompt: + prompts.append(prompt) + else: + raise ValueError("Either prompt or prompt_file must be specified.") + + if len(prompts) < num_prompts: + repeats = (num_prompts // len(prompts)) + 1 + prompts = (prompts * repeats)[:num_prompts] + else: + prompts = prompts[:num_prompts] + + return [VisualGenSampleRequest(prompt=p) for p in prompts] + + +def build_visual_gen_result_dict( + backend: str, + model_id: str, + benchmark_duration: float, + metrics: VisualGenBenchmarkMetrics, + outputs: list[VisualGenRequestOutput], + gen_params: dict, +) -> dict: + """Build the result dictionary for JSON serialization.""" + return { + "backend": backend, + "model": model_id, + "duration": benchmark_duration, + "num_gpus": metrics.num_gpus, + "total_requests": metrics.total_requests, + "completed": metrics.completed, + "request_throughput": metrics.request_throughput, + "per_gpu_throughput": metrics.per_gpu_throughput, + "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms, + "median_e2e_latency_ms": metrics.median_e2e_latency_ms, + "std_e2e_latency_ms": metrics.std_e2e_latency_ms, + "min_e2e_latency_ms": metrics.min_e2e_latency_ms, + "max_e2e_latency_ms": metrics.max_e2e_latency_ms, + "percentiles_e2e_latency_ms": { + f"p{int(p) if int(p) == p else p}": v for p, v in metrics.percentiles_e2e_latency_ms + }, + "e2e_latencies": [out.e2e_latency for out in outputs], + "errors": [out.error for out in outputs], + "gen_params": gen_params, + } diff --git a/tensorrt_llm/commands/bench.py b/tensorrt_llm/commands/bench.py index ab4755082f6..9ba1a5258df 100644 --- a/tensorrt_llm/commands/bench.py +++ b/tensorrt_llm/commands/bench.py @@ -5,6 +5,7 @@ from 
tensorrt_llm.bench.benchmark.low_latency import latency_command from tensorrt_llm.bench.benchmark.throughput import throughput_command +from tensorrt_llm.bench.benchmark.visual_gen import visual_gen_command from tensorrt_llm.bench.build.build import build_command from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment from tensorrt_llm.bench.dataset.prepare_dataset import prepare_dataset @@ -67,6 +68,7 @@ def main( main.add_command(throughput_command) main.add_command(latency_command) main.add_command(prepare_dataset) +main.add_command(visual_gen_command) if __name__ == "__main__": main() diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py index 2cd47787559..eb20f513cd5 100644 --- a/tensorrt_llm/commands/serve.py +++ b/tensorrt_llm/commands/serve.py @@ -20,7 +20,8 @@ from tensorrt_llm._tensorrt_engine import LLM from tensorrt_llm._utils import mpi_rank from tensorrt_llm.commands.utils import (get_is_diffusion_model, - get_visual_gen_model_type) + get_visual_gen_model_type, + get_visual_gen_num_gpus) from tensorrt_llm.executor.utils import LlmLauncherEnvs from tensorrt_llm.inputs.multimodal import MultimodalServerConfig from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy, @@ -465,11 +466,9 @@ def launch_visual_gen_server( model = visual_gen_config["model"] logger.info(f"Initializing VisualGen ({model})") - n_workers = 1 + n_workers = get_visual_gen_num_gpus(visual_gen_config) parallel_config = visual_gen_config.get("parallel", {}) if parallel_config: - n_workers = parallel_config.get( - "dit_cfg_size", 1) * parallel_config.get("dit_ulysses_size", 1) logger.info(f"World size: {n_workers}") logger.info(f"CFG size: {parallel_config.get('dit_cfg_size', 1)}") logger.info( diff --git a/tensorrt_llm/commands/utils.py b/tensorrt_llm/commands/utils.py index df1442c6e78..a85e152b9f3 100644 --- a/tensorrt_llm/commands/utils.py +++ b/tensorrt_llm/commands/utils.py @@ -3,6 +3,7 @@ import logging import os +from 
tensorrt_llm._torch.visual_gen.config import ParallelConfig from tensorrt_llm.llmapi.utils import download_hf_partial logger = logging.getLogger(__name__) @@ -115,6 +116,7 @@ def get_model_path(extra_argv): VISUAL_GEN_PARTIAL_MODEL_NAME_TO_MODEL_TYPE = { + "FLUX.1": "flux1", "FLUX.2": "flux2", "LTX-2": "ltx2", "Wan2": "wan2", @@ -130,3 +132,15 @@ def get_visual_gen_model_type(model_path: str): f"Unknown VISUAL_GEN model type for model path: {model_path}," f"available models: {VISUAL_GEN_PARTIAL_MODEL_NAME_TO_MODEL_TYPE.keys()}" ) + + +def get_visual_gen_num_gpus(diffusion_config: dict) -> int: + """Compute the number of GPUs from a visual_gen config. + + Uses ParallelConfig.model_construct (skips env validators) + so this is safe to call from non-worker processes. + """ + parallel = diffusion_config.get("parallel", {}) + if isinstance(parallel, dict): + parallel = ParallelConfig.model_construct(**parallel) + return parallel.n_workers diff --git a/tensorrt_llm/llmapi/visual_gen.py b/tensorrt_llm/llmapi/visual_gen.py index 2113b435484..8e742911cee 100644 --- a/tensorrt_llm/llmapi/visual_gen.py +++ b/tensorrt_llm/llmapi/visual_gen.py @@ -25,7 +25,6 @@ AWAIT_TIMEOUT = 0.05 THREAD_TIMEOUT = 5.0 WORKER_TIMEOUT = 2.0 -READY_TIMEOUT = 1200 # 20 minutes for large models (Wan 2.2 with transformer_2) def find_free_port() -> int: @@ -317,27 +316,43 @@ def shutdown(self): p.kill() p.join(timeout=WORKER_TIMEOUT) - def _wait_ready(self, timeout: float = READY_TIMEOUT): + def _wait_ready(self): """Wait for workers to be ready (sync wrapper for async operation).""" logger.info("DiffusionClient: Waiting for workers") - # Run the async wait in the background thread's event loop - future = asyncio.run_coroutine_threadsafe(self._wait_ready_async(timeout), self._event_loop) - return future.result(timeout=timeout) + future = asyncio.run_coroutine_threadsafe(self._wait_ready_async(), self._event_loop) + try: + future.result() + except Exception: + self.shutdown() + raise + + async def 
_wait_ready_async(self): + """Wait for workers to be ready (async version). - async def _wait_ready_async(self, timeout: float = READY_TIMEOUT): - """Wait for workers to be ready (async version).""" + Polls indefinitely for the ready signal. If any worker process dies + during initialization, raises RuntimeError immediately (LLM-style). + """ start_time = time.time() + last_log_time = start_time + log_interval = 300 while True: async with self.lock: if -1 in self.completed_responses: self.completed_responses.pop(-1) - logger.info("DiffusionClient: Workers ready") + elapsed = time.time() - start_time + logger.info(f"DiffusionClient: Workers ready ({elapsed:.1f}s)") return - if time.time() - start_time > timeout: - raise RuntimeError("DiffusionClient: Timeout waiting for workers") + if any(not p.is_alive() for p in self.worker_processes): + raise RuntimeError("DiffusionClient: Worker died during initialization") + + now = time.time() + if now - last_log_time >= log_interval: + elapsed = now - start_time + logger.info(f"DiffusionClient: Still waiting for workers ({elapsed:.0f}s elapsed)") + last_log_time = now try: await asyncio.wait_for(self.response_event.wait(), timeout=AWAIT_TIMEOUT) diff --git a/tensorrt_llm/serve/scripts/benchmark_visual_gen.py b/tensorrt_llm/serve/scripts/benchmark_visual_gen.py new file mode 100644 index 00000000000..abad5973454 --- /dev/null +++ b/tensorrt_llm/serve/scripts/benchmark_visual_gen.py @@ -0,0 +1,616 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Benchmark online serving throughput for VisualGen (image/video generation). + +On the server side, run: + trtllm-serve Wan-AI/Wan2.2-T2V-A14B-Diffusers --extra_visual_gen_options + +On the client side, run: + python -m tensorrt_llm.serve.scripts.benchmark_visual_gen \ + --model Wan-AI/Wan2.2-T2V-A14B-Diffusers \ + --backend openai-videos \ + --prompt "A cat playing in the park" \ + --num-prompts 5 \ + --size 480x832 \ + --num-frames 81 \ + --fps 16 \ + --num-inference-steps 50 \ + --max-concurrency 1 \ + --save-result +""" + +import argparse +import asyncio +import gc +import json +import os +import random +import sys +import time +import traceback +from argparse import ArgumentParser as FlexibleArgumentParser +from collections.abc import AsyncGenerator +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Optional + +import aiohttp +import numpy as np +import yaml +from tqdm.asyncio import tqdm + +from tensorrt_llm.bench.benchmark.visual_gen_utils import ( + VisualGenRequestOutput, + VisualGenSampleRequest, + build_visual_gen_result_dict, + calculate_metrics, + load_visual_gen_prompts, + print_visual_gen_results, +) + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + + +@dataclass +class VisualGenRequestInput: + """HTTP request payload for online (server) benchmarking.""" + + prompt: str + api_url: str + model: str + size: str = "auto" + seconds: float = 4.0 + fps: int = 24 + num_inference_steps: Optional[int] = None + guidance_scale: Optional[float] = None + negative_prompt: Optional[str] = None + seed: 
Optional[int] = None + extra_body: Optional[dict] = None + + +def _build_payload_common(request_input: VisualGenRequestInput) -> dict: + """Build common payload fields shared by image and video generation.""" + payload: dict[str, Any] = { + "model": request_input.model, + "prompt": request_input.prompt, + "size": request_input.size, + } + if request_input.num_inference_steps is not None: + payload["num_inference_steps"] = request_input.num_inference_steps + if request_input.guidance_scale is not None: + payload["guidance_scale"] = request_input.guidance_scale + if request_input.negative_prompt is not None: + payload["negative_prompt"] = request_input.negative_prompt + if request_input.seed is not None: + payload["seed"] = request_input.seed + if request_input.extra_body: + payload.update(request_input.extra_body) + return payload + + +def _get_headers() -> dict[str, str]: + return { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY', 'unused')}", + } + + +async def _do_post( + request_input: VisualGenRequestInput, + payload: dict[str, Any], + pbar: Optional[tqdm], + session: Optional[aiohttp.ClientSession], +) -> VisualGenRequestOutput: + """Execute HTTP POST, measure E2E latency, return output.""" + request_session = session or aiohttp.ClientSession( + trust_env=True, + timeout=AIOHTTP_TIMEOUT, + connector=aiohttp.TCPConnector(limit=0, limit_per_host=0), + ) + + output = VisualGenRequestOutput() + st = time.perf_counter() + try: + async with request_session.post( + url=request_input.api_url, json=payload, headers=_get_headers() + ) as response: + if response.status == 200: + await response.read() + output.success = True + output.e2e_latency = time.perf_counter() - st + else: + body = await response.text() + output.error = f"HTTP {response.status}: {body}" + output.success = False + except Exception as e: + output.success = False + exc_info = sys.exc_info() + output.error = 
"".join(traceback.format_exception(*exc_info)) + output.exception_type = e.__class__.__name__ + finally: + if session is None: + await request_session.close() + + if pbar: + pbar.update(1) + return output + + +async def async_request_image_generation( + request_input: VisualGenRequestInput, + pbar: Optional[tqdm] = None, + session: Optional[aiohttp.ClientSession] = None, +) -> VisualGenRequestOutput: + """POST /v1/images/generations and measure E2E latency.""" + payload = _build_payload_common(request_input) + payload["response_format"] = "b64_json" + payload["n"] = 1 + return await _do_post(request_input, payload, pbar, session) + + +async def async_request_video_generation( + request_input: VisualGenRequestInput, + pbar: Optional[tqdm] = None, + session: Optional[aiohttp.ClientSession] = None, +) -> VisualGenRequestOutput: + """POST /v1/videos/generations (sync endpoint) and measure E2E latency.""" + payload = _build_payload_common(request_input) + payload["seconds"] = request_input.seconds + payload["fps"] = request_input.fps + return await _do_post(request_input, payload, pbar, session) + + +VISUAL_GEN_REQUEST_FUNCS = { + "openai-images": async_request_image_generation, + "openai-videos": async_request_video_generation, +} + + +async def get_request( + input_requests: list[VisualGenSampleRequest], + request_rate: float, + burstiness: float = 1.0, +) -> AsyncGenerator[VisualGenSampleRequest, None]: + """Asynchronously generates requests at a specified rate with optional burstiness.""" + assert burstiness > 0, f"A positive burstiness factor is expected, but given {burstiness}." 
+ theta = 1.0 / (request_rate * burstiness) + for request in input_requests: + yield request + if request_rate == float("inf"): + continue + interval = np.random.gamma(shape=burstiness, scale=theta) + await asyncio.sleep(interval) + + +async def benchmark( + backend: str, + api_url: str, + model_id: str, + input_requests: list[VisualGenSampleRequest], + request_rate: float, + burstiness: float, + disable_tqdm: bool, + selected_percentiles: list[float], + max_concurrency: Optional[int], + gen_params: dict[str, Any], + extra_body: Optional[dict], + no_test_input: bool = False, + request_timeout: float = 6 * 60 * 60, + num_gpus: int = 1, +) -> dict[str, Any]: + if backend not in VISUAL_GEN_REQUEST_FUNCS: + raise ValueError( + f"Unknown backend: {backend}. Available: {list(VISUAL_GEN_REQUEST_FUNCS.keys())}" + ) + + request_func = VISUAL_GEN_REQUEST_FUNCS[backend] + + def _make_request_input(prompt: str) -> VisualGenRequestInput: + return VisualGenRequestInput( + prompt=prompt, + api_url=api_url, + model=model_id, + size=gen_params.get("size", "auto"), + seconds=gen_params.get("seconds", 4.0), + fps=gen_params.get("fps", 24), + num_inference_steps=gen_params.get("num_inference_steps"), + guidance_scale=gen_params.get("guidance_scale"), + negative_prompt=gen_params.get("negative_prompt"), + seed=gen_params.get("seed"), + extra_body=extra_body, + ) + + if not no_test_input: + print("Starting initial single prompt test run...") + test_input = _make_request_input(input_requests[0].prompt) + test_output = await request_func(request_input=test_input) + if not test_output.success: + raise ValueError( + "Initial test run failed - Please make sure benchmark " + "arguments are correctly specified. " + f"Error: {test_output.error}" + ) + else: + print("Initial test run completed. Starting main benchmark run...") + else: + print("Skipping initial test run. 
Starting main benchmark run...") + + if burstiness == 1.0: + distribution = "Poisson process" + else: + distribution = "Gamma distribution" + + print(f"Traffic request rate: {request_rate}") + print(f"Burstiness factor: {burstiness} ({distribution})") + print(f"Maximum request concurrency: {max_concurrency}") + + pbar = None if disable_tqdm else tqdm(total=len(input_requests), desc="Benchmarking") + + semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None + + async def limited_request_func(req_input, pbar_ref, sess): + if semaphore is None: + return await request_func(request_input=req_input, pbar=pbar_ref, session=sess) + async with semaphore: + return await request_func(request_input=req_input, pbar=pbar_ref, session=sess) + + timeout = aiohttp.ClientTimeout(total=request_timeout) + benchmark_start_time = time.perf_counter() + tasks: list[asyncio.Task] = [] + async with aiohttp.ClientSession( + trust_env=True, + timeout=timeout, + connector=aiohttp.TCPConnector(limit=0, limit_per_host=0, force_close=True), + ) as session: + async for request in get_request(input_requests, request_rate, burstiness): + request_input = _make_request_input(request.prompt) + tasks.append(asyncio.create_task(limited_request_func(request_input, pbar, session))) + + outputs: list[VisualGenRequestOutput] = await asyncio.gather(*tasks) + + if pbar is not None: + pbar.close() + + benchmark_duration = time.perf_counter() - benchmark_start_time + + metrics = calculate_metrics( + outputs=outputs, + dur_s=benchmark_duration, + selected_percentiles=selected_percentiles, + num_gpus=num_gpus, + ) + + print_visual_gen_results(backend, model_id, benchmark_duration, metrics) + + result = build_visual_gen_result_dict( + backend=backend, + model_id=model_id, + benchmark_duration=benchmark_duration, + metrics=metrics, + outputs=outputs, + gen_params=gen_params, + ) + + return result + + +def load_prompts(args: argparse.Namespace) -> list[VisualGenSampleRequest]: + """Load prompts 
from --prompt or --prompt-file (delegates to shared util).""" + return load_visual_gen_prompts(args.prompt, args.prompt_file, args.num_prompts) + + +def _resolve_num_gpus(args: argparse.Namespace) -> int: + """Determine the number of GPUs from explicit arg or server config YAML. + + Priority: --num-gpus (explicit) > --extra-visual-gen-options YAML > default 1. + """ + if args.num_gpus is not None: + return args.num_gpus + + if args.extra_visual_gen_options is not None: + with open(args.extra_visual_gen_options, encoding="utf-8") as f: + config = yaml.safe_load(f) or {} + from tensorrt_llm.commands.utils import get_visual_gen_num_gpus + + return get_visual_gen_num_gpus(config) + + return 1 + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + np.random.seed(args.seed) + + backend = args.backend + model_id = args.model + + endpoint_map = { + "openai-images": "/v1/images/generations", + "openai-videos": "/v1/videos/generations", + } + endpoint = args.endpoint or endpoint_map.get(backend) + if endpoint is None: + raise ValueError( + f"Cannot resolve endpoint for backend '{backend}'. " + "Please specify --endpoint explicitly." 
+ ) + + if args.base_url is not None: + api_url = f"{args.base_url}{endpoint}" + else: + api_url = f"http://{args.host}:{args.port}{endpoint}" + + input_requests = load_prompts(args) + + seconds = args.seconds + if args.num_frames is not None: + seconds = args.num_frames / args.fps + print(f"Computed seconds={seconds:.3f} from num_frames={args.num_frames} / fps={args.fps}") + + gen_params: dict[str, Any] = { + "size": args.size, + "seconds": seconds, + "fps": args.fps, + } + if args.num_inference_steps is not None: + gen_params["num_inference_steps"] = args.num_inference_steps + if args.guidance_scale is not None: + gen_params["guidance_scale"] = args.guidance_scale + if args.negative_prompt is not None: + gen_params["negative_prompt"] = args.negative_prompt + if args.seed is not None: + gen_params["seed"] = args.seed + + extra_body = None + if args.extra_body: + try: + extra_body = json.loads(args.extra_body) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in --extra-body: {e}") from e + + num_gpus = _resolve_num_gpus(args) + + gc.disable() + + benchmark_result = asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + model_id=model_id, + input_requests=input_requests, + request_rate=args.request_rate, + burstiness=args.burstiness, + disable_tqdm=args.disable_tqdm, + selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")], + max_concurrency=args.max_concurrency, + gen_params=gen_params, + extra_body=extra_body, + no_test_input=args.no_test_input, + request_timeout=args.request_timeout, + num_gpus=num_gpus, + ) + ) + + if args.save_result: + result_json: dict[str, Any] = {} + + current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") + result_json["date"] = current_dt + result_json["backend"] = backend + result_json["model_id"] = model_id + result_json["num_prompts"] = args.num_prompts + + if args.metadata: + for item in args.metadata: + if "=" in item: + key, value = item.split("=", 1) + result_json[key.strip()] = 
value.strip() + else: + raise ValueError("Invalid metadata format. Please use KEY=VALUE format.") + + result_json = {**result_json, **benchmark_result} + + if not args.save_detailed: + for field_name in ["e2e_latencies", "errors"]: + result_json.pop(field_name, None) + + result_json["request_rate"] = ( + args.request_rate if args.request_rate < float("inf") else "inf" + ) + result_json["burstiness"] = args.burstiness + result_json["max_concurrency"] = args.max_concurrency + result_json["num_gpus"] = num_gpus + + base_model_id = model_id.split("/")[-1] + max_concurrency_str = ( + f"-concurrency{args.max_concurrency}" if args.max_concurrency is not None else "" + ) + file_name = ( + f"{backend}-{args.request_rate}qps" + f"{max_concurrency_str}-{base_model_id}" + f"-{current_dt}.json" + ) + if args.result_filename: + file_name = args.result_filename + if args.result_dir: + os.makedirs(args.result_dir, exist_ok=True) + file_name = os.path.join(args.result_dir, file_name) + + with open(file_name, "w", encoding="utf-8") as outfile: + json.dump(result_json, outfile, indent=2) + + print(f"Results saved to: {file_name}") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="Benchmark VisualGen (image/video generation) serving." + ) + + parser.add_argument( + "--backend", + type=str, + default="openai-videos", + choices=list(VISUAL_GEN_REQUEST_FUNCS.keys()), + help="Backend API type.", + ) + parser.add_argument( + "--model", + type=str, + required=True, + help="HuggingFace model ID (e.g. Wan-AI/Wan2.1-T2V-14B).", + ) + parser.add_argument("--host", type=str, default="127.0.0.1", help="Server host.") + parser.add_argument("--port", type=int, default=8000, help="Server port.") + parser.add_argument( + "--base-url", type=str, default=None, help="Full base URL (overrides --host/--port)." 
+ ) + parser.add_argument( + "--endpoint", + type=str, + default=None, + help="API endpoint path (auto-resolved from backend if not specified).", + ) + + prompt_group = parser.add_mutually_exclusive_group() + prompt_group.add_argument( + "--prompt", + type=str, + default=None, + help="Single text prompt (repeated --num-prompts times).", + ) + prompt_group.add_argument( + "--prompt-file", + type=str, + default=None, + help="Path to prompt file. Supports plain text (one prompt " + "per line) or JSONL with 'text'/'prompt' field.", + ) + parser.add_argument( + "--num-prompts", type=int, default=5, help="Number of prompts to benchmark." + ) + + gen_group = parser.add_argument_group("Generation Parameters") + gen_group.add_argument( + "--size", + type=str, + default="auto", + help="Output resolution in WxH format (e.g. 480x832) or 'auto'.", + ) + gen_group.add_argument("--seconds", type=float, default=4.0, help="Video duration in seconds.") + gen_group.add_argument("--fps", type=int, default=16, help="Frames per second.") + gen_group.add_argument( + "--num-frames", + type=int, + default=None, + help="Total frames to generate. Overrides --seconds (computed as num_frames / fps).", + ) + gen_group.add_argument( + "--num-inference-steps", type=int, default=None, help="Number of diffusion denoising steps." + ) + gen_group.add_argument( + "--guidance-scale", type=float, default=None, help="Classifier-free guidance scale." + ) + gen_group.add_argument( + "--seed", type=int, default=None, help="Random seed for reproducibility." + ) + gen_group.add_argument( + "--negative-prompt", type=str, default=None, help="Negative prompt (concepts to avoid)." + ) + gen_group.add_argument( + "--extra-body", + type=str, + default=None, + help="JSON string of extra request body parameters (e.g. 
'{\"guidance_rescale\": 0.7}').", + ) + + traffic_group = parser.add_argument_group("Traffic Control") + traffic_group.add_argument( + "--request-rate", + type=float, + default=float("inf"), + help="Request rate (req/s). Default inf sends all at once.", + ) + traffic_group.add_argument( + "--burstiness", + type=float, + default=1.0, + help="Burstiness factor for request generation. 1.0 = Poisson process.", + ) + traffic_group.add_argument( + "--max-concurrency", type=int, default=None, help="Maximum concurrent requests." + ) + traffic_group.add_argument( + "--request-timeout", + type=float, + default=6 * 60 * 60, + help="Request timeout in seconds (default: 6 hours).", + ) + + parser.add_argument( + "--extra-visual-gen-options", + type=str, + default=None, + help="Path to the server config YAML (same file passed to trtllm-serve " + "via --extra_visual_gen_options). Parallelism settings are read to " + "automatically determine the number of GPUs.", + ) + parser.add_argument( + "--num-gpus", + type=int, + default=None, + help="Number of GPUs used by the server. Overrides the value inferred " + "from --extra-visual-gen-options. Defaults to 1 if neither is given.", + ) + + output_group = parser.add_argument_group("Output") + output_group.add_argument( + "--save-result", action="store_true", help="Save results to JSON file." + ) + output_group.add_argument( + "--save-detailed", action="store_true", help="Include per-request details in saved results." + ) + output_group.add_argument( + "--result-dir", type=str, default=None, help="Directory for result files." + ) + output_group.add_argument( + "--result-filename", type=str, default=None, help="Custom result filename." 
+ ) + output_group.add_argument( + "--metric-percentiles", + type=str, + default="50,90,99", + help="Comma-separated percentile values (default: '50,90,99').", + ) + output_group.add_argument( + "--metadata", + type=str, + nargs="*", + default=None, + help="Key=value pairs to add to result metadata.", + ) + + parser.add_argument("--disable-tqdm", action="store_true", help="Disable progress bar.") + parser.add_argument( + "--no-test-input", action="store_true", help="Skip the initial single-prompt test run." + ) + + args = parser.parse_args() + + if args.prompt is None and args.prompt_file is None: + parser.error("Either --prompt or --prompt-file must be specified.") + + main(args) diff --git a/tests/integration/defs/examples/test_visual_gen.py b/tests/integration/defs/examples/test_visual_gen.py index 7b876fee3d0..65bdb2bedff 100644 --- a/tests/integration/defs/examples/test_visual_gen.py +++ b/tests/integration/defs/examples/test_visual_gen.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,8 +20,8 @@ import pytest import torch +from defs import conftest from defs.common import venv_check_call -from defs.conftest import llm_models_root from defs.trt_test_alternative import check_call WAN_T2V_MODEL_SUBPATH = "Wan2.1-T2V-1.3B-Diffusers" @@ -163,7 +163,7 @@ def _generate_wan_video(llm_venv, llm_root, model_subpath, output_subdir): Returns the path to the generated .mp4, or calls pytest.skip if the model is not found under LLM_MODELS_ROOT. 
""" - scratch_space = llm_models_root() + scratch_space = conftest.llm_models_root() model_path = os.path.join(scratch_space, model_subpath) if not os.path.isdir(model_path): pytest.skip( @@ -373,3 +373,19 @@ def test_vbench_dimension_score_wan22_a14b_nvfp4( golden_scores=VBENCH_WAN22_A14B_NVFP4_GOLDEN_SCORES, max_score_diff=0.05, ) + + +def test_visual_gen_benchmark_serving(llm_venv): + """Run benchmark_visual_gen.py against a live trtllm-serve visual-gen server.""" + test_root = conftest.unittest_path() / "_torch" / "visual_gen" + llm_venv.run_cmd( + [ + "-m", + "pytest", + "-v", + str( + test_root / "_test_trtllm_serve_visual_gen_benchmark.py" + "::test_visual_gen_benchmark_video[openai-videos]" + ), + ] + ) diff --git a/tests/integration/defs/visual_gen/__init__.py b/tests/integration/defs/visual_gen/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/integration/defs/visual_gen/test_visual_gen_benchmark.py b/tests/integration/defs/visual_gen/test_visual_gen_benchmark.py new file mode 100644 index 00000000000..19cae81bd24 --- /dev/null +++ b/tests/integration/defs/visual_gen/test_visual_gen_benchmark.py @@ -0,0 +1,418 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""E2E tests for VisualGen benchmarking (online serving and offline trtllm-bench). 
+ +Online tests launch a trtllm-serve server and run benchmark_visual_gen.py against it. +Offline tests run trtllm-bench visual-gen directly (no server). +Both require GPU and model weights. +""" + +import json +import os +import subprocess +import sys +import tempfile +import time +from pathlib import Path +from typing import List, Optional + +import pytest +import requests +import yaml + +from defs import conftest +from tensorrt_llm._utils import get_free_port + +# --------------------------------------------------------------------------- +# Model discovery +# --------------------------------------------------------------------------- + +_WAN_T2V_MODEL = "Wan2.1-T2V-1.3B-Diffusers" + + +def _wan_t2v_path() -> Path: + """Resolve the Wan T2V model path, or call pytest.skip if unavailable.""" + root = Path(conftest.llm_models_root()) + model_path = root / _WAN_T2V_MODEL + if not model_path.is_dir(): + pytest.skip( + f"Wan T2V model not found: {model_path} " + f"(set LLM_MODELS_ROOT or place {_WAN_T2V_MODEL} under scratch)" + ) + return model_path + + +# Common small-scale generation params for fast CI +_SMALL_GEN_PARAMS = { + "size": "480x320", + "num_frames": "9", + "fps": "8", + "num_inference_steps": "4", + "seed": "42", +} + + +def _make_visual_gen_options(**extra) -> dict: + """Build a minimal VisualGen YAML config dict.""" + config = { + "linear": {"type": "default"}, + "parallel": {"dit_cfg_size": 1, "dit_ulysses_size": 1}, + } + config.update(extra) + return config + + +def _write_config_file(config: dict, tmp_dir: Path) -> str: + """Write config dict to a temp YAML file and return the path.""" + config_file = tmp_dir / "visual_gen_config.yml" + with open(config_file, "w") as f: + yaml.dump(config, f) + return str(config_file) + + +# --------------------------------------------------------------------------- +# Remote server helper (for online benchmark tests) +# --------------------------------------------------------------------------- + + +class 
RemoteVisualGenServer: + MAX_SERVER_START_WAIT_S = 1200 + + def __init__( + self, + model: str, + extra_visual_gen_options: Optional[dict] = None, + cli_args: Optional[List[str]] = None, + host: str = "localhost", + port: Optional[int] = None, + ) -> None: + self.host = host + self.port = port if port is not None else get_free_port() + self._config_file: Optional[str] = None + self.proc: Optional[subprocess.Popen] = None + + args = ["--host", self.host, "--port", str(self.port)] + if cli_args: + args += cli_args + + if extra_visual_gen_options: + fd, self._config_file = tempfile.mkstemp(suffix=".yml", prefix="vg_bench_cfg_") + with os.fdopen(fd, "w") as f: + yaml.dump(extra_visual_gen_options, f) + args += ["--extra_visual_gen_options", self._config_file] + + launch_cmd = ["trtllm-serve", model] + args + self.proc = subprocess.Popen( + launch_cmd, + env=os.environ.copy(), + stdout=sys.stdout, + stderr=sys.stderr, + ) + self._wait_for_server(timeout=self.MAX_SERVER_START_WAIT_S) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.terminate() + + def terminate(self): + if self.proc is None: + return + self.proc.terminate() + try: + self.proc.wait(timeout=30) + except subprocess.TimeoutExpired: + self.proc.kill() + self.proc.wait(timeout=30) + self.proc = None + if self._config_file: + try: + os.remove(self._config_file) + except OSError: + pass + self._config_file = None + + def _wait_for_server(self, timeout: float): + url = f"http://{self.host}:{self.port}/health" + start = time.time() + while True: + try: + if requests.get(url, timeout=5).status_code == 200: + return + except requests.RequestException as err: + result = self.proc.poll() + if result is not None and result != 0: + raise RuntimeError("Visual-gen server exited unexpectedly.") from err + time.sleep(2) + if time.time() - start > timeout: + self.terminate() + raise RuntimeError(f"Visual-gen server failed to start within {timeout}s.") + + +# 
--------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def server(): + model_path = _wan_t2v_path() + with RemoteVisualGenServer( + model=str(model_path), + extra_visual_gen_options=_make_visual_gen_options(), + ) as srv: + yield srv + + +@pytest.fixture(scope="module") +def benchmark_script(): + llm_root = os.getenv("LLM_ROOT") + if llm_root is None: + llm_root = str(Path(__file__).resolve().parents[4]) + return os.path.join( + llm_root, + "tensorrt_llm", + "serve", + "scripts", + "benchmark_visual_gen.py", + ) + + +# =========================================================================== +# Online benchmark tests (trtllm-serve + benchmark_visual_gen.py) +# =========================================================================== + + +@pytest.mark.parametrize("backend", ["openai-videos"]) +def test_online_benchmark_video( + server: RemoteVisualGenServer, + benchmark_script: str, + backend: str, +): + """Run benchmark_visual_gen.py for video generation and validate output.""" + cmd = [ + sys.executable, + benchmark_script, + "--backend", + backend, + "--model", + _WAN_T2V_MODEL, + "--host", + server.host, + "--port", + str(server.port), + "--prompt", + "A cat walking in a garden", + "--num-prompts", + "2", + "--size", + _SMALL_GEN_PARAMS["size"], + "--num-frames", + _SMALL_GEN_PARAMS["num_frames"], + "--fps", + _SMALL_GEN_PARAMS["fps"], + "--num-inference-steps", + _SMALL_GEN_PARAMS["num_inference_steps"], + "--seed", + _SMALL_GEN_PARAMS["seed"], + "--max-concurrency", + "1", + "--disable-tqdm", + ] + + result = subprocess.run( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True, + ) + + assert result.returncode == 0 + assert "Benchmark Result (VisualGen)" in result.stdout + + +@pytest.mark.parametrize("backend", ["openai-videos"]) +def test_online_benchmark_save_result( + 
server: RemoteVisualGenServer, + benchmark_script: str, + backend: str, + tmp_path, +): + """Verify online benchmark --save-result produces a valid JSON file.""" + result_dir = str(tmp_path / "results") + cmd = [ + sys.executable, + benchmark_script, + "--backend", + backend, + "--model", + _WAN_T2V_MODEL, + "--host", + server.host, + "--port", + str(server.port), + "--prompt", + "A bird flying over the ocean", + "--num-prompts", + "1", + "--size", + _SMALL_GEN_PARAMS["size"], + "--num-frames", + _SMALL_GEN_PARAMS["num_frames"], + "--fps", + _SMALL_GEN_PARAMS["fps"], + "--num-inference-steps", + _SMALL_GEN_PARAMS["num_inference_steps"], + "--seed", + _SMALL_GEN_PARAMS["seed"], + "--max-concurrency", + "1", + "--save-result", + "--result-dir", + result_dir, + "--disable-tqdm", + ] + + result = subprocess.run( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True, + ) + + assert result.returncode == 0 + assert "Benchmark Result (VisualGen)" in result.stdout + + result_files = list(Path(result_dir).glob("*.json")) + assert len(result_files) >= 1, f"No JSON result file found in {result_dir}" + + with open(result_files[0]) as f: + data = json.load(f) + assert "completed" in data + assert data["completed"] >= 1 + assert "mean_e2e_latency_ms" in data + + +# =========================================================================== +# Offline benchmark tests (trtllm-bench visual-gen) +# =========================================================================== + + +def test_offline_benchmark(tmp_path): + """Run trtllm-bench visual-gen and validate output.""" + model_path = _wan_t2v_path() + config_file = _write_config_file(_make_visual_gen_options(), tmp_path) + + cmd = [ + "trtllm-bench", + "--model", + str(model_path), + "--model_path", + str(model_path), + "visual-gen", + "--extra_visual_gen_options", + config_file, + "--prompt", + "A cat walking in a garden", + "--num_prompts", + "2", + "--size", + _SMALL_GEN_PARAMS["size"], + 
"--num_frames", + _SMALL_GEN_PARAMS["num_frames"], + "--fps", + _SMALL_GEN_PARAMS["fps"], + "--num_inference_steps", + _SMALL_GEN_PARAMS["num_inference_steps"], + "--seed", + _SMALL_GEN_PARAMS["seed"], + "--max_concurrency", + "1", + "--warmup", + "1", + ] + + result = subprocess.run( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True, + ) + + assert result.returncode == 0 + assert "Benchmark Result (VisualGen)" in result.stdout + + +def test_offline_benchmark_save_result(tmp_path): + """Verify trtllm-bench visual-gen --save_result produces valid JSON.""" + model_path = _wan_t2v_path() + config_file = _write_config_file(_make_visual_gen_options(), tmp_path) + result_dir = str(tmp_path / "results") + + cmd = [ + "trtllm-bench", + "--model", + str(model_path), + "--model_path", + str(model_path), + "visual-gen", + "--extra_visual_gen_options", + config_file, + "--prompt", + "A bird flying over the ocean", + "--num_prompts", + "1", + "--size", + _SMALL_GEN_PARAMS["size"], + "--num_frames", + _SMALL_GEN_PARAMS["num_frames"], + "--fps", + _SMALL_GEN_PARAMS["fps"], + "--num_inference_steps", + _SMALL_GEN_PARAMS["num_inference_steps"], + "--seed", + _SMALL_GEN_PARAMS["seed"], + "--max_concurrency", + "1", + "--warmup", + "0", + "--save_result", + "--result_dir", + result_dir, + ] + + result = subprocess.run( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True, + ) + + assert result.returncode == 0 + assert "Benchmark Result (VisualGen)" in result.stdout + + result_files = list(Path(result_dir).glob("*.json")) + assert len(result_files) >= 1, f"No JSON result file found in {result_dir}" + + with open(result_files[0]) as f: + data = json.load(f) + assert "completed" in data + assert data["completed"] >= 1 + assert "mean_e2e_latency_ms" in data From 2ee7dbae9d36818705654d81b9a874e90be730b0 Mon Sep 17 00:00:00 2001 From: Jin Li <59594262+liji-nv@users.noreply.github.com> Date: Thu, 5 Mar 2026 17:18:33 +0800 
Subject: [PATCH 7/7] [None][feat] Run extra general warmup to warm up memory pool (#10340) Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com> --- .../fla/fused_sigmoid_gating_recurrent.py | 4 +- .../_torch/pyexecutor/model_engine.py | 60 ++++++++++++------- tests/integration/test_lists/waives.txt | 1 - 3 files changed, 39 insertions(+), 26 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py b/tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py index 70589b762de..87902a68fe5 100644 --- a/tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py +++ b/tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py @@ -44,7 +44,7 @@ def fused_sigmoid_gating_delta_rule_update_kernel( """ Fused kernel that combines sigmoid gating computation with recurrent delta rule update. """ - i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_nh, i_v, i_k = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_n, i_hv = i_nh // HV, i_nh % HV i_h = i_hv // (HV // H) @@ -189,7 +189,7 @@ def fused_sigmoid_gating_delta_rule_update( assert scale > 0, "scale must be positive" o = q.new_empty(NK, *v.shape) - grid = (NK, NV, N * HV) + grid = (N * HV, NV, NK) fused_sigmoid_gating_delta_rule_update_kernel[grid]( A_log=A_log, diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 1bedaffccf3..f1f2174adbc 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -677,13 +677,18 @@ def warmup(self, resource_manager: ResourceManager) -> None: if not self.mapping.has_cp_helix(): self._run_autotuner_warmup(resource_manager) self._run_cuda_graph_warmup(resource_manager) - - # Set the value back to the original value after all warmups are complete - self.enable_spec_decode = self.is_spec_decode + if not self.is_draft_model and not self.mapping.has_cp_helix( + ) and 
self.guided_decoder is None: + # Run extra general warmup to warmup memory pool before running real requests to reduce memory fragmentation. + self._general_warmup(resource_manager, reverse=True) def _general_warmup(self, resource_manager: ResourceManager, reverse: bool = False): + """ + A General warmup to warmup with several different requests. + It is used to warmup torch.compile path and warmup memory pool before running real requests. + """ kv_cache_manager = resource_manager.get_resource_manager( self.kv_cache_manager_key) token_num_upper_bound = min(self.max_num_tokens, @@ -692,8 +697,8 @@ def _general_warmup(self, token_num_upper_bound=token_num_upper_bound, max_num_draft_tokens=self.original_max_draft_len) max_batch_size = min( - self.batch_size, - curr_max_num_tokens // (1 + self.runtime_draft_len)) + self.batch_size, curr_max_num_tokens // + (1 + self.runtime_draft_len) // self.max_beam_width) warmup_requests_configs = { (1, 1), # Specialize for 1 token. @@ -706,19 +711,28 @@ def _general_warmup(self, reverse=reverse) for num_tokens, num_gen_tokens in warmup_requests_configs: - with self._release_batch_context( - self._create_warmup_request(resource_manager, num_tokens, - num_gen_tokens), - resource_manager) as batch: - if batch is None: - continue # Not enough KV cache space - logger.info( - f"Run warmup with {num_tokens} tokens, include {num_gen_tokens} generation tokens" - ) - self.forward(batch, - new_tensors_device=None, - resource_manager=resource_manager) - torch.cuda.synchronize() + # Helix CP does not support warmup with context requests. 
+ if self.mapping.has_cp_helix() and num_tokens != num_gen_tokens: + continue + try: + with self._release_batch_context( + self._create_warmup_request(resource_manager, + num_tokens, num_gen_tokens), + resource_manager) as batch: + if batch is None: + continue # Not enough KV cache space + logger.info( + f"Run warmup with {num_tokens} tokens, include {num_gen_tokens} generation tokens" + ) + self.forward(batch, + new_tensors_device=None, + resource_manager=resource_manager) + torch.cuda.synchronize() + except torch.OutOfMemoryError: + logger.warning( + f"OOM during general warmup with {num_tokens} tokens, " + f"{num_gen_tokens} generation tokens. Skipping.") + torch.cuda.empty_cache() def _run_torch_compile_warmup(self, resource_manager: ResourceManager): """Runs warmup iterations to specialize torch.compile kernels.""" @@ -868,6 +882,8 @@ def _capture_generation_cuda_graphs(self, new_tensors_device=None, resource_manager=resource_manager) torch.cuda.synchronize() + # Set the value back to the original value after cuda graph warmups are complete + self.enable_spec_decode = self.is_spec_decode def _capture_piecewise_cuda_graphs(self, resource_manager: ResourceManager): """Captures piecewise CUDA graphs for context/prefill steps via torch.compile.""" @@ -1025,8 +1041,8 @@ def _create_warmup_request( blocks_to_use = num_full_seqs * math.ceil( max_seq_len / kv_cache_manager.tokens_per_block) + math.ceil( - num_left_over_tokens / - kv_cache_manager.tokens_per_block) + num_gen_requests + num_left_over_tokens / kv_cache_manager.tokens_per_block + ) + num_gen_requests * self.max_beam_width if blocks_to_use > available_blocks and isinstance( kv_cache_manager, KVCacheManager): @@ -2782,8 +2798,6 @@ def previous_seq_slots_device(): num_generation_requests = len(gen_request_seq_slots) # Cache indirection is only used for beam search on generation requests if self.use_beam_search and num_generation_requests > 0: - # CUDA Graph needs to set beam width during warmup (where the 
graph is captured), to ensure that cache indirection buffer is correctly picked up by the CUDA graph - is_cuda_graph_during_warmup = self.is_warmup and attn_metadata.is_cuda_graph if cache_indirection_buffer is not None: #Copy cache indirection to local buffer with offsets changing: seq_slots[i] -> i # Convert to GPU tensor to avoid implicit sync @@ -2794,7 +2808,7 @@ def previous_seq_slots_device(): non_blocking=True) self.cache_indirection_attention[:num_generation_requests].copy_( cache_indirection_buffer[gen_request_seq_slots_tensor]) - if cache_indirection_buffer is not None or is_cuda_graph_during_warmup: + if cache_indirection_buffer is not None or self.is_warmup: attn_metadata.beam_width = self.max_beam_width else: attn_metadata.beam_width = 1 diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index d2495def9e2..e48f2ac30b9 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -253,7 +253,6 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-trtllm-auto] SKIP (https://nvbugs/5651865) accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-ep4-trtllm-auto] SKIP (https://nvbugs/5651865) accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-sampler_async_worker=False] SKIP (https://nvbugs/5701445) -accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp4[tp4-cuda_graph=True] SKIP (https://nvbugs/5820734) accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-ep4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5800672) 
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5800672) accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=TRTLLM-torch_compile=True] SKIP (https://nvbugs/5826604)