From 958c5839706383688f7058e28e95a6b25a2944ad Mon Sep 17 00:00:00 2001 From: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com> Date: Mon, 20 Apr 2026 17:43:10 +0800 Subject: [PATCH 1/8] [None][infra] Waive 9 failed cases for main in post-merge (#13204) Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com> --- tests/integration/test_lists/waives.txt | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index d85dff82bd24..62debb3a13ef 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -370,6 +370,15 @@ accuracy/test_llm_api_pytorch.py::TestNemotronNas::test_auto_dtype_tp8 SKIP (htt accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_bf16_4gpu[tp4ep4_cudagraph_overlap_adp_on] SKIP (https://nvbugs/6094068) accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9-fp8kv=True] SKIP (https://nvbugs/6094066) accuracy/test_llm_api_autodeploy.py::TestModelRegistryAccuracy::test_autodeploy_from_registry[nvidia_Llama-3.1-8B-Instruct-NVFP4-True] SKIP (https://nvbugs/6093715) +disaggregated/test_auto_scaling.py::test_service_discovery[etcd-kv_cache_aware] SKIP (https://nvbugs/6094100) +disaggregated/test_disaggregated.py::test_disaggregated_chat_completion_tool_calls[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/6094100) +disaggregated/test_auto_scaling.py::test_service_discovery[http-load_balancing] SKIP (https://nvbugs/6094100) +disaggregated/test_auto_scaling.py::test_minimal_instances[http-round_robin] SKIP (https://nvbugs/6094100) +disaggregated/test_disaggregated.py::test_disaggregated_load_balance[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/6094100) +disaggregated/test_auto_scaling.py::test_worker_restart[http-round_robin] SKIP (https://nvbugs/6094100) +disaggregated/test_auto_scaling.py::test_service_discovery[etcd-load_balancing] SKIP (https://nvbugs/6094100) +disaggregated/test_auto_scaling.py::test_disagg_server_restart[etcd-round_robin] SKIP (https://nvbugs/6094100) +disaggregated/test_disaggregated.py::test_disaggregated_ctxpp2_gentp2[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/6094100) disaggregated/test_disaggregated.py::test_disaggregated_multi_gpu[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/6094100) disaggregated/test_auto_scaling.py::test_worker_restart[etcd-round_robin] SKIP (https://nvbugs/6094100) disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_trt_backend[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/6094100) From b1a3b6632731498c029aa21eeb463bce63e50183 Mon Sep 17 00:00:00 2001 From: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com> Date: Mon, 20 Apr 2026 17:58:56 +0800 Subject: [PATCH 2/8] [None][infra] Waive 6 failed cases for main in post-merge (#13195) Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com> --- tests/integration/test_lists/waives.txt | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 62debb3a13ef..2de8dfab584f 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -367,6 +367,12 @@ perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-r1-fp4_ perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-r1-fp4_1k1k_con1024_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX] SKIP (https://nvbugs/6088149) 
perf/test_perf_sanity.py::test_e2e[aggr_upload-k25_thinking_fp4_2_nodes_grace_blackwell-k25_thinking_fp4_dep8_32k8k] SKIP (https://nvbugs/6088149) accuracy/test_llm_api_pytorch.py::TestNemotronNas::test_auto_dtype_tp8 SKIP (https://nvbugs/6070857) +accuracy/test_llm_api_autodeploy.py::TestNemotronNanoV3::test_accuracy[fp8-1-trtllm] SKIP (https://nvbugs/6094208) +accuracy/test_llm_api_autodeploy.py::TestNemotronNanoV3::test_accuracy[bf16-1-trtllm] SKIP (https://nvbugs/6094208) +accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[torch-True-1] SKIP (https://nvbugs/6093714) +accuracy/test_llm_api_autodeploy.py::TestGLM4Flash::test_auto_dtype[trtllm-True] SKIP (https://nvbugs/6093713) +accuracy/test_llm_api_autodeploy.py::TestGLM4Flash::test_auto_dtype[trtllm-False] SKIP (https://nvbugs/6093713) +accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_gen_first[ctx_tp1pp1-gen_tp1pp1] SKIP (https://nvbugs/6093712) accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_bf16_4gpu[tp4ep4_cudagraph_overlap_adp_on] SKIP (https://nvbugs/6094068) accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9-fp8kv=True] SKIP (https://nvbugs/6094066) accuracy/test_llm_api_autodeploy.py::TestModelRegistryAccuracy::test_autodeploy_from_registry[nvidia_Llama-3.1-8B-Instruct-NVFP4-True] SKIP (https://nvbugs/6093715) From bae6e6def26a3f9fec6e2e9ea2f9231000f06468 Mon Sep 17 00:00:00 2001 From: Stanley Sun <190317771+StanleySun639@users.noreply.github.com> Date: Mon, 20 Apr 2026 17:59:39 +0800 Subject: [PATCH 3/8] [None][test] Add doc test (#13152) Signed-off-by: Stanley Sun --- tests/integration/defs/test_doc.py | 223 ++++++++++++++++++ .../test_lists/test-db/l0_h100.yml | 2 + 2 files changed, 225 insertions(+) create mode 100644 tests/integration/defs/test_doc.py diff --git a/tests/integration/defs/test_doc.py b/tests/integration/defs/test_doc.py new file mode 100644 index 000000000000..e9a4a78775e9 --- /dev/null +++ b/tests/integration/defs/test_doc.py @@ -0,0 +1,223 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import concurrent.futures +import os +import re +from collections import defaultdict +from urllib.parse import urlparse + +import pytest +import requests +from requests.adapters import HTTPAdapter +from requests.packages.urllib3.util.retry import Retry + +requests.packages.urllib3.disable_warnings( + requests.packages.urllib3.exceptions.InsecureRequestWarning +) + +# Markdown discovery filters. The walker prunes any directory whose name is in +# SKIP_DIR_NAMES or starts with a prefix in SKIP_DIR_PREFIXES, and drops any +# file in SKIP_FILENAMES (e.g., auto-generated attribution files). 
+SKIP_DIR_NAMES = {"3rdparty", "_deps", "build", "include", "node_modules", ".git"} +SKIP_DIR_PREFIXES = (".venv", "venv") +SKIP_FILENAMES = { + "ATTRIBUTIONS-Python.md", + "ATTRIBUTIONS-CPP-x86_64.md", + "ATTRIBUTIONS-CPP-aarch64.md", +} + +# URLs that return 404 at HTTP level but are valid in a browser +# (e.g., GitHub Pages sites using JS redirects) +EXCEPTION_URLS = [ + "https://nvidia.github.io/", +] + +HTML_LINK_PATTERN = re.compile(r']*?\s+)?href="([^"]*)"') + + +def _get_session(): + session = requests.Session() + retry = Retry(total=3, backoff_factor=0.5, status_forcelist=[429, 500, 502, 503, 504]) + adapter = HTTPAdapter(max_retries=retry) + session.mount("http://", adapter) + session.mount("https://", adapter) + return session + + +def _extract_markdown_links(text): + """Extract markdown links handling nested parentheses.""" + links = [] + i = 0 + while i < len(text): + start_bracket = text.find("[", i) + if start_bracket == -1: + break + close_bracket = text.find("]", start_bracket) + if close_bracket == -1 or close_bracket + 1 >= len(text) or text[close_bracket + 1] != "(": + i = start_bracket + 1 + continue + + open_paren = close_bracket + 1 + depth = 1 + j = open_paren + 1 + close_paren = -1 + while j < len(text) and depth > 0: + if text[j] == "(": + depth += 1 + elif text[j] == ")": + depth -= 1 + if depth == 0: + close_paren = j + j += 1 + + if close_paren != -1: + url = text[open_paren + 1 : close_paren] + links.append(url) + i = close_paren + 1 + else: + i = open_paren + 1 + return links + + +def _clean_url(url): + if url.startswith("<") and url.endswith(">"): + url = url[1:-1] + open_count = url.count("(") + close_count = url.count(")") + if open_count != close_count: + if close_count > open_count and url.endswith(")"): + while close_count > open_count and url.endswith(")"): + url = url[:-1] + close_count -= 1 + while url and url[-1] in ".,;:'\"]": + url = url[:-1] + return url.strip() + + +def _find_markdown_files(root_dir): + markdown_files = [] + for dirpath, dirnames, filenames in os.walk(root_dir): + # Prune in-place so os.walk doesn't descend into skipped dirs. 
+ dirnames[:] = [ + d for d in dirnames if d not in SKIP_DIR_NAMES and not d.startswith(SKIP_DIR_PREFIXES) + ] + for filename in filenames: + if filename.lower().endswith(".md"): + if filename in SKIP_FILENAMES: + continue + markdown_files.append(os.path.join(dirpath, filename)) + return markdown_files + + +def _extract_urls(file_path): + """Extract and normalize URLs from a markdown file.""" + with open(file_path, "r", encoding="utf-8", errors="replace") as f: + lines = f.read().split("\n") + + url_info_list = [] + + for line_num, line in enumerate(lines, 1): + for url in _extract_markdown_links(line): + url_info_list.append((_clean_url(url), line_num)) + for match in HTML_LINK_PATTERN.finditer(line): + url_info_list.append((_clean_url(match.group(1)), line_num)) + + normalized = [] + for url, line_num in url_info_list: + if url.startswith("www."): + url = "https://" + url + if not url.startswith(("http://", "https://")): + continue + normalized.append((url, line_num)) + return normalized + + +def _check_url(url_info): + """Return (is_valid, url, line_num, reason).""" + url, line_num = url_info + + if url in EXCEPTION_URLS: + return True, url, line_num, "Known exception URL (skipped validation)" + + parsed = urlparse(url) + if not all([parsed.scheme, parsed.netloc]): + return False, url, line_num, "Invalid URL format" + if parsed.netloc in ("localhost",) or parsed.netloc.startswith("127.0.0."): + return True, url, line_num, "local" + if "drive.google.com" in parsed.netloc: + return True, url, line_num, "Google Drive (auth required)" + if parsed.netloc == "github.com" and ("/blob/" in parsed.path or "/tree/" in parsed.path): + return True, url, line_num, "GitHub repo-internal ref" + + session = _get_session() + try: + resp = session.head(url, timeout=10, allow_redirects=True, verify=False) + if resp.status_code == 404: + resp = session.get(url, timeout=10, allow_redirects=True, verify=False, stream=True) + resp.close() + if resp.status_code == 404: + return False, url, line_num, "404 Not Found" + return True, url, line_num, f"HTTP {resp.status_code}" + except requests.exceptions.RequestException as e: + if "Connection" in str(e): + return True, url, line_num, "connection issue (transient)" + return False, url, line_num, str(e) + except Exception as e: + return False, url, line_num, f"Error: {e}" + + +def test_url_validity(llm_root): + """Scan all markdown files in the repo and assert no URLs return 404.""" + md_files = _find_markdown_files(llm_root) + assert md_files, f"No markdown files found under {llm_root}" + + all_urls = [] + for md_file in md_files: + for url, line_num in _extract_urls(md_file): + all_urls.append((url, line_num, md_file)) + + if not all_urls: + pytest.skip("No URLs found in any markdown file") + + # De-duplicate URLs (check each unique URL once, keep all locations for reporting) + unique_urls = {} + for url, line_num, md_file in all_urls: + if url not in unique_urls: + unique_urls[url] = [] + unique_urls[url].append((md_file, line_num)) + + url_items = [(url, 0) for url in unique_urls] # line_num=0 placeholder + + invalid = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: + futures = {executor.submit(_check_url, item): item for item in url_items} + for future in concurrent.futures.as_completed(futures): + is_valid, url, _, reason = future.result() + if not is_valid: + for md_file, line_num in unique_urls[url]: + invalid.append((md_file, line_num, url, reason)) + + if invalid: + invalid.sort() + by_file = defaultdict(list) + for md_file, 
line_num, url, reason in invalid: + by_file[md_file].append((line_num, url, reason)) + report_lines = [f"Found {len(invalid)} invalid URL(s) in {len(by_file)} file(s):"] + for md_file, entries in sorted(by_file.items()): + report_lines.append(f"{md_file}:") + for line_num, url, reason in entries: + report_lines.append(f" L{line_num} [{reason}] {url}") + pytest.fail("\n".join(report_lines)) diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index 63efbd1aa88a..c920c6ef7b83 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -363,6 +363,8 @@ l0_h100: - test_e2e.py::test_draft_token_tree_quickstart_advanced_eagle3_depth_1_tree[Llama-3.1-8b-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct-EAGLE3-LLaMA3.1-Instruct-8B] # https://nvbugs/5563469: Disable Nemotron-Nano-8B-v1 test due to non-deterministic failures, revisit as part of TRTLLM-7885 # - examples/test_nemotron_nas.py::test_nemotron_nano_8b_lora_torch[Llama-3.1-Nemotron-Nano-8B-v1] + # Documentation URL validation (CPU-only, no GPU needed) + - test_doc.py::test_url_validity - condition: ranges: system_gpu_count: From d08817c595b386311cede29bb402b58e70df71f6 Mon Sep 17 00:00:00 2001 From: chenfeiz0326 Date: Mon, 20 Apr 2026 18:01:28 +0800 Subject: [PATCH 4/8] [https://nvbugs/6071070][fix] Add K2.5 DISAGG Gen Only EPLB Cases into CI (#13185) Signed-off-by: Chenfei Zhang --- jenkins/L0_Test.groovy | 6 +++--- .../test-db/l0_gb200_multi_gpus_perf_sanity.yml | 8 ++++---- ..._nodes_perf_sanity_ctx1_node1_gpu4_gen1_node2_gpu8.yml | 2 +- ...nodes_perf_sanity_ctx1_node1_gpu4_gen1_node4_gpu16.yml | 4 ++-- ...nodes_perf_sanity_ctx1_node1_gpu4_gen1_node8_gpu32.yml | 8 ++++---- ..._con2048_ctx1_dep4_gen1_dep32_eplb0_mtp0_ccb-UCX.yaml} | 3 --- ..._con1024_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX.yaml} | 3 --- ..._con4096_ctx1_dep4_gen1_dep16_eplb0_mtp0_ccb-UCX.yaml} | 3 --- ..._con2048_ctx1_dep4_gen1_dep32_eplb0_mtp0_ccb-UCX.yaml} | 3 --- ..._con1024_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX.yaml} | 3 --- ..._con4096_ctx1_dep4_gen1_dep16_eplb0_mtp0_ccb-UCX.yaml} | 3 --- 11 files changed, 14 insertions(+), 32 deletions(-) rename tests/scripts/perf-sanity/disaggregated/{gb200_kimi-k2-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb384_mtp0_ccb-UCX.yaml => gb200_kimi-k2-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb0_mtp0_ccb-UCX.yaml} (97%) rename tests/scripts/perf-sanity/disaggregated/{gb200_kimi-k2-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb416_mtp3_ccb-UCX.yaml => gb200_kimi-k2-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX.yaml} (97%) rename tests/scripts/perf-sanity/disaggregated/{gb200_kimi-k2-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb384_mtp0_ccb-UCX.yaml => gb200_kimi-k2-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb0_mtp0_ccb-UCX.yaml} (97%) rename tests/scripts/perf-sanity/disaggregated/{gb200_kimi-k25-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb384_mtp0_ccb-UCX.yaml => gb200_kimi-k25-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb0_mtp0_ccb-UCX.yaml} (97%) rename tests/scripts/perf-sanity/disaggregated/{gb200_kimi-k25-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb416_mtp3_ccb-UCX.yaml => gb200_kimi-k25-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX.yaml} (97%) rename tests/scripts/perf-sanity/disaggregated/{gb200_kimi-k25-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb384_mtp0_ccb-UCX.yaml => 
gb200_kimi-k25-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb0_mtp0_ccb-UCX.yaml} (97%) diff --git a/jenkins/L0_Test.groovy b/jenkins/L0_Test.groovy index 665fda0f6c6a..221822d09838 100644 --- a/jenkins/L0_Test.groovy +++ b/jenkins/L0_Test.groovy @@ -3376,7 +3376,7 @@ def launchTestJobs(pipeline, testFilter) "GB200-12_GPUs-3_Nodes-PyTorch-Disagg-PerfSanity-CTX1-NODE1-GPU4-GEN1-NODE2-GPU8-Post-Merge", "auto:gb200-flex", "l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node2_gpu8", - 6, + 7, 12, 3 ) @@ -3394,7 +3394,7 @@ def launchTestJobs(pipeline, testFilter) "GB200-20_GPUs-5_Nodes-PyTorch-Disagg-PerfSanity-CTX1-NODE1-GPU4-GEN1-NODE4-GPU16-Post-Merge", "auto:gb200-flex", "l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node4_gpu16", - 1, + 2, 20, 5 ) @@ -3420,7 +3420,7 @@ def launchTestJobs(pipeline, testFilter) "GB200-36_GPUs-9_Nodes-PyTorch-Disagg-PerfSanity-CTX1-NODE1-GPU4-GEN1-NODE8-GPU32-Post-Merge", "auto:gb200-flex", "l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node8_gpu32", - 7, + 8, 36, 9 ) diff --git a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus_perf_sanity.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus_perf_sanity.yml index 292d8633f2ab..39db52739d90 100644 --- a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus_perf_sanity.yml +++ b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus_perf_sanity.yml @@ -85,12 +85,12 @@ l0_gb200_multi_gpus_perf_sanity: - perf/test_perf_sanity.py::test_e2e[aggr_upload-ctx_only-gb200_gpt-oss-120b-fp4_8k1k_con4_ctx1_tp1_gen1_tp4_eplb0_mtp0_ccb-UCX] TIMEOUT (120) - perf/test_perf_sanity.py::test_e2e[aggr_upload-ctx_only-gb200_gpt-oss-120b-fp4_8k1k_con512_ctx1_tp1_gen1_dep2_eplb0_mtp0_ccb-UCX] TIMEOUT (120) # kimi-k25-thinking-fp4 - # - perf/test_perf_sanity.py::test_e2e[aggr_upload-ctx_only-gb200_kimi-k25-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb384_mtp0_ccb-UCX] TIMEOUT (120) + # - perf/test_perf_sanity.py::test_e2e[aggr_upload-ctx_only-gb200_kimi-k25-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb0_mtp0_ccb-UCX] TIMEOUT (120) - perf/test_perf_sanity.py::test_e2e[aggr_upload-ctx_only-gb200_kimi-k25-thinking-fp4_1k1k_con4096_ctx1_dep4_gen1_dep8_eplb0_mtp0_ccb-UCX] TIMEOUT (120) - perf/test_perf_sanity.py::test_e2e[aggr_upload-ctx_only-gb200_kimi-k25-thinking-fp4_1k1k_con4_ctx1_dep4_gen1_tep4_eplb0_mtp0_ccb-UCX] TIMEOUT (120) - # - perf/test_perf_sanity.py::test_e2e[aggr_upload-ctx_only-gb200_kimi-k25-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb416_mtp3_ccb-UCX] TIMEOUT (120) - - perf/test_perf_sanity.py::test_e2e[aggr_upload-ctx_only-gb200_kimi-k25-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb384_mtp0_ccb-UCX] TIMEOUT (120) - # - perf/test_perf_sanity.py::test_e2e[aggr_upload-ctx_only-gb200_kimi-k25-thinking-fp4_8k1k_con4_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120) + # - perf/test_perf_sanity.py::test_e2e[aggr_upload-ctx_only-gb200_kimi-k25-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX] TIMEOUT (120) + - perf/test_perf_sanity.py::test_e2e[aggr_upload-ctx_only-gb200_kimi-k25-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb0_mtp0_ccb-UCX] TIMEOUT (120) + - perf/test_perf_sanity.py::test_e2e[aggr_upload-ctx_only-gb200_kimi-k25-thinking-fp4_8k1k_con4_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120) # qwen3-235b-fp4 - perf/test_perf_sanity.py::test_e2e[aggr_upload-ctx_only-gb200_qwen3-235b-fp4_8k1k_con1024_ctx1_tp1_gen1_dep8_eplb0_mtp0_ccb-UCX] TIMEOUT (120) - 
perf/test_perf_sanity.py::test_e2e[aggr_upload-ctx_only-gb200_qwen3-235b-fp4_8k1k_con64_ctx1_tp1_gen1_tep4_eplb0_mtp0_ccb-UCX] TIMEOUT (120) diff --git a/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node2_gpu8.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node2_gpu8.yml index 32fdf1be3c61..30d6500886c3 100644 --- a/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node2_gpu8.yml +++ b/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node2_gpu8.yml @@ -20,7 +20,7 @@ l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node2_gpu8: - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_1k1k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120) - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_32k4k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120) - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_8k1k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120) - # - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k25-thinking-fp4_1k1k_con4096_ctx1_dep4_gen1_dep8_eplb0_mtp0_ccb-UCX] TIMEOUT (120) # Failed requests + - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k25-thinking-fp4_1k1k_con4096_ctx1_dep4_gen1_dep8_eplb0_mtp0_ccb-UCX] TIMEOUT (120) # Failed requests # - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k25-thinking-fp4_8k1k_con4_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120) # - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_deepseek-r1-fp4_1k1k_con1024_ctx1_dep4_gen1_dep8_eplb0_mtp0_ccb-UCX] TIMEOUT (120) # - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_deepseek-r1-fp4_1k1k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120) diff --git a/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node4_gpu16.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node4_gpu16.yml index ac69996882b3..0e4e68522b1e 100644 --- a/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node4_gpu16.yml +++ b/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node4_gpu16.yml @@ -15,6 +15,6 @@ l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node4_gpu16: backend: pytorch tests: - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-r1-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb0_mtp1_ccb-UCX] TIMEOUT (120) - # - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k25-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb384_mtp0_ccb-UCX] TIMEOUT (120) + - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k25-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb0_mtp0_ccb-UCX] TIMEOUT (120) # - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_deepseek-r1-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb0_mtp1_ccb-UCX] TIMEOUT (120) - # - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_kimi-k25-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb384_mtp0_ccb-UCX] TIMEOUT (120) + # - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_kimi-k25-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb0_mtp0_ccb-UCX] TIMEOUT (120) diff --git 
a/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node8_gpu32.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node8_gpu32.yml index 1e95f41898eb..9260d25ca87f 100644 --- a/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node8_gpu32.yml +++ b/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node8_gpu32.yml @@ -22,8 +22,8 @@ l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node8_gpu32: - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_32k4k_con256_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX] TIMEOUT (120) - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb256_mtp3_ccb-UCX] TIMEOUT (120) - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_8k1k_con4096_ctx1_dep4_gen1_dep32_eplb256_mtp0_ccb-UCX] TIMEOUT (120) - # - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k25-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb384_mtp0_ccb-UCX] TIMEOUT (120) - # - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k25-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb416_mtp3_ccb-UCX] TIMEOUT (120) + - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k25-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb0_mtp0_ccb-UCX] TIMEOUT (120) + # - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k25-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX] TIMEOUT (120) # - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_deepseek-r1-fp4_1k1k_con1024_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX] TIMEOUT (120) # - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_deepseek-r1-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX] TIMEOUT (120) # - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_deepseek-r1-fp4_1k1k_con1024_ctx1_dep4_gen1_dep32_eplb0_mtp0_ccb-DEFAULT] TIMEOUT (120) @@ -32,5 +32,5 @@ l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node8_gpu32: # - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_deepseek-v32-fp4_32k4k_con256_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX] TIMEOUT (120) # - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_deepseek-v32-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb256_mtp3_ccb-UCX] TIMEOUT (120) # - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_deepseek-v32-fp4_8k1k_con4096_ctx1_dep4_gen1_dep32_eplb256_mtp0_ccb-UCX] TIMEOUT (120) - # - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_kimi-k25-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb384_mtp0_ccb-UCX] TIMEOUT (120) - # - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_kimi-k25-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb416_mtp3_ccb-UCX] TIMEOUT (120) + # - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_kimi-k25-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb0_mtp0_ccb-UCX] TIMEOUT (120) + # - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_kimi-k25-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX] TIMEOUT (120) diff --git a/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k2-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb384_mtp0_ccb-UCX.yaml b/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k2-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb0_mtp0_ccb-UCX.yaml similarity index 
97% rename from tests/scripts/perf-sanity/disaggregated/gb200_kimi-k2-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb384_mtp0_ccb-UCX.yaml rename to tests/scripts/perf-sanity/disaggregated/gb200_kimi-k2-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb0_mtp0_ccb-UCX.yaml index 03983b003bdb..6d65a89bf440 100644 --- a/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k2-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb384_mtp0_ccb-UCX.yaml +++ b/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k2-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb0_mtp0_ccb-UCX.yaml @@ -65,9 +65,6 @@ worker_config: moe_config: backend: CUTEDSL use_low_precision_moe_combine: true - load_balancer: - num_slots: 384 - layer_updates_per_iter: 1 cache_transceiver_config: max_tokens_in_buffer: 16384 backend: UCX diff --git a/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k2-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb416_mtp3_ccb-UCX.yaml b/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k2-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX.yaml similarity index 97% rename from tests/scripts/perf-sanity/disaggregated/gb200_kimi-k2-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb416_mtp3_ccb-UCX.yaml rename to tests/scripts/perf-sanity/disaggregated/gb200_kimi-k2-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX.yaml index 7d5a3e3ae85b..a9292d239bf5 100644 --- a/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k2-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb416_mtp3_ccb-UCX.yaml +++ b/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k2-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX.yaml @@ -65,9 +65,6 @@ worker_config: moe_config: backend: CUTEDSL use_low_precision_moe_combine: true - load_balancer: - num_slots: 416 - layer_updates_per_iter: 1 cache_transceiver_config: max_tokens_in_buffer: 16384 backend: UCX diff --git a/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k2-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb384_mtp0_ccb-UCX.yaml b/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k2-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb0_mtp0_ccb-UCX.yaml similarity index 97% rename from tests/scripts/perf-sanity/disaggregated/gb200_kimi-k2-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb384_mtp0_ccb-UCX.yaml rename to tests/scripts/perf-sanity/disaggregated/gb200_kimi-k2-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb0_mtp0_ccb-UCX.yaml index ea656cde3dcb..d0d612be80ab 100644 --- a/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k2-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb384_mtp0_ccb-UCX.yaml +++ b/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k2-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb0_mtp0_ccb-UCX.yaml @@ -65,9 +65,6 @@ worker_config: moe_config: backend: CUTEDSL use_low_precision_moe_combine: true - load_balancer: - num_slots: 384 - layer_updates_per_iter: 1 cache_transceiver_config: max_tokens_in_buffer: 16384 backend: UCX diff --git a/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k25-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb384_mtp0_ccb-UCX.yaml b/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k25-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb0_mtp0_ccb-UCX.yaml similarity index 97% rename from tests/scripts/perf-sanity/disaggregated/gb200_kimi-k25-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb384_mtp0_ccb-UCX.yaml rename to 
tests/scripts/perf-sanity/disaggregated/gb200_kimi-k25-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb0_mtp0_ccb-UCX.yaml index 93e7ba0535ce..4e3eb972d195 100644 --- a/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k25-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb384_mtp0_ccb-UCX.yaml +++ b/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k25-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb0_mtp0_ccb-UCX.yaml @@ -65,9 +65,6 @@ worker_config: moe_config: backend: CUTEDSL use_low_precision_moe_combine: true - load_balancer: - num_slots: 384 - layer_updates_per_iter: 1 cache_transceiver_config: max_tokens_in_buffer: 16384 backend: UCX diff --git a/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k25-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb416_mtp3_ccb-UCX.yaml b/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k25-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX.yaml similarity index 97% rename from tests/scripts/perf-sanity/disaggregated/gb200_kimi-k25-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb416_mtp3_ccb-UCX.yaml rename to tests/scripts/perf-sanity/disaggregated/gb200_kimi-k25-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX.yaml index 5394bbd77324..d4ab81988a60 100644 --- a/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k25-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb416_mtp3_ccb-UCX.yaml +++ b/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k25-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX.yaml @@ -65,9 +65,6 @@ worker_config: moe_config: backend: CUTEDSL use_low_precision_moe_combine: true - load_balancer: - num_slots: 416 - layer_updates_per_iter: 1 cache_transceiver_config: max_tokens_in_buffer: 16384 backend: UCX diff --git a/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k25-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb384_mtp0_ccb-UCX.yaml b/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k25-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb0_mtp0_ccb-UCX.yaml similarity index 97% rename from tests/scripts/perf-sanity/disaggregated/gb200_kimi-k25-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb384_mtp0_ccb-UCX.yaml rename to tests/scripts/perf-sanity/disaggregated/gb200_kimi-k25-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb0_mtp0_ccb-UCX.yaml index dc849424ba01..d84d38d742b7 100644 --- a/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k25-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb384_mtp0_ccb-UCX.yaml +++ b/tests/scripts/perf-sanity/disaggregated/gb200_kimi-k25-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb0_mtp0_ccb-UCX.yaml @@ -65,9 +65,6 @@ worker_config: moe_config: backend: CUTEDSL use_low_precision_moe_combine: true - load_balancer: - num_slots: 384 - layer_updates_per_iter: 1 cache_transceiver_config: max_tokens_in_buffer: 16384 backend: UCX From bf769afe9b82808c8111ee1a8add9c9faa308f3c Mon Sep 17 00:00:00 2001 From: Zhanrui Sun <184402041+ZhanruiSunCh@users.noreply.github.com> Date: Mon, 20 Apr 2026 19:04:40 +0800 Subject: [PATCH 5/8] [None][infra] Waive 2 failed cases for main in post-merge 2663 (#13216) Signed-off-by: ZhanruiSunCh <184402041+ZhanruiSunCh@users.noreply.github.com> --- tests/integration/test_lists/waives.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 2de8dfab584f..db92eadd862c 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -405,3 +405,5 @@ 
disaggregated/test_disaggregated.py::test_disaggregated_perf_metrics[TinyLlama-1 disaggregated/test_auto_scaling.py::test_worker_restart[http-load_balancing] SKIP (https://nvbugs/6094100) disaggregated/test_auto_scaling.py::test_worker_restart[http-kv_cache_aware] SKIP (https://nvbugs/6094100) disaggregated/test_auto_scaling.py::test_service_discovery[http-round_robin] SKIP (https://nvbugs/6094100) +perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp4_v2_2_nodes_grace_blackwell-r1_fp4_v2_tep8_mtp3] SKIP (https://nvbugs/6095700) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False-v2_kv_cache=True] SKIP (https://nvbugs/6095851) From a8bd7b36a0574101d548193b17e2b78a15cd8c7e Mon Sep 17 00:00:00 2001 From: Grzegorz Kwasniewski <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Mon, 20 Apr 2026 13:17:53 +0200 Subject: [PATCH 6/8] [TRTLLM-12291][feat] New sharding infrastructure (#12419) Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .claude/skills/ad-model-onboard/SKILL.md | 137 + .claude/skills/ci-failure-retrieval/SKILL.md | 139 +- examples/auto_deploy/.gitignore | 1 + .../model_registry/enable_sharder_ir.yaml | 9 + .../auto_deploy/model_registry/models.yaml | 8 + .../nemotron/nemotron_fp8_ir_test.yaml | 34 + .../_torch/auto_deploy/config/default.yaml | 7 + .../config/export_edgellm_onnx.yaml | 2 + .../custom_ops/attention/torch_attention.py | 5 + .../custom_ops/fla/fla_gated_delta.py | 47 +- .../custom_ops/fused_moe/mxfp4_moe.py | 4 + .../custom_ops/fused_moe/torch_moe.py | 18 +- .../custom_ops/fused_moe/trtllm_moe.py | 11 +- .../auto_deploy/custom_ops/linear/linear.py | 57 +- .../custom_ops/mamba/torch_causal_conv.py | 46 +- .../custom_ops/mamba/torch_mamba.py | 44 + .../auto_deploy/custom_ops/mla/torch_mla.py | 52 +- .../custom_ops/normalization/rms_norm.py | 10 + .../custom_ops/quantization/quant.py | 26 +- .../custom_ops/quantization/torch_quant.py | 70 +- .../auto_deploy/custom_ops/sharding_ops.py | 152 + tensorrt_llm/_torch/auto_deploy/llm_args.py | 33 +- .../auto_deploy/models/custom/__init__.py | 10 + .../models/custom/mla_rope_utils.py | 12 +- .../models/custom/modeling_deepseek_ir.py | 768 +++++ .../models/custom/modeling_llama3_ir.py | 446 +++ .../models/custom/modeling_nemotron_h_ir.py | 822 +++++ .../models/custom/modeling_qwen3_5_moe_ir.py | 3050 +++++++++++++++++ .../models/custom/modeling_qwen3_ir.py | 456 +++ .../_torch/auto_deploy/shim/ad_executor.py | 36 +- .../_torch/auto_deploy/transform/interface.py | 3 +- .../transform/library/adapt_to_edgellm.py | 4 +- .../transform/library/attention.py | 1 + .../transform/library/collectives.py | 13 +- .../library/fuse_rmsnorm_quant_fp8.py | 3 +- .../transform/library/fused_moe.py | 14 +- .../auto_deploy/transform/library/fusion.py | 9 + .../transform/library/mxfp4_moe.py | 1 + .../transform/library/quantization.py | 15 +- .../auto_deploy/transform/library/rms_norm.py | 1 + .../auto_deploy/transform/library/rope.py | 9 +- .../auto_deploy/transform/library/sharding.py | 171 +- .../transform/library/sharding_ir.py | 1123 ++++++ .../library/short_reshape_attention_output.py | 7 +- .../_torch/auto_deploy/transform/optimizer.py | 13 +- .../_torch/auto_deploy/utils/dist_config.py | 152 + .../_torch/auto_deploy/utils/mapping_utils.py | 34 - .../_torch/auto_deploy/utils/node_utils.py | 433 +-- .../defs/accuracy/test_llm_api_autodeploy.py | 129 
+ .../library/test_apply_sharding_hints.py | 143 + .../library/test_ep_sharding.py | 8 +- .../library/test_tp_sharding.py | 2 +- .../singlegpu/custom_ops/test_sharding_ops.py | 95 + .../singlegpu/utils/test_dist_config.py | 115 + .../utils/test_node_utils_sharding.py | 168 + 55 files changed, 8733 insertions(+), 445 deletions(-) create mode 100644 examples/auto_deploy/model_registry/enable_sharder_ir.yaml create mode 100644 examples/auto_deploy/nemotron/nemotron_fp8_ir_test.yaml create mode 100644 tensorrt_llm/_torch/auto_deploy/custom_ops/sharding_ops.py create mode 100644 tensorrt_llm/_torch/auto_deploy/models/custom/modeling_deepseek_ir.py create mode 100644 tensorrt_llm/_torch/auto_deploy/models/custom/modeling_llama3_ir.py create mode 100644 tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h_ir.py create mode 100644 tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe_ir.py create mode 100644 tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_ir.py create mode 100644 tensorrt_llm/_torch/auto_deploy/transform/library/sharding_ir.py create mode 100644 tensorrt_llm/_torch/auto_deploy/utils/dist_config.py delete mode 100644 tensorrt_llm/_torch/auto_deploy/utils/mapping_utils.py create mode 100644 tests/unittest/auto_deploy/multigpu/transformations/library/test_apply_sharding_hints.py create mode 100644 tests/unittest/auto_deploy/singlegpu/custom_ops/test_sharding_ops.py create mode 100644 tests/unittest/auto_deploy/singlegpu/utils/test_dist_config.py create mode 100644 tests/unittest/auto_deploy/singlegpu/utils/test_node_utils_sharding.py diff --git a/.claude/skills/ad-model-onboard/SKILL.md b/.claude/skills/ad-model-onboard/SKILL.md index c37fd3ed07e9..ab6e61c2ec3a 100644 --- a/.claude/skills/ad-model-onboard/SKILL.md +++ b/.claude/skills/ad-model-onboard/SKILL.md @@ -303,6 +303,143 @@ GH_CONFIG_DIR= gh pr view --json reviews,state **Do NOT stop polling prematurely.** The loop must continue until the PR is approved or a clear termination signal is received. If polling has been running for an extended period (e.g., >2 hours) with no new activity, inform the user that you are still monitoring and ask if they want you to continue or stop. +## Sharding-aware IR model porting (`modeling_*_ir.py`) + +Use this when porting an existing AutoDeploy custom model (`tensorrt_llm/_torch/auto_deploy/models/custom/modeling_*.py`) to explicit sharding hint ops in `modeling_*_ir.py` **in the same directory** (no separate `new_sharding/` tree). The exported FX graph must fully specify how the model should be sharded: the `apply_sharding_hints` transform combines hints with a runtime `DistConfig` for deterministic, node-local sharding. + +**Argument reference:** Do not duplicate operator tables here. Refer to the custom op docstrings in `tensorrt_llm/_torch/auto_deploy/custom_ops/` for the complete argument reference (including sharding hints, `tp_mode`, `layer_type`, and which ops accept hints). 
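+
+As a quick orientation before the reference table, the sketch below contrasts a plain projection
+call with its hint-annotated form. This is an illustration only, not a definitive signature:
+`hidden` and `self.o_proj` are hypothetical names, and the exact argument list (positional vs.
+keyword, how weight/bias are passed) must be taken from the op docstrings referenced above.
+
+```python
+# Original modeling_foo.py (hypothetical layer):
+#   out = self.o_proj(hidden)
+#
+# Hint-annotated form in modeling_foo_ir.py (sketch; verify against the op docstring):
+out = torch.ops.auto_deploy.torch_linear_simple(
+    hidden, self.o_proj.weight, self.o_proj.bias,
+    tp_mode="rowwise",   # closing projection -> rowwise (see Step 3 below)
+    layer_type="mha",    # lets shard_layers include or skip this node
+)
+out = torch.ops.auto_deploy.all_reduce(out, layer_type="mha")  # see Step 6 below
+```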
+ +### Reference examples (study before porting) + +| Original | IR / sharding-aware | Layer types | +|----------|---------------------|-------------| +| `modeling_nemotron_h.py` | `modeling_nemotron_h_ir.py` | Mamba SSM, MHA, SwiGLU MLP, MoE | +| `modeling_qwen3_5_moe.py` | `modeling_qwen3_5_moe_ir.py` | GatedDeltaNet, Gated MHA, SwiGLU MLP, MoE | +| `modeling_mistral.py` | `modeling_mistral_ir.py` | MHA, SwiGLU MLP (simplest) | +| `modeling_deepseek_v2.py` | `modeling_deepseek_v2_ir.py` | MLA, SwiGLU MLP, MoE | + +### Step-by-step porting procedure + +#### Step 1: Copy the source file + +```bash +cp tensorrt_llm/_torch/auto_deploy/models/custom/modeling_foo.py \ + tensorrt_llm/_torch/auto_deploy/models/custom/modeling_foo_ir.py +``` + +#### Step 2: Update the module docstring and add imports + +At the top of the IR file: + +```python +import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 -- register all ops +``` + +Do **not** add global `SHARD_*` flags. Layer-level control uses the `layer_type` hint on each op and `shard_layers` in YAML. + +#### Step 3: Replace linear projections + +For every `self.proj(x)` or `nn.Linear` call, use `torch.ops.auto_deploy.torch_linear_simple` with explicit `tp_mode` and `layer_type`. Always set `tp_mode` unconditionally (no `if _s else "none"`). **Rules:** opening projections (Q/K/V/gate/up/in_proj) → `"colwise"`; closing (O/down/out_proj) → `"rowwise"`; tiny outputs (e.g. `shared_expert_gate` dim 1) → `"none"`; MLA latent projections (q_a, kv_a) → `"none"`. For fused weights split later, pass `output_sizes=[...]`. For GQA, use `tp_min_local_shape=self.head_dim` on K/V colwise lines. + +#### Step 4: Replace split / chunk after fused colwise projections + +Use `torch.ops.auto_deploy.split_with_sizes` with `shardable` / `layer_type` where sizes scale with TP. + +#### Step 5: Replace view / reshape with concrete head counts + +During `torch.export`, `-1` becomes concrete; after TP, wrong values break. Any reshape whose dimension is a head count that scales with TP must use `torch.ops.auto_deploy.view` with `tp_scaled_dim` set appropriately. Safe cases: flat-to-2D, or `[B,S,-1]` when the input is already correctly sharded. + +#### Step 6: Insert `all_reduce` + +After every rowwise projection, add `torch.ops.auto_deploy.all_reduce(..., layer_type=...)`. **Parallel branch rule:** when branches merge by addition, use a **single** `all_reduce` after the sum (e.g. MoE routed + shared expert; parallel attention + MLP residual branches). + +#### Step 7: Special ops (Conv1d, SSM, GatedDeltaNet, gated RMSNorm) + +Add sharding hints on `torch_causal_conv1d`, `torch_ssm`, `torch_gated_delta_rule`, `torch_rmsnorm_gated` per docstrings—typically `shardable` / `output_sizes` / `tp_mode` as required. + +#### Step 8: MoE + +Pass `layer_type="moe"` into `torch_moe`; `apply_sharding_hints` handles EP/TP. + +#### Step 9: Register the IR model + +1. Bottom of the IR file: `AutoModelForCausalLMFactory.register_custom_model_cls("ConfigClassName", ForCausalLM)` (same pattern as Phase 4). +2. Add a **side-effect import** in `tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py` (e.g. `from . import modeling_foo_ir # noqa: F401`) and extend `__all__` if you export symbols. Without this import, worker processes may not load your class and `apply_sharding_hints` can report **0 nodes processed**. Do **not** use a separate `register_sharded_models.py` indirection. 
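+
+Before moving on to the YAML wiring in Step 10, here is a consolidated sketch of Steps 3-6 applied
+to a plain SwiGLU MLP. It is a hypothetical example, not lifted from any of the reference files:
+module names (`gate_up_proj`, `down_proj`, `intermediate_size`) are illustrative, and the exact op
+signatures (how weight/bias and the hint arguments are passed) must be checked against the custom
+op docstrings in `custom_ops/`.
+
+```python
+# Hypothetical SwiGLU MLP forward in modeling_foo_ir.py, consolidating Steps 3-6.
+def forward(self, x):
+    # Step 3: fused opening projection -> colwise, with output_sizes so the
+    # fused weight can still be split correctly after TP sharding.
+    gate_up = torch.ops.auto_deploy.torch_linear_simple(
+        x, self.gate_up_proj.weight, None,
+        tp_mode="colwise",
+        layer_type="mlp",
+        output_sizes=[self.intermediate_size, self.intermediate_size],
+    )
+    # Step 4: the split sizes scale with TP, so mark the split as shardable.
+    gate, up = torch.ops.auto_deploy.split_with_sizes(
+        gate_up, [self.intermediate_size, self.intermediate_size], -1,
+        shardable=True, layer_type="mlp",
+    )
+    h = torch.nn.functional.silu(gate) * up
+    # Step 3: closing projection -> rowwise.
+    out = torch.ops.auto_deploy.torch_linear_simple(
+        h, self.down_proj.weight, None,
+        tp_mode="rowwise", layer_type="mlp",
+    )
+    # Step 6: exactly one all_reduce after every rowwise projection.
+    return torch.ops.auto_deploy.all_reduce(out, layer_type="mlp")
+```
+
+With `shard_layers: ['mlp']` in the YAML below, only nodes carrying the `"mlp"` hint would be
+processed; leaving `shard_layers` unset shards all shardable nodes.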
+ +#### Step 10: YAML — composable registry pattern + +Prefer the model registry (`examples/auto_deploy/model_registry/models.yaml`) and **compose** shared fragments under `examples/auto_deploy/model_registry/configs/`, same as other models: list `dashboard_default.yaml`, the right `world_size_N.yaml`, then a dedicated fragment (e.g. `enable_sharder_ir.yaml`) that holds IR sharding transforms. That fragment should disable legacy sharding passes and enable hint-driven sharding. Registry fragments are deep-merged in `yaml_extra` order (see `DynamicYamlMixInForSettings` in `tensorrt_llm/_torch/auto_deploy/utils/_config.py`); place transform keys under `transforms:` so they merge with `dashboard_default.yaml`. Standalone experiment YAMLs for `build_and_run_ad` may wrap the same fields under a top-level `args:` block matching `LlmArgs`. + +Example transform block: + +```yaml +# Typical contents for enable_sharder_ir.yaml (registry composable fragment) +transforms: + export_to_gm: + num_moe_experts_for_export: 2 # often required when expert count is large (>64) + detect_sharding: + stage: sharding + enabled: false + sharding_transform_executor: + stage: sharding + enabled: false + apply_sharding_hints: + stage: sharding + enabled: true + run_shape_prop: true + allreduce_strategy: NCCL + # shard_layers: ['mha', 'mlp'] # optional selective sharding + gather_logits_before_lm_head: + enabled: true +``` + +Use `world_size: 8` when validating TP head-divisibility. Optional `shard_layers` limits which `layer_type` hints are processed; unset means shard all shardable nodes. + +#### Step 11: Validate + +Do not report success until a run completes successfully. + +1. Prefer `python examples/auto_deploy/build_and_run_ad.py --model --use-registry` after adding/updating the registry entry and composable YAMLs (Phase 8–9 style). +2. `apply_sharding_hints` logs should show **`N nodes processed` with N > 0**. +3. If validation fails with infrastructure limits (e.g. head count not divisible by `world_size`), document the assert and compatible sizes; do not “fix” core `sharding.py` / custom op schemas without owner review. +4. If blocked by missing infrastructure support, rename artifacts to `broken_modeling_*_ir.py` / broken YAML and file a short error report for humans (do not silently patch core transforms). + +**Layer type strings** (for `layer_type` / `shard_layers`): use `"mha"`, `"mla"`, `"mlp"`, `"moe"`, `"ssm"`, `"delta"`, or `"unknown"` (default; skipped when `shard_layers` is set). Match the conventions used in `apply_sharding_hints` and project enums. + +### Layer-specific sharding patterns + +**MHA (standard or gated):** `layer_type="mha"`: q/k/v colwise (GQA: `tp_min_local_shape`), `view` with `tp_scaled_dim` for head dim, o rowwise + `all_reduce`. Fused Q+gate interleaved per head: colwise without `output_sizes`; contiguous Q|K|V fused blocks need `output_sizes`. + +**SwiGLU MLP:** `layer_type="mlp"`: gate/up colwise, down rowwise + `all_reduce`. + +**Mamba / SSM:** `layer_type="ssm"`: in_proj colwise + `output_sizes`, splits shardable, conv1d shardable + `output_sizes`, views, `torch_ssm` shardable, norm gated colwise if weight scales, out rowwise + `all_reduce`. + +**GatedDeltaNet:** `layer_type="delta"`: in_proj_qkv with `output_sizes`, other in_projs colwise, conv1d/splits/views as above, `torch_gated_delta_rule` shardable, out rowwise + `all_reduce`. + +**MoE + shared expert:** `layer_type="moe"`: router replicated; one `all_reduce` after `routed + shared`, not two. 
+ +**MLA (DeepSeek):** `layer_type="mla"`: keep `torch_mla` intact with `shardable=True`—do **not** decompose into separate linears + `torch_attention` (introduces bad `expand`/`view` with concrete head counts). q_a/kv_a latent: `tp_mode="none"`; q_b colwise; `o_proj` rowwise + `all_reduce`. + +### Common pitfalls (sharding IR) + +1. **Missing `auto_deploy::view` for head reshapes** — concrete shapes from export break after sharding. +2. **Sharding tiny projections** — dim-1 gates: `tp_mode="none"`. +3. **Double `all_reduce` in MoE** — one merge-point reduction for routed + shared. +4. **Cross-layer parameter contamination** — in `_apply_hint_*` handlers using `get_source_nodes()`, restrict with `allowed_ops` so residual links do not pull weights from other layers. +5. **Missing `num_moe_experts_for_export`** for very large expert counts — export can hang. +6. **Decomposing ops that absorb weights** (e.g. `torch_mla`) — use `shardable` + handler instead of splitting into plain linears. +7. **Interleaved vs contiguous fused weights** — interleaved per-head groups: colwise only; contiguous Q|K|V blocks: require `output_sizes`. +8. **Omitting `layer_type` when using `shard_layers`** — `"unknown"` nodes are skipped; set hints explicitly on sharding-aware ops. +9. **`layer_type` on non-hint ops** — do **not** pass `layer_type` to ops that are not designed for sharding hints (e.g. `torch_attention`, `torch_l2norm`, `torch_rope_*`); extra positional args break calls. Confirm in `custom_ops/` docstrings which ops accept hints. +10. **Conditional hint values** — no `if _s else "none"`; use unconditional hints and rely on `shard_layers` / transform config. + +### Sharding IR validation checklist (human review) + +- `world_size=1`: unsharded path; hints should not break correctness. +- `world_size=2` and `8`: shape checks and coherent output. +- `apply_sharding_hints` node count vs expectation. +- Optional: `shard_layers: ['moe']` to verify selective sharding. + ## Key Gotchas - **Canonical ops first:** Always use `torch.ops.auto_deploy.torch_*` canonical ops whenever one exists for the operation. This is how AD knows what to optimize. Writing manual attention, MoE, RoPE, or normalization in plain PyTorch instead of using the canonical op will prevent AD transforms from working. - **No `repeat_interleave`:** AD attention ops handle GQA natively. Never repeat K/V heads manually. diff --git a/.claude/skills/ci-failure-retrieval/SKILL.md b/.claude/skills/ci-failure-retrieval/SKILL.md index 9ac0a58cf321..6ce60ffa61f4 100644 --- a/.claude/skills/ci-failure-retrieval/SKILL.md +++ b/.claude/skills/ci-failure-retrieval/SKILL.md @@ -10,56 +10,136 @@ metadata: **Input:** a PR number or a request to check CI failures. **Auth requirement:** requires corporate network access to resolve the Jenkins base URL. **Output:** a summary of failed tests with error details, and optionally full stdout/stderr for specific failures. -## Phase 1 — Get the Jenkins Build Number +## Important: SSL and Authentication + +Jenkins requires SSL with certificate verification disabled. Always use `ssl` context bypass in Python or `-sk` flags in curl: +```python +import ssl +ctx = ssl.create_default_context() +ctx.check_hostname = False +ctx.verify_mode = ssl.CERT_NONE +``` +The `curl -s` approach often returns HTML login pages; prefer the Python `urllib` approach with SSL bypass. 
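+
+For reference, every Jenkins request in the phases below passes this context explicitly in the same
+way (`url` here stands for whichever Jenkins endpoint a phase needs):
+
+```python
+import json, urllib.request
+
+# `ctx` is the SSL context created above.
+resp = urllib.request.urlopen(urllib.request.Request(url), context=ctx, timeout=30)
+data = json.loads(resp.read())
+```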
+ +## Phase 0 — Get the Latest CI Run Info + +First, determine the latest CI run commit, build number, and high-level pass/fail counts: -The CI bot (`tensorrt-cicd`) posts comments with links to the Jenkins build. Extract the `L0_MergeRequest_PR` build number: ```bash +source ~/utils/github/set_github_token.sh + PR_NUM= -BUILD_NUM=$(gh api "repos/NVIDIA/TensorRT-LLM/issues/${PR_NUM}/comments" --jq \ + +# Get the latest CI bot comment (contains build number and commit) +gh api "repos/NVIDIA/TensorRT-LLM/issues/${PR_NUM}/comments" --paginate --jq \ + '[.[] | select(.user.login == "tensorrt-cicd") | select(.body | test("L0_MergeRequest_PR"))] | last | .body' + +# Get the PR HEAD commit and its blossom-ci status (high-level pass/fail counts) +HEAD_SHA=$(gh api "repos/NVIDIA/TensorRT-LLM/pulls/${PR_NUM}" --jq '.head.sha') +gh api "repos/NVIDIA/TensorRT-LLM/commits/${HEAD_SHA}/statuses" --jq \ + '[.[] | select(.context == "blossom-ci")] | first | {state, description}' +``` + +The `description` field shows aggregate counts like `"23969 passed, 1 failed, 8962 skipped"`. + +## Phase 1 — Get the Jenkins Build Number + +Extract the `L0_MergeRequest_PR` build number from the CI bot comment: +```bash +BUILD_NUM=$(gh api "repos/NVIDIA/TensorRT-LLM/issues/${PR_NUM}/comments" --paginate --jq \ '[.[] | select(.user.login == "tensorrt-cicd") | select(.body | test("L0_MergeRequest_PR"))] | last | .body' \ | grep -oP 'L0_MergeRequest_PR/\K\d+') ``` -## Phase 2 — Query the Jenkins testReport API for Failures +## Phase 1.5 — Check Pipeline Stage Failures (before diving into test details) -Resolve the Jenkins base URL dynamically from the internal shortcut (requires corporate network): -```bash -JENKINS_BASE="$(curl -skI 'https://nv/trt-llm-cicd' 2>/dev/null | grep -i '^location:' | sed 's/^[Ll]ocation: *//;s/[[:space:]]*$//')job/main/job/L0_MergeRequest_PR" +Many CI failures are **infrastructure-level** (Slurm node issues, pipeline aborts, resource exhaustion) where no test code executes at all. Always check the pipeline stages first: + +```python +import json, ssl, urllib.request + +ctx = ssl.create_default_context() +ctx.check_hostname = False +ctx.verify_mode = ssl.CERT_NONE + +JENKINS_BASE = "https://prod.blsm.nvidia.com/sw-tensorrt-top-1/job/LLM/job/main/job/L0_MergeRequest_PR" +BUILD_NUM = + +# Get pipeline stage overview +url = f"{JENKINS_BASE}/{BUILD_NUM}/wfapi/describe" +resp = urllib.request.urlopen(urllib.request.Request(url), context=ctx, timeout=30) +data = json.loads(resp.read()) + +print(f"Pipeline status: {data.get('status')}") +for stage in data.get('stages', []): + status = stage.get('status', '') + if status not in ('SUCCESS', 'SKIPPED', 'NOT_EXECUTED'): + name = stage.get('name', '') + print(f" [{status}] {name}") + if 'error' in stage: + print(f" Error: {stage['error']}") ``` -```bash -curl -s "${JENKINS_BASE}/${BUILD_NUM}/testReport/api/json" | python3 -c " -import json, sys -data = json.load(sys.stdin) -print(f'Summary: {data[\"passCount\"]} passed, {data[\"failCount\"]} failed, {data[\"skipCount\"]} skipped') +## Phase 1.6 — Read Console Log Analysis (Most Valuable for Infrastructure Failures) + +The Jenkins console log contains a **CI failure analysis summary** with sections like `## Recommended Actions` and `## Infrastructure Notes`. 
This is the single most valuable source for understanding infrastructure failures: + +```python +url = f"{JENKINS_BASE}/{BUILD_NUM}/consoleText" +resp = urllib.request.urlopen(urllib.request.Request(url), context=ctx, timeout=30) +text = resp.read().decode('utf-8', errors='replace') + +# Extract failure-related lines from the end of the log +for line in text[-8000:].split('\n'): + lo = line.lower() + if any(kw in lo for kw in ['fail', 'error', 'abort', 'likely cause', + 'recommended action', 'infrastructure', + 'no test code', 'stage result']): + print(line.strip()[:300]) +``` + +Key sections to look for in the console log: +- **`Failing job`** / **`Failed stage`**: which Jenkins sub-job and stage failed +- **`Likely cause`**: automated root cause analysis (Slurm issues, pipeline timeouts, etc.) +- **`No test code was executed`**: confirms infrastructure-only failure (no code fix needed) +- **`Recommended Actions`**: whether to re-trigger CI or investigate code changes + +## Phase 2 — Query the Jenkins testReport API for Test Failures + +Only proceed here if Phase 1.5/1.6 indicate actual test failures (not infrastructure issues): + +```python +url = f"{JENKINS_BASE}/{BUILD_NUM}/testReport/api/json" +resp = urllib.request.urlopen(urllib.request.Request(url), context=ctx, timeout=30) +data = json.loads(resp.read()) + +print(f'Summary: {data["passCount"]} passed, {data["failCount"]} failed, {data["skipCount"]} skipped') + failed = [] for suite in data.get('suites', []): for case in suite.get('cases', []): if case.get('status') in ('FAILED', 'REGRESSION'): failed.append(case) + if not failed: - print('No test failures!') + print('No test failures in testReport!') else: print(f'Failed tests ({len(failed)}):') for f in failed: - print(f' - {f[\"className\"]}.{f[\"name\"]}') + print(f' - {f["className"]}.{f["name"]}') err = (f.get('errorDetails') or '')[:200] if err: print(f' Error: {err}') -" ``` -## Phase 3 — Get Full stdout/stderr for a Specific Failure +## Phase 3 — Get Full stdout/stderr for a Specific Test Failure -The `errorStackTrace` can be incomplete when errors originate from subprocesses. In that case, fetch `stdout` and `stderr` for the specific test case to find the real error: -```bash -curl -s "${JENKINS_BASE}/${BUILD_NUM}/testReport/api/json" | python3 -c " -import json, sys -data = json.load(sys.stdin) +The `errorStackTrace` can be incomplete when errors originate from subprocesses. 
Fetch `stdout` and `stderr` for the specific test case to find the real error: +```python for suite in data.get('suites', []): for case in suite.get('cases', []): if case.get('status') in ('FAILED', 'REGRESSION'): - name = f'{case[\"className\"]}.{case[\"name\"]}' + name = f'{case["className"]}.{case["name"]}' if '' in name: print(f'=== {name} ===') print('--- Error ---') @@ -71,7 +151,6 @@ for suite in data.get('suites', []): print('--- Stderr (last 3000 chars) ---') print((case.get('stderr') or '')[-3000:]) break -" ``` ## Available Fields per Failed Test Case (Jenkins testReport API) @@ -82,8 +161,20 @@ for suite in data.get('suites', []): - `errorStackTrace`: full stack trace (may be incomplete for subprocess errors) - `stdout`, `stderr`: full test output (can be large, check these when stack trace is insufficient) +## Common Failure Patterns + +| Pattern | Diagnosis | Action | +|---------|-----------|--------| +| `No test code was executed` + Slurm errors | Infrastructure: Slurm node resource exhaustion | Re-trigger CI | +| `ABORTED` stage + `Downstream job did not succeed` | Cascading failure from fail-fast policy | Fix root cause stage, re-trigger | +| `newosproc` / `errno=11` / `fork/exec` | Kernel process table exhaustion on login node | Wait and re-trigger | +| `testReport: 0 failed` but `blossom-ci: N failed` | Stage-level failures, not test failures | Check Phase 1.5/1.6 | +| `testReport: N failed` with real test names | Actual test code failures | Investigate test errors in Phase 3 | + ## Anti-Patterns -- Do not guess Jenkins URLs; always resolve dynamically via the internal shortcut. +- Do not guess Jenkins URLs; always use the known base `https://prod.blsm.nvidia.com/sw-tensorrt-top-1/job/LLM/job/main/job/L0_MergeRequest_PR`. +- Do not use `curl -s` for Jenkins API; it returns HTML login pages. Use Python `urllib` with SSL bypass. +- Do not jump to testReport (Phase 2) before checking pipeline stages (Phase 1.5) — many failures are infrastructure-only with zero test failures. - Do not stop at `errorStackTrace` if it mentions generic wrapper failures like `Process exited with status 1`; check `stdout` and `stderr` for the real error. - Do not fetch all test cases when looking for a specific failure; use the `` filter in Phase 3. diff --git a/examples/auto_deploy/.gitignore b/examples/auto_deploy/.gitignore index 0999a4ed7619..bb5a36a5efc3 100644 --- a/examples/auto_deploy/.gitignore +++ b/examples/auto_deploy/.gitignore @@ -7,3 +7,4 @@ benchmark_results.json !nano_v3.yaml !nemotron_flash.yaml !model_registry/configs/*.yaml +!model_registry/enable_sharder_ir.yaml diff --git a/examples/auto_deploy/model_registry/enable_sharder_ir.yaml b/examples/auto_deploy/model_registry/enable_sharder_ir.yaml new file mode 100644 index 000000000000..e7b48d0fea56 --- /dev/null +++ b/examples/auto_deploy/model_registry/enable_sharder_ir.yaml @@ -0,0 +1,9 @@ +# Enable the hint-driven IR sharding system and disable the legacy heuristic sharding. +# Use via yaml_extra in models.yaml or --yaml-extra on the command line. 
+transforms: + detect_sharding: + enabled: false + sharding_transform_executor: + enabled: false + apply_sharding_hints: + enabled: true diff --git a/examples/auto_deploy/model_registry/models.yaml b/examples/auto_deploy/model_registry/models.yaml index cae73b1b941c..4bec39c2abb7 100644 --- a/examples/auto_deploy/model_registry/models.yaml +++ b/examples/auto_deploy/model_registry/models.yaml @@ -505,3 +505,11 @@ models: yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml', 'multimodal.yaml'] - name: CohereLabs/aya-vision-32b yaml_extra: ['dashboard_default.yaml', 'world_size_4.yaml', 'multimodal.yaml'] +# ============================================================================= +# IR sharding (hint-driven) models +# These use enable_sharder_ir.yaml to opt in to the new sharding system. +# ============================================================================= +- name: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8-IR + yaml_extra: ['dashboard_default.yaml', 'world_size_4.yaml', 'nano_v3.yaml', 'enable_sharder_ir.yaml'] +- name: deepseek-ai/DeepSeek-R1-IR + yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'num_hidden_layers_5.yaml', 'enable_sharder_ir.yaml'] diff --git a/examples/auto_deploy/nemotron/nemotron_fp8_ir_test.yaml b/examples/auto_deploy/nemotron/nemotron_fp8_ir_test.yaml new file mode 100644 index 000000000000..47e2ae8818fe --- /dev/null +++ b/examples/auto_deploy/nemotron/nemotron_fp8_ir_test.yaml @@ -0,0 +1,34 @@ +# Test config: Nemotron Nano FP8 with IR sharding, strip_sharding_hints DISABLED +model: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 +args: + world_size: 2 + runtime: trtllm + compile_backend: torch-simple + max_seq_len: 512 + max_num_tokens: 512 + max_batch_size: 128 + enable_chunked_prefill: true + model_factory: AutoModelForCausalLM + skip_loading_weights: false + kv_cache_config: + free_gpu_memory_fraction: 0.88 + mamba_ssm_cache_dtype: auto + transforms: + detect_sharding: + enabled: false + sharding_transform_executor: + enabled: false + apply_sharding_hints: + enabled: true + stage: sharding + run_shape_prop: true + allreduce_strategy: NCCL + strip_sharding_hints: + enabled: false + gather_logits_before_lm_head: + enabled: true + fuse_mamba_a_log: + stage: post_load_fusion + enabled: true + insert_cached_ssm_attention: + backend: flashinfer_ssm diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index fd1ac47f9e0f..721f908d2761 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -126,9 +126,16 @@ transforms: sharding_transform_executor: stage: sharding run_shape_prop: true + apply_sharding_hints: + enabled: false + stage: sharding + run_shape_prop: true + allreduce_strategy: NCCL ############################################################################################ # MOVE MODEL AND LOAD WEIGHTS ############################################################################################ + strip_sharding_hints: + stage: weight_load load_weights: stage: weight_load run_per_gm: false diff --git a/tensorrt_llm/_torch/auto_deploy/config/export_edgellm_onnx.yaml b/tensorrt_llm/_torch/auto_deploy/config/export_edgellm_onnx.yaml index 24db07b847c6..bec43b44df1d 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/export_edgellm_onnx.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/export_edgellm_onnx.yaml @@ -74,6 +74,8 @@ transforms: 
############################################################################################ # MOVE MODEL AND LOAD WEIGHTS ############################################################################################ + strip_sharding_hints: + stage: weight_load load_weights: stage: weight_load run_per_gm: false diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py index 77b5cfc9820a..205ed4120573 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py @@ -121,6 +121,7 @@ def torch_attention( layout: str = "bnsd", # "bnsd" or "bsnd" layer_idx: Optional[int] = None, shared_kv_source_layer_idx: Optional[int] = None, + layer_type: str = "mha", ) -> torch.Tensor: """ SDPA attention (with optional GQA) that supports two memory layouts via `layout`: @@ -130,6 +131,9 @@ def torch_attention( The `attn_mask` is always interpreted as [b, n, s_q, s_k]. Returns a tensor in the SAME layout as inputs specified by `layout`. + + ``layer_type`` is graph metadata for ``apply_sharding_hints`` and does not + affect the numeric result. """ # `layer_idx` and `shared_kv_source_layer_idx` are graph metadata used by the KV-cache # transform; the eager attention kernel itself does not need them. @@ -245,5 +249,6 @@ def torch_attention_fake( layout: str = "bnsd", layer_idx: Optional[int] = None, shared_kv_source_layer_idx: Optional[int] = None, + layer_type: str = "mha", ): return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous() diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_gated_delta.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_gated_delta.py index 75db3d4ed459..557224d8ec59 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_gated_delta.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_gated_delta.py @@ -148,25 +148,46 @@ def torch_gated_delta_rule( A_log: torch.Tensor, dt_bias: torch.Tensor, scale: Optional[float] = None, + enable_sharding: bool = False, + layer_type: str = "delta", ) -> torch.Tensor: - """Gated Delta Rule custom op for linear attention (torch reference implementation). + """Gated Delta Rule (GDN) custom op for linear attention (torch reference). Performs L2 normalization, GQA repeat-interleave, gating computation, and the - gated delta rule recurrence internally. All inputs use the autodeploy [B, S, H, D] - (bsnd) layout convention. + chunked gated-delta recurrence internally. All tensor arguments use the AutoDeploy + ``[B, S, H, D]`` (``bsnd``) layout convention. Args: - q: [B, S, H_k, K] - raw query states (un-normalized, un-expanded) - k: [B, S, H_k, K] - raw key states (un-normalized, un-expanded) - v: [B, S, HV, V] - value states - a: [B, S, HV] - raw gating projection (before softplus) - b: [B, S, HV] - raw beta projection (before sigmoid) - A_log: [HV] - log of decay base per value head - dt_bias: [HV] - bias added to gating projection - scale: optional query scaling factor (defaults to K^-0.5) + q: Raw query states (not L2-normalized, not GQA-expanded). Shape + ``[B, S, H_k, K]`` where ``H_k`` is the number of key/query heads. + k: Raw key states, shape ``[B, S, H_k, K]``. + v: Value states, shape ``[B, S, H_v, V]`` where ``H_v`` may exceed ``H_k`` + for GQA (value heads are the parallel dimension for the recurrence). + a: Raw gating projection logits (before ``softplus``), shape ``[B, S, H_v]``. 
+ b: Raw beta projection (before ``sigmoid``), shape ``[B, S, H_v]``. + A_log: Logarithm of the per-head decay base, shape ``[H_v]``. Combined with + ``a`` and ``dt_bias`` to form ``g = -exp(A_log) * softplus(a + dt_bias)``. + dt_bias: Bias added to ``a`` inside ``softplus`` for the gating path, shape + ``[H_v]``. + scale: Optional query scale; default ``K ** -0.5`` when ``None``. + enable_sharding: When ``True``, ``apply_sharding_hints`` shards ``A_log`` and + ``dt_bias`` along the **head** dimension (each rank holds the slice for + its local value heads). When ``False``, those 1D parameters are not + head-sharded by the hint pass. + layer_type: Layer classification for selective sharding via ``shard_layers`` + config. Values: ``"mha"``, ``"mla"``, ``"mlp"``, ``"moe"``, ``"ssm"``, + ``"delta"``, ``"unknown"``. + + Sharding hint arguments (graph-level metadata for ``apply_sharding_hints``): + ``enable_sharding``: When ``True``, ``apply_sharding_hints`` will shard the op's + weight/parameter ancestors along the head dimension (here: ``A_log`` and + ``dt_bias``). + ``layer_type``: Layer classification for selective sharding via + ``shard_layers`` config. Returns: - output: [B, S, HV, V] + Linear-attention output of shape ``[B, S, H_v, V]`` (same head/count layout as + ``v``). """ H_k = q.shape[2] HV = v.shape[2] @@ -206,6 +227,8 @@ def torch_gated_delta_rule_fake( A_log: torch.Tensor, dt_bias: torch.Tensor, scale: Optional[float] = None, + enable_sharding: bool = False, + layer_type: str = "delta", ) -> torch.Tensor: # Output shape is [B, S, H, V] matching v (not q/k which may have fewer heads) return torch.empty_like(v) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/mxfp4_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/mxfp4_moe.py index 6bf291ef7299..6da2b61a8411 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/mxfp4_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/mxfp4_moe.py @@ -172,6 +172,7 @@ def triton_mxfp4_moe( down_blocks: torch.Tensor, # [E, H, I//32, 16] in uint8 down_bias: torch.Tensor, # [E, H] down_scales: torch.Tensor, # [E, H, I//32] in uint8 + layer_type: str = "moe", ) -> torch.Tensor: def _global_route_fn(logits: torch.Tensor): # routing() removed in triton_kernels 3.6.0 @@ -208,6 +209,7 @@ def _mxfp4_mlp_fake( down_blocks: torch.Tensor, down_bias: torch.Tensor, down_scales: torch.Tensor, + layer_type: str = "moe", ): return torch.empty_like(hidden_states) @@ -231,6 +233,7 @@ def triton_mxfp4_moe_ep( # EP topology ep_size: int, ep_rank: int, + layer_type: str = "moe", ) -> torch.Tensor: triton_ep_router = TritonEPRouter() @@ -269,5 +272,6 @@ def _mxfp4_mlp_ep_fake( down_scales: torch.Tensor, ep_size: int, ep_rank: int, + layer_type: str = "moe", ): return torch.empty_like(hidden_states) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py index d231de351364..cbc463ceee9e 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py @@ -20,9 +20,8 @@ import torch.nn.functional as F from tensorrt_llm._torch.auto_deploy.distributed import common as dist_common -from tensorrt_llm._torch.auto_deploy.utils.mapping_utils import deserialize_mapping +from tensorrt_llm._torch.auto_deploy.utils.dist_config import DistConfig from tensorrt_llm._torch.utils import ActivationType -from tensorrt_llm.mapping import Mapping def 
_template_moe_alltoall( @@ -31,7 +30,7 @@ def _template_moe_alltoall( routing_weights: torch.Tensor, mlps: List[Callable[[torch.Tensor], torch.Tensor]], apply_routing_on_input: bool, - mapping: Mapping, + mapping: DistConfig, max_num_tokens: int = 0, ) -> torch.Tensor: """ @@ -215,7 +214,7 @@ def _template_moe( """ # Check if all-to-all mode is enabled - mapping = deserialize_mapping(mapping_config) if mapping_config else None + mapping = DistConfig.deserialize(mapping_config) if mapping_config else None enable_alltoall = ( mapping is not None and mapping.enable_attention_dp and mapping.moe_ep_size > 1 ) @@ -287,6 +286,7 @@ def torch_moe( mapping_config: str = "", max_num_tokens: int = 0, apply_routing_on_input: bool = False, + layer_type: str = "moe", ) -> torch.Tensor: """ Unified Mixture-of-Experts (MoE) operator that uses a Mixtral-style dispatch @@ -311,6 +311,9 @@ def torch_moe( This means: silu(input) * routing_weight Returns: torch.Tensor: Output tensor with the same shape as the input x. + + ``layer_type`` is graph metadata for ``apply_sharding_hints`` and does not + affect the numeric result. """ torch_act_fn = _resolve_torch_fn(act_fn) @@ -360,6 +363,7 @@ def torch_moe_fake( mapping_config: str = "", max_num_tokens: int = 0, apply_routing_on_input: bool = False, + layer_type: str = "moe", ) -> torch.Tensor: return torch.empty_like(x) @@ -449,6 +453,7 @@ def torch_quant_fp8_moe( mapping_config: str = "", max_num_tokens: int = 0, apply_routing_on_input: bool = False, + layer_type: str = "moe", ) -> torch.Tensor: """ FP8 MoE op using quantized linear operations. Computes a Mixture-of-Experts layer similar to the reference @@ -568,6 +573,7 @@ def torch_quant_fp8_moe_fake( mapping_config: str = "", max_num_tokens: int = 0, apply_routing_on_input: bool = False, + layer_type: str = "moe", ) -> torch.Tensor: return torch.empty_like(x) @@ -594,6 +600,7 @@ def torch_quant_nvfp4_moe( mapping_config: str = "", max_num_tokens: int = 0, apply_routing_on_input: bool = False, + layer_type: str = "moe", ) -> torch.Tensor: """ FP4 MoE op using quantized linear operations. @@ -729,6 +736,7 @@ def torch_quant_nvfp4_moe_fake( mapping_config: str = "", max_num_tokens: int = 0, apply_routing_on_input: bool = False, + layer_type: str = "moe", ) -> torch.Tensor: return torch.empty_like(x) @@ -799,6 +807,7 @@ def torch_quant_finegrained_fp8_moe( mapping_config: str = "", max_num_tokens: int = 0, apply_routing_on_input: bool = False, + layer_type: str = "moe", ) -> torch.Tensor: """ FineGrainedFP8 MoE op using block-wise FP8 quantized linear operations. 
@@ -912,5 +921,6 @@ def torch_quant_finegrained_fp8_moe_fake( mapping_config: str = "", max_num_tokens: int = 0, apply_routing_on_input: bool = False, + layer_type: str = "moe", ) -> torch.Tensor: return torch.empty_like(x) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index 0060cb2e78a4..4d053297fcf2 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -20,7 +20,7 @@ from tensorrt_llm._torch.auto_deploy.custom_ops.quantization.quant import ( TRTLLM_NVFP4_SCALING_VECTOR_SIZE, ) -from tensorrt_llm._torch.auto_deploy.utils.mapping_utils import deserialize_mapping +from tensorrt_llm._torch.auto_deploy.utils.dist_config import DistConfig from tensorrt_llm._torch.distributed.moe_alltoall import MoeAlltoAll from tensorrt_llm._torch.modules.fused_moe.routing import RoutingMethodType from tensorrt_llm._torch.utils import ActivationType @@ -36,10 +36,11 @@ def _check_moe_alltoall(mapping_config: str, max_num_tokens: int) -> Tuple[Mappi Returns: (mapping, enable_alltoall) — mapping is None when mapping_config is empty. """ - mapping = deserialize_mapping(mapping_config) if mapping_config else None - enable_alltoall = ( - mapping is not None and mapping.enable_attention_dp and mapping.moe_ep_size > 1 - ) + if not mapping_config: + return None, False + dc = DistConfig.deserialize(mapping_config) + mapping = dc.to_mapping() + enable_alltoall = dc.enable_attention_dp and dc.moe_ep_size > 1 if enable_alltoall and max_num_tokens <= 0: raise ValueError("max_num_tokens must be > 0 when enable_alltoall is True") return mapping, enable_alltoall diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/linear/linear.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/linear/linear.py index 4a7a04adef19..254410e6ba47 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/linear/linear.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/linear/linear.py @@ -15,25 +15,68 @@ """Custom ops for linear layers.""" -from typing import Optional +from typing import List, Optional import torch @torch.library.custom_op("auto_deploy::torch_linear_simple", mutates_args=()) -def simple(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: +def simple( + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + tp_mode: str = "none", + output_sizes: Optional[List[int]] = None, + tp_min_local_shape: int = 1, + layer_type: str = "unknown", +) -> torch.Tensor: """A wrapper for the linear functional to control how it is exposed. - By default F.linear (used in linear layers) will be represented as a call to - torch.ops.aten.linear.default wrapped with two view ops to flatten/unflatten multiple batch - dimensions into one batch dimension. + By default ``F.linear`` (used in linear layers) is represented as a call to + ``torch.ops.aten.linear.default`` wrapped with two ``view`` ops to flatten/unflatten + multiple batch dimensions into one batch dimension. This wrapper avoids exposing + that reshape pattern during export. - This wrapper avoids exposing this view op during the export graph. + Args: + input: Input activations passed to ``torch.nn.functional.linear`` + (``input @ weight.T + bias``). Shape is typically ``(..., in_features)``. + weight: Weight matrix of shape ``(out_features, in_features)``. + bias: Optional bias vector of shape ``(out_features,)``. 
If ``None``, no bias + is applied. + tp_mode: TP sharding mode hint (see "Sharding hint arguments" below). + output_sizes: Fused-weight group sizes hint (see below). + tp_min_local_shape: Minimum per-rank output width hint (see below). + layer_type: Layer classification hint for selective sharding (see below). + + Sharding hint arguments (graph-level metadata for ``apply_sharding_hints``): + ``tp_mode``: TP sharding mode. ``"colwise"`` shards weight dim 0, + ``"rowwise"`` shards weight dim 1, ``"none"`` skips sharding. + ``output_sizes``: Group sizes for fused-weight proportional column sharding + (e.g., ``[q_dim, kv_dim, kv_dim]`` for fused QKV). + ``tp_min_local_shape``: Minimum output size per rank after sharding. Used for + GQA where ``num_kv_heads < tp_size`` (set to ``head_dim``). + ``layer_type``: Layer classification for selective sharding via + ``shard_layers`` config. Values: ``"mha"``, ``"mla"``, ``"mlp"``, + ``"moe"``, ``"ssm"``, ``"delta"``, ``"unknown"``. + + These hint arguments do not change the numeric result of the linear; they only + guide graph transforms when tensor-parallel sharding is applied. + + Returns: + Output tensor of shape ``(..., out_features)``. """ return torch.ops.aten.linear(input, weight, bias) @simple.register_fake -def simple_fake(input, weight, bias): +def simple_fake( + input, + weight, + bias, + tp_mode="none", + output_sizes=None, + tp_min_local_shape=1, + layer_type="unknown", +): """Fake implementation of simple_linear.""" return torch.ops.aten.linear(input, weight, bias) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_causal_conv.py index 28b15388fb9d..6bded30fb78c 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_causal_conv.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_causal_conv.py @@ -15,7 +15,7 @@ """Custom op collection for uncached causal conv (sliding window with 1d).""" -from typing import Optional +from typing import List, Optional import torch import torch.nn.functional as F @@ -31,7 +31,48 @@ def _torch_causal_conv1d( dilation: int = 1, groups: int = 1, padding_mode: str = "zeros", + enable_sharding: bool = False, + output_sizes: Optional[List[int]] = None, + layer_type: str = "ssm", ) -> torch.Tensor: + """Causal 1D convolution along the sequence axis (sliding window over time). + + The input layout is ``[batch, seq_len, channels]``. The implementation transposes + to ``[batch, channels, seq_len]``, applies :func:`torch.nn.functional.conv1d`, + then trims to the original sequence length so the receptive field is causal. + + Args: + input: Activations of shape ``[batch, seq_len, in_channels]`` (or compatible + channel-last layout consumed by the transpose into conv1d). + weight: Conv1d kernel of shape + ``(out_channels, in_channels / groups, kernel_size)``. + bias: Optional bias of shape ``(out_channels,)``. + stride: Conv1d stride (default ``1``). + padding: Conv1d padding (default ``0``). + dilation: Conv1d dilation (default ``1``). + groups: Conv1d groups (default ``1``). + padding_mode: Must be ``"zeros"``; other modes raise. + enable_sharding: When ``True``, ``apply_sharding_hints`` shards the conv1d + ``weight`` along its **output channel** dimension (head-parallel conv + weights). When ``False``, sharding passes leave weights unchanged. 
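[Editorial aside on the hint kwargs documented for `torch_linear_simple` above: a minimal sketch, mine and not part of the patch, of a column-parallel/row-parallel pair of hinted calls. It assumes a TensorRT-LLM build where importing `auto_deploy.custom_ops` registers the ops; the tensor names are illustrative. The point is that, before any sharding transform runs, the hinted calls compute exactly the same thing as plain linears.]

```python
import torch
import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401 -- registers the hinted ops

x = torch.randn(2, 16, 64)
w_up = torch.randn(256, 64)    # colwise candidate: shard output dim (weight dim 0)
w_down = torch.randn(64, 256)  # rowwise candidate: shard input dim (weight dim 1)

h = torch.ops.auto_deploy.torch_linear_simple(
    x, w_up, None, tp_mode="colwise", layer_type="mlp"
)
y = torch.ops.auto_deploy.torch_linear_simple(
    h, w_down, None, tp_mode="rowwise", layer_type="mlp"
)

# Without any sharding transform applied, the hinted calls are plain linears.
ref = torch.nn.functional.linear(torch.nn.functional.linear(x, w_up), w_down)
assert torch.allclose(y, ref)
```

[Under `apply_sharding_hints`, the same pair would be expected to become per-rank weight slices, with the collective implied by the rowwise output.]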
+ output_sizes: Optional group sizes for fused-weight proportional column + sharding (same convention as linear fused projections; consumed by + ``apply_sharding_hints`` when applicable). + layer_type: Layer classification for selective sharding via ``shard_layers`` + config. Values: ``"mha"``, ``"mla"``, ``"mlp"``, ``"moe"``, ``"ssm"``, + ``"delta"``, ``"unknown"``. + + Sharding hint arguments (graph-level metadata for ``apply_sharding_hints``): + ``enable_sharding``: When ``True``, ``apply_sharding_hints`` will shard the op's + weight ancestors along the conv output-channel dimension (per-rank conv). + ``output_sizes``: Group sizes for fused-weight proportional column sharding + when the surrounding graph uses fused projections. + ``layer_type``: Layer classification for selective sharding via + ``shard_layers`` config. + + Returns: + Tensor of the same batch/sequence layout as ``input`` after causal conv. + """ assert padding_mode == "zeros", "padding_mode must be zeros" batch_size, seq_len, _ = input.shape @@ -61,5 +102,8 @@ def _torch_causal_conv1d_meta( dilation: int = 1, groups: int = 1, padding_mode: str = "zeros", + enable_sharding: bool = False, + output_sizes: Optional[List[int]] = None, + layer_type: str = "ssm", ) -> torch.Tensor: return torch.empty_like(input) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py index dbec15699eee..227285b6783d 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py @@ -190,7 +190,49 @@ def _torch_ssm( float ], # NOTE: `torch` custom ops do not like `Tuple` inputs. Using `List` is the suggested WAR. chunk_size: int, + enable_sharding: bool = False, + layer_type: str = "ssm", ) -> torch.Tensor: + """Mamba state-space model (SSM) mixer forward (prefill / uncached path). + + Implements the chunked SSM recurrence in ``_torch_ssm_prefill``: discretizes + states with ``dt`` / ``dt_bias``, applies intra- and inter-chunk updates, and + returns the hidden output. ``time_step_limit`` clamps the softplus-scaled + time step. + + Args: + hidden_states: Input hidden states, shape + ``[batch, seq_len, num_heads, head_dim]``. + A: SSM decay/state matrix term, broadcast-compatible with chunked layout + (see internal reshape into chunks). + B: SSM input matrix (low-rank / group structure as in Mamba). + C: SSM output matrix (low-rank / group structure as in Mamba). + D: Residual skip scale, applied per head/feature as in the reference. + dt: Raw delta / time-step logits before softplus and clamping. + dt_bias: Bias added inside the softplus path for ``dt``. + time_step_limit: Two-element list ``[min, max]`` for clamping the effective + time step after ``softplus(dt + dt_bias)``. (Passed as ``List`` because + Torch custom ops avoid tuple inputs.) + chunk_size: Chunk length for blocked SSM computation along the sequence axis. + enable_sharding: When ``True``, ``apply_sharding_hints`` shards parameters such as + ``A``, ``D``, and ``dt_bias`` along the **head** dimension (per-rank head + slices). When ``False``, those parameter nodes are not head-sharded by the + hint pass. + layer_type: Layer classification for selective sharding via ``shard_layers`` + config. Values: ``"mha"``, ``"mla"``, ``"mlp"``, ``"moe"``, ``"ssm"``, + ``"delta"``, ``"unknown"``. 
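[A minimal numeric sketch, my own illustration rather than code from this patch, of the time-step path described in the SSM docstring above: the effective step is `softplus(dt + dt_bias)` clamped into `time_step_limit`. The shapes are assumptions for illustration.]

```python
import torch
import torch.nn.functional as F

# Assumed shapes: dt and dt_bias are per head.
dt = torch.randn(2, 8, 4)   # [batch, seq, num_heads] raw time-step logits
dt_bias = torch.zeros(4)    # [num_heads]
time_step_limit = [1e-3, 1e2]

# Effective time step: softplus(dt + dt_bias), clamped into [min, max].
dt_eff = torch.clamp(F.softplus(dt + dt_bias), time_step_limit[0], time_step_limit[1])
print(dt_eff.shape)  # torch.Size([2, 8, 4])
```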
+ + Sharding hint arguments (graph-level metadata for ``apply_sharding_hints``): + ``enable_sharding``: When ``True``, ``apply_sharding_hints`` will shard the op's + weight/parameter ancestors along the head dimension (e.g., ``A``, ``D``, + ``dt_bias``). + ``layer_type``: Layer classification for selective sharding via + ``shard_layers`` config. + + Returns: + SSM output tensor, same shape as ``hidden_states`` (float32 compute may be + used internally; see fake/meta for export). + """ y, _ = _torch_ssm_prefill(hidden_states, A, B, C, D, dt, dt_bias, time_step_limit, chunk_size) return y @@ -206,5 +248,7 @@ def _torch_ssm_meta( dt_bias: torch.Tensor, time_step_limit: List[float], chunk_size: int, + enable_sharding: bool = False, + layer_type: str = "ssm", ) -> torch.Tensor: return torch.empty_like(hidden_states, dtype=torch.float32) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_mla.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_mla.py index 5be00a764eb6..164899a0eb5d 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_mla.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_mla.py @@ -22,25 +22,51 @@ def torch_mla( is_causal: bool = True, scale: Optional[float] = None, layout: str = "bsnd", + enable_sharding: bool = False, + layer_type: str = "mla", ) -> torch.Tensor: """Multi-head Latent Attention (MLA) with FlashInfer-compatible compressed KV. - This op expands compressed_kv using kv_b_proj_weight and computes attention. - For prefill, this is the standard formulation. For the cached version, - weight absorption is used for efficiency. + This op expands ``compressed_kv`` with ``kv_b_proj_weight`` and computes + standard dot-product attention. For prefill, this is the direct matmul/softmax + formulation; a separate cached path may use weight absorption elsewhere. Args: - q_nope: Query non-positional component [B, S, N, qk_nope_head_dim] (bsnd) - q_pe: Query positional component with RoPE applied [B, S, N, qk_rope_head_dim] (bsnd) - compressed_kv: Compressed KV latent [B, S, kv_lora_rank] (before kv_b_proj) - kpe: Key positional encoding with RoPE applied [B, S, 1, qk_rope_head_dim] (bsnd) - kv_b_proj_weight: Projection weights [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank] - is_causal: Whether to apply causal masking (default: True) - scale: Softmax scale factor (default: 1/sqrt(qk_head_dim)) - layout: Input/output layout, either "bsnd" or "bnsd" (default: "bsnd") + q_nope: Query non-positional component. Shape ``[B, S, N, qk_nope_head_dim]`` + when ``layout == "bsnd"``, or ``[B, N, S, qk_nope_head_dim]`` when + ``layout == "bnsd"``. + q_pe: Query positional (RoPE) component. Shape ``[B, S, N, qk_rope_head_dim]`` + or ``[B, N, S, qk_rope_head_dim]`` matching ``layout``. + compressed_kv: Compressed KV latent ``[B, S, kv_lora_rank]`` **before** + ``kv_b_proj`` expansion. + kpe: Key positional (RoPE) encodings ``[B, S, 1, qk_rope_head_dim]`` (or the + ``bnsd`` transpose consistent with ``layout``). + kv_b_proj_weight: Unpacked projection weights of shape + ``[num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]``. This is + argument index 4 (the fifth argument). + is_causal: If ``True`` and ``s_q == s_k``, apply a causal upper-triangular + mask to attention logits. + scale: Softmax temperature; default ``1 / sqrt(qk_nope_head_dim + qk_rope_head_dim)``. + layout: ``"bsnd"`` or ``"bnsd"`` for batch/sequence/head dimension ordering. 
+ enable_sharding: When ``True``, ``apply_sharding_hints`` shards ``kv_b_proj_weight`` + **column-wise along the head dimension**: the weight is treated as a + stacked per-head projection, so each TP rank keeps the slice of rows + corresponding to its local heads (out_features grouped by head). When + ``False``, the hint pass does not apply that head-parallel rewrite to + ``kv_b_proj_weight``. + layer_type: Layer classification for selective sharding via ``shard_layers`` + config. Values: ``"mha"``, ``"mla"``, ``"mlp"``, ``"moe"``, ``"ssm"``, + ``"delta"``, ``"unknown"``. + + Sharding hint arguments (graph-level metadata for ``apply_sharding_hints``): + ``enable_sharding``: When ``True``, ``apply_sharding_hints`` shards ``kv_b_proj_weight`` + (arg ``kv_b_proj_weight`` / arg[4]) columnwise along the head dimension. + ``layer_type``: Selects whether MLA nodes are rewritten for a given + ``shard_layers`` configuration. Returns: - Attention output with shape [B, S, N, v_head_dim] (bsnd) + Attention output: ``[B, S, N, v_head_dim]`` for ``bsnd``, or ``[B, N, S, v_head_dim]`` + for ``bnsd``, consistent with ``layout``. """ if layout not in ("bnsd", "bsnd"): raise ValueError(f"layout must be 'bnsd' or 'bsnd', got {layout!r}") @@ -135,6 +161,8 @@ def torch_mla_fake( is_causal: bool = True, scale: Optional[float] = None, layout: str = "bsnd", + enable_sharding: bool = False, + layer_type: str = "mla", ) -> torch.Tensor: """Fake implementation for torch_mla.""" # Infer v_head_dim from kv_b_proj_weight diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/normalization/rms_norm.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/normalization/rms_norm.py index d1a74ea29096..9f32079d91e5 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/normalization/rms_norm.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/normalization/rms_norm.py @@ -112,6 +112,8 @@ def torch_rmsnorm_gated( eps: float, group_size: int, norm_before_gate: bool = False, + tp_mode: str = "none", + layer_type: str = "unknown", ) -> torch.Tensor: """Custom operator for Torch gated RMSNorm implementation. @@ -124,6 +126,8 @@ def torch_rmsnorm_gated( eps: Small constant for numerical stability. group_size: Size of groups for grouped normalization. H must be divisible by group_size. norm_before_gate: If True, apply gating after normalization. If False, apply before. + tp_mode: Tensor-parallel sharding hint for transforms. + layer_type: Layer id hint for selective sharding (e.g. ``shard_layers``). Returns: Normalized and optionally gated tensor of shape like x. @@ -158,6 +162,8 @@ def _( eps: float, group_size: int, norm_before_gate: bool = False, + tp_mode: str = "none", + layer_type: str = "unknown", ) -> torch.Tensor: """Fake implementation for the custom operator during tracing.""" return x.new_empty(x.shape, dtype=x.dtype) @@ -171,6 +177,8 @@ def triton_rmsnorm_gated( eps: float, group_size: int, norm_before_gate: bool = False, + tp_mode: str = "none", + layer_type: str = "unknown", ) -> torch.Tensor: """ Group RMSNorm with optional SiLU gating, using Triton kernel `_layer_norm_fwd`. 
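[To make the `kv_b_proj_weight` hint above concrete: a short sketch, mine, of the per-head row grouping that columnwise-by-head sharding relies on. The dimensions are made up; the weight's out_features are stacked per head, so a rank's shard is simply the rows of its local heads.]

```python
import torch

# Illustrative dimensions only.
num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 8, 128, 128, 512
w = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

tp_size, tp_rank = 4, 1
heads_per_rank = num_heads // tp_size

# Rows are grouped per head, so the local shard is the rows of the local heads.
w_per_head = w.view(num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank)
w_local = w_per_head[tp_rank * heads_per_rank:(tp_rank + 1) * heads_per_rank]
w_local = w_local.reshape(-1, kv_lora_rank)
print(w_local.shape)  # torch.Size([512, 512]) -- 2 local heads x 256 rows each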
@@ -227,6 +235,8 @@ def _triton_rmsnorm_gated_meta( eps: float, group_size: int, norm_before_gate: bool = False, + tp_mode: str = "none", + layer_type: str = "unknown", ): assert x.dim() >= 2, "x must be at least 2D" H = x.shape[-1] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/quantization/quant.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/quantization/quant.py index 03ba73954f5f..92e235a9b1c8 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/quantization/quant.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/quantization/quant.py @@ -16,7 +16,7 @@ """Definition of the quant module that can be used for PTQ.""" import warnings -from typing import Optional, Tuple +from typing import List, Optional, Tuple import torch from flashinfer import bmm_fp8 @@ -143,6 +143,10 @@ def trtllm_quant_fp8_linear( input_scale: Optional[torch.Tensor] = None, weight_scale: Optional[torch.Tensor] = None, out_dtype: Optional[str] = None, + tp_mode: str = "none", + output_sizes: Optional[List[int]] = None, + tp_min_local_shape: int = 1, + layer_type: str = "unknown", ) -> torch.Tensor: """FP8 linear op similar to torch.nn.linear using TensorRT-LLM FP8 operations. @@ -193,6 +197,10 @@ def trtllm_quant_fp8_linear_fake( input_scale: Optional[torch.Tensor] = None, weight_scale: Optional[torch.Tensor] = None, out_dtype: Optional[str] = None, + tp_mode: str = "none", + output_sizes: Optional[List[int]] = None, + tp_min_local_shape: int = 1, + layer_type: str = "unknown", ) -> torch.Tensor: # Match real op behavior: FP8 input requires explicit output dtype. if input.dtype == torch.float8_e4m3fn: @@ -258,6 +266,10 @@ def fp8_linear( bias: Optional[torch.Tensor] = None, input_scale: Optional[torch.Tensor] = None, weight_scale: Optional[torch.Tensor] = None, + tp_mode: str = "none", + output_sizes: Optional[List[int]] = None, + tp_min_local_shape: int = 1, + layer_type: str = "unknown", ) -> torch.Tensor: """FP8 linear op similar to torch.nn.linear. @@ -333,6 +345,10 @@ def fp8_linear_fake( bias: Optional[torch.Tensor] = None, input_scale: Optional[torch.Tensor] = None, weight_scale: Optional[torch.Tensor] = None, + tp_mode: str = "none", + output_sizes: Optional[List[int]] = None, + tp_min_local_shape: int = 1, + layer_type: str = "unknown", ) -> torch.Tensor: return torch.ops.aten.linear(input, weight_fp8.to(input.dtype), bias) @@ -419,6 +435,10 @@ def nvfp4_linear( input_scale: Optional[torch.Tensor] = None, weight_scale: Optional[torch.Tensor] = None, alpha: Optional[torch.Tensor] = None, + tp_mode: str = "none", + output_sizes: Optional[List[int]] = None, + tp_min_local_shape: int = 1, + layer_type: str = "unknown", ) -> torch.Tensor: """FP4 linear op similar to torch.nn.linear. 
@@ -487,6 +507,10 @@ def fp4_linear_fake( input_scale: Optional[torch.Tensor] = None, weight_scale: Optional[torch.Tensor] = None, alpha: Optional[torch.Tensor] = None, + tp_mode: str = "none", + output_sizes: Optional[List[int]] = None, + tp_min_local_shape: int = 1, + layer_type: str = "unknown", ) -> torch.Tensor: return torch.ops.aten.linear(input, weight_fp4.repeat(1, 2).to(input.dtype), bias) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/quantization/torch_quant.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/quantization/torch_quant.py index 6e129a920619..5046d87065d3 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/quantization/torch_quant.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/quantization/torch_quant.py @@ -184,6 +184,10 @@ def torch_fake_quant_fp8_linear( weight_scale: List[torch.Tensor], input_zp: List[torch.Tensor], weight_zp: List[torch.Tensor], + tp_mode: str = "none", + output_sizes: Optional[List[int]] = None, + tp_min_local_shape: int = 1, + layer_type: str = "unknown", ) -> torch.Tensor: """ Reference (eager) implementation for multiple quant formats via `format_type`. @@ -219,6 +223,10 @@ def torch_fake_quant_fp8_linear( weight_scale: List[torch.Tensor], input_zp: List[torch.Tensor], weight_zp: List[torch.Tensor], + tp_mode: str = "none", + output_sizes: Optional[List[int]] = None, + tp_min_local_shape: int = 1, + layer_type: str = "unknown", ) -> torch.Tensor: w = weight_quantized.to(input.dtype) return torch.ops.aten.linear(input, w, bias) @@ -233,6 +241,10 @@ def torch_fake_quant_nvfp4_linear( weight_scale: List[torch.Tensor], input_zp: List[torch.Tensor], weight_zp: List[torch.Tensor], + tp_mode: str = "none", + output_sizes: Optional[List[int]] = None, + tp_min_local_shape: int = 1, + layer_type: str = "unknown", ) -> torch.Tensor: """ Reference (eager) implementation for multiple quant formats via `format_type`. 
@@ -295,6 +307,10 @@ def torch_fake_quant_nvfp4_linear( weight_scale: List[torch.Tensor], input_zp: List[torch.Tensor], weight_zp: List[torch.Tensor], + tp_mode: str = "none", + output_sizes: Optional[List[int]] = None, + tp_min_local_shape: int = 1, + layer_type: str = "unknown", ) -> torch.Tensor: return torch.ops.aten.linear(input, weight_quantized.repeat(1, 2).to(input.dtype), bias) @@ -308,6 +324,10 @@ def torch_fake_quant_int4_linear( weight_scale: List[torch.Tensor], # [ weight_scale ] input_zp: List[torch.Tensor], weight_zp: List[torch.Tensor], + tp_mode: str = "none", + output_sizes: Optional[List[int]] = None, + tp_min_local_shape: int = 1, + layer_type: str = "unknown", ) -> torch.Tensor: BLOCK_SIZE = 128 # activation pre-scale @@ -323,7 +343,15 @@ def torch_fake_quant_int4_linear( # Dequantize w_deq = (q_int4.to(torch.float32) / scale_full).to(input.dtype) - return torch.ops.auto_deploy.torch_linear_simple.default(x_scaled, w_deq, bias) + return torch.ops.auto_deploy.torch_linear_simple.default( + x_scaled, + w_deq, + bias, + tp_mode=tp_mode, + output_sizes=output_sizes, + tp_min_local_shape=tp_min_local_shape, + layer_type=layer_type, + ) @torch_fake_quant_int4_linear.register_fake @@ -335,6 +363,10 @@ def _fake( weight_scale: List[torch.Tensor], input_zp: List[torch.Tensor], weight_zp: List[torch.Tensor], + tp_mode: str = "none", + output_sizes: Optional[List[int]] = None, + tp_min_local_shape: int = 1, + layer_type: str = "unknown", ) -> torch.Tensor: N_half = weight_quantized.shape[-2] N = N_half * 2 @@ -350,6 +382,10 @@ def torch_fake_quant_int4_gptq_linear( weight_scale: List[torch.Tensor], # GPTQ scales [G, N] input_zp: List[torch.Tensor], # unused for GPTQ weight_zp: List[torch.Tensor], # GPTQ qzeros [G, N/8] int32 + tp_mode: str = "none", + output_sizes: Optional[List[int]] = None, + tp_min_local_shape: int = 1, + layer_type: str = "unknown", ) -> torch.Tensor: """ GPTQ INT4 linear with compatible signature to other quant ops. @@ -431,6 +467,10 @@ def torch_fake_quant_int4_gptq_linear_fake( weight_scale: List[torch.Tensor], input_zp: List[torch.Tensor], weight_zp: List[torch.Tensor], + tp_mode: str = "none", + output_sizes: Optional[List[int]] = None, + tp_min_local_shape: int = 1, + layer_type: str = "unknown", ) -> torch.Tensor: N = weight_quantized.size(1) return torch.empty((*input.shape[:-1], N), dtype=input.dtype, device=input.device) @@ -492,6 +532,10 @@ def torch_fake_quant_finegrained_fp8_linear( weight_scale: List[torch.Tensor], # [weight_scale_inv] input_zp: List[torch.Tensor], # unused weight_zp: List[torch.Tensor], # unused + tp_mode: str = "none", + output_sizes: Optional[List[int]] = None, + tp_min_local_shape: int = 1, + layer_type: str = "unknown", ) -> torch.Tensor: """FineGrainedFP8 linear operation. 
- weight_scale[0] = weight_scale_inv (per-block weight scale) @@ -503,11 +547,11 @@ def torch_fake_quant_finegrained_fp8_linear( weight_scale_inv = weight_scale[0] # Infer block_size from weight and weight_scale_inv shapes - # weight shape: [N, K], weight_scale_inv shape: [N/block_n, K/block_k] + # weight shape: [N, K], weight_scale_inv shape: [ceil(N/block_n), ceil(K/block_k)] N, K = weight_quantized.shape scale_n, scale_k = weight_scale_inv.shape - block_n = N // scale_n - block_k = K // scale_k + block_n = triton.cdiv(N, scale_n) + block_k = triton.cdiv(K, scale_k) block_size = [block_n, block_k] qinput, scale = _safe_act_quant(input, block_size[1]) @@ -535,6 +579,10 @@ def _torch_fake_quant_finegrained_fp8_linear_fake( weight_scale: List[torch.Tensor], input_zp: List[torch.Tensor], weight_zp: List[torch.Tensor], + tp_mode: str = "none", + output_sizes: Optional[List[int]] = None, + tp_min_local_shape: int = 1, + layer_type: str = "unknown", ) -> torch.Tensor: """Fake implementation for torch.export tracing.""" out_features = weight_quantized.shape[0] @@ -547,6 +595,10 @@ def trtllm_finegrained_fp8_linear( weight: torch.Tensor, # [N, K] float8_e4m3fn bias: Optional[torch.Tensor], # [N] or None weight_scale: torch.Tensor, # [N/128, K/128] per-block weight scale + tp_mode: str = "none", + output_sizes: Optional[List[int]] = None, + tp_min_local_shape: int = 1, + layer_type: str = "unknown", ) -> torch.Tensor: """TRT-LLM optimized FineGrainedFP8 linear operation. @@ -576,8 +628,10 @@ def trtllm_finegrained_fp8_linear( f"(shape={weight_scale.shape}), weight shape={weight.shape}. " f"This usually means scale tensor sharding produced an empty tensor." ) - block_n = N // scale_n - block_k = K // scale_k + # Ceiling division is required because the weight dimension may not be + # evenly divisible by the number of scale blocks (e.g. after TP sharding). + block_n = triton.cdiv(N, scale_n) + block_k = triton.cdiv(K, scale_k) # TRT-LLM fp8_block_scaling_gemm requires exact 128x128 blocks. # For small layers where a dimension < 128 (e.g. N=64), the derived block @@ -617,6 +671,10 @@ def _trtllm_finegrained_fp8_linear_fake( weight: torch.Tensor, bias: Optional[torch.Tensor], weight_scale: torch.Tensor, + tp_mode: str = "none", + output_sizes: Optional[List[int]] = None, + tp_min_local_shape: int = 1, + layer_type: str = "unknown", ) -> torch.Tensor: """Fake implementation for torch.export tracing.""" out_features = weight.shape[0] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/sharding_ops.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/sharding_ops.py new file mode 100644 index 000000000000..5f5c71102687 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/sharding_ops.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sharding-aware custom ops for the new hint-driven sharding architecture. 
+ +These ops encode sharding intent as metadata kwargs. At graph level they behave +identically to their torch.ops.aten counterparts. The ``apply_sharding_hints`` +transform reads the hint kwargs together with a runtime ``Mapping`` to apply +deterministic, node-local sharding transformations. +""" + +from typing import List + +import torch + + +@torch.library.custom_op("auto_deploy::view", mutates_args=()) +def view( + x: torch.Tensor, + shape: List[int], + tp_scaled_dim: int = -1, + layer_type: str = "unknown", +) -> torch.Tensor: + """Sharding-aware view/reshape. + + At runtime this behaves like ``x.reshape(shape).clone()``. When tensor-parallel + sharding is enabled, ``apply_sharding_hints`` may scale one dimension of + ``shape`` by ``1 / tp_size`` so the reshaped tensor matches per-rank shapes. + + Args: + x: Input tensor to reshape. + shape: Target shape, same semantics as :meth:`torch.Tensor.reshape`. + tp_scaled_dim: Index into ``shape`` for TP scaling. When non-negative + (``0``, ``1``, ``2``, ...), ``apply_sharding_hints`` divides + ``shape[tp_scaled_dim]`` by ``tp_size``. ``-1`` means this dimension is + not scaled for TP. + layer_type: Layer classification for selective sharding via ``shard_layers`` + config. Values: ``"mha"``, ``"mla"``, ``"mlp"``, ``"moe"``, ``"ssm"``, + ``"delta"``, ``"unknown"``. + + Sharding hint arguments (graph-level metadata for ``apply_sharding_hints``): + ``tp_scaled_dim`` and ``layer_type`` are hints only; they do not change the + unsharded reshape result. + + Returns: + Reshaped tensor, same values as ``x.reshape(shape)`` up to clone semantics. + """ + return x.reshape(shape).clone() + + +@view.register_fake +def _view_fake( + x: torch.Tensor, + shape: List[int], + tp_scaled_dim: int = -1, + layer_type: str = "unknown", +) -> torch.Tensor: + return x.reshape(shape).clone() + + +@torch.library.custom_op("auto_deploy::split_with_sizes", mutates_args=()) +def split_with_sizes( + x: torch.Tensor, + split_sizes: List[int], + dim: int = -1, + enable_sharding: bool = False, + layer_type: str = "unknown", +) -> List[torch.Tensor]: + """Sharding-aware :func:`torch.split` with explicit chunk sizes. + + At runtime this behaves like ``torch.split(x, split_sizes, dim=dim)``, with each + chunk cloned. When ``enable_sharding`` is ``True`` and TP sharding is applied, + ``apply_sharding_hints`` scales ``split_sizes`` so each rank splits its local + activation width consistently. + + Args: + x: Tensor to split along ``dim``. + split_sizes: Size of each chunk along ``dim`` (same as PyTorch + ``split_with_sizes``). + dim: Dimension along which to split. May be negative (same semantics as + PyTorch). + enable_sharding: When ``True``, ``apply_sharding_hints`` divides every element of + ``split_sizes`` by ``tp_size`` so splits match per-rank tensor shapes. + layer_type: Layer classification for selective sharding via ``shard_layers`` + config. Values: ``"mha"``, ``"mla"``, ``"mlp"``, ``"moe"``, ``"ssm"``, + ``"delta"``, ``"unknown"``. + + Sharding hint arguments (graph-level metadata for ``apply_sharding_hints``): + ``enable_sharding``: When ``True``, ``apply_sharding_hints`` scales ``split_sizes`` + for TP (each chunk size is divided by ``tp_size`` when applicable). + ``layer_type``: Selects whether this node is rewritten for a given + ``shard_layers`` configuration. + + Returns: + List of tensors, one per chunk, same as :func:`torch.split`. 
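[Editorial aside on the ops defined in this file: a short sketch, not part of the patch, of their unsharded runtime behavior. It assumes the `auto_deploy` custom ops are registered via the usual import; the hint kwargs only mark dims/sizes for later scaling and do not change the result here.]

```python
import torch
import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401 -- registers the hinted ops

x = torch.randn(2, 16, 96)

# Behaves like reshape; tp_scaled_dim=2 only marks the head dim for later 1/tp_size scaling.
y = torch.ops.auto_deploy.view(x, [2, 16, 8, 12], tp_scaled_dim=2, layer_type="mha")
assert torch.equal(y, x.reshape(2, 16, 8, 12))

# Behaves like torch.split; enable_sharding=True only marks the sizes for per-rank scaling.
q, k, v = torch.ops.auto_deploy.split_with_sizes(
    x, [32, 32, 32], dim=-1, enable_sharding=True, layer_type="mha"
)
assert q.shape == (2, 16, 32) and torch.equal(torch.cat([q, k, v], dim=-1), x)
```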
+ """ + + return [t.clone() for t in torch.split(x, split_sizes, dim=dim)] + + +@split_with_sizes.register_fake +def _split_with_sizes_fake( + x: torch.Tensor, + split_sizes: List[int], + dim: int = -1, + enable_sharding: bool = False, + layer_type: str = "unknown", +) -> List[torch.Tensor]: + return [t.clone() for t in torch.split(x, split_sizes, dim=dim)] + + +@torch.library.custom_op("auto_deploy::all_reduce", mutates_args=()) +def all_reduce(x: torch.Tensor, layer_type: str = "unknown") -> torch.Tensor: + """Sharding-aware all-reduce placeholder. + + At runtime this returns ``x.clone()``. After ``apply_sharding_hints``, the node + may become a real ``dist.all_reduce`` when ``tp_size > 1`` and attention + data-parallel replication is disabled; otherwise it remains an identity on the + local tensor. + + Args: + x: Activation tensor to combine across TP ranks when an all-reduce is + inserted (e.g., partial attention outputs that must be summed). + layer_type: Layer classification for selective sharding via ``shard_layers`` + config. Values: ``"mha"``, ``"mla"``, ``"mlp"``, ``"moe"``, ``"ssm"``, + ``"delta"``, ``"unknown"``. + + Sharding hint arguments (graph-level metadata for ``apply_sharding_hints``): + ``layer_type`` gates whether this placeholder is eligible for replacement by + a collective in a given configuration. + + Returns: + Tensor with the same shape and dtype as ``x`` (clone of input when unsharded). + """ + return x.clone() + + +@all_reduce.register_fake +def _all_reduce_fake(x: torch.Tensor, layer_type: str = "unknown") -> torch.Tensor: + return x.clone() diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py index b56079a23db1..4f3a3e568ca5 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm_args.py +++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py @@ -17,6 +17,7 @@ ) from .models import ModelFactory, ModelFactoryRegistry from .utils._config import DynamicYamlMixInForSettings +from .utils.dist_config import DistConfig from .utils.logger import ad_logger PathLike = Union[str, Path] @@ -382,28 +383,38 @@ def create_factory(self) -> ModelFactory: def is_cuda_graph_enabled(self) -> bool: return self.compile_backend in ["torch-cudagraph", "torch-opt"] - def init_mapping_from_config(self, rank: int, world_size: int) -> Mapping: - sharding_config = self.transforms.get("detect_sharding", {}) + def init_dist_config(self, rank: int, world_size: int) -> DistConfig: + """Build DistConfig from YAML transform config and runtime MPI info. + + Reads ``dist_mapping`` from ``apply_sharding_hints`` (preferred) or + ``detect_sharding`` (fallback). Runtime ``rank`` and ``world_size`` + come from MPI, not from YAML. + + Note: AutoDeploy blocks direct parallelism fields (tensor_parallel_size, + etc.) via ``ensure_no_custom_parallel_config``. Users configure MoE + topology exclusively through YAML ``dist_mapping`` blocks. If that + restriction is lifted in the future, a Tier-1 path deriving DistConfig + from ``self.parallel_config.to_mapping()`` should be added here. + """ + ash = self.transforms.get("apply_sharding_hints", {}) + sharding_config = ( + ash if ash.get("enabled", False) else self.transforms.get("detect_sharding", {}) + ) dist_mapping_config = sharding_config.get("dist_mapping", {}) enable_attention_dp = sharding_config.get("enable_attention_dp", False) - # Determine MoE parallelism dimensions if enable_attention_dp: - # EP + TP 2D parallelism is currently NOT supported with attention-DP. 
- # EP-only: experts sharded across GPUs, use all-to-all dispatch/combine moe_ep_size = self.world_size moe_tp_size = 1 ad_logger.info( f"Attention-DP with EP-only MoE: moe_ep_size={moe_ep_size}, moe_tp_size={moe_tp_size}" ) else: - # No attention-DP: use dist_mapping config or defaults moe_tp_size = dist_mapping_config.get("moe_tp", 1) moe_ep_size = dist_mapping_config.get("moe_ep", self.world_size) - # Create Mapping with proper distributed configuration try: - mapping = Mapping( + dc = DistConfig( world_size=world_size, rank=rank, tp_size=dist_mapping_config.get("tp", self.world_size), @@ -418,7 +429,11 @@ def init_mapping_from_config(self, rank: int, world_size: int) -> Mapping: f"Please check your dist_mapping configuration: {dist_mapping_config}" ) from e - return mapping + return dc + + def init_mapping_from_config(self, rank: int, world_size: int) -> Mapping: + """Build a Mapping for external APIs that still require it.""" + return self.init_dist_config(rank, world_size).to_mapping() ### PRIVATE METHODS ############################################################################ @classmethod diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py b/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py index 35a485ba2d99..148c07532c7b 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py @@ -1,3 +1,5 @@ +import os + from .modeling_deepseek import DeepSeekV3ForCausalLM from .modeling_gemma3n import Gemma3nForCausalLM, Gemma3nForConditionalGeneration from .modeling_gemma4 import Gemma4ForCausalLM, Gemma4ForConditionalGeneration @@ -8,6 +10,14 @@ from .modeling_nemotron_h import NemotronHForCausalLM from .modeling_qwen3_5_moe import Qwen3_5MoeForCausalLM, Qwen3_5MoeForConditionalGeneration +if os.environ.get("AD_USE_IR_MODELS"): + from .modeling_deepseek_ir import DeepSeekV3ForCausalLM # noqa: F811 + from .modeling_nemotron_h_ir import NemotronHForCausalLM # noqa: F811 + from .modeling_qwen3_5_moe_ir import ( # noqa: F811 + Qwen3_5MoeForCausalLM, + Qwen3_5MoeForConditionalGeneration, + ) + __all__ = ( "DeepSeekV3ForCausalLM", "Gemma3nForCausalLM", diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/mla_rope_utils.py b/tensorrt_llm/_torch/auto_deploy/models/custom/mla_rope_utils.py index 226eecab2da1..298ca399c116 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/mla_rope_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/mla_rope_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. """Shared MLA RoPE utilities for auto_deploy custom models. 
@@ -63,21 +63,27 @@ def _rope_deinterleave_load_hook( q_key = layer_prefix + "q_b_proj.weight" if q_key in state_dict: w = state_dict[q_key] + orig_dtype = w.dtype + if not w.is_floating_point() or w.dtype == torch.float8_e4m3fn: + w = w.to(torch.bfloat16) w = w.view(num_heads, qk_head_dim, -1) w_nope = w[:, :qk_nope_head_dim, :] w_rope = w[:, qk_nope_head_dim:, :] w_rope = _index_select_with_float8_cpu_workaround(w_rope, 1, perm) w = torch.cat([w_nope, w_rope], dim=1) - state_dict[q_key] = w.view(-1, w.shape[-1]) + state_dict[q_key] = w.view(-1, w.shape[-1]).to(orig_dtype) # --- kv_a_proj_with_mqa.weight --- kv_key = layer_prefix + "kv_a_proj_with_mqa.weight" if kv_key in state_dict: w = state_dict[kv_key] + orig_dtype = w.dtype + if not w.is_floating_point() or w.dtype == torch.float8_e4m3fn: + w = w.to(torch.bfloat16) w_kv = w[:kv_lora_rank, :] w_pe = w[kv_lora_rank:, :] w_pe = _index_select_with_float8_cpu_workaround(w_pe, 0, perm) - state_dict[kv_key] = torch.cat([w_kv, w_pe], dim=0) + state_dict[kv_key] = torch.cat([w_kv, w_pe], dim=0).to(orig_dtype) # --- kv_a_proj_with_mqa.bias (if present) --- kv_bias_key = layer_prefix + "kv_a_proj_with_mqa.bias" diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_deepseek_ir.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_deepseek_ir.py new file mode 100644 index 000000000000..009813d3d050 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_deepseek_ir.py @@ -0,0 +1,768 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DeepSeekV3 model with explicit sharding hint ops. + +WARNING: tested only on up to 8 layers (8xH100 machine can't fit the full model, +the output is not fully verified). The sharding pipeline processes all nodes +correctly (41 nodes at 4 layers, 59+ at 8 layers) but coherent end-to-end output +has not been validated due to memory constraints. Additionally, the DeepSeek-V3/R1 +FP8 checkpoints use dsv3_router_gemm_op which crashes on H100 (pre-existing bug). + +Based on the original modeling_deepseek.py. All enable_sharding operations use +AutoDeploy custom ops with sharding hint kwargs. The graph produced by this +model is a complete, self-contained specification of "how this model should be +sharded." The ``apply_sharding_hints`` transform reads the hints together with +a runtime ``DistConfig`` to apply deterministic, node-local sharding. 
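[One way to see the "self-contained sharding specification" idea described in the module docstring that follows: export a toy module that uses a hinted op and read the hints back off the graph. The sketch below is editorial; the module and names are mine, it assumes the `auto_deploy` custom ops are registered, and depending on the PyTorch version the hint values land in `node.args` or `node.kwargs`.]

```python
import torch
import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401 -- registers the hinted ops


class TinyProj(torch.nn.Module):
    """Toy module (illustrative only) with one hinted linear."""

    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(32, 16))

    def forward(self, x):
        return torch.ops.auto_deploy.torch_linear_simple(
            x, self.weight, None, tp_mode="colwise", layer_type="mlp"
        )


ep = torch.export.export(TinyProj(), (torch.randn(4, 16),))
for node in ep.graph.nodes:
    if node.op == "call_function" and "torch_linear_simple" in str(node.target):
        # The sharding hints travel with the node; a hint-reading pass inspects them here.
        print(node.target, node.args[3:], node.kwargs)
```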
+ +Shardable custom ops used: + - torch.ops.auto_deploy.torch_linear_simple (tp_mode) + - torch.ops.auto_deploy.view (tp_scaled_dim) + - torch.ops.auto_deploy.all_reduce (identity / dist.all_reduce) + - torch.ops.auto_deploy.torch_mla (enable_sharding) + - torch.ops.auto_deploy.torch_moe (sharded by apply_sharding_hints) +""" + +import math +from dataclasses import dataclass +from functools import partial +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from transformers.activations import ACT2FN +from transformers.generation import GenerationMixin +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput + +import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 -- register all ops +from tensorrt_llm._torch.auto_deploy.models.custom import mla_rope_utils +from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory +from tensorrt_llm._torch.utils import ActivationType + + +class DeepSeekV3RMSNorm(nn.Module): + """RMS Normalization for DeepSeekV3.""" + + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return torch.ops.auto_deploy.triton_rms_norm( + hidden_states, self.weight, self.variance_epsilon + ).to(hidden_states.dtype) + + +class DeepSeekV3RotaryEmbedding(nn.Module): + """Rotary Position Embedding for DeepSeekV3. + + Simplified version that precomputes and caches cos/sin values. + Returns full cached values (not sliced by seq_len) to enable export. + """ + + def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000.0): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build cos/sin cache + self._set_cos_sin_cache(max_position_embeddings) + + def _set_cos_sin_cache(self, seq_len: int): + self.max_seq_len_cached = seq_len + t = torch.arange(seq_len, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos(), persistent=False) + self.register_buffer("sin_cached", emb.sin(), persistent=False) + + def forward( + self, x: torch.Tensor, seq_len: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + return ( + self.cos_cached.to(dtype=x.dtype, device=x.device), + self.sin_cached.to(dtype=x.dtype, device=x.device), + ) + + +class DeepSeekV3YarnRotaryEmbedding(DeepSeekV3RotaryEmbedding): + """YaRN-extended rotary embedding for DeepSeekV3.""" + + def __init__( + self, + dim: int, + max_position_embeddings: int = 2048, + base: float = 10000.0, + scaling_factor: float = 1.0, + original_max_position_embeddings: int = 4096, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1.0, + mscale_all_dim: float = 0.0, + ): + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + super().__init__(dim, max_position_embeddings, base) + + def _set_cos_sin_cache(self, seq_len: int): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / (self.base 
** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + freq_inter = 1.0 / ( + self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim) + ) + + low, high = self._yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - self._yarn_linear_ramp_mask(low, high, dim // 2) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + + _mscale = float( + self._yarn_get_mscale(self.scaling_factor, self.mscale) + / self._yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", (emb.cos() * _mscale), persistent=False) + self.register_buffer("sin_cached", (emb.sin() * _mscale), persistent=False) + + @staticmethod + def _yarn_find_correction_dim( + num_rotations: float, dim: int, base: float = 10000, max_position_embeddings: int = 2048 + ) -> float: + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + def _yarn_find_correction_range( + self, low_rot: int, high_rot: int, dim: int, base: float, max_position_embeddings: int + ) -> Tuple[int, int]: + low = math.floor( + self._yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + self._yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) + + @staticmethod + def _yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + @staticmethod + def _yarn_linear_ramp_mask(min_val: float, max_val: float, dim: int) -> torch.Tensor: + if min_val == max_val: + max_val += 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val) + return torch.clamp(linear_func, 0, 1) + + +class DeepSeekV3MLP(nn.Module): + """MLP layer for DeepSeekV3 (SwiGLU activation) with sharding hints. + + When used as a shared expert inside MoE, ``add_all_reduce`` is set to False + and the all_reduce is deferred to the MoE merge point. 
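+
+    Sharding strategy (summarizing the hints emitted in ``forward``):
+        gate_proj -> colwise
+        up_proj -> colwise
+        down_proj -> rowwise, followed by all_reduce when ``add_all_reduce`` is True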
+ """ + + def __init__( + self, + config, + hidden_size: Optional[int] = None, + intermediate_size: Optional[int] = None, + add_all_reduce: bool = True, + layer_type: str = "mlp", + ): + super().__init__() + self.config = config + self.hidden_size = hidden_size or config.hidden_size + self.intermediate_size = intermediate_size or config.intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + self.add_all_reduce = add_all_reduce + self.layer_type = layer_type + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate = torch.ops.auto_deploy.torch_linear_simple( + x, + self.gate_proj.weight, + self.gate_proj.bias, + tp_mode="colwise", + layer_type=self.layer_type, + ) + up = torch.ops.auto_deploy.torch_linear_simple( + x, + self.up_proj.weight, + self.up_proj.bias, + tp_mode="colwise", + layer_type=self.layer_type, + ) + down = torch.ops.auto_deploy.torch_linear_simple( + self.act_fn(gate) * up, + self.down_proj.weight, + self.down_proj.bias, + tp_mode="rowwise", + layer_type=self.layer_type, + ) + if self.add_all_reduce: + down = torch.ops.auto_deploy.all_reduce(down, layer_type=self.layer_type) + return down + + +class DeepSeekV3MoEGate(nn.Module): + """MoE Gating for DeepSeekV3 with noaux_tc top-k selection.""" + + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.n_group = config.n_group + self.topk_group = config.topk_group + + self.weight = nn.Parameter( + torch.empty((self.n_routed_experts, config.hidden_size), dtype=torch.float32) + ) + self.register_buffer( + "e_score_correction_bias", + torch.zeros(self.n_routed_experts, dtype=torch.float32), + ) + self.reset_parameters() + + def reset_parameters(self) -> None: + """Initialize gate weights using kaiming uniform (matches original DeepSeek implementation).""" + torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass returning (selected_experts, routing_weights).""" + bsz, seq_len, hidden_dim = hidden_states.shape + hidden_states_flat = hidden_states.view(-1, hidden_dim) + + if self.weight.dtype == torch.float32: + router_logits = F.linear(hidden_states_flat.float(), self.weight) + else: + router_logits = torch.ops.trtllm.dsv3_router_gemm_op( + hidden_states_flat, self.weight.t(), bias=None, out_dtype=torch.float32 + ) + + topk_weights, topk_indices = torch.ops.trtllm.noaux_tc_op( + router_logits, + self.e_score_correction_bias, + self.n_group, + self.topk_group, + self.top_k, + self.routed_scaling_factor, + ) + + return topk_indices, topk_weights + + +class DeepSeekV3MoE(nn.Module): + """Mixture of Experts layer for DeepSeekV3 with sharding hints. + + Routed experts are handled by torch_moe (sharded by apply_sharding_hints). + Shared expert uses TP-sharded MLP with deferred all_reduce. + Single all_reduce at the merge point (routed + shared). 
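+
+    The routed path hands per-expert weight lists to ``torch_moe`` as
+    w1=gate_proj, w2=down_proj, w3=up_proj with ``is_gated_mlp=True`` and
+    Silu activation, so the lists map one-to-one onto the per-expert
+    checkpoint tensors.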
+ """ + + def __init__(self, config): + super().__init__() + self.config = config + self.num_experts_per_tok = config.num_experts_per_tok + + self.experts = nn.ModuleList( + [ + DeepSeekV3MLP(config, intermediate_size=config.moe_intermediate_size) + for _ in range(config.n_routed_experts) + ] + ) + + self.gate = DeepSeekV3MoEGate(config) + + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepSeekV3MLP( + config, + intermediate_size=intermediate_size, + add_all_reduce=False, + layer_type="moe", + ) + else: + self.shared_experts = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + identity = hidden_states + orig_shape = hidden_states.shape + + selected_experts, routing_weights = self.gate(hidden_states) + + final_hidden_states = torch.ops.auto_deploy.torch_moe( + hidden_states.view(-1, hidden_states.shape[-1]), + selected_experts, + routing_weights, + w1_weight=[expert.gate_proj.weight for expert in self.experts], + w2_weight=[expert.down_proj.weight for expert in self.experts], + w3_weight=[expert.up_proj.weight for expert in self.experts], + is_gated_mlp=True, + act_fn=int(ActivationType.Silu), + layer_type="moe", + ) + + final_hidden_states = final_hidden_states.view(*orig_shape) + + if self.shared_experts is not None: + final_hidden_states = final_hidden_states + self.shared_experts(identity) + + final_hidden_states = torch.ops.auto_deploy.all_reduce( + final_hidden_states, layer_type="moe" + ) + + return final_hidden_states.to(hidden_states.dtype) + + +class DeepSeekV3Attention(nn.Module): + """Multi-head Latent Attention (MLA) for DeepSeekV3 with sharding hints. + + MLA sharding strategy (from porting instructions): + q_a_proj -> tp_mode="none" (replicated latent projection) + q_a_layernorm -> unchanged + q_b_proj -> tp_mode="colwise" (shard by num_heads) + kv_a_proj -> tp_mode="none" (replicated latent projection) + kv_a_layernorm -> unchanged + torch_mla -> enable_sharding=True (kv_b_proj_weight sharded by _apply_hint_mla) + view -> tp_scaled_dim=2 for num_heads (Q reshape only) + o_proj -> tp_mode="rowwise" + all_reduce + """ + + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.q_lora_rank = config.q_lora_rank + self.kv_lora_rank = config.kv_lora_rank + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.v_head_dim = config.v_head_dim + self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + if self.q_lora_rank is None: + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, config.q_lora_rank, bias=config.attention_bias + ) + self.q_a_layernorm = DeepSeekV3RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear( + config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False + ) + + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = DeepSeekV3RMSNorm(self.kv_lora_rank) + self.kv_b_proj = nn.Linear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = 
nn.Linear( + self.num_heads * self.v_head_dim, self.hidden_size, bias=config.attention_bias + ) + + self._init_rope() + + self.softmax_scale = self.q_head_dim ** (-0.5) + if config.rope_scaling is not None: + mscale_all_dim = config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = config.rope_scaling["factor"] + if mscale_all_dim: + mscale = DeepSeekV3YarnRotaryEmbedding._yarn_get_mscale( + scaling_factor, mscale_all_dim + ) + self.softmax_scale = self.softmax_scale * mscale * mscale + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = DeepSeekV3RotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + + if scaling_type == "yarn": + kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rotary_emb = DeepSeekV3YarnRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **kwargs, + ) + else: + self.rotary_emb = DeepSeekV3RotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + ) -> torch.Tensor: + bsz, q_len, _ = hidden_states.size() + + # Q projection: latent projections are replicated, q_b_proj is colwise + if self.q_lora_rank is None: + q = torch.ops.auto_deploy.torch_linear_simple( + hidden_states, + self.q_proj.weight, + self.q_proj.bias, + tp_mode="colwise", + layer_type="mla", + ) + else: + q = torch.ops.auto_deploy.torch_linear_simple( + hidden_states, + self.q_a_proj.weight, + self.q_a_proj.bias, + tp_mode="none", + layer_type="mla", + ) + q = self.q_a_layernorm(q) + q = torch.ops.auto_deploy.torch_linear_simple( + q, + self.q_b_proj.weight, + self.q_b_proj.bias, + tp_mode="colwise", + layer_type="mla", + ) + + # Shape: [B, S, N, q_head_dim] -- num_heads at dim 2 scales with TP + q = torch.ops.auto_deploy.view( + q, + [bsz, q_len, self.num_heads, self.q_head_dim], + tp_scaled_dim=2, + layer_type="mla", + ) + # Split on last dim (head_dim) -- does NOT scale with TP, plain torch.split + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + # KV projection -- replicated (latent compression, not per-head) + kv_a_output = torch.ops.auto_deploy.torch_linear_simple( + hidden_states, + self.kv_a_proj_with_mqa.weight, + self.kv_a_proj_with_mqa.bias, + tp_mode="none", + layer_type="mla", + ) + compressed_kv, k_pe = torch.split( + kv_a_output, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_layernorm(compressed_kv) + + # k_pe: [B, S, 1, qk_rope_head_dim] -- shared across heads, dim 2 is always 1 + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim) + + kv_seq_len = q_len + + cos, sin = self.rotary_emb(hidden_states, seq_len=kv_seq_len) + cos = cos[position_ids] + sin = sin[position_ids] + + q_pe_rotated, kpe = torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin( + q_pe, + k_pe, + cos, + sin, + 2, + ) + + # MLA: enable_sharding=True lets _apply_hint_mla shard kv_b_proj_weight colwise + attn_output = torch.ops.auto_deploy.torch_mla( + q_nope, + q_pe_rotated, + compressed_kv, + kpe, + 
self.kv_b_proj.weight, + True, + self.softmax_scale, + "bsnd", + enable_sharding=True, + layer_type="mla", + ) + + # Output: [B, S, N, v_head_dim] -> [B, S, N * v_head_dim] + attn_output = torch.ops.auto_deploy.view( + attn_output, + [bsz, q_len, self.num_heads * self.v_head_dim], + tp_scaled_dim=2, + layer_type="mla", + ) + attn_output = torch.ops.auto_deploy.torch_linear_simple( + attn_output, + self.o_proj.weight, + self.o_proj.bias, + tp_mode="rowwise", + layer_type="mla", + ) + attn_output = torch.ops.auto_deploy.all_reduce(attn_output, layer_type="mla") + + return attn_output + + +class DeepSeekV3DecoderLayer(nn.Module): + """Transformer decoder layer for DeepSeekV3.""" + + def __init__(self, config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + self.self_attn = DeepSeekV3Attention(config, layer_idx=layer_idx) + + use_moe = ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + if use_moe: + self.mlp = DeepSeekV3MoE(config) + else: + self.mlp = DeepSeekV3MLP(config) + + self.input_layernorm = DeepSeekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = DeepSeekV3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(hidden_states, position_ids) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +@dataclass +class DeepSeekV3Output(ModelOutput): + """Output for DeepSeekV3Model.""" + + last_hidden_state: Optional[torch.FloatTensor] = None + + +@dataclass +class DeepSeekV3CausalLMOutput(ModelOutput): + """Output for DeepSeekV3ForCausalLM.""" + + logits: Optional[torch.FloatTensor] = None + + +class DeepSeekV3PreTrainedModel(PreTrainedModel): + """Base class for DeepSeekV3 models.""" + + base_model_prefix = "model" + _no_split_modules = ["DeepSeekV3DecoderLayer"] + supports_gradient_checkpointing = False + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class DeepSeekV3Model(DeepSeekV3PreTrainedModel): + """DeepSeekV3 transformer decoder model.""" + + def __init__(self, config): + super().__init__(config) + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [ + DeepSeekV3DecoderLayer(config, layer_idx=idx) + for idx in range(config.num_hidden_layers) + ] + ) + self.norm = DeepSeekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: 
Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> DeepSeekV3Output: + if input_ids is not None and inputs_embeds is not None: + raise ValueError("Cannot specify both input_ids and inputs_embeds") + elif input_ids is None and inputs_embeds is None: + raise ValueError("Must specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = inputs_embeds.shape[:2] + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand(batch_size, -1) + + hidden_states = inputs_embeds + + for decoder_layer in self.layers: + hidden_states = decoder_layer(hidden_states, position_ids) + + hidden_states = self.norm(hidden_states) + + return DeepSeekV3Output(last_hidden_state=hidden_states) + + +class DeepSeekV3ForCausalLM(DeepSeekV3PreTrainedModel, GenerationMixin): + """DeepSeekV3 model with language modeling head.""" + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = DeepSeekV3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self._register_load_state_dict_pre_hook( + partial( + mla_rope_utils._rope_deinterleave_load_hook, + qk_rope_head_dim=config.qk_rope_head_dim, + qk_nope_head_dim=config.qk_nope_head_dim, + num_heads=config.num_attention_heads, + kv_lora_rank=config.kv_lora_rank, + num_layers=config.num_hidden_layers, + ) + ) + + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> DeepSeekV3CausalLMOutput: + outputs = self.model( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states).float() + + return DeepSeekV3CausalLMOutput(logits=logits) + + +AutoModelForCausalLMFactory.register_custom_model_cls("DeepseekV3Config", DeepSeekV3ForCausalLM) diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_llama3_ir.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_llama3_ir.py new file mode 100644 index 000000000000..89ae18350f21 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_llama3_ir.py @@ -0,0 +1,446 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Llama 3 model with explicit sharding hint ops. + +This is a rewrite of modeling_llama3.py where all sharding-enabled operations use +AutoDeploy custom ops with sharding hint kwargs. The graph produced by this +model is a complete, self-contained specification of how this model should be +sharded. The ``apply_sharding_hints`` transform reads the hints together with a +runtime ``DistConfig`` to apply deterministic, node-local sharding. + +Shardable custom ops used: + - torch.ops.auto_deploy.torch_linear_simple (tp_mode, tp_min_local_shape, layer_type) + - torch.ops.auto_deploy.view (tp_scaled_dim, layer_type) + - torch.ops.auto_deploy.all_reduce (identity / dist.all_reduce, layer_type) +""" + +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import nn +from transformers.activations import ACT2FN +from transformers.generation import GenerationMixin +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.utils import ModelOutput + +import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 -- register all ops +from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory + + +class Llama3RMSNorm(nn.Module): + """RMS Normalization for Llama using AutoDeploy torch_rmsnorm reference op.""" + + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return torch.ops.auto_deploy.torch_rmsnorm( + hidden_states, self.weight, self.variance_epsilon + ) + + +class Llama3RotaryEmbedding(nn.Module): + """Rotary Position Embedding for Llama 3 family. + + Supports all rope types (default, llama3, linear, dynamic, etc.) via + transformers ROPE_INIT_FUNCTIONS. Precomputes and caches cos/sin values. + Slices by position_ids once and returns pre-sliced cos/sin to all layers. + + Uses _ad_ prefix for buffer names to work with AutoDeploy's lift_to_meta. + """ + + def __init__(self, config: LlamaConfig): + super().__init__() + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type", "default") + ) + else: + rope_type = "default" + + inv_freq, self.attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, device=None) + + max_pos = config.max_position_embeddings + t = torch.arange(max_pos, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_ad_cos_cached", emb.cos() * self.attention_scaling, persistent=False) + self.register_buffer("_ad_sin_cached", emb.sin() * self.attention_scaling, persistent=False) + + def forward( + self, x: torch.Tensor, position_ids: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + cos = self._ad_cos_cached.to(dtype=x.dtype, device=x.device) + sin = self._ad_sin_cached.to(dtype=x.dtype, device=x.device) + return cos[position_ids], sin[position_ids] + + +class Llama3MLP(nn.Module): + """MLP layer for Llama 3 (SwiGLU) with sharding hints. 
+ + Sharding strategy: + gate_proj -> colwise + up_proj -> colwise + down_proj -> rowwise + all_reduce + """ + + def __init__(self, config: LlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate = torch.ops.auto_deploy.torch_linear_simple( + x, + self.gate_proj.weight, + self.gate_proj.bias, + tp_mode="colwise", + layer_type="mlp", + ) + up = torch.ops.auto_deploy.torch_linear_simple( + x, + self.up_proj.weight, + self.up_proj.bias, + tp_mode="colwise", + layer_type="mlp", + ) + down = torch.ops.auto_deploy.torch_linear_simple( + self.act_fn(gate) * up, + self.down_proj.weight, + self.down_proj.bias, + tp_mode="rowwise", + layer_type="mlp", + ) + down = torch.ops.auto_deploy.all_reduce(down, layer_type="mlp") + return down + + +class Llama3Attention(nn.Module): + """Grouped Query Attention for Llama 3 with sharding hints. + + Uses AD canonical ops for attention and RoPE. GQA is handled natively + by torch_attention — no repeat_kv needed. + + Sharding strategy: + q_proj -> colwise (+ tp_min_local_shape for GQA) + k_proj -> colwise (+ tp_min_local_shape for GQA) + v_proj -> colwise (+ tp_min_local_shape for GQA) + view -> tp_scaled_dim=2 (head count dimension) + o_proj -> rowwise + all_reduce + """ + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim = ( + getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + ) + self.scaling = self.head_dim ** (-0.5) + + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + self.hidden_size, + self.num_kv_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + self.hidden_size, + self.num_kv_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + bsz, q_len, _ = hidden_states.size() + + q = torch.ops.auto_deploy.torch_linear_simple( + hidden_states, + self.q_proj.weight, + self.q_proj.bias, + tp_mode="colwise", + tp_min_local_shape=self.head_dim, + layer_type="mha", + ) + k = torch.ops.auto_deploy.torch_linear_simple( + hidden_states, + self.k_proj.weight, + self.k_proj.bias, + tp_mode="colwise", + tp_min_local_shape=self.head_dim, + layer_type="mha", + ) + v = torch.ops.auto_deploy.torch_linear_simple( + hidden_states, + self.v_proj.weight, + self.v_proj.bias, + tp_mode="colwise", + tp_min_local_shape=self.head_dim, + layer_type="mha", + ) + + q = torch.ops.auto_deploy.view( + q, + [bsz, q_len, self.num_heads, self.head_dim], + tp_scaled_dim=2, + layer_type="mha", + ) + k = torch.ops.auto_deploy.view( + k, + [bsz, q_len, self.num_kv_heads, self.head_dim], + tp_scaled_dim=2, + 
layer_type="mha", + ) + v = torch.ops.auto_deploy.view( + v, + [bsz, q_len, self.num_kv_heads, self.head_dim], + tp_scaled_dim=2, + layer_type="mha", + ) + + cos, sin = position_embeddings + q, k = torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin( + q, + k, + cos, + sin, + 2, + ) + + attn_output = torch.ops.auto_deploy.torch_attention( + q, + k, + v, + None, + 0.0, + True, + self.scaling, + None, + None, + None, + "bsnd", + ) + + attn_output = torch.ops.auto_deploy.view( + attn_output, + [bsz, q_len, self.num_heads * self.head_dim], + tp_scaled_dim=2, + layer_type="mha", + ) + + attn_output = torch.ops.auto_deploy.torch_linear_simple( + attn_output, + self.o_proj.weight, + self.o_proj.bias, + tp_mode="rowwise", + layer_type="mha", + ) + attn_output = torch.ops.auto_deploy.all_reduce(attn_output, layer_type="mha") + + return attn_output + + +class Llama3DecoderLayer(nn.Module): + """Transformer decoder layer for Llama 3.""" + + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Llama3Attention(config, layer_idx=layer_idx) + self.mlp = Llama3MLP(config) + self.input_layernorm = Llama3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Llama3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(hidden_states, position_embeddings) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +@dataclass +class Llama3Output(ModelOutput): + """Output for Llama3Model.""" + + last_hidden_state: Optional[torch.FloatTensor] = None + + +@dataclass +class Llama3CausalLMOutput(ModelOutput): + """Output for Llama3ForCausalLM.""" + + logits: Optional[torch.FloatTensor] = None + + +class Llama3PreTrainedModel(PreTrainedModel): + """Base class for Llama 3 models.""" + + config_class = LlamaConfig + base_model_prefix = "model" + _no_split_modules = ["Llama3DecoderLayer"] + supports_gradient_checkpointing = False + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class Llama3Model(Llama3PreTrainedModel): + """Llama 3 transformer decoder model.""" + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Llama3DecoderLayer(config, layer_idx=idx) for idx in range(config.num_hidden_layers)] + ) + self.norm = Llama3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.rotary_emb = Llama3RotaryEmbedding(config) + + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def 
forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> Llama3Output: + if input_ids is not None and inputs_embeds is not None: + raise ValueError("Cannot specify both input_ids and inputs_embeds") + elif input_ids is None and inputs_embeds is None: + raise ValueError("Must specify either input_ids or inputs_embeds") + + assert position_ids is not None, "position_ids must be provided for AD export" + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + inputs_embeds = inputs_embeds.to(self.norm.weight.dtype) + + position_embeddings = self.rotary_emb(inputs_embeds, position_ids) + + hidden_states = inputs_embeds + + for decoder_layer in self.layers: + hidden_states = decoder_layer(hidden_states, position_embeddings) + + hidden_states = self.norm(hidden_states) + + return Llama3Output(last_hidden_state=hidden_states) + + +class Llama3ForCausalLM(Llama3PreTrainedModel, GenerationMixin): + """Llama 3 model with language modeling head.""" + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config, **kwargs): + super().__init__(config) + self.model = Llama3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> Llama3CausalLMOutput: + assert position_ids is not None, "position_ids must be provided for AD export" + outputs = self.model( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states).float() + + return Llama3CausalLMOutput(logits=logits) + + +AutoModelForCausalLMFactory.register_custom_model_cls("LlamaConfig", Llama3ForCausalLM) diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h_ir.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h_ir.py new file mode 100644 index 000000000000..fac4767c33d4 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h_ir.py @@ -0,0 +1,822 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sharding-aware NemotronH model for AutoDeploy IR sharding. 
+ +Derived from modeling_nemotron_h.py with explicit sharding hints on all +custom ops. ``apply_sharding_hints`` reads these hints to apply +deterministic, node-local TP/EP sharding. + +Shardable custom ops used: + - torch.ops.auto_deploy.torch_linear_simple (tp_mode, output_sizes, layer_type) + - torch.ops.auto_deploy.view (tp_scaled_dim, layer_type) + - torch.ops.auto_deploy.split_with_sizes (enable_sharding, layer_type) + - torch.ops.auto_deploy.all_reduce (layer_type) + - torch.ops.auto_deploy.torch_causal_conv1d (enable_sharding, layer_type) + - torch.ops.auto_deploy.torch_ssm (enable_sharding, layer_type) + - torch.ops.auto_deploy.torch_rmsnorm_gated (tp_mode, layer_type) + - torch.ops.auto_deploy.torch_moe (layer_type) + - torch.ops.auto_deploy.torch_attention (sharding-invariant, no hints needed) +""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from transformers.activations import ACT2FN +from transformers.generation import GenerationMixin +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput + +import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401, I001 -- register all ops +from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory +from tensorrt_llm._torch.utils import ActivationType + + +class MambaRMSNormGated(torch.nn.Module): + def __init__(self, hidden_size, group_size, eps=1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.group_size = group_size + + def forward(self, hidden_states, gate=None): + return torch.ops.auto_deploy.torch_rmsnorm_gated( + hidden_states, + self.weight, + gate, + self.variance_epsilon, + self.group_size, + tp_mode="colwise", + layer_type="ssm", + ) + + +class NemotronHMamba2Mixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. 
+ A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + """ + + def __init__(self, config, layer_idx: int): + super().__init__() + self.num_heads = config.mamba_num_heads + self.hidden_size = config.hidden_size + self.ssm_state_size = config.ssm_state_size + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = config.mamba_num_heads * config.mamba_head_dim + self.layer_idx = layer_idx + self.use_conv_bias = config.use_conv_bias + self.activation = config.mamba_hidden_act + self.act = ACT2FN[config.mamba_hidden_act] + + self.layer_norm_epsilon = config.layer_norm_epsilon + + self.n_groups = config.n_groups + self.head_dim = config.mamba_head_dim + self.chunk_size = config.chunk_size + + self.time_step_limit = config.time_step_limit + self.time_step_min = config.time_step_min + self.time_step_max = config.time_step_max + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=config.use_conv_bias, + kernel_size=config.conv_kernel, + groups=self.conv_dim, + padding=config.conv_kernel - 1, + ) + + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=config.use_bias, + ) + + self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) + + A = torch.arange(1, self.num_heads + 1) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True + self.norm = MambaRMSNormGated( + self.intermediate_size, + eps=self.layer_norm_epsilon, + group_size=self.intermediate_size // self.n_groups, + ) + self.D = nn.Parameter(torch.ones(self.num_heads)) + self.D._no_weight_decay = True + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + self.use_bias = config.use_bias + + # Fused output sizes for in_proj sharding: + # in_proj output = [gate | hidden | B | C | dt] + self._in_proj_output_sizes = [ + self.intermediate_size, + self.intermediate_size, + self.n_groups * self.ssm_state_size, + self.n_groups * self.ssm_state_size, + self.num_heads, + ] + + # Fused output sizes for conv1d sharding: + # conv_dim = [hidden | B | C] + self._conv1d_output_sizes = [ + self.intermediate_size, + self.n_groups * self.ssm_state_size, + self.n_groups * self.ssm_state_size, + ] + + def torch_forward(self, input_states): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + + # 1. Gated MLP's linear projection (colwise with fused output sizes) + projected_states = torch.ops.auto_deploy.torch_linear_simple( + input_states, + self.in_proj.weight, + self.in_proj.bias, + tp_mode="colwise", + output_sizes=self._in_proj_output_sizes, + layer_type="ssm", + ) + gate, hidden_states_B_C, dt = torch.ops.auto_deploy.split_with_sizes( + projected_states, + [self.intermediate_size, self.conv_dim, self.num_heads], + dim=-1, + enable_sharding=True, + layer_type="ssm", + ) + + # 2. 
Convolution sequence transformation + hidden_states_B_C = self.act( + torch.ops.auto_deploy.torch_causal_conv1d( + hidden_states_B_C, + self.conv1d.weight, + self.conv1d.bias, + self.conv1d.stride[0], + self.conv1d.padding[0], + self.conv1d.dilation[0], + self.conv1d.groups, + self.conv1d.padding_mode, + enable_sharding=True, + output_sizes=self._conv1d_output_sizes, + layer_type="ssm", + ) + ) + + hidden_states, B, C = torch.ops.auto_deploy.split_with_sizes( + hidden_states_B_C, + [ + self.intermediate_size, + self.n_groups * self.ssm_state_size, + self.n_groups * self.ssm_state_size, + ], + dim=-1, + enable_sharding=True, + layer_type="ssm", + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) + y = torch.ops.auto_deploy.torch_ssm( + hidden_states=torch.ops.auto_deploy.view( + hidden_states, + [batch_size, seq_len, -1, self.head_dim], + tp_scaled_dim=2, + layer_type="ssm", + ), + A=A, + B=torch.ops.auto_deploy.view( + B, + [batch_size, seq_len, -1, self.ssm_state_size], + tp_scaled_dim=2, + layer_type="ssm", + ), + C=torch.ops.auto_deploy.view( + C, + [batch_size, seq_len, -1, self.ssm_state_size], + tp_scaled_dim=2, + layer_type="ssm", + ), + D=self.D, + dt=dt, + dt_bias=self.dt_bias, + time_step_limit=list(self.time_step_limit), + chunk_size=self.chunk_size, + enable_sharding=True, + layer_type="ssm", + ) + y = y.reshape(batch_size, seq_len, -1) + + scan_output = self.norm(y, gate) + + # 4. Final linear projection (rowwise) + all_reduce + contextualized_states = torch.ops.auto_deploy.torch_linear_simple( + scan_output.to(dtype), + self.out_proj.weight, + self.out_proj.bias, + tp_mode="rowwise", + layer_type="ssm", + ) + contextualized_states = torch.ops.auto_deploy.all_reduce( + contextualized_states, layer_type="ssm" + ) + return contextualized_states + + def forward(self, hidden_states): + return self.torch_forward(hidden_states) + + +class NemotronHRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + NemotronHRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + # Weights are in float32 + return (self.weight.to(torch.float32) * hidden_states).to(input_dtype) + + +class NemotronHBlock(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.residual_in_fp32 = config.residual_in_fp32 + self.norm = NemotronHRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + # M: Mamba2, *: Attention, -: MLP + self.block_type = config.layers_block_type[layer_idx] + if self.block_type == "mamba": + self.mixer = NemotronHMamba2Mixer(config, layer_idx=layer_idx) + elif self.block_type == "attention": + self.mixer = NemotronHAttention(config, layer_idx=layer_idx) + elif self.block_type == "mlp": + self.mixer = NemotronHMLP(config, layer_idx=layer_idx) + elif self.block_type == "moe": + self.mixer = NemotronHMOE(config, layer_idx=layer_idx) + else: + raise ValueError(f"Invalid layer pattern {config.hybrid_override_pattern[layer_idx]}") + + def forward(self, hidden_states): + residual = hidden_states + hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = 
residual.to(torch.float32) + + hidden_states = self.mixer(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class NemotronHMLP(nn.Module): + def __init__( + self, + config, + layer_idx: int, + intermediate_size: Optional[int] = None, + is_expert: bool = False, + add_all_reduce: bool = True, + sharding_layer_type: str = "mlp", + ): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size or config.intermediate_size + use_latent_size = (getattr(self.config, "moe_latent_size", None) is not None) and is_expert + input_size = self.config.moe_latent_size if use_latent_size else self.hidden_size + self.up_proj = nn.Linear(input_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, input_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.mlp_hidden_act] + self.add_all_reduce = add_all_reduce + self.sharding_layer_type = sharding_layer_type + + def forward(self, x): + _lt = self.sharding_layer_type + up = torch.ops.auto_deploy.torch_linear_simple( + x, + self.up_proj.weight, + self.up_proj.bias, + tp_mode="colwise", + layer_type=_lt, + ) + down = torch.ops.auto_deploy.torch_linear_simple( + self.act_fn(up), + self.down_proj.weight, + self.down_proj.bias, + tp_mode="rowwise", + layer_type=_lt, + ) + if self.add_all_reduce: + down = torch.ops.auto_deploy.all_reduce(down, layer_type=_lt) + return down + + +class NemotronHMOE(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.experts = nn.ModuleList( + [ + NemotronHMLP( + config, + layer_idx=layer_idx, + intermediate_size=config.moe_intermediate_size, + is_expert=True, + add_all_reduce=False, + sharding_layer_type="moe", + ) + for _ in range(config.n_routed_experts) + ] + ) + self.gate = NemotronHTopkRouter(config) + self.shared_experts = NemotronHMLP( + config=config, + intermediate_size=config.moe_shared_expert_intermediate_size, + layer_idx=layer_idx, + is_expert=False, + add_all_reduce=False, + sharding_layer_type="moe", + ) + if getattr(config, "moe_latent_size", None) is not None: + self.fc1_latent_proj = nn.Linear( + config.hidden_size, config.moe_latent_size, bias=config.mlp_bias + ) + self.fc2_latent_proj = nn.Linear( + config.moe_latent_size, config.hidden_size, bias=config.mlp_bias + ) + else: + self.fc1_latent_proj = nn.Identity() + self.fc2_latent_proj = nn.Identity() + + def forward(self, hidden_states: torch.Tensor): + residuals = hidden_states + orig_shape = hidden_states.shape + topk_indices, topk_weights = self.gate(hidden_states) + x_flat = hidden_states.view(-1, hidden_states.shape[-1]) + + # Shared expert first (dispatch order matches exported graph node order) + shared_out = self.shared_experts(residuals) + + has_latent_proj = hasattr(self, "fc1_latent_proj") and hasattr(self, "fc2_latent_proj") + + if has_latent_proj: + x_flat = self.fc1_latent_proj(x_flat) + + out_flat = torch.ops.auto_deploy.torch_moe( + x_flat, + topk_indices, + topk_weights, + w1_weight=[e.up_proj.weight for e in self.experts], + w2_weight=[e.down_proj.weight for e in self.experts], + w3_weight=[], + act_fn=ActivationType.Relu2, + is_gated_mlp=False, + layer_type="moe", + ) + + if has_latent_proj: + out_flat = self.fc2_latent_proj(out_flat) + + routed_out = out_flat.view(*orig_shape) + out = shared_out + routed_out + out = torch.ops.auto_deploy.all_reduce(out, layer_type="moe") + return out + + 
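+# Note on the router -> torch_moe handshake (illustrative shapes; T denotes the
+# flattened token count batch_size * seq_len):
+#   NemotronHTopkRouter flattens hidden_states to (T, hidden_size), computes
+#   router logits over n_routed_experts, and trtllm.noaux_tc_op is expected to
+#   return
+#       topk_weights: (T, num_experts_per_tok) routing weights scaled by
+#                     routed_scaling_factor
+#       topk_indices: (T, num_experts_per_tok) selected expert ids
+#   NemotronHMOE.forward feeds these, together with the per-expert up/down
+#   projection weight lists, into torch.ops.auto_deploy.torch_moe
+#   (is_gated_mlp=False, Relu2 activation).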
+class NemotronHTopkRouter(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.n_group = config.n_group + self.topk_group = config.topk_group + self.norm_topk_prob = config.norm_topk_prob + + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) + self.register_buffer( + "e_score_correction_bias", torch.zeros(self.n_routed_experts, dtype=torch.float32) + ) + + def forward(self, hidden_states): + hidden_states = hidden_states.view(-1, self.config.hidden_size) + if self.weight.dtype == torch.float32: + router_logits = F.linear(hidden_states.type(torch.float32), self.weight) + else: + router_logits = torch.ops.trtllm.dsv3_router_gemm_op( + hidden_states, self.weight.t(), bias=None, out_dtype=torch.float32 + ) + + topk_weights, topk_indices = torch.ops.trtllm.noaux_tc_op( + router_logits, + self.e_score_correction_bias, + self.n_group, + self.topk_group, + self.top_k, + self.routed_scaling_factor, + ) + + return topk_indices, topk_weights + + +class NemotronHAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + raise ValueError("Please make sure to provide a `layer_idx` when creating this class.") + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + if hasattr(config, "head_dim"): + head_dim = config.head_dim + elif hasattr(config, "attention_head_dim"): + head_dim = config.attention_head_dim + else: + raise AttributeError( + "Expected either `head_dim` or `attention_head_dim` to be present in the config " + "class, found neither." 
+ ) + + if head_dim is not None: + self.head_dim = head_dim + else: + self.head_dim = config.hidden_size // config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.is_causal = True + + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + self.head_dim * self.num_heads, self.hidden_size, bias=config.attention_bias + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + bsz, q_len, _ = hidden_states.size() + + query_states = torch.ops.auto_deploy.torch_linear_simple( + hidden_states, + self.q_proj.weight, + self.q_proj.bias, + tp_mode="colwise", + layer_type="mha", + ) + key_states = torch.ops.auto_deploy.torch_linear_simple( + hidden_states, + self.k_proj.weight, + self.k_proj.bias, + tp_mode="colwise", + tp_min_local_shape=self.head_dim, + layer_type="mha", + ) + value_states = torch.ops.auto_deploy.torch_linear_simple( + hidden_states, + self.v_proj.weight, + self.v_proj.bias, + tp_mode="colwise", + tp_min_local_shape=self.head_dim, + layer_type="mha", + ) + + query_states = torch.ops.auto_deploy.view( + query_states, + [bsz, q_len, -1, self.head_dim], + tp_scaled_dim=2, + layer_type="mha", + ) + key_states = torch.ops.auto_deploy.view( + key_states, + [bsz, q_len, -1, self.head_dim], + tp_scaled_dim=2, + layer_type="mha", + ) + value_states = torch.ops.auto_deploy.view( + value_states, + [bsz, q_len, -1, self.head_dim], + tp_scaled_dim=2, + layer_type="mha", + ) + + attn_output = torch.ops.auto_deploy.torch_attention( + query_states, + key_states, + value_states, + attn_mask=None, + dropout_p=0.0, + is_causal=True, + layout="bsnd", + ) + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = torch.ops.auto_deploy.torch_linear_simple( + attn_output, + self.o_proj.weight, + self.o_proj.bias, + tp_mode="rowwise", + layer_type="mha", + ) + attn_output = torch.ops.auto_deploy.all_reduce(attn_output, layer_type="mha") + + return attn_output + + +class NemotronHPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + base_model_prefix = "backbone" + _no_split_modules = ["NemotronHBlock"] + supports_gradient_checkpointing = True + _is_stateful = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, NemotronHMamba2Mixer): + module.A_log._no_weight_decay = True + module.D._no_weight_decay = True + + dt = torch.exp( + torch.rand(self.config.mamba_num_heads) + * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + + math.log(self.config.time_step_min) + ).clamp(min=self.config.time_step_floor) + + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + module.dt_bias.copy_(inv_dt) + module.dt_bias._no_reinit = True + + if isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=self.config.initializer_range) + + if self.config.rescale_prenorm_residual: + for name, p in module.named_parameters(): + if name in ["out_proj.weight"]: + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(self.config.num_hidden_layers) + + +@dataclass +class NemotronHOutput(ModelOutput): + """ + Class for the NemotronH model outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + + +@dataclass +class NemotronHCausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + """ + + logits: Optional[torch.FloatTensor] = None + + +class NemotronHModel(NemotronHPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [NemotronHBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)] + ) + + self.norm_f = NemotronHRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self._register_load_state_dict_pre_hook(self.load_hook) + self.post_init() + + def load_hook(self, state_dict, prefix, *args): + for k in state_dict: + if "embedding." 
in k: + state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k) + break + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, NemotronHOutput]: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + hidden_states = inputs_embeds + + for mixer_block in self.layers: + hidden_states = mixer_block(hidden_states) + + hidden_states = self.norm_f(hidden_states) + + return NemotronHOutput(last_hidden_state=hidden_states) + + +class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.backbone = NemotronHModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.backbone.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.backbone.set_input_embeddings(new_embeddings) + + def get_final_normalization(self): + return self.backbone.norm_f + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, NemotronHCausalLMOutput]: + nemotron_h_outputs = self.backbone(input_ids, inputs_embeds=inputs_embeds) + hidden_states = nemotron_h_outputs[0] + + logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() + + return NemotronHCausalLMOutput(logits) + + +AutoModelForCausalLMFactory.register_custom_model_cls("NemotronHConfig", NemotronHForCausalLM) + + +# ============================================================================= +# Eagle Layer Builder for NemotronH MTP (Multi-Token Prediction) +# ============================================================================= + + +class NemotronHEagleLayer(nn.Module): + """Eagle layer for NemotronH models. + + NemotronH does not use RoPE, so position_ids is accepted but ignored. 
+ The layer implements the MTP (Multi-Token Prediction) architecture: + - First layer fuses embeds + hidden_states via start projections (enorm, hnorm, eh_proj) + - All layers have pre-norm residual block with mixer (Attention or MoE) + - Last layer applies final_layernorm + + Supported layers are * (Attention with start projections) and E (MoE with final_layernorm) + """ + + def __init__( + self, + config, + layer_idx: int, + layer_type: str, + has_start_projections: bool, + has_end_norm: bool, + ): + super().__init__() + eps = getattr(config, "layer_norm_epsilon") + if eps is None: + raise ValueError("layer_norm_epsilon is not set in the config") + self.residual_in_fp32 = config.residual_in_fp32 + self.has_start_projections = has_start_projections + self.has_end_norm = has_end_norm + + if has_start_projections: + self.enorm = NemotronHRMSNorm(config.hidden_size, eps=eps) + self.hnorm = NemotronHRMSNorm(config.hidden_size, eps=eps) + self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) + + self.norm = NemotronHRMSNorm(config.hidden_size, eps=eps) + + if layer_type == "*": + self.mixer = NemotronHAttention(config, layer_idx=layer_idx) + elif layer_type == "E": + self.mixer = NemotronHMOE(config, layer_idx=layer_idx) + else: + raise ValueError( + f"Unsupported MTP layer type in NemotronHEagleLayer. Only * and E are currently supported." + f"Layer type: {layer_type}" + ) + + if has_end_norm: + self.final_layernorm = NemotronHRMSNorm(config.hidden_size, eps=eps) + + def forward( + self, + hidden_states: torch.Tensor, + inputs_embeds: torch.Tensor, + position_ids: torch.LongTensor, + ) -> torch.Tensor: + if self.has_start_projections: + e_normed = self.enorm(inputs_embeds) + h_normed = self.hnorm(hidden_states) + hidden_states = self.eh_proj(torch.cat([e_normed, h_normed], dim=-1)) + + residual = hidden_states + hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + hidden_states = self.mixer(hidden_states) + hidden_states = residual + hidden_states + + if self.has_end_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + +def build_nemotron_eagle_layers(config) -> list[nn.Module]: + """Build NemotronH MTP layers for Eagle drafter.""" + pattern = getattr(config, "mtp_hybrid_override_pattern", None) + if pattern is None: + raise ValueError("mtp_hybrid_override_pattern is not set in the config") + + return [ + NemotronHEagleLayer( + config, + layer_idx=i, + layer_type=char, + has_start_projections=(i == 0), + has_end_norm=(i == len(pattern) - 1), + ) + for i, char in enumerate(pattern) + ] diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe_ir.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe_ir.py new file mode 100644 index 000000000000..fe82de08f226 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe_ir.py @@ -0,0 +1,3050 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sharding-aware Qwen3.5 MoE model for AutoDeploy IR sharding (text + vision). + +Reference HF modeling file (not yet in a released transformers version): + transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py + +This implementation differs from the HuggingFace original in the following ways: + * External kernel dependencies (flash-linear-attention, causal_conv1d) are replaced with + autodeploy custom ops. + * Cache-related code paths have been removed (prefill-only). + * Training-related code paths have been removed. + * Unnecessary output fields have been removed. + * The GatedDeltaNet forward is adapted from the Qwen3Next GDN patch + (tensorrt_llm/_torch/auto_deploy/models/patches/qwen3_next.py). + * The MoE implementation uses expert lists (individual nn.Linear layers per expert) + that directly match the checkpoint structure, dispatched via torch_moe op. + * The VLM wrapper passes 3D ``position_ids (3, B, S)`` to the text model, + which computes mRoPE cos/sin internally via its own ``rotary_emb``. + For text-only inputs the wrapper expands the executor's 2D positions to 3D; + for multimodal inputs it computes spatial (T, H, W) positions via ``get_rope_index``. + +This allows us to have a "pytorch" native reference implementation decoupled from bugs and +dependency issues in the source, while remaining weight-compatible with HF checkpoints. +""" + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn.functional as F +from PIL import Image +from torch import nn +from torch.export import Dim +from transformers import AutoConfig +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPooling +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput + +import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401, I001 -- register all ops +from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import BatchInfo +from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactoryRegistry +from tensorrt_llm._torch.auto_deploy.models.hf import ( + AutoModelForCausalLMFactory, + AutoModelForImageTextToTextFactory, + TextModelExportInfo, +) +from tensorrt_llm.inputs.multimodal import MultimodalInput, apply_mm_hashes, hexdigest_to_int32 +from tensorrt_llm.inputs.utils import VideoData + +# ============================================================================= +# Configuration +# ============================================================================= + + +class Qwen3_5MoeTextConfig(PretrainedConfig): + """Minimal config class for Qwen3.5 MoE text model. + + Mirrors the attributes of the upstream Qwen3_5MoeTextConfig. Only attributes + needed by the slimmed-down prefill model are included. 
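+
+    Example (illustrative only; values follow the constructor defaults rather than
+    any released checkpoint). With ``num_hidden_layers=8`` the default pattern
+    makes every 4th layer a full-attention layer::
+
+        cfg = Qwen3_5MoeTextConfig(num_hidden_layers=8)
+        cfg.layer_types
+        # ['linear_attention', 'linear_attention', 'linear_attention', 'full_attention',
+        #  'linear_attention', 'linear_attention', 'linear_attention', 'full_attention']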
+ """ + + model_type = "qwen3_5_moe_text" + + def __init__( + self, + vocab_size=248320, + hidden_size=2048, + num_hidden_layers=40, + num_attention_heads=16, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + tie_word_embeddings=False, + rope_parameters=None, + attention_bias=False, + attention_dropout=0.0, + head_dim=256, + # linear attention (GatedDeltaNet) params + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=32, + # MoE params + moe_intermediate_size=512, + shared_expert_intermediate_size=512, + num_experts_per_tok=8, + num_experts=256, + # layer types + layer_types=None, + pad_token_id=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.head_dim = head_dim + + if rope_parameters is None: + rope_parameters = { + "rope_type": "default", + "rope_theta": 1000000.0, + "partial_rotary_factor": 0.25, + "mrope_section": [11, 11, 10], + } + self.rope_parameters = rope_parameters + + self.linear_conv_kernel_dim = linear_conv_kernel_dim + self.linear_key_head_dim = linear_key_head_dim + self.linear_value_head_dim = linear_value_head_dim + self.linear_num_key_heads = linear_num_key_heads + self.linear_num_value_heads = linear_num_value_heads + + self.moe_intermediate_size = moe_intermediate_size + self.shared_expert_intermediate_size = shared_expert_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + + self.layer_types = layer_types + if self.layer_types is None: + # Default pattern: every 4th layer is full_attention, rest are linear_attention + interval_pattern = kwargs.pop("full_attention_interval", 4) + self.layer_types = [ + "linear_attention" if bool((i + 1) % interval_pattern) else "full_attention" + for i in range(self.num_hidden_layers) + ] + + super().__init__( + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +# ============================================================================= +# Normalization +# ============================================================================= + + +class Qwen3_5MoeRMSNorm(nn.Module): + """RMSNorm with weight scaling. + + The HF checkpoint stores weights in ``(1 + w)`` parameterisation (zeros + init). A load-time pre-hook adds 1.0 so that the forward can use a plain + ``weight * x`` multiply, which matches the autodeploy RMSNorm pattern and + gets fused into a single ``flashinfer_rms_norm`` kernel. 
+ """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + self._register_load_state_dict_pre_hook(self._offset_weight) + + @staticmethod + def _offset_weight(state_dict, prefix, *args): + key = prefix + "weight" + assert key in state_dict, f"RMSNorm: Key {key} not found in state_dict" + state_dict[key] = state_dict[key] + 1.0 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + input_dtype = x.dtype + output = x.to(torch.float32) + output = output * torch.rsqrt(output.pow(2).mean(-1, keepdim=True) + self.eps) + return (self.weight.to(torch.float32) * output).to(input_dtype) + + +class Qwen3_5MoeRMSNormGated(nn.Module): + """Gated RMSNorm: norm(x) * weight * silu(gate). Weight is initialized to ones.""" + + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = self.weight * hidden_states.to(input_dtype) + hidden_states = hidden_states * F.silu(gate.to(torch.float32)) + return hidden_states.to(input_dtype) + + +# ============================================================================= +# Rotary Position Embedding (mRoPE) +# ============================================================================= + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + unsqueeze_dim: int = 2, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply partial RoPE to query and key tensors. + + Supports partial rotary where only the first `rotary_dim` dimensions are rotated. + Default unsqueeze_dim=2 is for bsnd layout (B, S, N, D). + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + return q_embed, k_embed + + +class Qwen3_5MoeTextRotaryEmbedding(nn.Module): + """Simplified mRoPE for text-only prefill. 
Supports only the "default" rope type.""" + + def __init__(self, config: Qwen3_5MoeTextConfig): + super().__init__() + rope_params = config.rope_parameters + base = rope_params["rope_theta"] + partial_rotary_factor = rope_params.get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.mrope_section = rope_params.get("mrope_section", [11, 11, 10]) + + @torch.no_grad() + def forward( + self, x: torch.Tensor, position_ids: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute cos/sin embeddings. + + Args: + x: Hidden states tensor, used only for dtype/device. + position_ids: Shape (3, B, S) for mRoPE or (B, S) for plain. + """ + if position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + # inv_freq: (dim/2,) -> (3, B, dim/2, 1) + inv_freq_expanded = ( + self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + ) + # position_ids: (3, B, S) -> (3, B, 1, S) + position_ids_expanded = position_ids[:, :, None, :].float() + + # freqs: (3, B, S, dim/2) + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + # Apply interleaved mRoPE: (3, B, S, dim/2) -> (B, S, dim/2) + freqs = self._apply_interleaved_mrope(freqs) + # Double for cos/sin: (B, S, dim) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + def _apply_interleaved_mrope(self, freqs: torch.Tensor) -> torch.Tensor: + """Apply interleaved mRoPE. 
Merges T/H/W frequency channels into one tensor.""" + freqs_t = freqs[0].clone() + for dim_idx, offset in enumerate((1, 2), start=1): + length = self.mrope_section[dim_idx] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim_idx, ..., idx] + return freqs_t + + +# ============================================================================= +# GatedDeltaNet (Linear Attention) +# ============================================================================= +# Adapted from the Qwen3Next GDN patch: +# tensorrt_llm/_torch/auto_deploy/models/patches/qwen3_next.py +# Uses autodeploy custom ops: torch_causal_conv1d, torch_gated_delta_rule + + +class Qwen3_5MoeGatedDeltaNet(nn.Module): + """Prefill-only GatedDeltaNet using autodeploy custom ops.""" + + def __init__(self, config: Qwen3_5MoeTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_idx = layer_idx + + # QKV convolution + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=False, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + ) + + # dt_bias and A_log for gated delta rule + self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads)) + A = torch.empty(self.num_v_heads).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A)) + + # Gated RMSNorm (per head_v_dim) + self.norm = Qwen3_5MoeRMSNormGated(self.head_v_dim, eps=config.rms_norm_eps) + + # Projections + self.in_proj_qkv = nn.Linear( + self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False + ) + self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False) + self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) + + self._conv_split_sizes = [self.key_dim, self.key_dim, self.value_dim] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, _ = hidden_states.shape + + # 1. Projections with sharding hints + mixed_qkv = torch.ops.auto_deploy.torch_linear_simple( + hidden_states, + self.in_proj_qkv.weight, + self.in_proj_qkv.bias, + tp_mode="colwise", + output_sizes=self._conv_split_sizes, + layer_type="delta", + ) + z = torch.ops.auto_deploy.torch_linear_simple( + hidden_states, + self.in_proj_z.weight, + self.in_proj_z.bias, + tp_mode="colwise", + layer_type="delta", + ) + z = torch.ops.auto_deploy.view( + z, + [batch_size, seq_len, -1, self.head_v_dim], + tp_scaled_dim=2, + layer_type="delta", + ) + b = torch.ops.auto_deploy.torch_linear_simple( + hidden_states, + self.in_proj_b.weight, + self.in_proj_b.bias, + tp_mode="colwise", + layer_type="delta", + ) + a = torch.ops.auto_deploy.torch_linear_simple( + hidden_states, + self.in_proj_a.weight, + self.in_proj_a.bias, + tp_mode="colwise", + layer_type="delta", + ) + + # 2. 
Causal Conv1d with sharding hint + mixed_qkv = torch.ops.auto_deploy.torch_causal_conv1d( + mixed_qkv, + self.conv1d.weight, + self.conv1d.bias, + self.conv1d.stride[0], + self.conv1d.padding[0], + self.conv1d.dilation[0], + self.conv1d.groups, + self.conv1d.padding_mode, + enable_sharding=True, + output_sizes=self._conv_split_sizes, + layer_type="delta", + ) + mixed_qkv = F.silu(mixed_qkv) + + # Split into Q, K, V with enable_sharding hint + query, key, value = torch.ops.auto_deploy.split_with_sizes( + mixed_qkv, + [self.key_dim, self.key_dim, self.value_dim], + dim=-1, + enable_sharding=True, + layer_type="delta", + ) + + # Reshape to per-head: [B, S, num_heads, head_dim] with -1 at head dim + query = torch.ops.auto_deploy.view( + query, + [batch_size, seq_len, -1, self.head_k_dim], + tp_scaled_dim=2, + layer_type="delta", + ) + key = torch.ops.auto_deploy.view( + key, + [batch_size, seq_len, -1, self.head_k_dim], + tp_scaled_dim=2, + layer_type="delta", + ) + value = torch.ops.auto_deploy.view( + value, + [batch_size, seq_len, -1, self.head_v_dim], + tp_scaled_dim=2, + layer_type="delta", + ) + + # 3. Gated Delta Rule with enable_sharding hint + core_attn_out = torch.ops.auto_deploy.torch_gated_delta_rule( + query, + key, + value, + a, + b, + self.A_log, + self.dt_bias, + enable_sharding=True, + layer_type="delta", + ) + + # 4. Gated RMSNorm (norm weight is replicated -- constant head_v_dim) + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z_flat = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z_flat) + core_attn_out = torch.ops.auto_deploy.view( + core_attn_out, + [batch_size, seq_len, -1, self.head_v_dim], + tp_scaled_dim=2, + layer_type="delta", + ) + core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1) + + # 5. Output projection (rowwise) + all_reduce + output = torch.ops.auto_deploy.torch_linear_simple( + core_attn_out, + self.out_proj.weight, + self.out_proj.bias, + tp_mode="rowwise", + layer_type="delta", + ) + output = torch.ops.auto_deploy.all_reduce(output, layer_type="delta") + return output + + +# ============================================================================= +# Attention +# ============================================================================= + + +class Qwen3_5MoeAttention(nn.Module): + """Multi-headed attention with gating, Q/K norms, and partial RoPE. + + Key differences from standard attention: + - q_proj outputs 2x (query + gate), gating applied to attention output. + - q_norm / k_norm applied per-head before RoPE. + - Partial RoPE: only first rotary_dim dimensions are rotated. 
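+
+    Gating sketch (illustrative shapes only; mirrors the steps in ``forward``)::
+
+        qg = q_proj(x).view(B, S, num_heads, 2 * head_dim)  # per-head [query | gate]
+        q, g = qg.chunk(2, dim=-1)
+        out = attention(q, k, v) * torch.sigmoid(g.reshape(B, S, -1))
+        out = o_proj(out)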
+ """ + + def __init__(self, config: Qwen3_5MoeTextConfig, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + + # q_proj outputs 2x for query + gate + self.q_proj = nn.Linear( + config.hidden_size, + config.num_attention_heads * self.head_dim * 2, + bias=config.attention_bias, + ) + self.k_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + + # Per-head Q/K norms + self.q_norm = Qwen3_5MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Qwen3_5MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + bsz, q_len, _ = hidden_states.size() + + # Q projection with gate (interleaved per-head [q_h0,g_h0,...] -- plain colwise is correct) + qg = torch.ops.auto_deploy.torch_linear_simple( + hidden_states, + self.q_proj.weight, + self.q_proj.bias, + tp_mode="colwise", + layer_type="mha", + ) + qg = torch.ops.auto_deploy.view( + qg, + [bsz, q_len, -1, self.head_dim * 2], + tp_scaled_dim=2, + layer_type="mha", + ) + query_states, gate = torch.chunk(qg, 2, dim=-1) + gate = gate.reshape(bsz, q_len, -1) + + # K, V projections with tp_min_local_shape for GQA + key_states = torch.ops.auto_deploy.torch_linear_simple( + hidden_states, + self.k_proj.weight, + self.k_proj.bias, + tp_mode="colwise", + tp_min_local_shape=self.head_dim, + layer_type="mha", + ) + key_states = torch.ops.auto_deploy.view( + key_states, + [bsz, q_len, -1, self.head_dim], + tp_scaled_dim=2, + layer_type="mha", + ) + value_states = torch.ops.auto_deploy.torch_linear_simple( + hidden_states, + self.v_proj.weight, + self.v_proj.bias, + tp_mode="colwise", + tp_min_local_shape=self.head_dim, + layer_type="mha", + ) + value_states = torch.ops.auto_deploy.view( + value_states, + [bsz, q_len, -1, self.head_dim], + tp_scaled_dim=2, + layer_type="mha", + ) + + # Per-head Q/K norms (norm on last dim = head_dim, no sharding) + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + # Partial RoPE in bsnd layout + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, unsqueeze_dim=2 + ) + + # Attention via autodeploy op (bsnd layout) + attn_output = torch.ops.auto_deploy.torch_attention( + query_states, + key_states, + value_states, + attn_mask=None, + dropout_p=0.0, + is_causal=True, + layout="bsnd", + ) + attn_output = attn_output.view(bsz, q_len, -1) + + # Gated output + attn_output = attn_output * torch.sigmoid(gate) + + # Output projection (rowwise) + all_reduce + attn_output = torch.ops.auto_deploy.torch_linear_simple( + attn_output, + self.o_proj.weight, + self.o_proj.bias, + tp_mode="rowwise", + layer_type="mha", + ) + attn_output = torch.ops.auto_deploy.all_reduce(attn_output, layer_type="mha") + return attn_output + + +# ============================================================================= +# MLP and MoE +# 
============================================================================= + + +class Qwen3_5MoeMLP(nn.Module): + """SwiGLU MLP used for the shared expert and standalone MLP layers.""" + + def __init__( + self, + config: Qwen3_5MoeTextConfig, + intermediate_size: int, + add_all_reduce: bool = True, + sharding_layer_type: str = "mlp", + ): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + self.add_all_reduce = add_all_reduce + self.sharding_layer_type = sharding_layer_type + + def forward(self, x: torch.Tensor) -> torch.Tensor: + _lt = self.sharding_layer_type + gate = torch.ops.auto_deploy.torch_linear_simple( + x, + self.gate_proj.weight, + self.gate_proj.bias, + tp_mode="colwise", + layer_type=_lt, + ) + up = torch.ops.auto_deploy.torch_linear_simple( + x, + self.up_proj.weight, + self.up_proj.bias, + tp_mode="colwise", + layer_type=_lt, + ) + down = torch.ops.auto_deploy.torch_linear_simple( + self.act_fn(gate) * up, + self.down_proj.weight, + self.down_proj.bias, + tp_mode="rowwise", + layer_type=_lt, + ) + if self.add_all_reduce: + down = torch.ops.auto_deploy.all_reduce(down, layer_type=_lt) + return down + + +class Qwen3_5MoeExpert(nn.Module): + """Single expert with gate, up, and down projections.""" + + def __init__(self, hidden_dim: int, intermediate_dim: int): + super().__init__() + self.gate_proj = nn.Linear(hidden_dim, intermediate_dim, bias=False) + self.up_proj = nn.Linear(hidden_dim, intermediate_dim, bias=False) + self.down_proj = nn.Linear(intermediate_dim, hidden_dim, bias=False) + + +class Qwen3_5MoeTopKRouter(nn.Module): + """Top-K router with softmax normalization.""" + + def __init__(self, config: Qwen3_5MoeTextConfig): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight) # (T, E) + routing_weights = F.softmax(router_logits, dtype=torch.float, dim=-1) + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) # (T, top_k) + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + return routing_weights, selected_experts + + +class Qwen3_5MoeSparseMoeBlock(nn.Module): + """MoE block with expert list implementation. + + Implements routed experts by iterating over selected experts and dispatching + tokens accordingly. Each expert is a separate nn.Linear triplet (gate, up, down). 
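+
+    Router sketch (illustrative; uses the config defaults of 256 experts with
+    top-8 selection)::
+
+        logits = hidden @ gate.weight.T               # (T, 256)
+        probs = softmax(logits, dim=-1)
+        weights, experts = probs.topk(8, dim=-1)      # (T, 8) each
+        weights = weights / weights.sum(-1, keepdim=True)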
+ """ + + def __init__(self, config: Qwen3_5MoeTextConfig): + super().__init__() + self.gate = Qwen3_5MoeTopKRouter(config) + self.experts = nn.ModuleList( + [ + Qwen3_5MoeExpert(config.hidden_size, config.moe_intermediate_size) + for _ in range(config.num_experts) + ] + ) + self.shared_expert = Qwen3_5MoeMLP( + config, + intermediate_size=config.shared_expert_intermediate_size, + add_all_reduce=False, + sharding_layer_type="moe", + ) + self.shared_expert_gate = nn.Linear(config.hidden_size, 1, bias=False) + self._register_load_state_dict_pre_hook(self._load_experts_from_fused_checkpoint) + + @staticmethod + def _load_experts_from_fused_checkpoint(state_dict, prefix, *args): + """Load fused MoE expert checkpoint tensors into per-expert ModuleList params. + + Checkpoint format: + - experts.gate_up_proj: [E, 2*I, H] with [gate, up] stacking + - experts.down_proj: [E, H, I] + + Target format: + - experts.{expert_id}.gate_proj.weight: [I, H] + - experts.{expert_id}.up_proj.weight: [I, H] + - experts.{expert_id}.down_proj.weight: [H, I] + """ + gate_up_key = prefix + "experts.gate_up_proj" + down_key = prefix + "experts.down_proj" + + if gate_up_key in state_dict: + fused = state_dict.pop(gate_up_key) + num_experts = fused.shape[0] + intermediate_dim = fused.shape[1] // 2 + gate_weights = fused[:, :intermediate_dim, :] + up_weights = fused[:, intermediate_dim:, :] + + for i in range(num_experts): + state_dict[f"{prefix}experts.{i}.gate_proj.weight"] = gate_weights[i] + state_dict[f"{prefix}experts.{i}.up_proj.weight"] = up_weights[i] + + if down_key in state_dict: + fused = state_dict.pop(down_key) + num_experts = fused.shape[0] + for i in range(num_experts): + state_dict[f"{prefix}experts.{i}.down_proj.weight"] = fused[i] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states_flat = hidden_states.view(-1, hidden_dim) + + # Router + routing_weights, selected_experts = self.gate(hidden_states_flat) + + # Routed experts via torch_moe (sharded by apply_sharding_hints) + w1_weights = [self.experts[i].gate_proj.weight for i in range(len(self.experts))] + w2_weights = [self.experts[i].down_proj.weight for i in range(len(self.experts))] + w3_weights = [self.experts[i].up_proj.weight for i in range(len(self.experts))] + + expert_output = torch.ops.auto_deploy.torch_moe( + hidden_states_flat, + selected_experts, + routing_weights, + w1_weights, + w2_weights, + w3_weights, + is_gated_mlp=True, + layer_type="moe", + ) + + # Shared expert with sigmoid gating (all_reduce deferred) + shared_expert_output = self.shared_expert(hidden_states_flat) + shared_expert_output = ( + F.sigmoid(self.shared_expert_gate(hidden_states_flat)) * shared_expert_output + ) + + # Merge routed + shared, then single all_reduce + expert_output = expert_output + shared_expert_output + expert_output = torch.ops.auto_deploy.all_reduce(expert_output, layer_type="moe") + + expert_output = expert_output.reshape(batch_size, sequence_length, hidden_dim) + return expert_output + + +# ============================================================================= +# Decoder Layer +# ============================================================================= + + +class Qwen3_5MoeDecoderLayer(nn.Module): + """Single decoder layer: token mixer (linear_attention or full_attention) + MoE.""" + + def __init__(self, config: Qwen3_5MoeTextConfig, layer_idx: int): + super().__init__() + self.layer_type = config.layer_types[layer_idx] + + if self.layer_type == 
"linear_attention": + self.linear_attn = Qwen3_5MoeGatedDeltaNet(config, layer_idx) + elif self.layer_type == "full_attention": + self.self_attn = Qwen3_5MoeAttention(config, layer_idx) + else: + raise ValueError(f"Unknown layer type: {self.layer_type}") + + self.mlp = Qwen3_5MoeSparseMoeBlock(config) + self.input_layernorm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3_5MoeRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + # Token mixer + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + if self.layer_type == "linear_attention": + hidden_states = self.linear_attn(hidden_states) + elif self.layer_type == "full_attention": + hidden_states = self.self_attn(hidden_states, position_embeddings=position_embeddings) + + hidden_states = residual + hidden_states + + # Channel mixer (MoE) + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +# ============================================================================= +# Model +# ============================================================================= + + +class Qwen3_5MoePreTrainedModel(PreTrainedModel): + """Base class for Qwen3.5 MoE pretrained models.""" + + config_class = Qwen3_5MoeTextConfig + base_model_prefix = "model" + _no_split_modules = ["Qwen3_5MoeDecoderLayer"] + supports_gradient_checkpointing = True + _is_stateful = True + + def _init_weights(self, module): + # Delegate nn.Linear / nn.Embedding / nn.Conv* to the base class, which + # safely resolves initializer_range via hasattr + get_text_config() fallback. + super()._init_weights(module) + std = getattr(self.config, "initializer_range", 0.02) + if isinstance(module, Qwen3_5MoeRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, Qwen3_5MoeRMSNormGated): + module.weight.data.fill_(1.0) + elif isinstance(module, Qwen3_5MoeGatedDeltaNet): + module.dt_bias.data.fill_(1.0) + module.A_log.data.copy_(torch.empty_like(module.A_log).uniform_(0, 16).log_()) + elif isinstance(module, Qwen3_5MoeTopKRouter): + module.weight.data.normal_(mean=0.0, std=std) + + +@dataclass +class Qwen3_5MoeOutput(ModelOutput): + """Output of the Qwen3.5 MoE text model.""" + + last_hidden_state: Optional[torch.FloatTensor] = None + + +@dataclass +class Qwen3_5MoeCausalLMOutput(ModelOutput): + """Output of the Qwen3.5 MoE causal language model.""" + + logits: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + + +class Qwen3_5MoeTextModel(Qwen3_5MoePreTrainedModel): + """Qwen3.5 MoE text model (embed + decoder layers + final norm + lm_head). + + lm_head is included so that the exported GraphModule contains it directly, + allowing sharding and gather_logits_before_lm_head transforms to see it. 
+ """ + + def __init__(self, config: Qwen3_5MoeTextConfig): + super().__init__(config) + pad_token_id = getattr(config, "pad_token_id", None) + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=pad_token_id + ) + self.layers = nn.ModuleList( + [ + Qwen3_5MoeDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3_5MoeTextRotaryEmbedding(config=config) + self.lm_head = None + + # Initialize weights and apply final processing + self.post_init() + + def set_lm_head(self, lm_head: nn.Module): + """Set the lm_head from the parent model.""" + self.lm_head = lm_head + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + rope_cos: Optional[torch.Tensor] = None, + rope_sin: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[Tuple, Qwen3_5MoeOutput]: + """Forward pass. + + There are three ways to provide position information (checked in order): + + 1. ``rope_cos`` + ``rope_sin``: separate tensors, each ``(B, S, rotary_dim)``. + Export-friendly -- each is a proper graph input with its own dynamic shape. + 2. ``position_embeddings``: pre-computed ``(cos, sin)`` tuple. Convenient for + the multimodal wrapper calling at the plain-PyTorch level. + 3. ``position_ids`` (or ``None``): standard 2D ``(B, S)`` or 3D ``(3, B, S)`` + position IDs. The internal ``rotary_emb`` computes cos/sin. This is the + default text-only path. + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # Resolve position embeddings from one of the three input modes. 
+ if rope_cos is not None and rope_sin is not None: + position_embeddings = (rope_cos, rope_sin) + elif position_embeddings is None: + if position_ids is None: + seq_len = inputs_embeds.shape[1] + position_ids = torch.arange(seq_len, device=inputs_embeds.device) + position_ids = position_ids.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + position_embeddings = self.rotary_emb(inputs_embeds, position_ids) + + hidden_states = inputs_embeds + for decoder_layer in self.layers: + hidden_states = decoder_layer(hidden_states, position_embeddings=position_embeddings) + + hidden_states = self.norm(hidden_states) + assert self.lm_head is not None, ( + "lm_head not set — call set_lm_head() from the parent model before forward()" + ) + logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() + return Qwen3_5MoeCausalLMOutput(logits=logits, last_hidden_state=hidden_states) + + +class Qwen3_5MoeForCausalLM(Qwen3_5MoePreTrainedModel, GenerationMixin): + """Qwen3.5 MoE causal language model (text model + lm_head).""" + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: Qwen3_5MoeTextConfig, **kwargs): + super().__init__(config) + self.model = Qwen3_5MoeTextModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.model.set_lm_head(self.lm_head) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.model.set_input_embeddings(new_embeddings) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + self.model.set_lm_head(new_embeddings) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + rope_cos: Optional[torch.Tensor] = None, + rope_sin: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[Tuple, Qwen3_5MoeCausalLMOutput]: + outputs = self.model( + input_ids, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + position_embeddings=position_embeddings, + rope_cos=rope_cos, + rope_sin=rope_sin, + ) + logits = outputs.logits + return Qwen3_5MoeCausalLMOutput(logits=logits) + + +# ============================================================================= +# Vision Configuration +# ============================================================================= + + +class Qwen3_5MoeVisionConfig(PretrainedConfig): + """Config class for the Qwen3.5 MoE vision tower. + + Mirrors the upstream ``Qwen3_5MoeVisionConfig``. 
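+
+    Illustrative token arithmetic with the defaults (``patch_size=16``,
+    ``spatial_merge_size=2``): a 448x448 image yields a 28x28 patch grid, which
+    the patch merger reduces to 14x14 = 196 tokens fed to the language model.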
+ """ + + model_type = "qwen3_5_moe_vision" + + def __init__( + self, + depth: int = 27, + hidden_size: int = 1152, + hidden_act: str = "gelu_pytorch_tanh", + intermediate_size: int = 4304, + num_heads: int = 16, + in_channels: int = 3, + patch_size: int = 16, + spatial_merge_size: int = 2, + temporal_patch_size: int = 2, + out_hidden_size: int = 3584, + num_position_embeddings: int = 2304, + initializer_range: float = 0.02, + **kwargs, + ): + super().__init__(**kwargs) + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.out_hidden_size = out_hidden_size + self.num_position_embeddings = num_position_embeddings + self.initializer_range = initializer_range + + +class Qwen3_5MoeConfig(PretrainedConfig): + """Composite config containing both text and vision configs. + + Mirrors the upstream ``Qwen3_5MoeConfig``. + """ + + model_type = "qwen3_5_moe" + + def __init__( + self, + text_config=None, + vision_config=None, + image_token_id: int = 248056, + video_token_id: int = 248057, + vision_start_token_id: int = 248053, + vision_end_token_id: int = 248054, + tie_word_embeddings: bool = False, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = Qwen3_5MoeVisionConfig(**vision_config) + elif vision_config is None: + self.vision_config = Qwen3_5MoeVisionConfig() + else: + self.vision_config = vision_config + + if isinstance(text_config, dict): + self.text_config = Qwen3_5MoeTextConfig(**text_config) + elif text_config is None: + self.text_config = Qwen3_5MoeTextConfig() + else: + self.text_config = text_config + + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +# ============================================================================= +# Vision Tower Components (plain PyTorch -- NOT exported) +# ============================================================================= + + +class Qwen3_5MoeVisionRotaryEmbedding(nn.Module): + """Simple rotary embedding for the vision tower (not mRoPE).""" + + def __init__(self, dim: int, theta: float = 10000.0): + super().__init__() + self.dim = dim + self.theta = theta + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply RoPE to vision Q/K tensors. 
Layout: (seq, heads, dim).""" + orig_q_dtype, orig_k_dtype = q.dtype, k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed.to(orig_q_dtype), k_embed.to(orig_k_dtype) + + +class Qwen3_5MoeVisionPatchEmbed(nn.Module): + """3D convolution patch embedding for images/videos.""" + + def __init__(self, config: Qwen3_5MoeVisionConfig): + super().__init__() + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + + kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] + self.proj = nn.Conv3d( + self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class Qwen3_5MoeVisionMLP(nn.Module): + """Feed-forward network for vision blocks.""" + + def __init__(self, config: Qwen3_5MoeVisionConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) + self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) + + +class Qwen3_5MoeVisionAttention(nn.Module): + """Bidirectional attention for vision tokens with cu_seqlens support. + + Uses either: + - Eager path: splits by sequence lengths, runs attention per chunk. + - (Future) Flash Attention: single call with cu_seqlens. + + Always non-causal (is_causal=False). 
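+
+    Illustrative example: two images of 64 and 100 patches packed into one
+    sequence give ``cu_seqlens = [0, 64, 164]``; the eager path runs attention
+    independently on the [0, 64) and [64, 164) chunks.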
+ """ + + def __init__(self, config: Qwen3_5MoeVisionConfig): + super().__init__() + self.dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.dim // self.num_heads + self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True) + self.proj = nn.Linear(self.dim, self.dim) + self.scaling = self.head_dim**-0.5 + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + query_states, key_states, value_states = ( + self.qkv(hidden_states) + .reshape(seq_length, 3, self.num_heads, -1) + .permute(1, 0, 2, 3) + .unbind(0) + ) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) + + # Layout: (1, num_heads, seq_len, head_dim) per chunk + query_states = query_states.transpose(0, 1).unsqueeze(0) # (1, H, S, D) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + + lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + q_splits = torch.split(query_states, lengths, dim=2) + k_splits = torch.split(key_states, lengths, dim=2) + v_splits = torch.split(value_states, lengths, dim=2) + + attn_outputs = [] + for q, k, v in zip(q_splits, k_splits, v_splits): + attn_outputs.append(F.scaled_dot_product_attention(q, k, v, is_causal=False)) + + attn_output = torch.cat(attn_outputs, dim=2) # (1, H, total_S, D) + attn_output = attn_output.squeeze(0).transpose(0, 1) # (S, H, D) + attn_output = attn_output.reshape(seq_length, -1).contiguous() + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen3_5MoeVisionBlock(nn.Module): + """Vision transformer block: LayerNorm -> Attention -> Residual -> LayerNorm -> MLP -> Residual.""" + + def __init__(self, config: Qwen3_5MoeVisionConfig): + super().__init__() + self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.attn = Qwen3_5MoeVisionAttention(config=config) + self.mlp = Qwen3_5MoeVisionMLP(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Qwen3_5MoeVisionPatchMerger(nn.Module): + """Merges spatial_merge_size^2 patches into one token and projects to LLM hidden size.""" + + def __init__(self, config: Qwen3_5MoeVisionConfig): + super().__init__() + self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) + self.norm = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size) + self.act_fn = nn.GELU() + self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x).view(-1, self.hidden_size) + x = self.linear_fc2(self.act_fn(self.linear_fc1(x))) + return x + + +class Qwen3_5MoeVisionModel(nn.Module): + """Complete vision tower: PatchEmbed + PositionEmbed + VisionBlocks + PatchMerger. + + This module is NOT exported -- it runs in plain PyTorch. 
+ """ + + def __init__(self, config: Qwen3_5MoeVisionConfig): + super().__init__() + self.config = config + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + self.dtype = None # set after loading weights + + self.patch_embed = Qwen3_5MoeVisionPatchEmbed(config=config) + self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size) + self.num_grid_per_side = int(config.num_position_embeddings**0.5) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen3_5MoeVisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([Qwen3_5MoeVisionBlock(config) for _ in range(config.depth)]) + self.merger = Qwen3_5MoeVisionPatchMerger(config=config) + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + """Compute rotary position embeddings for vision tokens.""" + merge_size = self.spatial_merge_size + max_hw = int(grid_thw[:, 1:].max().item()) + freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim//2) + device = freq_table.device + + total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + + offset = 0 + for num_frames, height, width in grid_thw: + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h, device=device) + block_cols = torch.arange(merged_w, device=device) + intra_row = torch.arange(merge_size, device=device) + intra_col = torch.arange(merge_size, device=device) + + row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] + col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] + + row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + num_tokens = coords.shape[0] + pos_ids[offset : offset + num_tokens] = coords + offset += num_tokens + + embeddings = freq_table[pos_ids] + embeddings = embeddings.flatten(1) + return embeddings + + def fast_pos_embed_interpolate(self, grid_thw: torch.Tensor) -> torch.Tensor: + """Bilinear interpolation of learned positional embeddings for variable image sizes.""" + grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + device = self.pos_embed.weight.device + + idx_list: List[List[int]] = [[] for _ in range(4)] + weight_list: List[List[float]] = [[] for _ in range(4)] + + for t, h, w in zip(grid_ts, grid_hs, grid_ws): + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, int(h.item())) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, int(w.item())) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * 
dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) + + idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) + weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device) + pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split( + [int(h.item()) * int(w.item()) for h, w in zip(grid_hs, grid_ws)] + ) + + merge_size = self.config.spatial_merge_size + patch_pos_embeds_permute = [] + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + t, h, w = int(t.item()), int(h.item()), int(w.item()) + pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( + pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + + return torch.cat(patch_pos_embeds_permute) + + def forward( + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor + ) -> BaseModelOutputWithPooling: + """Run the vision tower. + + Args: + hidden_states: Raw pixel values reshaped for patch embedding. + grid_thw: Shape ``(num_images_or_videos, 3)`` -- temporal, height, width. + + Returns: + ``BaseModelOutputWithPooling`` with ``pooler_output`` containing merged features. + """ + hidden_states = self.patch_embed(hidden_states) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + for blk in self.blocks: + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + ) + + merged_hidden_states = self.merger(hidden_states) + + return BaseModelOutputWithPooling( + last_hidden_state=hidden_states, + pooler_output=merged_hidden_states, + ) + + +# ============================================================================= +# Multimodal Wrapper (plain PyTorch -- NOT exported) +# ============================================================================= + + +@dataclass +class Qwen3_5MoeConditionalOutput(ModelOutput): + """Output of the Qwen3.5 MoE conditional generation model.""" + + logits: Optional[torch.FloatTensor] = None + + +def compute_mrope_positions( + input_ids: torch.LongTensor, + image_grid_thw: Optional[torch.LongTensor], + video_grid_thw: Optional[torch.LongTensor], + image_token_id: int, + video_token_id: int, + vision_start_token_id: int, + spatial_merge_size: int, + attention_mask: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute 3D mRoPE position IDs for multimodal sequences. + + Standalone function usable by both the model forward and the input processor. + For each sample in the batch, scans for vision placeholder tokens and assigns + spatial (T, H, W) positions to vision tokens while text tokens get sequential + positions. 
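+
+    Worked example (illustrative): a prompt of 5 text tokens followed by an image
+    whose merged grid is ``(T=1, H=2, W=2)`` gets text positions 0..4 on all three
+    axes; the 4 image tokens then get ``t = 5`` with ``h`` and ``w`` in ``{5, 6}``,
+    and any text that follows resumes at position 7.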
+ + Args: + input_ids: Token IDs, shape ``(B, S)``. + image_grid_thw: Grid dimensions ``(N_images, 3)`` with ``(T, H, W)`` per image. + video_grid_thw: Grid dimensions ``(N_videos, 3)`` with ``(T, H, W)`` per video. + image_token_id: Token ID for image placeholders. + video_token_id: Token ID for video placeholders. + vision_start_token_id: Token ID marking the start of a vision segment. + spatial_merge_size: Factor by which the vision patch merger reduces spatial dims. + attention_mask: Optional mask, shape ``(B, S)``. + + Returns: + ``(position_ids, mrope_position_deltas)`` where ``position_ids`` has + shape ``(3, B, S)`` and ``mrope_position_deltas`` has shape ``(B, 1)``. + """ + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) + video_grid_thw[:, 0] = 1 + + mrope_position_deltas = [] + + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + + for i, ids in enumerate(total_input_ids): + ids = ids[attention_mask[i] == 1] + vision_start_indices = torch.argwhere(ids == vision_start_token_id).squeeze(1) + vision_tokens = ids[vision_start_indices + 1] + image_nums = int((vision_tokens == image_token_id).sum().item()) + video_nums = int((vision_tokens == video_token_id).sum().item()) + input_tokens = ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + + for _ in range(image_nums + video_nums): + ed_image = ( + input_tokens.index(image_token_id, st) + if image_token_id in input_tokens and remain_images > 0 + else len(input_tokens) + 1 + ) + ed_video = ( + input_tokens.index(video_token_id, st) + if video_token_id in input_tokens and remain_videos > 0 + else len(input_tokens) + 1 + ) + + if ed_image < ed_video: + t, h, w = image_grid_thw[image_index].tolist() + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = video_grid_thw[video_index].tolist() + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t = int(t) + llm_grid_h = int(h) // spatial_merge_size + llm_grid_w = int(w) // spatial_merge_size + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( + device=position_ids.device, dtype=position_ids.dtype + ) + 
mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + + mrope_position_deltas = torch.tensor( + mrope_position_deltas, device=input_ids.device + ).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], device=input_ids.device, dtype=input_ids.dtype + ) + + return position_ids, mrope_position_deltas + + +def _normalize_video_grid_for_mrope( + video_grid_thw: Optional[torch.Tensor], +) -> Optional[torch.Tensor]: + if video_grid_thw is None: + return None + video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) + video_grid_thw = video_grid_thw.clone() + video_grid_thw[:, 0] = 1 + return video_grid_thw + + +def _extract_mm_item_types_from_input_ids( + input_ids: torch.Tensor, + image_token_id: int, + video_token_id: int, + vision_start_token_id: int, +) -> List[int]: + """Return multimodal item types in prompt order for a single request.""" + flat_ids = input_ids.reshape(-1).tolist() + item_types: List[int] = [] + for idx in range(len(flat_ids) - 1): + if flat_ids[idx] != vision_start_token_id: + continue + next_token = flat_ids[idx + 1] + if next_token == image_token_id: + item_types.append(0) + elif next_token == video_token_id: + item_types.append(1) + return item_types + + +def _is_qwen_video_frame(value: Any) -> bool: + return isinstance(value, (Image.Image, torch.Tensor)) + + +def _normalize_qwen_image_items(images: Any) -> list[Any]: + if images is None: + return [] + if isinstance(images, list): + return images + return [images] + + +def _normalize_qwen_video_items(videos: Any) -> list[Any]: + if videos is None: + return [] + if isinstance(videos, VideoData): + return [videos] + if isinstance(videos, list): + if not videos: + return [] + if all(_is_qwen_video_frame(frame) for frame in videos): + return [videos] + normalized_items = [] + for item in videos: + if isinstance(item, VideoData): + normalized_items.append(item) + elif ( + isinstance(item, list) + and item + and all(_is_qwen_video_frame(frame) for frame in item) + ): + normalized_items.append(item) + else: + normalized_items.append(item) + return normalized_items + return [videos] + + +def _get_qwen_video_num_spans(video: Any) -> int: + if isinstance(video, VideoData): + return len(video.frames) + if isinstance(video, list): + if not video: + return 0 + if all(_is_qwen_video_frame(frame) for frame in video): + return len(video) + shape = getattr(video, "shape", None) + if shape is not None and len(shape) >= 4: + return int(shape[0]) + return 1 + + +def _compute_mm_item_special_counts( + mm_token_lengths: torch.Tensor, + mm_special_offsets_cu_seqlen: torch.Tensor, + mm_special_offsets: torch.Tensor, + req_idx: int, +) -> List[int]: + item_lengths = mm_token_lengths.tolist() + special_start = int(mm_special_offsets_cu_seqlen[req_idx].item()) + special_end = int(mm_special_offsets_cu_seqlen[req_idx + 1].item()) + special_offsets = mm_special_offsets[special_start:special_end].tolist() + counts: 
List[int] = [] + mm_offset = 0 + for item_len in item_lengths: + item_end = mm_offset + int(item_len) + num_special = sum(1 for off in special_offsets if mm_offset <= int(off) < item_end) + counts.append(num_special) + mm_offset = item_end + return counts + + +def _compute_request_mrope_delta( + mm_item_types: torch.Tensor, + mm_token_lengths: torch.Tensor, + special_counts: Sequence[int], + image_grid_thw: Optional[torch.Tensor], + video_grid_thw: Optional[torch.Tensor], + spatial_merge_size: int, +) -> int: + image_idx = 0 + video_idx = 0 + total_delta = 0 + for item_type, item_len, num_special in zip( + mm_item_types.tolist(), mm_token_lengths.tolist(), special_counts + ): + num_placeholders = int(item_len) - int(num_special) + if item_type == 0: + if image_grid_thw is None: + raise ValueError("Expected image_grid_thw for image multimodal item") + t, h, w = [int(v) for v in image_grid_thw[image_idx].tolist()] + image_idx += 1 + else: + if video_grid_thw is None: + raise ValueError("Expected video_grid_thw for video multimodal item") + t, h, w = [int(v) for v in video_grid_thw[video_idx].tolist()] + video_idx += 1 + llm_grid_t = int(t) + llm_grid_h = int(h) // spatial_merge_size + llm_grid_w = int(w) // spatial_merge_size + total_delta += max(llm_grid_t, llm_grid_h, llm_grid_w) - num_placeholders + return total_delta + + +@torch.library.custom_op("auto_deploy::qwen3_mrope_delta", mutates_args=()) +def qwen3_mrope_delta( + batch_info_host: torch.Tensor, + mm_item_cu_seqlen: torch.Tensor, + mm_item_types: torch.Tensor, + mm_token_lengths: torch.Tensor, + mm_special_offsets_cu_seqlen: torch.Tensor, + mm_special_offsets: torch.Tensor, + image_grid_thw: Optional[torch.Tensor], + video_grid_thw: Optional[torch.Tensor], + spatial_merge_size: int, +) -> torch.Tensor: + num_prefill, _, num_decode = BatchInfo(batch_info_host).get_absorbed_info() + num_seq = num_prefill + num_decode + device = mm_item_cu_seqlen.device + out = torch.zeros((num_seq, 1), dtype=torch.int32, device=device) + video_grid_norm = _normalize_video_grid_for_mrope(video_grid_thw) + img_idx = 0 + vid_idx = 0 + for req_idx in range(num_prefill): + item_start = int(mm_item_cu_seqlen[req_idx].item()) + item_end = int(mm_item_cu_seqlen[req_idx + 1].item()) + req_item_types = mm_item_types[item_start:item_end] + req_item_lengths = mm_token_lengths[item_start:item_end] + if req_item_lengths.numel() == 0: + continue + num_images = int((req_item_types == 0).sum().item()) + num_videos = int((req_item_types == 1).sum().item()) + req_image_grid = image_grid_thw[img_idx : img_idx + num_images] if num_images > 0 else None + req_video_grid = video_grid_norm[vid_idx : vid_idx + num_videos] if num_videos > 0 else None + special_counts = _compute_mm_item_special_counts( + req_item_lengths, + mm_special_offsets_cu_seqlen, + mm_special_offsets, + req_idx, + ) + out[req_idx, 0] = _compute_request_mrope_delta( + req_item_types, + req_item_lengths, + special_counts, + req_image_grid, + req_video_grid, + spatial_merge_size, + ) + img_idx += num_images + vid_idx += num_videos + return out + + +@qwen3_mrope_delta.register_fake +def qwen3_mrope_delta_fake( + batch_info_host: torch.Tensor, + mm_item_cu_seqlen: torch.Tensor, + mm_item_types: torch.Tensor, + mm_token_lengths: torch.Tensor, + mm_special_offsets_cu_seqlen: torch.Tensor, + mm_special_offsets: torch.Tensor, + image_grid_thw: Optional[torch.Tensor], + video_grid_thw: Optional[torch.Tensor], + spatial_merge_size: int, +) -> torch.Tensor: + num_prefill, _, num_decode = 
BatchInfo(batch_info_host).get_absorbed_info() + num_seq = num_prefill + num_decode + return torch.zeros((num_seq, 1), dtype=torch.int32, device=batch_info_host.device) + + +@torch.library.custom_op( + "auto_deploy::qwen3_mrope_delta_with_cache", mutates_args=("mrope_delta_cache",) +) +def qwen3_mrope_delta_with_cache( + batch_info_host: torch.Tensor, + slot_idx: torch.Tensor, + mm_item_cu_seqlen: Optional[torch.Tensor], + mm_item_types: Optional[torch.Tensor], + mm_token_lengths: Optional[torch.Tensor], + mm_special_offsets_cu_seqlen: Optional[torch.Tensor], + mm_special_offsets: Optional[torch.Tensor], + image_grid_thw: Optional[torch.Tensor], + video_grid_thw: Optional[torch.Tensor], + mrope_delta_cache: torch.Tensor, + spatial_merge_size: int, +) -> torch.Tensor: + num_prefill, _, num_decode = BatchInfo(batch_info_host).get_absorbed_info() + num_seq = num_prefill + num_decode + out = torch.zeros((num_seq, 1), dtype=torch.int32, device=mrope_delta_cache.device) + video_grid_norm = _normalize_video_grid_for_mrope(video_grid_thw) + if num_prefill > 0: + has_mm_metadata = all( + arg is not None + for arg in ( + mm_item_cu_seqlen, + mm_item_types, + mm_token_lengths, + mm_special_offsets_cu_seqlen, + mm_special_offsets, + ) + ) + if has_mm_metadata: + img_idx = 0 + vid_idx = 0 + for req_idx in range(num_prefill): + item_start = int(mm_item_cu_seqlen[req_idx].item()) + item_end = int(mm_item_cu_seqlen[req_idx + 1].item()) + req_item_types = mm_item_types[item_start:item_end] + req_item_lengths = mm_token_lengths[item_start:item_end] + if req_item_lengths.numel() == 0: + continue + num_images = int((req_item_types == 0).sum().item()) + num_videos = int((req_item_types == 1).sum().item()) + req_image_grid = ( + image_grid_thw[img_idx : img_idx + num_images] if num_images > 0 else None + ) + req_video_grid = ( + video_grid_norm[vid_idx : vid_idx + num_videos] if num_videos > 0 else None + ) + special_counts = _compute_mm_item_special_counts( + req_item_lengths, + mm_special_offsets_cu_seqlen, + mm_special_offsets, + req_idx, + ) + out[req_idx, 0] = _compute_request_mrope_delta( + req_item_types, + req_item_lengths, + special_counts, + req_image_grid, + req_video_grid, + spatial_merge_size, + ) + img_idx += num_images + vid_idx += num_videos + mrope_delta_cache.index_copy_( + 0, + slot_idx[:num_prefill].to(torch.long), + out[:num_prefill].to(mrope_delta_cache.dtype), + ) + if num_decode > 0: + out[num_prefill:num_seq] = mrope_delta_cache[ + slot_idx[num_prefill:num_seq].to(torch.long) + ].to(torch.int32) + return out + + +@qwen3_mrope_delta_with_cache.register_fake +def qwen3_mrope_delta_with_cache_fake( + batch_info_host: torch.Tensor, + slot_idx: torch.Tensor, + mm_item_cu_seqlen: Optional[torch.Tensor], + mm_item_types: Optional[torch.Tensor], + mm_token_lengths: Optional[torch.Tensor], + mm_special_offsets_cu_seqlen: Optional[torch.Tensor], + mm_special_offsets: Optional[torch.Tensor], + image_grid_thw: Optional[torch.Tensor], + video_grid_thw: Optional[torch.Tensor], + mrope_delta_cache: torch.Tensor, + spatial_merge_size: int, +) -> torch.Tensor: + num_prefill, _, num_decode = BatchInfo(batch_info_host).get_absorbed_info() + num_seq = num_prefill + num_decode + return torch.zeros((num_seq, 1), dtype=torch.int32, device=slot_idx.device) + + +class Qwen3_5MoeModel(nn.Module): + """Multimodal wrapper: vision tower + embedding merge + mRoPE + language model. + + This module is NOT exported. 
It orchestrates the vision pipeline in plain + PyTorch and calls the (potentially exported) language model with 3D + ``position_ids (3, B, S)`` so that the text model's internal ``rotary_emb`` + computes the correct mRoPE cos/sin. + """ + + def __init__(self, config: Qwen3_5MoeConfig): + super().__init__() + self.config = config + self.visual = Qwen3_5MoeVisionModel(config.vision_config) + self.language_model = Qwen3_5MoeTextModel(config.text_config) + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def get_rope_index( + self, + input_ids: torch.LongTensor, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute 3D mRoPE position IDs. Delegates to ``compute_mrope_positions``.""" + return compute_mrope_positions( + input_ids=input_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + image_token_id=self.config.image_token_id, + video_token_id=self.config.video_token_id, + vision_start_token_id=self.config.vision_start_token_id, + spatial_merge_size=self.config.vision_config.spatial_merge_size, + attention_mask=attention_mask, + ) + + def get_image_features( + self, pixel_values: torch.Tensor, image_grid_thw: torch.LongTensor + ) -> List[torch.Tensor]: + """Run vision tower on images and split by grid dimensions.""" + vision_output: BaseModelOutputWithPooling = self.visual( + pixel_values, grid_thw=image_grid_thw + ) + image_embeds = vision_output.pooler_output + split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + return list(torch.split(image_embeds, split_sizes)) + + def get_video_features( + self, pixel_values_videos: torch.Tensor, video_grid_thw: torch.LongTensor + ) -> List[torch.Tensor]: + """Run vision tower on videos (same as images).""" + return self.get_image_features(pixel_values_videos, video_grid_thw) + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: Optional[torch.Tensor] = None, + video_features: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Find image/video placeholder token positions in the embedding sequence.""" + special_image_mask = ( + (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) + ) + special_video_mask = ( + (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) + ) + return special_image_mask, special_video_mask + + def _select_request_chunk_multimodal_embeds( + self, + req_input_pos: int, + req_seq_len: int, + req_mm_item_types: Sequence[int], + req_mm_positions: Sequence[int], + req_mm_lengths: Sequence[int], + req_special_offsets: Sequence[int], + image_embeds_list: Optional[Sequence[torch.Tensor]], + video_embeds_list: Optional[Sequence[torch.Tensor]], + ) -> torch.Tensor: + chunk_end = req_input_pos + req_seq_len + mm_cumulative_offset = 0 + img_idx = 0 + vid_idx = 0 + chunks: list[torch.Tensor] = [] + hidden_size = self.config.text_config.hidden_size + special_offsets_set = set(int(x) for x in req_special_offsets) + + for item_type, mm_start, mm_len in zip(req_mm_item_types, req_mm_positions, req_mm_lengths): + item_mm_offset = mm_cumulative_offset + item_mm_len = int(mm_len) + item_abs_start = int(mm_start) + item_abs_end = item_abs_start + item_mm_len + overlap_start = max(req_input_pos, item_abs_start) + overlap_end = min(chunk_end, item_abs_end) + 
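+            # Descriptive note (not in the original patch): embedding sources are consumed
+            # in prompt order, so image and video items keep independent running indices
+            # (img_idx / vid_idx) into the per-request embedding lists sliced by the caller.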
+ if item_type == 0: + if image_embeds_list is None: + raise ValueError("Missing image embeddings for image multimodal item") + item_embeds = image_embeds_list[img_idx] + img_idx += 1 + elif item_type == 1: + if video_embeds_list is None: + raise ValueError("Missing video embeddings for video multimodal item") + item_embeds = video_embeds_list[vid_idx] + vid_idx += 1 + else: + raise ValueError(f"Unsupported multimodal item type: {item_type}") + + local_to_feature_idx: list[Optional[int]] = [] + feature_idx = 0 + for rel in range(item_mm_len): + if item_mm_offset + rel in special_offsets_set: + local_to_feature_idx.append(None) + else: + local_to_feature_idx.append(feature_idx) + feature_idx += 1 + + if feature_idx != item_embeds.shape[0]: + raise ValueError( + "Multimodal embedding length mismatch for Qwen3.5 item: " + f"type={item_type}, expected={feature_idx}, actual={item_embeds.shape[0]}, " + f"mm_len={item_mm_len}, item_start={item_abs_start}, " + f"special_offsets={sorted(special_offsets_set)}" + ) + + if overlap_start < overlap_end: + selected_indices = [ + local_to_feature_idx[rel] + for rel in range(overlap_start - item_abs_start, overlap_end - item_abs_start) + if local_to_feature_idx[rel] is not None + ] + if selected_indices: + chunks.append(item_embeds[selected_indices]) + + mm_cumulative_offset += item_mm_len + + if chunks: + return torch.cat(chunks, dim=0) + + device = None + dtype = None + if image_embeds_list: + device = image_embeds_list[0].device + dtype = image_embeds_list[0].dtype + elif video_embeds_list: + device = video_embeds_list[0].device + dtype = video_embeds_list[0].dtype + if device is None or dtype is None: + raise ValueError( + "Cannot build empty multimodal chunk without image or video embeddings" + ) + return torch.empty(0, hidden_size, device=device, dtype=dtype) + + def _expand_video_embeds_by_span( + self, + video_embeds_list: Optional[Sequence[torch.Tensor]], + video_grid_thw: Optional[torch.Tensor], + ) -> Optional[List[torch.Tensor]]: + if video_embeds_list is None or video_grid_thw is None: + return None + + merge = self.config.vision_config.spatial_merge_size + video_span_embeds: List[torch.Tensor] = [] + for video_embeds, grid in zip(video_embeds_list, video_grid_thw): + t, h, w = [int(v) for v in grid.tolist()] + frame_tokens = (int(h) // merge) * (int(w) // merge) + expected = int(t) * frame_tokens + if video_embeds.shape[0] != expected: + raise ValueError( + "Video embedding length mismatch in Qwen3.5 VLM forward: " + f"expected={expected}, actual={video_embeds.shape[0]}, grid={tuple(grid.tolist())}" + ) + video_span_embeds.extend(list(torch.split(video_embeds, frame_tokens, dim=0))) + return video_span_embeds + + def _build_chunked_multimodal_embeds( + self, + input_ids: torch.LongTensor, + batch_info: torch.Tensor, + cu_seqlen: torch.Tensor, + input_pos: torch.Tensor, + seq_len: torch.Tensor, + image_embeds_list: Optional[Sequence[torch.Tensor]], + video_span_embeds_list: Optional[Sequence[torch.Tensor]], + mm_item_cu_seqlen: torch.Tensor, + mm_item_types: torch.Tensor, + mm_token_positions: torch.Tensor, + mm_token_lengths: torch.Tensor, + mm_special_offsets_cu_seqlen: Optional[torch.Tensor], + mm_special_offsets: Optional[torch.Tensor], + ) -> torch.Tensor: + num_prefill_seqs = int(batch_info[0].item()) + img_idx = 0 + vid_idx = 0 + chunks: list[torch.Tensor] = [] + + for i in range(num_prefill_seqs): + item_start = int(mm_item_cu_seqlen[i].item()) + item_end = int(mm_item_cu_seqlen[i + 1].item()) + req_mm_item_types = 
mm_item_types[item_start:item_end].tolist() + req_mm_positions = mm_token_positions[item_start:item_end].tolist() + req_mm_lengths = mm_token_lengths[item_start:item_end].tolist() + + req_special_offsets: list[int] = [] + if mm_special_offsets_cu_seqlen is not None and mm_special_offsets is not None: + special_start = int(mm_special_offsets_cu_seqlen[i].item()) + special_end = int(mm_special_offsets_cu_seqlen[i + 1].item()) + req_special_offsets = mm_special_offsets[special_start:special_end].tolist() + + num_images = sum(item_type == 0 for item_type in req_mm_item_types) + num_videos = sum(item_type == 1 for item_type in req_mm_item_types) + req_image_embeds = ( + image_embeds_list[img_idx : img_idx + num_images] + if image_embeds_list is not None + else None + ) + req_video_embeds = ( + video_span_embeds_list[vid_idx : vid_idx + num_videos] + if video_span_embeds_list is not None + else None + ) + img_idx += num_images + vid_idx += num_videos + + req_chunk_embeds = self._select_request_chunk_multimodal_embeds( + req_input_pos=int(input_pos[i].item()), + req_seq_len=int(seq_len[i].item()), + req_mm_item_types=req_mm_item_types, + req_mm_positions=req_mm_positions, + req_mm_lengths=req_mm_lengths, + req_special_offsets=req_special_offsets, + image_embeds_list=req_image_embeds, + video_embeds_list=req_video_embeds, + ) + chunks.append(req_chunk_embeds) + + if chunks: + return torch.cat(chunks, dim=0) + + hidden_size = self.config.text_config.hidden_size + return torch.empty( + 0, hidden_size, device=input_ids.device, dtype=self.get_input_embeddings().weight.dtype + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + mrope_position_deltas: Optional[torch.Tensor] = None, + batch_info: Optional[torch.Tensor] = None, + **kwargs, + ) -> Qwen3_5MoeOutput: + """Multimodal forward: vision encoding + embedding merge + mRoPE + text model. + + 3D mRoPE positions are computed per request at forward time using + ``cu_seqlen`` to identify request boundaries and ``image_grid_thw`` + to derive spatial positions for multimodal requests. + + Position assembly cases: + + 1. Images present + ``batch_info`` (mixed or prefill-only with images): + iterate prefill requests via ``cu_seqlen``, call + ``compute_mrope_positions`` for multimodal requests and expand 2D + positions to 3D for text-only requests. Decode tokens get + delta-adjusted 3D expansion. + 2. Otherwise (decode-only or text-only prefill without images): expand + ``position_ids + delta`` to 3D where delta defaults to 0. 
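+
+        In both cases the per-request ``delta`` follows the Qwen mRoPE convention
+        ``llm_positions.max() + 1 - len(input_ids[i])``, so decode-time positions continue
+        from the compressed multimodal timeline rather than from the raw token count.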
+ """ + inputs_embeds = self.get_input_embeddings()(input_ids) + + has_images = pixel_values is not None and image_grid_thw is not None + has_videos = pixel_values_videos is not None and video_grid_thw is not None + + image_embeds_list = None + if has_images: + image_embeds_list = [ + embeds.to(inputs_embeds.device, inputs_embeds.dtype) + for embeds in self.get_image_features(pixel_values, image_grid_thw) + ] + + video_embeds_list = None + if has_videos: + video_embeds_list = [ + embeds.to(inputs_embeds.device, inputs_embeds.dtype) + for embeds in self.get_video_features(pixel_values_videos, video_grid_thw) + ] + video_span_embeds_list = self._expand_video_embeds_by_span( + video_embeds_list, video_grid_thw + ) + + delta = mrope_position_deltas if mrope_position_deltas is not None else 0 + + vision_grid = image_grid_thw if has_images else video_grid_thw if has_videos else None + if batch_info is None: + batch_info = kwargs.get("batch_info_host") + batch_info_host = kwargs.get("batch_info_host", batch_info) + cu_seqlen = kwargs.get("cu_seqlen") + if cu_seqlen is None: + cu_seqlen = kwargs.get("cu_seqlen_host") + seq_len = kwargs.get("seq_len") + if seq_len is None and cu_seqlen is not None: + seq_len = cu_seqlen[1:] - cu_seqlen[:-1] + input_pos = kwargs.get("input_pos") + if input_pos is None: + seq_len_with_cache = kwargs.get("seq_len_with_cache") + if seq_len_with_cache is None: + seq_len_with_cache = kwargs.get("seq_len_with_cache_host") + if seq_len_with_cache is not None and seq_len is not None: + input_pos = seq_len_with_cache.to(seq_len.device) - seq_len + mm_item_cu_seqlen = kwargs.get("mm_item_cu_seqlen") + mm_token_positions = kwargs.get("mm_token_positions") + mm_token_lengths = kwargs.get("mm_token_lengths") + mm_item_types = kwargs.get("mm_item_types") + mm_special_offsets_cu_seqlen = kwargs.get("mm_special_offsets_cu_seqlen") + mm_special_offsets = kwargs.get("mm_special_offsets") + slot_idx = kwargs.get("slot_idx") + mrope_delta_cache = kwargs.get("mrope_delta_cache") + if mrope_delta_cache is None: + for key, value in kwargs.items(): + if key.endswith("_mrope_delta_cache"): + mrope_delta_cache = value + break + has_chunk_mm_layout = ( + mm_item_cu_seqlen is not None + and mm_item_types is not None + and mm_token_positions is not None + and mm_token_lengths is not None + and mm_item_cu_seqlen.numel() > 0 + and int(mm_item_cu_seqlen[-1].item()) > 0 + and mm_token_positions.numel() > 0 + and mm_token_lengths.numel() > 0 + ) + + if has_images or has_videos: + multimodal_mask = ( + ( + (input_ids == self.config.image_token_id) + | (input_ids == self.config.video_token_id) + ) + .unsqueeze(-1) + .expand_as(inputs_embeds) + ) + num_multimodal_tokens = int( + ( + (input_ids == self.config.image_token_id) + | (input_ids == self.config.video_token_id) + ) + .sum() + .item() + ) + if ( + batch_info is not None + and cu_seqlen is not None + and input_pos is not None + and seq_len is not None + and has_chunk_mm_layout + ): + multimodal_embeds = self._build_chunked_multimodal_embeds( + input_ids=input_ids, + batch_info=batch_info, + cu_seqlen=cu_seqlen, + input_pos=input_pos, + seq_len=seq_len, + image_embeds_list=image_embeds_list, + video_span_embeds_list=video_span_embeds_list, + mm_item_cu_seqlen=mm_item_cu_seqlen, + mm_item_types=mm_item_types, + mm_token_positions=mm_token_positions, + mm_token_lengths=mm_token_lengths, + mm_special_offsets_cu_seqlen=mm_special_offsets_cu_seqlen, + mm_special_offsets=mm_special_offsets, + ) + else: + if image_embeds_list is not None: + image_embeds 
= torch.cat(image_embeds_list, dim=0) + image_mask = ( + (input_ids == self.config.image_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + if video_embeds_list is not None: + video_embeds = torch.cat(video_embeds_list, dim=0) + video_mask = ( + (input_ids == self.config.video_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + multimodal_embeds = None + + if ( + multimodal_embeds is not None + and multimodal_embeds.shape[0] != num_multimodal_tokens + ): + raise ValueError( + "Multimodal embedding count mismatch in Qwen3.5 VLM forward: " + f"selected={multimodal_embeds.shape[0]}, placeholders={num_multimodal_tokens}, " + f"input_shape={tuple(input_ids.shape)}" + ) + if multimodal_embeds is not None: + inputs_embeds = inputs_embeds.masked_scatter(multimodal_mask, multimodal_embeds) + if mrope_delta_cache is not None and batch_info_host is not None and slot_idx is not None: + delta = torch.ops.auto_deploy.qwen3_mrope_delta_with_cache( + batch_info_host, + slot_idx, + mm_item_cu_seqlen, + mm_item_types, + mm_token_lengths, + mm_special_offsets_cu_seqlen, + mm_special_offsets, + image_grid_thw, + video_grid_thw, + mrope_delta_cache, + self.config.vision_config.spatial_merge_size, + ).to(input_ids.dtype) + + if ( + vision_grid is not None + and batch_info is not None + and cu_seqlen is not None + and input_pos is not None + and seq_len is not None + and has_chunk_mm_layout + ): + position_ids_3d = self._build_chunked_multimodal_positions( + input_ids, + position_ids, + delta, + batch_info, + cu_seqlen, + input_pos, + seq_len, + image_grid_thw if has_images else None, + video_grid_thw if has_videos else None, + mm_item_cu_seqlen, + mm_item_types, + mm_token_positions, + mm_token_lengths, + mm_special_offsets_cu_seqlen, + mm_special_offsets, + ) + elif vision_grid is not None and batch_info is not None and cu_seqlen is not None: + position_ids_3d = self._build_mixed_positions( + input_ids, + position_ids, + delta, + batch_info, + cu_seqlen, + image_grid_thw if has_images else None, + video_grid_thw if has_videos else None, + ) + elif vision_grid is not None: + position_ids_3d, _ = compute_mrope_positions( + input_ids=input_ids, + image_grid_thw=image_grid_thw if has_images else None, + video_grid_thw=video_grid_thw if has_videos else None, + image_token_id=self.config.image_token_id, + video_token_id=self.config.video_token_id, + vision_start_token_id=self.config.vision_start_token_id, + spatial_merge_size=self.config.vision_config.spatial_merge_size, + ) + else: + if position_ids is None: + raise ValueError("position_ids is required for text-only or decode-only forward") + is_flattened_cached_layout = position_ids.ndim == 1 or ( + position_ids.ndim == 2 and position_ids.shape[0] == 1 + ) + if is_flattened_cached_layout: + flat_position_ids = position_ids.reshape(-1) + token_delta = torch.zeros_like(flat_position_ids) + if ( + torch.is_tensor(delta) + and cu_seqlen is not None + and delta.ndim == 2 + and delta.shape[0] == cu_seqlen.numel() - 1 + ): + seq_lens = (cu_seqlen[1:] - cu_seqlen[:-1]).to(torch.long) + token_delta = torch.repeat_interleave( + delta.squeeze(-1).to(flat_position_ids.device, flat_position_ids.dtype), + seq_lens.to(flat_position_ids.device), + ) + position_ids_3d = (flat_position_ids + token_delta).view(1, 1, -1).expand(3, 1, -1) + else: + position_ids_3d = (position_ids + delta)[None].expand(3, -1, -1) + + for 
key in ( + "input_pos", + "mm_chunk_flat_start", + "mm_chunk_count", + "mm_item_cu_seqlen", + "mm_item_types", + "mm_token_positions", + "mm_token_lengths", + "mm_special_offsets_cu_seqlen", + "mm_special_offsets", + "mrope_delta_cache", + ): + kwargs.pop(key, None) + for key in list(kwargs.keys()): + if key.endswith("_mrope_delta_cache"): + kwargs.pop(key, None) + + return self.language_model( + inputs_embeds=inputs_embeds, + position_ids=position_ids_3d, + **kwargs, + ) + + def _build_chunked_multimodal_positions( + self, + input_ids: torch.LongTensor, + position_ids: Optional[torch.LongTensor], + delta, + batch_info: torch.Tensor, + cu_seqlen: torch.Tensor, + input_pos: torch.Tensor, + seq_len: torch.Tensor, + image_grid_thw: Optional[torch.LongTensor], + video_grid_thw: Optional[torch.LongTensor], + mm_item_cu_seqlen: torch.Tensor, + mm_item_types: torch.Tensor, + mm_token_positions: torch.Tensor, + mm_token_lengths: torch.Tensor, + mm_special_offsets_cu_seqlen: Optional[torch.Tensor], + mm_special_offsets: Optional[torch.Tensor], + ) -> torch.Tensor: + """Build 3D positions using chunk runtime metadata from the executor. + + This path is for chunked multimodal prefill where ``input_ids`` only contains the current + chunk but full multimodal tensors are still available. It reconstructs the chunk's 3D + mRoPE positions in absolute request coordinates from: + - per-request chunk start/end (`input_pos`, `seq_len`) + - per-request multimodal item layout (`mm_token_positions`, `mm_token_lengths`) + - full multimodal grids (`image_grid_thw` / `video_grid_thw`) + """ + num_prefill_seqs = batch_info[0].item() + num_prefill_tokens = batch_info[1].item() + + img_grid_idx = 0 + vid_grid_idx = 0 + prefill_3d_parts: list[torch.Tensor] = [] + normalized_video_grid_thw = _normalize_video_grid_for_mrope(video_grid_thw) + + for i in range(num_prefill_seqs): + start = cu_seqlen[i].item() + end = cu_seqlen[i + 1].item() + req_input_pos = int(input_pos[i].item()) + req_seq_len = int(seq_len[i].item()) + + item_start = int(mm_item_cu_seqlen[i].item()) + item_end = int(mm_item_cu_seqlen[i + 1].item()) + req_mm_item_types = mm_item_types[item_start:item_end].tolist() + req_mm_positions = mm_token_positions[item_start:item_end].tolist() + req_mm_lengths = mm_token_lengths[item_start:item_end].tolist() + + req_special_offsets: list[int] = [] + if mm_special_offsets_cu_seqlen is not None and mm_special_offsets is not None: + special_start = int(mm_special_offsets_cu_seqlen[i].item()) + special_end = int(mm_special_offsets_cu_seqlen[i + 1].item()) + req_special_offsets = mm_special_offsets[special_start:special_end].tolist() + + has_img = image_grid_thw is not None and len(req_mm_positions) > 0 + has_vid = normalized_video_grid_thw is not None and len(req_mm_positions) > 0 + + if has_img or has_vid: + req_img_grid = None + req_vid_grid = None + num_images = sum(item_type == 0 for item_type in req_mm_item_types) + num_videos = sum(item_type == 1 for item_type in req_mm_item_types) + if has_img: + req_img_grid = image_grid_thw[img_grid_idx : img_grid_idx + num_images] + img_grid_idx += num_images + if has_vid: + req_vid_grid = normalized_video_grid_thw[ + vid_grid_idx : vid_grid_idx + num_videos + ] + vid_grid_idx += num_videos + + pos_3d = self._compute_request_chunk_mrope_positions( + req_input_pos=req_input_pos, + req_seq_len=req_seq_len, + req_mm_item_types=req_mm_item_types, + req_mm_positions=req_mm_positions, + req_mm_lengths=req_mm_lengths, + req_special_offsets=req_special_offsets, + 
image_grid_thw=req_img_grid, + video_grid_thw=req_vid_grid, + dtype=input_ids.dtype, + device=input_ids.device, + ) + prefill_3d_parts.append(pos_3d) + else: + if position_ids is not None: + req_pos = position_ids[..., start:end] + else: + req_pos = torch.arange( + req_input_pos, + req_input_pos + req_seq_len, + device=input_ids.device, + dtype=input_ids.dtype, + ).unsqueeze(0) + prefill_3d_parts.append(req_pos[None].expand(3, -1, -1)) + + prefill_pos = torch.cat(prefill_3d_parts, dim=-1) + + if num_prefill_tokens < input_ids.shape[-1]: + if position_ids is None: + raise ValueError("position_ids is required when decode tokens are present") + decode_pos_2d = position_ids[..., num_prefill_tokens:] + if isinstance(delta, torch.Tensor): + gen_deltas = delta[num_prefill_seqs:] + decode_adjusted = decode_pos_2d + gen_deltas.T + else: + decode_adjusted = decode_pos_2d + delta + decode_pos_3d = decode_adjusted[None].expand(3, -1, -1) + return torch.cat([prefill_pos, decode_pos_3d], dim=-1) + + return prefill_pos + + def _compute_request_chunk_mrope_positions( + self, + req_input_pos: int, + req_seq_len: int, + req_mm_item_types: Sequence[int], + req_mm_positions: Sequence[int], + req_mm_lengths: Sequence[int], + req_special_offsets: Sequence[int], + image_grid_thw: Optional[torch.Tensor], + video_grid_thw: Optional[torch.Tensor], + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor: + """Compute chunk-local 3D mRoPE positions for one request in absolute coordinates.""" + chunk_end = req_input_pos + req_seq_len + out = torch.empty((3, 1, req_seq_len), dtype=dtype, device=device) + special_offsets_set = set(int(x) for x in req_special_offsets) + mm_cumulative_offset = 0 + abs_cursor = 0 + comp_cursor = 0 + img_idx = 0 + vid_idx = 0 + + def fill_text(abs_start: int, abs_end: int, comp_start: int) -> None: + ov_start = max(req_input_pos, abs_start) + ov_end = min(chunk_end, abs_end) + if ov_start >= ov_end: + return + start_pos = comp_start + (ov_start - abs_start) + text_pos = torch.arange( + start_pos, start_pos + (ov_end - ov_start), device=device, dtype=dtype + ) + out[:, 0, ov_start - req_input_pos : ov_end - req_input_pos] = text_pos.unsqueeze( + 0 + ).expand(3, -1) + + def fill_vision(abs_start: int, grid: torch.Tensor, comp_start: int) -> Tuple[int, int]: + t, h, w = [int(v) for v in grid.tolist()] + llm_grid_t = int(t) + llm_grid_h = int(h) // self.config.vision_config.spatial_merge_size + llm_grid_w = int(w) // self.config.vision_config.spatial_merge_size + vision_len = llm_grid_t * llm_grid_h * llm_grid_w + + ov_start = max(req_input_pos, abs_start) + ov_end = min(chunk_end, abs_start + vision_len) + if ov_start < ov_end: + t_index = ( + torch.arange(llm_grid_t, device=device, dtype=dtype) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h, device=device, dtype=dtype) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w, device=device, dtype=dtype) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + positions = torch.stack([t_index, h_index, w_index]) + comp_start + local_start = ov_start - abs_start + local_end = ov_end - abs_start + out[:, 0, ov_start - req_input_pos : ov_end - req_input_pos] = positions[ + :, local_start:local_end + ] + + return vision_len, comp_start + max(llm_grid_t, llm_grid_h, llm_grid_w) + + for item_type, mm_start, mm_len in zip(req_mm_item_types, req_mm_positions, req_mm_lengths): + item_mm_offset = 
mm_cumulative_offset + leading_specials = 0 + while item_mm_offset + leading_specials in special_offsets_set: + leading_specials += 1 + + vision_abs_start = int(mm_start) + leading_specials + fill_text(abs_cursor, vision_abs_start, comp_cursor) + comp_cursor += vision_abs_start - abs_cursor + + if item_type == 0: + if image_grid_thw is None: + raise ValueError("Missing image_grid_thw for image multimodal item") + grid = image_grid_thw[img_idx] + img_idx += 1 + elif item_type == 1: + if video_grid_thw is None: + raise ValueError("Missing video_grid_thw for video multimodal item") + grid = video_grid_thw[vid_idx] + vid_idx += 1 + else: + raise ValueError(f"Unsupported multimodal item type: {item_type}") + + _, next_comp_cursor = fill_vision(vision_abs_start, grid, comp_cursor) + comp_cursor = next_comp_cursor + abs_cursor = int(mm_start) + int(mm_len) + mm_cumulative_offset += int(mm_len) + + fill_text(abs_cursor, chunk_end, comp_cursor) + return out + + def _build_mixed_positions( + self, + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + delta, + batch_info: torch.Tensor, + cu_seqlen: torch.Tensor, + image_grid_thw: Optional[torch.LongTensor], + video_grid_thw: Optional[torch.LongTensor], + ) -> torch.Tensor: + """Build 3D mRoPE positions for a batch with per-request granularity. + + Iterates over prefill requests using ``cu_seqlen`` boundaries. For + each request that contains vision tokens, calls + ``compute_mrope_positions`` with the matching ``image_grid_thw`` rows. + Text-only prefill requests get trivial 3D expansion. Decode tokens + are delta-adjusted uniformly. + """ + num_prefill_seqs = batch_info[0].item() + num_prefill_tokens = batch_info[1].item() + + img_grid_idx = 0 + vid_grid_idx = 0 + prefill_3d_parts: list = [] + + for i in range(num_prefill_seqs): + start = cu_seqlen[i].item() + end = cu_seqlen[i + 1].item() + req_ids = input_ids[..., start:end] + + has_img = image_grid_thw is not None and (req_ids == self.config.image_token_id).any() + has_vid = video_grid_thw is not None and (req_ids == self.config.video_token_id).any() + + if has_img or has_vid: + req_img_grid = None + req_vid_grid = None + req_item_types = _extract_mm_item_types_from_input_ids( + req_ids, + image_token_id=self.config.image_token_id, + video_token_id=self.config.video_token_id, + vision_start_token_id=self.config.vision_start_token_id, + ) + if has_img: + n_img = sum(item_type == 0 for item_type in req_item_types) + req_img_grid = image_grid_thw[img_grid_idx : img_grid_idx + n_img] + img_grid_idx += n_img + if has_vid: + n_vid = sum(item_type == 1 for item_type in req_item_types) + req_vid_grid = video_grid_thw[vid_grid_idx : vid_grid_idx + n_vid] + vid_grid_idx += n_vid + + pos_3d, _ = compute_mrope_positions( + input_ids=req_ids, + image_grid_thw=req_img_grid, + video_grid_thw=req_vid_grid, + image_token_id=self.config.image_token_id, + video_token_id=self.config.video_token_id, + vision_start_token_id=self.config.vision_start_token_id, + spatial_merge_size=self.config.vision_config.spatial_merge_size, + ) + prefill_3d_parts.append(pos_3d) + else: + req_pos = position_ids[..., start:end] + prefill_3d_parts.append((req_pos + 0)[None].expand(3, -1, -1)) + + prefill_pos = torch.cat(prefill_3d_parts, dim=-1) + + if num_prefill_tokens < input_ids.shape[-1]: + decode_pos_2d = position_ids[..., num_prefill_tokens:] + if isinstance(delta, torch.Tensor): + num_prefill_seqs_t = batch_info[0].item() + gen_deltas = delta[num_prefill_seqs_t:] + decode_adjusted = decode_pos_2d + gen_deltas.T + 
else: + decode_adjusted = decode_pos_2d + delta + decode_pos_3d = decode_adjusted[None].expand(3, -1, -1) + return torch.cat([prefill_pos, decode_pos_3d], dim=-1) + + return prefill_pos + + +class Qwen3_5MoeForConditionalGeneration(Qwen3_5MoePreTrainedModel): + """Top-level multimodal model: vision + language model + lm_head. + + This wraps ``Qwen3_5MoeModel`` (which contains the vision tower and the + text ``Qwen3_5MoeTextModel`` as ``language_model``) and adds an ``lm_head`` + at the top level -- matching the HF checkpoint weight layout. + """ + + config_class = Qwen3_5MoeConfig + + def __init__(self, config: Qwen3_5MoeConfig, **kwargs): + super().__init__(config) + self.model = Qwen3_5MoeModel(config) + self.lm_head = nn.Linear( + config.text_config.hidden_size, config.text_config.vocab_size, bias=False + ) + # Share lm_head with the text model so it's inside the exported graph + self.model.language_model.set_lm_head(self.lm_head) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.language_model.get_input_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + self.model.language_model.set_lm_head(new_embeddings) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Qwen3_5MoeConditionalOutput: + outputs = self.model( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + **kwargs, + ) + logits = outputs.logits + return Qwen3_5MoeConditionalOutput(logits=logits) + + +# ============================================================================= +# Custom Export Info and Factory +# ============================================================================= + + +class Qwen3_5MoeTextExportInfo(TextModelExportInfo): + """Export info for mRoPE models that receive 3D position_ids ``(3, B, S)``. + + Dim 0 is always 3 (temporal, height, width) and is static; dims 1 and 2 + (batch, sequence) are dynamic. + """ + + def _init_dynamic_shape_lookup(self): + base = super()._init_dynamic_shape_lookup() + batch_size_dyn = Dim.DYNAMIC + seq_len_dyn = Dim.DYNAMIC + base["position_ids"] = {1: batch_size_dyn, 2: seq_len_dyn} + return base + + +class Qwen3_5MoeADInputProcessor: + """Qwen-specific AD input processor that emits exact multimodal spans from tokenized input.""" + + def __init__(self, base_processor): + self.base_processor = base_processor + # Bypass the generic hashing wrapper. We produce multimodal_input directly. + self.multimodal_hashing_supported = False + + def __getattr__(self, name: str): + return getattr(self.base_processor, name) + + @property + def get_num_multimodal_tokens(self): + """Delegate multimodal token counting to the wrapped Qwen HF processor.""" + if hasattr(self.processor, "_get_num_multimodal_tokens"): + return self.processor._get_num_multimodal_tokens + raise NotImplementedError( + f"get_num_multimodal_tokens not implemented for {self.__class__.__name__}. " + "Please ensure the processor exposes _get_num_multimodal_tokens." 
+ ) + + def get_num_tokens_per_image(self, *, image: Image.Image, **kwargs) -> int: + image_size = (image.height, image.width) + return self.get_num_multimodal_tokens([image_size], **kwargs)["num_image_tokens"][0] + + def get_num_tokens_per_video(self, *, video: List[Image.Image], **kwargs) -> int: + video_size = (len(video), video[0].height, video[0].width) + num_video_tokens = self.get_num_multimodal_tokens(video_sizes=[video_size], **kwargs).get( + "num_video_tokens" + ) + if num_video_tokens is None: + raise NotImplementedError("Underlying processor does not expose num_video_tokens.") + return num_video_tokens[0] + + def get_vocab_size(self) -> Optional[int]: + """Return the tokenizer vocabulary size for Qwen multimodal hashing helpers.""" + if self.tokenizer is not None and hasattr(self.tokenizer, "vocab_size"): + return int(self.tokenizer.vocab_size) + wrapped_tokenizer = getattr(self.tokenizer, "tokenizer", None) + if wrapped_tokenizer is not None and hasattr(wrapped_tokenizer, "vocab_size"): + return int(wrapped_tokenizer.vocab_size) + processor_tokenizer = getattr(self.processor, "tokenizer", None) + if processor_tokenizer is not None and hasattr(processor_tokenizer, "vocab_size"): + return int(processor_tokenizer.vocab_size) + return None + + def get_mm_token_ids(self) -> Optional[torch.Tensor]: + if hasattr(self.processor, "mm_token_ids"): + return self.processor.mm_token_ids + sources = [ + self.processor, + getattr(self.processor, "tokenizer", None), + self.tokenizer, + getattr(self.tokenizer, "tokenizer", None), + ] + token_ids = [] + for source in sources: + if source is None: + continue + for attr in ("image_token_id", "video_token_id"): + value = getattr(source, attr, None) + if value is not None: + token_ids.append(int(value)) + if token_ids: + return torch.tensor(sorted(set(token_ids)), dtype=torch.int32) + return None + + def get_mm_special_token_ids(self) -> Optional[torch.Tensor]: + if hasattr(self.processor, "mm_special_token_ids"): + return self.processor.mm_special_token_ids + sources = [ + self.processor, + getattr(self.processor, "tokenizer", None), + self.tokenizer, + getattr(self.tokenizer, "tokenizer", None), + ] + token_ids = [] + for source in sources: + if source is None: + continue + for attr in ("vision_start_token_id", "vision_end_token_id"): + value = getattr(source, attr, None) + if value is not None: + token_ids.append(int(value)) + if token_ids: + return torch.tensor(sorted(set(token_ids)), dtype=torch.int32) + return None + + def _build_multimodal_input( + self, + token_ids: List[int], + inputs: Dict[str, Any], + ) -> Optional[Tuple[MultimodalInput, List[int], List[int]]]: + mm_data = inputs.get("multi_modal_data") + if not mm_data or not any(k in mm_data for k in ("image", "video")): + return None + + image_token_id = int(self.processor.image_token_id) + video_token_id = int(self.processor.video_token_id) + vision_start_token_id = int(self.processor.vision_start_token_id) + ids = token_ids + + starts: List[int] = [] + lengths: List[int] = [] + special_offsets: List[int] = [] + item_types: List[int] = [] + mm_union_offset = 0 + i = 0 + while i < len(ids): + if ids[i] != vision_start_token_id: + i += 1 + continue + + if i + 1 >= len(ids): + i += 1 + continue + + if ids[i + 1] == image_token_id: + item_token_id = image_token_id + item_type = 0 + elif ids[i + 1] == video_token_id: + item_token_id = video_token_id + item_type = 1 + else: + i += 1 + continue + + j = i + 1 + while j < len(ids) and ids[j] == item_token_id: + j += 1 + if j == i + 1: + i += 
1 + continue + + starts.append(i) + lengths.append(j - i) + special_offsets.append(mm_union_offset) + item_types.append(item_type) + mm_union_offset += j - i + i = j + + image_items = _normalize_qwen_image_items(mm_data.get("image")) + video_items = _normalize_qwen_video_items(mm_data.get("video")) + + num_video_spans = sum(item_type == 1 for item_type in item_types) + video_span_counts = [_get_qwen_video_num_spans(video) for video in video_items] + if num_video_spans != sum(video_span_counts): + raise ValueError( + "Mismatch between Qwen video prompt spans and video inputs: " + f"spans={num_video_spans}, expected_from_videos={sum(video_span_counts)}" + ) + + if len(starts) != len(image_items) + num_video_spans: + raise ValueError( + "Mismatch between multimodal prompt spans and multimodal items: " + f"spans={len(starts)}, images={len(image_items)}, video_spans={num_video_spans}" + ) + + mm_uuids = inputs.get("multi_modal_uuids", None) + mm_hash_inputs = {} + if image_items: + mm_hash_inputs["image"] = image_items + if video_items: + mm_hash_inputs["video"] = video_items + mm_hashes, _ = apply_mm_hashes(mm_hash_inputs, mm_uuids) + + image_hashes = [hexdigest_to_int32(h) for h in mm_hashes.get("image", [])] + video_hashes = [hexdigest_to_int32(h) for h in mm_hashes.get("video", [])] + image_uuids = list((mm_uuids or {}).get("image", [None] * len(image_items))) + video_uuids = list((mm_uuids or {}).get("video", [None] * len(video_items))) + image_idx = 0 + video_idx = 0 + remaining_video_spans = video_span_counts[0] if video_span_counts else 0 + mm_hashes_flat: List[List[int]] = [] + mm_uuid_list: List[Optional[str]] = [] + for item_type in item_types: + if item_type == 0: + mm_hashes_flat.append(image_hashes[image_idx]) + mm_uuid_list.append(image_uuids[image_idx]) + image_idx += 1 + else: + if video_idx >= len(video_hashes): + raise ValueError("Video span count exceeded available video items") + mm_hashes_flat.append(video_hashes[video_idx]) + mm_uuid_list.append(video_uuids[video_idx]) + remaining_video_spans -= 1 + if remaining_video_spans == 0: + video_idx += 1 + remaining_video_spans = ( + video_span_counts[video_idx] if video_idx < len(video_span_counts) else 0 + ) + return ( + MultimodalInput.from_components( + mm_hashes_flat, + starts, + lengths, + mm_uuid_list if mm_uuids is not None else None, + ), + special_offsets, + item_types, + ) + + def __call__(self, inputs, sampling_params): + token_ids, extra_processed_inputs = self.base_processor(inputs, sampling_params) + if "multi_modal_data" not in inputs: + return token_ids, extra_processed_inputs + + built = self._build_multimodal_input(token_ids, inputs) + if built is None: + return token_ids, extra_processed_inputs + + multimodal_input, special_offsets, item_types = built + if extra_processed_inputs is None: + extra_processed_inputs = {} + extra_processed_inputs["multimodal_input"] = multimodal_input + multimodal_data = extra_processed_inputs.get("multimodal_data", {}) + multimodal_data["layout_metadata"] = { + "special_token_offsets": torch.tensor(special_offsets, dtype=torch.int32), + "item_types": torch.tensor(item_types, dtype=torch.int32), + } + extra_processed_inputs["multimodal_data"] = multimodal_data + return token_ids, extra_processed_inputs + + +@ModelFactoryRegistry.register("Qwen3_5MoeForConditionalGeneration") +class Qwen3_5MoeFactory(AutoModelForImageTextToTextFactory): + """Factory for Qwen3.5 MoE that uses 3D mRoPE position_ids export info.""" + + def get_export_infos(self, model: nn.Module): + return 
[Qwen3_5MoeTextExportInfo.from_autoinferred(model)] + + def init_input_processor(self, base): + return Qwen3_5MoeADInputProcessor(base) + + +# ============================================================================= +# Registration +# ============================================================================= + +AutoConfig.register("qwen3_5_moe", Qwen3_5MoeConfig) +AutoConfig.register("qwen3_5_moe_text", Qwen3_5MoeTextConfig) + +AutoModelForCausalLMFactory.register_custom_model_cls("Qwen3_5MoeTextConfig", Qwen3_5MoeForCausalLM) +Qwen3_5MoeFactory.register_custom_model_cls("Qwen3_5MoeConfig", Qwen3_5MoeForConditionalGeneration) diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_ir.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_ir.py new file mode 100644 index 000000000000..d1ce41e5b663 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_ir.py @@ -0,0 +1,456 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Qwen3 model with explicit sharding hint ops. + +This is a rewrite of modeling_qwen3.py where all sharding-enabled operations use +AutoDeploy custom ops with sharding hint kwargs. The graph produced by this +model is a complete, self-contained specification of how this model should be +sharded. The ``apply_sharding_hints`` transform reads the hints together with a +runtime ``DistConfig`` to apply deterministic, node-local sharding. + +Shardable custom ops used: + - torch.ops.auto_deploy.torch_linear_simple (tp_mode, tp_min_local_shape, layer_type) + - torch.ops.auto_deploy.view (tp_scaled_dim, layer_type) + - torch.ops.auto_deploy.all_reduce (identity / dist.all_reduce, layer_type) +""" + +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import nn +from transformers.activations import ACT2FN +from transformers.generation import GenerationMixin +from transformers.modeling_utils import PreTrainedModel +from transformers.models.qwen3.configuration_qwen3 import Qwen3Config +from transformers.utils import ModelOutput + +import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 -- register all ops +from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory + + +class Qwen3RMSNorm(nn.Module): + """RMS Normalization for Qwen3 using AutoDeploy torch_rmsnorm reference op.""" + + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return torch.ops.auto_deploy.torch_rmsnorm( + hidden_states, self.weight, self.variance_epsilon + ) + + +class Qwen3RotaryEmbedding(nn.Module): + """Rotary Position Embedding for Qwen3. + + Simplified version that precomputes and caches cos/sin values. 
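+    The cache is built once for ``max_position_embeddings`` using the standard RoPE
+    schedule ``inv_freq[i] = 1 / base**(2i / dim)``, with ``emb = cat(freqs, freqs)``
+    before taking cos/sin.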
+ Returns full cached values (not sliced by seq_len) to enable export. + + Uses _ad_ prefix for buffer names to work with AutoDeploy's lift_to_meta. + """ + + def __init__( + self, + dim: int, + max_position_embeddings: int = 32768, + base: float = 10000.0, + ): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + self._set_cos_sin_cache(max_position_embeddings) + + def _set_cos_sin_cache(self, seq_len: int): + self.max_seq_len_cached = seq_len + t = torch.arange(seq_len, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_ad_cos_cached", emb.cos(), persistent=False) + self.register_buffer("_ad_sin_cached", emb.sin(), persistent=False) + + def forward( + self, x: torch.Tensor, position_ids: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + cos = self._ad_cos_cached.to(dtype=x.dtype, device=x.device) + sin = self._ad_sin_cached.to(dtype=x.dtype, device=x.device) + return cos[position_ids], sin[position_ids] + + +class Qwen3MLP(nn.Module): + """MLP layer for Qwen3 (SwiGLU) with sharding hints. + + Sharding strategy: + gate_proj -> colwise + up_proj -> colwise + down_proj -> rowwise + all_reduce + """ + + def __init__(self, config: Qwen3Config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate = torch.ops.auto_deploy.torch_linear_simple( + x, + self.gate_proj.weight, + self.gate_proj.bias, + tp_mode="colwise", + layer_type="mlp", + ) + up = torch.ops.auto_deploy.torch_linear_simple( + x, + self.up_proj.weight, + self.up_proj.bias, + tp_mode="colwise", + layer_type="mlp", + ) + down = torch.ops.auto_deploy.torch_linear_simple( + self.act_fn(gate) * up, + self.down_proj.weight, + self.down_proj.bias, + tp_mode="rowwise", + layer_type="mlp", + ) + down = torch.ops.auto_deploy.all_reduce(down, layer_type="mlp") + return down + + +class Qwen3Attention(nn.Module): + """Grouped Query Attention for Qwen3 with per-head Q/K normalization and sharding hints. 
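+    Q and K are RMS-normalized per head after the head-split view, and attention runs
+    through ``torch.ops.auto_deploy.torch_attention`` in ``bsnd`` layout.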
+ + Sharding strategy: + q_proj -> colwise (+ tp_min_local_shape for GQA) + k_proj -> colwise (+ tp_min_local_shape for GQA) + v_proj -> colwise (+ tp_min_local_shape for GQA) + view -> tp_scaled_dim=2 (head count dimension) + o_proj -> rowwise + all_reduce + """ + + def __init__(self, config: Qwen3Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.scaling = self.head_dim ** (-0.5) + + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias + ) + + self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + bsz, q_len, _ = hidden_states.size() + + q = torch.ops.auto_deploy.torch_linear_simple( + hidden_states, + self.q_proj.weight, + self.q_proj.bias, + tp_mode="colwise", + tp_min_local_shape=self.head_dim, + layer_type="mha", + ) + k = torch.ops.auto_deploy.torch_linear_simple( + hidden_states, + self.k_proj.weight, + self.k_proj.bias, + tp_mode="colwise", + tp_min_local_shape=self.head_dim, + layer_type="mha", + ) + v = torch.ops.auto_deploy.torch_linear_simple( + hidden_states, + self.v_proj.weight, + self.v_proj.bias, + tp_mode="colwise", + tp_min_local_shape=self.head_dim, + layer_type="mha", + ) + + q = torch.ops.auto_deploy.view( + q, + [bsz, q_len, self.num_heads, self.head_dim], + tp_scaled_dim=2, + layer_type="mha", + ) + k = torch.ops.auto_deploy.view( + k, + [bsz, q_len, self.num_kv_heads, self.head_dim], + tp_scaled_dim=2, + layer_type="mha", + ) + v = torch.ops.auto_deploy.view( + v, + [bsz, q_len, self.num_kv_heads, self.head_dim], + tp_scaled_dim=2, + layer_type="mha", + ) + + q = self.q_norm(q) + k = self.k_norm(k) + + cos, sin = position_embeddings + + q, k = torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin( + q, + k, + cos, + sin, + 2, + ) + + attn_output = torch.ops.auto_deploy.torch_attention( + q, + k, + v, + None, + 0.0, + True, + self.scaling, + None, + None, + None, + "bsnd", + ) + + attn_output = torch.ops.auto_deploy.view( + attn_output, + [bsz, q_len, self.num_heads * self.head_dim], + tp_scaled_dim=2, + layer_type="mha", + ) + + attn_output = torch.ops.auto_deploy.torch_linear_simple( + attn_output, + self.o_proj.weight, + self.o_proj.bias, + tp_mode="rowwise", + layer_type="mha", + ) + attn_output = torch.ops.auto_deploy.all_reduce(attn_output, layer_type="mha") + + return attn_output + + +class Qwen3DecoderLayer(nn.Module): + """Transformer decoder layer for Qwen3.""" + + def __init__(self, config: Qwen3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Qwen3Attention(config, layer_idx=layer_idx) + self.mlp = Qwen3MLP(config) + self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + 
self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(hidden_states, position_embeddings) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +@dataclass +class Qwen3Output(ModelOutput): + """Output for Qwen3Model.""" + + last_hidden_state: Optional[torch.FloatTensor] = None + + +@dataclass +class Qwen3CausalLMOutput(ModelOutput): + """Output for Qwen3ForCausalLM.""" + + logits: Optional[torch.FloatTensor] = None + + +class Qwen3PreTrainedModel(PreTrainedModel): + """Base class for Qwen3 models.""" + + config_class = Qwen3Config + base_model_prefix = "model" + _no_split_modules = ["Qwen3DecoderLayer"] + supports_gradient_checkpointing = False + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class Qwen3Model(Qwen3PreTrainedModel): + """Qwen3 transformer decoder model.""" + + def __init__(self, config: Qwen3Config): + super().__init__(config) + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen3DecoderLayer(config, layer_idx=idx) for idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.rotary_emb = Qwen3RotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> Qwen3Output: + if input_ids is not None and inputs_embeds is not None: + raise ValueError("Cannot specify both input_ids and inputs_embeds") + elif input_ids is None and inputs_embeds is None: + raise ValueError("Must specify either input_ids or inputs_embeds") + + assert position_ids is not None, "position_ids must be provided for AD export" + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + inputs_embeds = inputs_embeds.to(self.norm.weight.dtype) + + position_embeddings = self.rotary_emb(inputs_embeds, position_ids) + + hidden_states = inputs_embeds + + for decoder_layer in self.layers: + hidden_states = decoder_layer(hidden_states, position_embeddings) + + hidden_states = self.norm(hidden_states) + + return Qwen3Output(last_hidden_state=hidden_states) + + +class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin): + """Qwen3 model with language 
modeling head.""" + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config, **kwargs): + super().__init__(config) + self.model = Qwen3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> Qwen3CausalLMOutput: + assert position_ids is not None, "position_ids must be provided for AD export" + outputs = self.model( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states).float() + + return Qwen3CausalLMOutput(logits=logits) + + +AutoModelForCausalLMFactory.register_custom_model_cls("Qwen3Config", Qwen3ForCausalLM) diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index ea336dab0f41..e020bb6a48ac 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -72,6 +72,7 @@ from ..llm_args import LlmArgs from ..transform.optimizer import InferenceOptimizer from ..utils._graph import get_input_embeddings, get_lm_head_weights +from ..utils.dist_config import DistConfig from ..utils.logger import ad_logger from .interface import CachedSequenceInterface, GetInferenceModel @@ -356,7 +357,7 @@ def _call_func(): can_pad = self.padding_dummy_request is not None # in attention DP mode, we check all ranks - if self.enable_attention_dp and self.mapping.tp_size > 1: + if self.enable_attention_dp and self.dist_config.tp_size > 1: assert self.dist is not None, "Distributed object is required for attention DP mode" all_rank_info = self.dist.tp_allgather([can_run_cuda_graph, can_pad, batch_size]) else: @@ -427,6 +428,8 @@ def _device(self) -> DeviceLikeType: def build_from_config( cls, ad_config: LlmArgs, + dist_config: Optional[DistConfig] = None, + # deprecation: Mapping will soon be replaced entirely by DistConfig mapping: Optional[Mapping] = None, dist: Optional[Distributed] = None, ): @@ -457,12 +460,9 @@ def build_from_config( enable_iter_perf_stats=ad_config.enable_iter_perf_stats, enable_iter_req_stats=ad_config.enable_iter_req_stats, ) - # TODO (lucaslie): consider how we move args around InferenceOptimizer.__init__, - # ADEngine.__init__, and ADEngine.build_from_config. Seems a bit unnatural atm. 
- # construct inference optimizer build_and_optimize = InferenceOptimizer( - factory=factory, config=ad_config.transforms, mapping=mapping + factory=factory, config=ad_config.transforms, dist_config=dist_config ) # construct engine @@ -470,6 +470,7 @@ def build_from_config( build_and_optimize, cache_seq_interface, ad_config=ad_config, + dist_config=dist_config, mapping=mapping, dist=dist, reporting_info=reporting_info, @@ -481,6 +482,7 @@ def __init__( get_inference_model: GetInferenceModel, cache_seq_interface: CachedSequenceInterface, ad_config: Optional[LlmArgs] = None, + dist_config: Optional[DistConfig] = None, mapping: Optional[Mapping] = None, dist: Optional[Distributed] = None, reporting_info: ReportingInfo = ReportingInfo(), @@ -491,7 +493,8 @@ def __init__( get_inference_model: Callable that builds the inference model. cache_seq_interface: The CachedSequenceInterface containing sequence and cache config. ad_config: Optional LLM configuration. - mapping: Optional distributed mapping configuration. + dist_config: DistConfig (single source of truth for distributed config within AD). + mapping: Mapping for external TRT-LLM APIs (KV cache, sampler, etc.). reporting_info: Reporting configuration for logging. """ # NOTE (lucaslie): create a fake Namespace to satisfy PyExecutor requirements... @@ -511,7 +514,7 @@ def __init__( self.iter_states = {} # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor... - self.enable_attention_dp = mapping.enable_attention_dp if mapping else False + self.enable_attention_dp = dist_config.enable_attention_dp if dist_config else False if ad_config is not None: self.max_beam_width = ad_config.max_beam_width @@ -559,6 +562,7 @@ def __init__( self.padding_dummy_request: Optional[LlmRequest] = None # Reuse _execute_logit_post_processors from PyTorchModelEngine + self.dist_config = dist_config self.mapping = mapping self.dist = dist self._execute_logit_post_processors = types.MethodType( @@ -942,7 +946,7 @@ def forward( ): spec_resource_manager.capture_hidden_states(self.cache_seq_interface) - if self.mapping is not None: + if self.dist_config is not None: self._execute_logit_post_processors(scheduled_requests, outputs) return outputs @@ -1144,8 +1148,10 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer world_size = mpi_world_size() rank = mpi_rank() - # Initialize Mapping from config - dist_mapping = ad_config.init_mapping_from_config(rank, world_size) + # DistConfig is the single source of truth within AutoDeploy. + # Mapping is derived only for external TRT-LLM APIs that still require it. 
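+    # Illustrative note: dc carries the tp/pp/MoE sizes and ranks used below;
+    # to_mapping() is kept only for legacy consumers (KV cache, sampler, etc.)
+    # that still take a Mapping.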
+ dc = ad_config.init_dist_config(rank, world_size) + dist_mapping = dc.to_mapping() dist = Distributed.get(dist_mapping) ad_logger.set_rank(rank) @@ -1153,7 +1159,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer port = dist.broadcast(get_free_port()) # use MPI broadcast to pick a free port initialize_or_skip(rank, world_size, port) - ad_logger.info(f"{dist_mapping=}, {dist=}, {port=}") + ad_logger.info(f"dist_config={dc}, {dist=}, {port=}") # Setup AutoTuner with distributed state for allreduce autotuning AutoTuner.get().setup_distributed_state(dist_mapping) @@ -1161,7 +1167,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer # some config assert ad_config.max_beam_width <= 1, "_autodeploy + beam_search is not supported" - max_num_sequences = ad_config.max_batch_size * dist_mapping.pp_size + max_num_sequences = ad_config.max_batch_size * dc.pp_size # some derivative properties max_draft_len = ( 0 if ad_config.speculative_config is None else ad_config.speculative_config.max_draft_len @@ -1173,7 +1179,9 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer ) # initialize model engine - engine = ADEngine.build_from_config(ad_config=ad_config, mapping=dist_mapping, dist=dist) + engine = ADEngine.build_from_config( + ad_config=ad_config, dist_config=dc, mapping=dist_mapping, dist=dist + ) spec_config = ad_config.speculative_config @@ -1279,7 +1287,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer guided_decoder = None if ( (guided_decoding_backend := ad_config.guided_decoding_backend) is not None - ) and dist_mapping.is_last_pp_rank(): + ) and dc.pp_rank == dc.pp_size - 1: if vocab_size_padded is None: raise RuntimeError( "Could not determine the vocabulary size. Required for guided decoding." 
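
A minimal sketch of the new executor wiring, for illustration only (names taken
from the hunks above; rank and world_size values are assumed to be provided by MPI):

    # DistConfig is AutoDeploy's internal source of truth; Mapping is derived
    # from it solely for external TRT-LLM APIs that still require it.
    dc = ad_config.init_dist_config(rank, world_size)
    dist_mapping = dc.to_mapping()
    dist = Distributed.get(dist_mapping)
    engine = ADEngine.build_from_config(
        ad_config=ad_config, dist_config=dc, mapping=dist_mapping, dist=dist
    )
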
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/interface.py b/tensorrt_llm/_torch/auto_deploy/transform/interface.py index 6289a86ee987..ad2079fab2f0 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/interface.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/interface.py @@ -28,6 +28,7 @@ run_shape_prop, ) from ..utils.cuda_mem_tracker import get_mem_info +from ..utils.dist_config import DistConfig # kept in utils to avoid circular imports from ..utils.graph_writer import graph_writer from ..utils.logger import ad_logger from .graph_module_visualizer import to_dot @@ -126,7 +127,7 @@ class SharedConfig(BaseModel): } local_rank: int = Field(default=0) world_size: int = Field(default=1) - mapping: Any = Field(default=None) # Mapping object from ad_executor + dist_config: Optional[DistConfig] = Field(default=None) class TransformConfig(BaseModel): diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/adapt_to_edgellm.py b/tensorrt_llm/_torch/auto_deploy/transform/library/adapt_to_edgellm.py index 75149b2c591a..1b580b33e70c 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/adapt_to_edgellm.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/adapt_to_edgellm.py @@ -22,7 +22,7 @@ from ...shim.interface import CachedSequenceInterface from ...utils._graph import run_shape_prop from ...utils.logger import ad_logger -from ...utils.node_utils import is_op, sync_weight_meta_dtype +from ...utils.node_utils import is_any_view_op, is_op, sync_weight_meta_dtype from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry @@ -101,7 +101,7 @@ def _insert_cast_after_attn_reshape(self, gm: GraphModule) -> int: getitem_0_node = user # Find reshape nodes that use this getitem[0] for reshape_user in list(getitem_0_node.users): - if is_op(reshape_user, torch.ops.aten.reshape.default): + if is_any_view_op(reshape_user): reshape_node = reshape_user # Insert cast (to float16) after reshape with graph.inserting_after(reshape_node): diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py b/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py index 91729d497b94..a57965528813 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py @@ -295,6 +295,7 @@ def register_repeat_kv(patterns: ADPatternMatcherPass): op_ignore_types={ torch.ops.aten.reshape.default: (int,), torch.ops.aten.expand.default: (int,), + torch.ops.auto_deploy.view.default: (int,), }, scalar_workaround={"n_rep": dummy_args[1]}, ) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py b/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py index 122a5464d424..dd5bab176b53 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py @@ -13,6 +13,7 @@ from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface +from ...utils.logger import ad_logger from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry @@ -129,8 +130,16 @@ def _apply( # Instantiate Pattern Functions # ============================================================================ - # Get the allreduce strategy from shared_config - strategy = gm._sharding_transform_container.config.allreduce_strategy.name + if shared_config.dist_config is not None: 
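+            # Preferred path: read the strategy from DistConfig (the new single
+            # source of truth); the container lookup below is a legacy fallback.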
+ strategy = shared_config.dist_config.allreduce_strategy + elif hasattr(gm, "_sharding_transform_container"): + strategy = gm._sharding_transform_container.config.allreduce_strategy.name + else: + ad_logger.warning("No dist config found, skipping allreduce-residual-rmsnorm fusion") + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + ad_logger.info(f"allreduce strategy selected = {strategy!r}") # TRT-LLM backend (MPI mode) - two patterns for different addition orders _allreduce_residual_rmsnorm_pattern_trtllm = _make_allreduce_residual_rmsnorm_pattern( diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rmsnorm_quant_fp8.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rmsnorm_quant_fp8.py index 886a5db527d6..435e80579da0 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rmsnorm_quant_fp8.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rmsnorm_quant_fp8.py @@ -29,6 +29,7 @@ collect_terminal_users_through_passthrough, extract_op_args, extract_output_tuple, + is_any_view_op, is_op, is_trivial_passthrough_user, ) @@ -91,7 +92,7 @@ def _collect_grouped_fp8_linear_users( def _is_view_like(node: Node) -> bool: - return is_op(node, torch.ops.aten.view.default) or is_op(node, torch.ops.aten.reshape.default) + return is_any_view_op(node) def _unwrap_post_norm_nodes(node: Node) -> Tuple[Node, list[Node]]: diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py index c2b4842b56c0..a3a9a3b279d7 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py @@ -29,7 +29,13 @@ from ...utils.cuda_mem_tracker import cuda_memory_tracker from ...utils.logger import ad_logger from ...utils.module import get_submodule_of_param -from ...utils.node_utils import bfs, extract_op_args, identify_regions_between_residuals, is_op +from ...utils.node_utils import ( + bfs, + extract_op_args, + identify_regions_between_residuals, + is_any_view_op, + is_op, +) from ..interface import ( BaseTransform, SharedConfig, @@ -1126,7 +1132,7 @@ def _find_output_and_routing_flavor(final_bmm: Node) -> Optional[Tuple[Node, boo # Llama4 pattern: bmm -> view([-1, hidden]) -> reshape([num_experts, -1, hidden]) -> sum(dim=0) output_view = None for user in final_bmm.users: - if is_op(user, torch.ops.aten.view): + if is_any_view_op(user): output_view = user break @@ -1136,7 +1142,7 @@ def _find_output_and_routing_flavor(final_bmm: Node) -> Optional[Tuple[Node, boo # Find reshape after view reshape_node = None for user in output_view.users: - if is_op(user, torch.ops.aten.reshape): + if is_any_view_op(user): reshape_node = user break @@ -1267,7 +1273,7 @@ def _match_bmm_moe_pattern( # Step 3: Get batched input and trace back to original input and routing batched_input = first_bmm.args[0] - if not isinstance(batched_input, Node) or not is_op(batched_input, torch.ops.aten.view): + if not isinstance(batched_input, Node) or not is_any_view_op(batched_input): continue result = MatchBmmMoePattern._find_input_and_routing(batched_input) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py index 77dee1e06c22..ce17fc4b361b 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py @@ -413,6 +413,15 @@ def _apply( 
grouped_nodes: Dict[tuple, List[Node]] = defaultdict(list) for node in gm.graph.nodes: if (is_linear_op(node) or is_fake_quantized_linear_op(node)) and node.args[2] is None: + # Skip linears with a unit dimension (e.g., [1, H] scalar gates). + # A weight with dim=1 is effectively a lower-order tensor and + # should not be fused with proper matrix projections. + try: + w = gm.get_parameter(extract_weight_name(node)) + if any(d == 1 for d in w.shape): + continue + except (AttributeError, KeyError): + pass grouped_nodes[(node.args[0], _get_op_key(node))].append(node) idx = -1 diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/mxfp4_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/mxfp4_moe.py index 4113732111ab..71a24bd78577 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/mxfp4_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/mxfp4_moe.py @@ -94,6 +94,7 @@ def _apply( op_ignore_types = { torch.ops.aten.view.default: (int,), torch.ops.aten.reshape.default: (int,), + torch.ops.auto_deploy.view.default: (int,), torch.ops.aten.repeat.default: (int,), torch.ops.aten.slice.Tensor: (int,), torch.ops.aten.unsqueeze.default: (int,), diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py index 633d0500e9f6..ee0520b9c2bf 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py @@ -19,6 +19,7 @@ from ...utils.logger import ad_logger from ...utils.node_utils import ( WeightBiasInfoCache, + extract_op_args, extract_weight_nodes, get_quantization_params_from_linear_node, is_bmm_op, @@ -212,8 +213,20 @@ def _insert_quantized_linear( custom_args = self.build_custom_args_for_linear(scales) + # Extract sharding hints by name so we don't depend on positional layout. 
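+        # Illustrative: for a linear hinted with tp_mode="colwise",
+        # tp_min_local_shape=128 and layer_type="mha", the rewrite below keeps
+        # (input, weight, bias, *custom_args) positional and moves the four
+        # hints into kwargs so later passes can read them by name.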
+ [tp_mode, output_sizes, tp_min_local_shape, layer_type] = extract_op_args( + node, "tp_mode", "output_sizes", "tp_min_local_shape", "layer_type" + ) + [inp, weight, bias] = extract_op_args(node, "input", "weight", "bias") node.target = self.target_op() - node.args = (*node.args, *custom_args) + node.args = (inp, weight, bias, *custom_args) + node.kwargs = { + **node.kwargs, + "tp_mode": tp_mode, + "output_sizes": output_sizes, + "tp_min_local_shape": tp_min_local_shape, + "layer_type": layer_type, + } def _insert_quantized_bmm( self, diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py b/tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py index 92b2da8b2e5f..544e85bb0f2a 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py @@ -182,6 +182,7 @@ def make_dummy_args_gated(group_size: int, eps: float) -> list: op_ignore_types = { torch.ops.aten.reshape.default: (int, list, tuple), torch.ops.aten.view.default: (int, list, tuple), + torch.ops.auto_deploy.view.default: (int, list, tuple), torch.ops.aten.mean.dim: (list, tuple), torch.ops.aten.to.dtype: (torch.dtype,), } diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/rope.py b/tensorrt_llm/_torch/auto_deploy/transform/library/rope.py index bf7075acbc57..bc91131327fb 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/rope.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/rope.py @@ -55,7 +55,7 @@ def apply_rotary_emb( from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface -from ...utils.node_utils import extract_op_args, extract_output_tuple, is_op +from ...utils.node_utils import extract_op_args, extract_output_tuple, is_any_view_op, is_op from ...utils.pattern_matcher import ADPatternMatcherPass, Match, register_ad_pattern from ..interface import ( BaseTransform, @@ -192,6 +192,7 @@ def _apply( torch.ops.aten.slice.Tensor: (int,), torch.ops.aten.reshape.default: (int,), torch.ops.aten.view.default: (int,), + torch.ops.auto_deploy.view.default: (int,), }, scalar_workaround={"unsqueeze_dim": 1}, ) @@ -202,6 +203,7 @@ def _apply( dummy_args=dummy_complex, op_ignore_types={ torch.ops.aten.reshape.default: (int,), + torch.ops.auto_deploy.view.default: (int,), }, scalar_workaround={"unsqueeze_dim": 1}, ) @@ -212,6 +214,7 @@ def _apply( dummy_args=dummy_complex_2, op_ignore_types={ torch.ops.aten.reshape.default: (int,), + torch.ops.auto_deploy.view.default: (int,), }, scalar_workaround={"unsqueeze_dim": 1}, ) @@ -1012,13 +1015,13 @@ def _match_input_interleave_pattern(node: Node) -> Optional[Dict[str, Node]]: Returns: {"interleaved": raw_node} if matched, else None. """ - if not is_op(node, torch.ops.aten.reshape): + if not is_any_view_op(node): return None transpose_node = node.args[0] if not is_op(transpose_node, torch.ops.aten.transpose): return None view_node = transpose_node.args[0] - if not is_op(view_node, torch.ops.aten.view): + if not is_any_view_op(view_node): return None raw_node = view_node.args[0] if not isinstance(raw_node, Node): diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index ba9b1a325c4b..a8ed3424e37e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -1,5 +1,11 @@ """Transformations to support graph sharding. +.. 
deprecated:: + The heuristic-based sharding infrastructure in this module (``detect_sharding``, + ``sharding_transform_executor``, and all ``ShardingInfo`` classes) is being replaced + by the hint-driven IR sharding system in ``sharding_ir.py``. New development should + target ``sharding_ir.py``; this module will be removed once the transition is complete. + Our sharding algorithm for tensor parallelism (TP) is based on the following steps: 1. Initialize/construct unsharded model. Ideally, this should be done on device="meta" to avoid @@ -28,8 +34,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator from torch.fx import GraphModule, Node -from tensorrt_llm._torch.auto_deploy.utils.mapping_utils import print_grid, serialize_mapping -from tensorrt_llm.mapping import Mapping +from tensorrt_llm._torch.auto_deploy.utils.dist_config import DistConfig from .....functional import AllReduceStrategy from ...custom_ops.distributed.trtllm_dist import is_trtllm_op_available @@ -52,7 +57,9 @@ is_any_delta_op, is_any_lin_op, is_any_moe_op, + is_any_split_op, is_any_ssm_op, + is_any_view_op, is_fake_quantized_linear_op, is_op, is_weight_node, @@ -244,43 +251,36 @@ class ShardingTransformConfig(TransformConfig): description="When True, skip TP sharding as attention data parallelism is enabled.", ) + shard_layers: Optional[List[str]] = Field( + default=None, + description="When set, only shard nodes whose layer_type hint is in this list. " + "Nodes with layer_type='unknown' or missing are NOT sharded. " + "When None (default), all enable_sharding nodes are processed regardless of layer_type.", + ) + dist_mapping: dict[str, int] = Field(default_factory=dict) - mapping: Mapping = Field(default_factory=Mapping) + mapping: Any = Field(default=None) # Legacy: tensorrt_llm.mapping.Mapping (kept for compat) + dist_config: DistConfig = Field(default_factory=DistConfig) def _init_mapping(self): - """Initialize Mapping from dist_mapping config. + """Initialize DistConfig from dist_mapping config. NOTE: This method is now primarily a fallback. The preferred flow is: - 1. Mapping is initialized in ad_executor.py from config.transforms['detect_sharding']['dist_mapping'] - 2. Passed through SharedConfig.mapping to the sharding transform - 3. Only if SharedConfig.mapping is None, this fallback is used - - This ensures Mapping is created once with the correct configuration from YAML, - rather than being recreated in multiple places. + 1. DistConfig is constructed in ad_executor.py from the Mapping object + 2. Passed through SharedConfig.dist_config to the sharding transform + 3. 
Only if SharedConfig.dist_config is None, this fallback is used """ - # by default, we use 1D parallelism (TP-only for token mixers and FFN, EP-only for MoE) - try: - self.mapping = Mapping( - world_size=self.world_size, - rank=self.rank, - tp_size=self.dist_mapping.get("tp", self.world_size), - moe_tp_size=self.dist_mapping.get("moe_tp", 1), - moe_ep_size=self.dist_mapping.get("moe_ep", self.world_size), - moe_cluster_size=self.dist_mapping.get("moe_cluster", 1), - enable_attention_dp=self.enable_attention_dp, - ) - except ValueError as e: - ad_logger.warning(f"Invalid parallel grid config: {e}") - ad_logger.warning("Defaulting to TP-only sharding (EP only for MoE)") - self.mapping = Mapping( - world_size=self.world_size, - rank=self.rank, - tp_size=self.world_size, - moe_tp_size=1, - moe_ep_size=self.world_size, - moe_cluster_size=1, - ) + self.dist_config = DistConfig( + world_size=self.world_size, + rank=self.rank, + tp_size=self.dist_mapping.get("tp", self.world_size), + moe_tp_size=self.dist_mapping.get("moe_tp", 1), + moe_ep_size=self.dist_mapping.get("moe_ep", self.world_size), + moe_cluster_size=self.dist_mapping.get("moe_cluster", 1), + enable_attention_dp=self.enable_attention_dp, + allreduce_strategy=self.allreduce_strategy.name, + ) def validate_config(self, sources: Union[ShardingSource, List[ShardingSource]] = None) -> bool: if sources is None: @@ -1077,7 +1077,7 @@ class Sharding(BaseTransform): The transformation is based on the following steps: - 1. Identify boundary nodes between residual nodes to identify shardable regions. + 1. Identify boundary nodes between residual nodes to identify enable_sharding regions. 2. Identify the GEMM nodes that can be sharded 3. Trace through the subgraph using DFS/BFS between each pair of boundary nodes 4. Account for each node in the trace to ensure the op is correct even after sharding. 
This is @@ -1116,11 +1116,9 @@ def _apply( config.rank = local_rank config.world_size = world_size - # Use Mapping from shared_config (initialized in ad_executor) if available - if shared_config.mapping is not None: - config.mapping = shared_config.mapping + if shared_config.dist_config is not None: + config.dist_config = shared_config.dist_config else: - # Fallback to creating mapping from dist_mapping config if not provided config._init_mapping() if world_size > 1: @@ -1149,7 +1147,7 @@ def _apply( ) info = TransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True) - ad_logger.info(print_grid(config.mapping)) + ad_logger.info(config.dist_config.print_grid()) with WeightBiasInfoCache(): # ============================= # ======== EP sharding ======== @@ -1167,7 +1165,7 @@ def _apply( # ======== TP sharding ======== if ShardingDim.TP not in config.sharding_dims: return gm, info - if config.mapping.enable_attention_dp: + if config.dist_config.enable_attention_dp: # only MoE all-to-all sharding is supported in attention DP mode # we already enforced 1D sharding (TP=1, EP=world_size) in init_mapping ad_logger.info( @@ -1455,7 +1453,7 @@ def _validate_sharded_shapes( next_lin_node, _ = bfs(node, is_any_lin_op, include_root=False) nodes_to_validate = subgraph( [node], - include=lambda n: is_op(n, [torch.ops.aten.view, torch.ops.aten.reshape]), + include=is_any_view_op, boundary_condition=is_any_lin_op, ) for view_node in nodes_to_validate: @@ -1482,7 +1480,7 @@ def _validate_sharded_shapes( split_nodes = subgraph( [node], [next_lin_node], - include=lambda n: is_op(n, [torch.ops.aten.split, torch.ops.aten.split_with_sizes]), + include=is_any_split_op, ) for split_node in split_nodes: orig_sizes = split_node.args[1] @@ -1494,15 +1492,25 @@ def _validate_sharded_shapes( TP_SHARDING_RULES = [ + # Standard FP8 (per-tensor scales, replicated) (lambda n: is_op(n, torch.ops.auto_deploy.torch_fake_quant_fp8_linear), FP8WeightShardingInfo), + (lambda n: is_op(n, torch.ops.auto_deploy.torch_quant_fp8_linear), FP8WeightShardingInfo), + (lambda n: is_op(n, torch.ops.auto_deploy.trtllm_quant_fp8_linear), FP8WeightShardingInfo), + # FP4 (per-block scales) ( lambda n: is_op(n, torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear), FP4WeightShardingInfo, ), + (lambda n: is_op(n, torch.ops.auto_deploy.torch_quant_nvfp4_linear), FP4WeightShardingInfo), + # Fine-grained FP8 (per-block scales, need shard + load hook) ( lambda n: is_op(n, torch.ops.auto_deploy.torch_fake_quant_finegrained_fp8_linear), FineGrainedFP8WeightShardingInfo, ), + ( + lambda n: is_op(n, torch.ops.auto_deploy.trtllm_finegrained_fp8_linear), + FineGrainedFP8WeightShardingInfo, + ), ] @@ -1511,8 +1519,8 @@ def _resolve_tp_cls_from_node(node: Node): try: if pred(node): return cls - except Exception: - pass + except AttributeError: + pass # Op not registered yet return WeightShardingInfo @@ -1532,12 +1540,17 @@ def _split_tensor_for_tp( When world_size exceeds the maximum number of even splits (e.g. GQA with num_kv_heads < world_size), multiple ranks share the same shard. + + TODO: support num_units % world_size != 0 via GCD-based partial replication. + When num_heads doesn't divide by world_size (e.g. 28 Q heads at tp_size=8), + use effective_splits = gcd(num_heads, world_size) to split at head + boundaries and replicate each shard across world_size // effective_splits + ranks. 
To compensate the duplication in all_reduce, scale rowwise weights + by 1 / replication_factor during sharding (baked into the weight tensor, + no changes to all_reduce or the graph needed). """ max_split_size = t.shape[dim] // min_local_shape if world_size > max_split_size: - # TODO: support remainder case (world_size % max_split_size != 0). - # Currently the downstream view/split/slice fixups in _process_column_sharding - # assume even division by world_size, so uneven grouping would produce wrong shapes. assert world_size % max_split_size == 0, ( f"world_size ({world_size}) must be divisible by max_split_size ({max_split_size}). " f"GQA with num_kv_heads not dividing world_size is not supported." @@ -1548,6 +1561,14 @@ def _split_tensor_for_tp( f"Splitting tensor to {num_groups} chunks" ) return torch.tensor_split(t, max_split_size, dim=dim)[rank // num_groups] + + assert max_split_size % world_size == 0, ( + f"Number of units ({max_split_size}, dim {dim} size {t.shape[dim]} / " + f"min_local_shape {min_local_shape}) must be divisible by world_size " + f"({world_size}). For attention heads, use a world_size that divides " + f"num_heads evenly (e.g. for {max_split_size} heads, try world_size in " + f"{[d for d in range(2, max_split_size + 1) if max_split_size % d == 0]})." + ) return torch.tensor_split(t, world_size, dim=dim)[rank] @@ -1887,10 +1908,10 @@ def _insert_sharded_moe( # ===================================================================================== # DISTRIBUTED GRID CONFIGURATION # ===================================================================================== - ep_size = config.mapping.moe_ep_size - ep_rank = config.mapping.moe_ep_rank - tp_size = config.mapping.moe_tp_size - tp_rank = config.mapping.moe_tp_rank + ep_size = config.dist_config.moe_ep_size + ep_rank = config.dist_config.moe_ep_rank + tp_size = config.dist_config.moe_tp_size + tp_rank = config.dist_config.moe_tp_rank # All-to-all is used when: # 1. Attention uses data parallelism (tokens distributed across ranks) # 2. AND we have EP > 1 (experts distributed across ranks) @@ -2041,7 +2062,7 @@ def get_partition(lst, world_size, rank): # Serialize Mapping for all-to-all dispatch/combine # (Will be used inside the op to determine enable_alltoall and workspace size) - mapping_config = serialize_mapping(config.mapping) + mapping_config = config.dist_config.serialize() # Write back weight/scale list updates (applied above) and inject mapping args. # set_op_args uses the op schema to place values into kwargs or the correct @@ -2228,9 +2249,7 @@ def _process_ssm_sharding( # in_proj and conv1d are fused, followed up by split nodes. 
Infer split sizes: assert len(entry_node.users) == 1, "Expecting exactly one user for the entry node" split_node_0 = list(entry_node.users)[0] - assert is_op(split_node_0, [torch.ops.aten.split_with_sizes]), ( - "Expecting split_with_sizes node for the entry node" - ) + assert is_any_split_op(split_node_0), "Expecting split_with_sizes node for the entry node" split_sizes_0 = split_node_0.args[1] # extract the single conv1d node conv1d_nodes = [ @@ -2242,9 +2261,7 @@ def _process_ssm_sharding( silu_node_1 = list(conv1d_node.users)[0] assert len(silu_node_1.users) == 1, "Expecting exactly one user for the silu node" split_node_1 = list(silu_node_1.users)[0] - assert is_op(split_node_1, [torch.ops.aten.split_with_sizes]), ( - "Expecting split_with_sizes node for the split node" - ) + assert is_any_split_op(split_node_1), "Expecting split_with_sizes node for the split node" split_sizes_1 = split_node_1.args[1] assert split_sizes_0[1] == sum(split_sizes_1) fused_weight_dims = { @@ -2343,9 +2360,7 @@ def _process_ssm_sharding( # ############################################################## # ############## update the view and reshape nodes ############# # ############################################################## - nodes_to_validate = [ - n for n in subgraph_nodes if is_op(n, [torch.ops.aten.view, torch.ops.aten.reshape]) - ] + nodes_to_validate = [n for n in subgraph_nodes if is_any_view_op(n)] for view_node in nodes_to_validate: if len(view_node.args) < 2: continue @@ -2424,8 +2439,7 @@ def _process_delta_sharding( # Find split([key_dim, key_dim, value_dim]) after conv1d (produces 3 outputs: q, k, v) split_node_after_conv, depth = bfs( conv1d_node, - lambda n: is_op(n, [torch.ops.aten.split_with_sizes, torch.ops.aten.split]) - and len(list(n.users)) >= 3, + lambda n: is_any_split_op(n) and len(list(n.users)) >= 3, ) # Extract conv split sizes early (needed for fused_weight_dims on unfused in_proj_qkv) @@ -2433,10 +2447,6 @@ def _process_delta_sharding( if split_node_after_conv is not None and len(split_node_after_conv.args) > 1: conv_split_sizes_original = tuple(split_node_after_conv.args[1]) [conv_groups] = extract_op_args(conv1d_node, "groups") - assert sum(conv_split_sizes_original) == conv_groups, ( - f"Split sizes {conv_split_sizes_original} (sum={sum(conv_split_sizes_original)}) " - f"do not match conv1d groups {conv_groups}" - ) # ############################################################## # ############## shard the opening nodes (column) ############## @@ -2554,9 +2564,7 @@ def _process_delta_sharding( # ############## update the view and reshape nodes ############# # ############################################################## # Shard dim 2 (head count) in view/reshape nodes with concrete num_heads values - nodes_to_validate = [ - n for n in subgraph_nodes if is_op(n, [torch.ops.aten.view, torch.ops.aten.reshape]) - ] + nodes_to_validate = [n for n in subgraph_nodes if is_any_view_op(n)] for view_node in nodes_to_validate: if len(view_node.args) < 2: continue @@ -2706,7 +2714,7 @@ def _process_mla_sharding( # attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # attn_output = self.o_proj(attn_output) candidate_reshape = layer_subgraph.terminating_node.args[0] - if is_op(candidate_reshape, [torch.ops.aten.reshape]): + if is_any_view_op(candidate_reshape): # reshape args are (attn_output, [bsz, q_len, num_heads * v_head_dim]) # set 3rd arg (num_heads * v_head_dim) to -1 reshape_args = list(candidate_reshape.args) @@ -2749,9 +2757,7 @@ def 
_determine_fused_weight_dims( if len(linear_nodes) == 1: linear_node = linear_nodes[0] # check if there are split nodes in the subgraph. They may indicate fused weights (e.g., QKV) - linear_split_users = list( - filtered_nodes(linear_node.users, ops=torch.ops.aten.split_with_sizes) - ) + linear_split_users = list(filtered_nodes(linear_node.users, target=is_any_split_op)) linear_slice_users = list(filtered_nodes(linear_node.users, ops=torch.ops.aten.slice)) linear_chunk_users = list(filtered_nodes(linear_node.users, ops=torch.ops.aten.chunk)) if len(linear_split_users) > 0: @@ -2807,6 +2813,7 @@ def _find_upstream_qk_proj(node: Node, gm: GraphModule) -> Optional[str]: passthrough_ops = [ torch.ops.aten.view, torch.ops.aten.reshape, + torch.ops.auto_deploy.view, torch.ops.aten.contiguous, torch.ops.aten.clone, torch.ops.aten.to, @@ -2885,7 +2892,7 @@ def _shard_qk_norm( Returns: Number of nodes added for sharding """ - if layer_subgraph.layer_type != LayerType.ATTENTION or layer_subgraph.terminating_node is None: + if layer_subgraph.layer_type != LayerType.MHA or layer_subgraph.terminating_node is None: return 0 config = transform_container.config @@ -3014,9 +3021,7 @@ def _process_column_sharding( ad_logger.debug("No nodes were added for column sharding. Skipping.") return 0 - nodes_to_validate = [ - n for n in subgraph_nodes if is_op(n, [torch.ops.aten.view, torch.ops.aten.reshape]) - ] + nodes_to_validate = [n for n in subgraph_nodes if is_any_view_op(n)] for view_node in nodes_to_validate: if len(view_node.args) < 2: continue @@ -3038,7 +3043,7 @@ def _process_column_sharding( # fused weight may either be processed by several slice nodes or a single split node linear_node = linear_nodes[0] - split_nodes = list(filtered_nodes(linear_node.users, ops=[torch.ops.aten.split_with_sizes])) + split_nodes = list(filtered_nodes(linear_node.users, target=is_any_split_op)) slice_nodes = list(filtered_nodes(linear_node.users, ops=[torch.ops.aten.slice])) if len(split_nodes) > 0: user = split_nodes[0] @@ -3123,7 +3128,7 @@ def detect_sharding_from_config( # use layer_subgraphs to determine the layer_type # and check the validity of the sharding transform layer_subgraphs, unprocessed_linear_nodes = get_all_layer_subgraphs( - gm, linear_nodes=_get_config_layer_linear_nodes(gm, tp_plan) + gm, linear_nodes=linear_nodes ) for lin_node in linear_nodes: @@ -3169,7 +3174,7 @@ def detect_sharding_from_config( layer_type=layer_type, ) ): - if layer_type == LayerType.ATTENTION: + if layer_type == LayerType.MHA: num_attention_shards += 1 num_row_col_shards += 1 elif config == "mamba": @@ -3300,7 +3305,7 @@ def detect_column_row_shard( The transformation is based on the following steps: - 1. Identify boundary nodes between residual nodes to identify shardable regions. + 1. Identify boundary nodes between residual nodes to identify enable_sharding regions. 2. Identify the GEMM nodes that can be sharded 3. Trace through the subgraph using DFS/BFS between each pair of boundary nodes 4. Account for each node in the trace to ensure the op is correct even after sharding. 
This is @@ -3369,7 +3374,7 @@ def detect_column_row_shard( ) continue - if layer.layer_type == LayerType.ATTENTION: + if layer.layer_type == LayerType.MHA: head_dim = layer.min_local_shape # if the QKV projection is fused, check if num_kv_heads is divisible by world_size if len(opening) == 1: @@ -3409,7 +3414,7 @@ def detect_column_row_shard( ) ): num_column_row_shards += 1 - if layer.layer_type == LayerType.ATTENTION: + if layer.layer_type == LayerType.MHA: num_mha_shards += 1 # simple shard remaining linear nodes diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding_ir.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding_ir.py new file mode 100644 index 000000000000..f0ba93b1fc00 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding_ir.py @@ -0,0 +1,1123 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hint-driven IR sharding transform for AutoDeploy. + +This module implements the ``apply_sharding_hints`` and ``strip_sharding_hints`` +transforms, which apply deterministic, node-local sharding based on explicit +hint kwargs on custom ops and a runtime ``DistConfig``. + +This is the replacement for the legacy heuristic-based sharding pipeline in +``sharding.py``. See the design documents in ``sharding_architecture_documents/`` +for background. +""" + +import operator +from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +from pydantic import Field, field_validator +from torch._ops import OpOverload, OpOverloadPacket +from torch.fx import GraphModule, Node + +from tensorrt_llm._torch.auto_deploy.utils.dist_config import DistConfig + +from .....functional import AllReduceStrategy +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ...utils._graph import del_attr_by_name, eliminate_dead_code +from ...utils.logger import ad_logger +from ...utils.node_utils import ( + WeightBiasInfoCache, + WeightNode, + _get_op_schema, + extract_op_args, + extract_weight_nodes, + invalidate_weight_node_cache, + is_any_lin_op, + set_op_args, + shape, +) +from ..interface import ( + BaseTransform, + SharedConfig, + TransformConfig, + TransformInfo, + TransformRegistry, +) + +# NOTE: sharding.py module will be deprecated in the future. The following +# imports will move into sharding_ir.py when legacy sharding is removed. +from .sharding import ( + SplitDimension, + _get_dist_ops, + _load_hook, + _shard_fp4_weight_scale, + shard_weight_tensor, + validate_allreduce_strategy, +) + + +def _split_fp8_block_scale( + scale: torch.Tensor, dim: int, rank: int, world_size: int +) -> torch.Tensor: + """Split a finegrained FP8 per-block scale tensor along *dim*. 
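+
+    Concretely (illustrative): with ``dim=0``, a ``[2, 64]`` scale tensor and
+    ``world_size=8``, ranks 0-3 receive row 0 and ranks 4-7 receive row 1,
+    since ``group = rank // (8 // 2)``.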
+ + Handles the edge case where ``scale.shape[dim] < world_size`` (e.g., a + 2-row scale shared across 8 GPUs) by grouping ranks that share a row. + """ + scale_dim = scale.shape[dim] + if scale_dim >= world_size: + return torch.tensor_split(scale, world_size, dim=dim)[rank] + group = rank // (world_size // scale_dim) + return torch.tensor_split(scale, scale_dim, dim=dim)[group] + + +def _shard_scale_and_hook( + gm: GraphModule, + sn: WeightNode, + sharded_scale: torch.Tensor, + f_split, +) -> None: + """Register a sharded scale buffer and its corresponding load hook.""" + buf_name = sn.node_key.rsplit(".", 1)[-1] + sn.submod.register_buffer(buf_name, sharded_scale) + gm._register_load_state_dict_pre_hook( + partial(_load_hook, f_split=f_split, param_key=sn.node_key, param_shape=sharded_scale.shape) + ) + + +_SHARDING_HINT_NAMES = frozenset( + { + "tp_mode", + "output_sizes", + "tp_min_local_shape", + "layer_type", + "enable_sharding", + "tp_scaled_dim", + } +) + +# ============================================================================= +# ShardableNode abstract base class +# ============================================================================= + + +class ShardableNode(ABC): + """Base class for graph nodes that carry sharding hints. + + Each specialized subclass encapsulates the sharding logic for one category of + custom op. Subclasses self-register via the ``@ShardableNode.register`` + decorator, and ``from_node`` dispatches an FX node to the correct subclass. + """ + + _REGISTRY: Dict[OpOverload, Type["ShardableNode"]] = {} + + def __init__(self, node: Node): + self.node = node + + @classmethod + def register(cls, *op_targets): + """Class decorator that registers a ShardableNode subclass for the given op targets.""" + + def decorator(subcls): + for target in op_targets: + if isinstance(target, OpOverloadPacket): + for overload_name in target.overloads(): + cls._REGISTRY[getattr(target, overload_name)] = subcls + else: + cls._REGISTRY[target] = subcls + return subcls + + return decorator + + @classmethod + def _resolve(cls, target) -> Optional[Type["ShardableNode"]]: + """Look up the registered subclass for an op target. + + Handles both ``OpOverload`` (e.g. ``torch_moe.default``) and + ``OpOverloadPacket`` (e.g. ``torch_moe``) since FX nodes may + store either form as their target. + """ + subcls = cls._REGISTRY.get(target) + if subcls is None and isinstance(target, OpOverloadPacket): + subcls = cls._REGISTRY.get(getattr(target, "default", None)) + return subcls + + @staticmethod + def from_node(node: Node) -> Optional["ShardableNode"]: + """Return a ShardableNode for *node*, or ``None`` if the op is not enable_sharding.""" + if not isinstance(node, Node) or node.op != "call_function": + return None + subcls = ShardableNode._resolve(node.target) + if subcls is None: + return None + return subcls(node) + + @abstractmethod + def apply(self, gm: GraphModule, dc: DistConfig, max_num_tokens: int = 0) -> int: + """Apply sharding to this node. Returns 1 if modified, 0 otherwise.""" + ... + + @classmethod + def strip_hints(cls, node: Node) -> bool: + """Strip sharding hint args/kwargs from *node*. Returns ``True`` if modified. + + Dispatches to the registered subclass's ``_strip_node_hints``. The base + implementation uses schema introspection to strip trailing hint args and + reset non-trailing ones to defaults. Subclasses that represent pure + placeholder ops (view, split_with_sizes, all_reduce) override this to + lower the node to a zero-copy aten equivalent instead. 
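+
+        For example (illustrative): an ``auto_deploy.view`` node is lowered to
+        ``aten.reshape`` keeping only ``(input, shape)``, and an ``all_reduce``
+        placeholder is erased entirely, forwarding its input to its users.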
+ """ + if node.op != "call_function": + return False + subcls = cls._resolve(node.target) + if subcls is None: + return False + return subcls._strip_node_hints(node) + + @classmethod + def _strip_node_hints(cls, node: Node) -> bool: + """Default: strip hint args/kwargs by schema introspection. + + Trailing hint args are popped entirely (so downstream transforms see + the canonical short args tuple). Non-trailing hints are reset to their + schema defaults via ``set_op_args``. Hint kwargs are removed. + """ + schema = _get_op_schema(node) + hint_defaults: Dict[str, Any] = {} + hint_positions: set = set() + for i, a in enumerate(schema.arguments): + if a.name in _SHARDING_HINT_NAMES: + hint_defaults[a.name] = a.default_value if a.has_default_value else None + hint_positions.add(i) + + has_hint_kwargs = bool(node.kwargs and (_SHARDING_HINT_NAMES & node.kwargs.keys())) + if not hint_defaults and not has_hint_kwargs: + return False + + modified = False + + # Pop trailing hint args to keep the args tuple minimal + args = list(node.args) + while args and (len(args) - 1) in hint_positions: + args.pop() + modified = True + if modified: + node.args = tuple(args) + + # Reset any non-trailing hints still in the args to their defaults + non_trailing = { + schema.arguments[i].name: hint_defaults[schema.arguments[i].name] + for i in hint_positions + if i < len(node.args) + } + if non_trailing: + set_op_args(node, **non_trailing) + modified = True + + # Strip hint kwargs + if has_hint_kwargs: + node.kwargs = {k: v for k, v in node.kwargs.items() if k not in _SHARDING_HINT_NAMES} + modified = True + + return modified + + +# ============================================================================= +# Specialized ShardableNode subclasses +# ============================================================================= + + +@ShardableNode.register( + torch.ops.auto_deploy.torch_linear_simple, + torch.ops.auto_deploy.torch_fake_quant_fp8_linear, + torch.ops.auto_deploy.torch_quant_fp8_linear, + torch.ops.auto_deploy.trtllm_quant_fp8_linear, + torch.ops.auto_deploy.torch_fake_quant_int4_linear, + torch.ops.auto_deploy.torch_fake_quant_int4_gptq_linear, +) +class LinearShardableNode(ShardableNode): + """Linear ops: weight + bias sharding. No quantized scale handling. + + Covers BF16, standard FP8 (per-tensor scales via List args), and INT4. + Quantization variants with per-block scale buffers use subclasses that + override ``_shard_scales``. 
+ """ + + def apply(self, gm: GraphModule, dc: DistConfig, max_num_tokens: int = 0) -> int: + [tp_mode, output_sizes, tp_min_local_shape] = extract_op_args( + self.node, "tp_mode", "output_sizes", "tp_min_local_shape" + ) + if tp_mode == "none": + return 0 + + split_dim = SplitDimension.COLUMN if tp_mode == "colwise" else SplitDimension.ROW + fused = tuple(output_sizes) if output_sizes else None + min_shape = tp_min_local_shape if tp_min_local_shape else 1 + + weight_nodes = extract_weight_nodes(self.node) + + for wn in weight_nodes.weights: + shard_weight_tensor( + gm=gm, + weight_tensor=wn.tensor, + param_key=wn.node_key, + dim=split_dim, + rank=dc.tp_rank, + world_size=dc.tp_size, + min_local_shape=min_shape, + fused_weight_dims=fused, + ) + + self._shard_scales(gm, dc, weight_nodes, split_dim, min_shape, fused) + + for bn in weight_nodes.biases: + if split_dim == SplitDimension.COLUMN: + shard_weight_tensor( + gm=gm, + weight_tensor=bn.tensor, + param_key=bn.node_key, + dim=SplitDimension.COLUMN, + rank=dc.tp_rank, + world_size=dc.tp_size, + fused_weight_dims=fused, + ) + + ad_logger.debug(f" sharded linear tp_mode={tp_mode}") + return 1 + + def _shard_scales(self, gm, dc, weight_nodes, dim, min_shape=1, fused=None): + """Override in quantization subclasses to shard per-block scale buffers.""" + pass + + +@ShardableNode.register( + torch.ops.auto_deploy.torch_fake_quant_finegrained_fp8_linear, + torch.ops.auto_deploy.trtllm_finegrained_fp8_linear, +) +class FineGrainedFP8LinearShardableNode(LinearShardableNode): + """FineGrained FP8 linear: shards per-block ``weight_scale_inv`` buffers.""" + + def _shard_scales(self, gm, dc, weight_nodes, dim, min_shape=1, fused=None): + for sn in weight_nodes.scales: + f_split = partial( + _split_fp8_block_scale, dim=dim, rank=dc.tp_rank, world_size=dc.tp_size + ) + sharded = f_split(sn.tensor) + _shard_scale_and_hook(gm, sn, sharded, f_split) + + +@ShardableNode.register( + torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear, + torch.ops.auto_deploy.torch_quant_nvfp4_linear, +) +class FP4LinearShardableNode(LinearShardableNode): + """NVFP4 linear: shards cutlass-format ``weight_scale`` buffers.""" + + def _shard_scales(self, gm, dc, weight_nodes, dim, min_shape=1, fused=None): + weight_shape = weight_nodes.weights[0].tensor.shape if weight_nodes.weights else None + if weight_shape is None: + return + for sn in weight_nodes.scales: + f_split = partial( + _shard_fp4_weight_scale, + original_uint8_weight_shape=weight_shape, + dim=dim, + rank=dc.tp_rank, + world_size=dc.tp_size, + min_local_shape=min_shape, + fused_weight_dims=fused, + ) + sharded = f_split(sn.tensor) + _shard_scale_and_hook(gm, sn, sharded, f_split) + + +@ShardableNode.register(torch.ops.auto_deploy.view) +class ViewShardableNode(ShardableNode): + """View op: replace shape[tp_scaled_dim] with -1 so PyTorch infers the sharded size.""" + + @classmethod + def _strip_node_hints(cls, node: Node) -> bool: + """Lower to aten.reshape, keeping only (input, shape).""" + node.target = torch.ops.aten.reshape.default + node.args = (node.args[0], node.args[1]) + node.kwargs = {} + return True + + def apply(self, gm: GraphModule, dc: DistConfig, max_num_tokens: int = 0) -> int: + [tp_scaled_dim, view_shape] = extract_op_args(self.node, "tp_scaled_dim", "shape") + if tp_scaled_dim == -1: + return 0 + + view_shape = list(view_shape) + if tp_scaled_dim < 0: + tp_scaled_dim = len(view_shape) + tp_scaled_dim + if tp_scaled_dim < len(view_shape) and isinstance(view_shape[tp_scaled_dim], int): + 
view_shape[tp_scaled_dim] = -1 + set_op_args(self.node, shape=view_shape) + ad_logger.debug(f" updated view shape at dim {tp_scaled_dim} to -1 (inferred)") + return 1 + return 0 + + +@ShardableNode.register(torch.ops.auto_deploy.split_with_sizes) +class SplitShardableNode(ShardableNode): + """split_with_sizes op: divide all split sizes by tp_size.""" + + @classmethod + def _strip_node_hints(cls, node: Node) -> bool: + """Lower to aten.split_with_sizes, keeping only (input, sizes, dim).""" + dim = node.args[2] if len(node.args) > 2 else -1 + node.target = torch.ops.aten.split_with_sizes.default + node.args = (node.args[0], node.args[1], dim) + node.kwargs = {} + return True + + def apply(self, gm: GraphModule, dc: DistConfig, max_num_tokens: int = 0) -> int: + [enable_sharding, split_sizes] = extract_op_args( + self.node, "enable_sharding", "split_sizes" + ) + if not enable_sharding: + return 0 + + split_sizes = list(split_sizes) + for s in split_sizes: + assert s % dc.tp_size == 0, ( + f"split_with_sizes: size {s} is not divisible by tp_size={dc.tp_size}. " + f"Full split_sizes={split_sizes}. Ensure the model dimensions are " + f"compatible with the tensor parallel degree." + ) + scaled = [s // dc.tp_size for s in split_sizes] + set_op_args(self.node, split_sizes=scaled) + ad_logger.debug(f" updated split_with_sizes: {split_sizes} -> {scaled}") + return 1 + + +@ShardableNode.register(torch.ops.auto_deploy.all_reduce) +class AllReduceShardableNode(ShardableNode): + """all_reduce placeholder: replace with real dist.all_reduce or identity.""" + + @classmethod + def _strip_node_hints(cls, node: Node) -> bool: + """Remove the all_reduce placeholder entirely (passthrough to input).""" + node.replace_all_uses_with(node.args[0]) + node.graph.erase_node(node) + return True + + def apply(self, gm: GraphModule, dc: DistConfig, max_num_tokens: int = 0) -> int: + if dc.tp_size <= 1: + return 0 + + _, all_reduce_op = _get_dist_ops("auto") + [x] = extract_op_args(self.node, "x") + self.node.target = all_reduce_op + self.node.args = (x, dc.allreduce_strategy) + ad_logger.debug(f" inserted real all_reduce ({all_reduce_op.__name__})") + return 1 + + +@ShardableNode.register(torch.ops.auto_deploy.torch_causal_conv1d) +class Conv1dShardableNode(ShardableNode): + """Conv1d op: shard weight/bias with fused dims, update groups.""" + + def apply(self, gm: GraphModule, dc: DistConfig, max_num_tokens: int = 0) -> int: + [enable_sharding, output_sizes] = extract_op_args( + self.node, "enable_sharding", "output_sizes" + ) + if not enable_sharding: + return 0 + + fused = list(output_sizes) if output_sizes else None + weight_nodes = extract_weight_nodes(self.node) + + for wn in weight_nodes.weights: + shard_weight_tensor( + gm=gm, + weight_tensor=wn.tensor, + param_key=wn.node_key, + dim=0, + rank=dc.tp_rank, + world_size=dc.tp_size, + fused_weight_dims=fused, + ) + + for bn in weight_nodes.biases: + shard_weight_tensor( + gm=gm, + weight_tensor=bn.tensor, + param_key=bn.node_key, + dim=0, + rank=dc.tp_rank, + world_size=dc.tp_size, + fused_weight_dims=fused, + ) + + # No quantized conv1d variants exist; scales not handled. + # If a quantized variant is added, add _shard_scales() like LinearShardableNode. 
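+        # Illustrative: with groups=4096 and tp_size=4, each rank keeps a dim-0
+        # shard of weight/bias and runs the conv with groups=1024.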
+ + [groups] = extract_op_args(self.node, "groups") + assert groups % dc.tp_size == 0, ( + f"conv1d groups ({groups}) must be divisible by tp_size ({dc.tp_size})" + ) + set_op_args(self.node, groups=groups // dc.tp_size) + ad_logger.debug(f" sharded conv1d, groups {groups} -> {groups // dc.tp_size}") + return 1 + + +@ShardableNode.register( + torch.ops.auto_deploy.torch_ssm, + torch.ops.auto_deploy.torch_gated_delta_rule, + torch.ops.auto_deploy.torch_mla, +) +class WeightedParamShardableNode(ShardableNode): + """Ops whose weight parameters are sharded along dim 0 (head dimension). + + Covers SSM (A, D, dt_bias), GatedDeltaNet (A_log, dt_bias), and MLA + (kv_b_proj). All share identical sharding logic: when ``enable_sharding`` + is ``True``, every discovered weight parameter is split along dim 0. + """ + + def apply(self, gm: GraphModule, dc: DistConfig, max_num_tokens: int = 0) -> int: + [enable_sharding] = extract_op_args(self.node, "enable_sharding") + if not enable_sharding: + return 0 + + weight_nodes = extract_weight_nodes(self.node) + + # SSM/GDN/MLA ops have only weight parameters (A, D, dt_bias, kv_b_proj); + # no biases or quantized scales. Assert this assumption explicitly. + assert not weight_nodes.biases, ( + f"Unexpected biases on {self.node.target}: {weight_nodes.biases}" + ) + assert not weight_nodes.scales, ( + f"Unexpected scales on {self.node.target}: {weight_nodes.scales}" + ) + + count = 0 + for wn in weight_nodes.weights: + shard_weight_tensor( + gm=gm, + weight_tensor=wn.tensor, + param_key=wn.node_key, + dim=0, + rank=dc.tp_rank, + world_size=dc.tp_size, + ) + count += 1 + + ad_logger.debug(f" sharded weighted params ({count} tensors)") + return 1 if count > 0 else 0 + + +@ShardableNode.register( + torch.ops.auto_deploy.torch_rmsnorm_gated, + torch.ops.auto_deploy.triton_rmsnorm_gated, +) +class NormShardableNode(ShardableNode): + """Gated RMSNorm op: shard weight parameter.""" + + def apply(self, gm: GraphModule, dc: DistConfig, max_num_tokens: int = 0) -> int: + [tp_mode] = extract_op_args(self.node, "tp_mode") + if tp_mode == "none": + return 0 + + weight_nodes = extract_weight_nodes(self.node) + count = 0 + for wn in weight_nodes.weights: + shard_weight_tensor( + gm=gm, + weight_tensor=wn.tensor, + param_key=wn.node_key, + dim=0, + rank=dc.tp_rank, + world_size=dc.tp_size, + ) + count += 1 + + ad_logger.debug(f" sharded norm ({count} tensors)") + return 1 if count > 0 else 0 + + +@ShardableNode.register(torch.ops.auto_deploy.torch_swiglu_mlp) +class SwiGLUShardableNode(ShardableNode): + """SwiGLU MLP ops: shard gate/up colwise, down rowwise. + + Handles the intermediate (unfused) SwiGLU representation produced by + ``match_swiglu_pattern``. The fused variant (``fused_swiglu_mlp``) is + created AFTER sharding in ``post_load_fusion`` and needs no handling. + + Quantized variants with per-block scale buffers use subclasses that + override ``_shard_scales``. 
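+
+    For example (illustrative): a parameter keyed ``...mlp.down_proj.weight`` is
+    sharded row-wise because its module path contains ``down``; gate/up weights
+    (and their biases) are sharded column-wise (see ``_dim_for_key``).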
+ """ + + @staticmethod + def _dim_for_key(node_key: str) -> int: + """Determine split dimension from the module path: ``down`` → ROW, else COLUMN.""" + return SplitDimension.ROW if "down" in node_key else SplitDimension.COLUMN + + def apply(self, gm: GraphModule, dc: DistConfig, max_num_tokens: int = 0) -> int: + weight_nodes = extract_weight_nodes(self.node) + if not weight_nodes.weights: + return 0 + + for wn in weight_nodes.weights: + shard_weight_tensor( + gm=gm, + weight_tensor=wn.tensor, + param_key=wn.node_key, + dim=self._dim_for_key(wn.node_key), + rank=dc.tp_rank, + world_size=dc.tp_size, + ) + + for bn in weight_nodes.biases: + if self._dim_for_key(bn.node_key) == SplitDimension.COLUMN: + shard_weight_tensor( + gm=gm, + weight_tensor=bn.tensor, + param_key=bn.node_key, + dim=SplitDimension.COLUMN, + rank=dc.tp_rank, + world_size=dc.tp_size, + ) + + self._shard_scales(gm, dc, weight_nodes) + + ad_logger.debug( + f" sharded SwiGLU MLP ({len(weight_nodes.weights)} weights, " + f"{len(weight_nodes.scales)} scales)" + ) + return 1 + + def _shard_scales(self, gm, dc, weight_nodes): + """Override in quantization subclasses to shard per-block scale buffers.""" + pass + + +@ShardableNode.register(torch.ops.auto_deploy.torch_finegrained_fp8_swiglu_mlp) +class FineGrainedFP8SwiGLUShardableNode(SwiGLUShardableNode): + """FineGrained FP8 SwiGLU: shards per-block ``weight_scale_inv`` buffers.""" + + def _shard_scales(self, gm, dc, weight_nodes): + for sn in weight_nodes.scales: + dim = self._dim_for_key(sn.node_key) + f_split = partial( + _split_fp8_block_scale, dim=dim, rank=dc.tp_rank, world_size=dc.tp_size + ) + _shard_scale_and_hook(gm, sn, f_split(sn.tensor), f_split) + + +@ShardableNode.register(torch.ops.auto_deploy.torch_nvfp4_swiglu_mlp) +class FP4SwiGLUShardableNode(SwiGLUShardableNode): + """NVFP4 SwiGLU: shards cutlass-format ``weight_scale`` buffers.""" + + def _shard_scales(self, gm, dc, weight_nodes): + weight_shape = weight_nodes.weights[0].tensor.shape if weight_nodes.weights else None + if weight_shape is None: + return + for sn in weight_nodes.scales: + dim = self._dim_for_key(sn.node_key) + f_split = partial( + _shard_fp4_weight_scale, + original_uint8_weight_shape=weight_shape, + dim=dim, + rank=dc.tp_rank, + world_size=dc.tp_size, + min_local_shape=1, + fused_weight_dims=None, + ) + _shard_scale_and_hook(gm, sn, f_split(sn.tensor), f_split) + + +@ShardableNode.register( + torch.ops.auto_deploy.torch_moe, + torch.ops.auto_deploy.torch_quant_fp8_moe, + torch.ops.auto_deploy.torch_quant_nvfp4_moe, + torch.ops.auto_deploy.torch_quant_finegrained_fp8_moe, +) +class MoEShardableNode(ShardableNode): + """List-based MoE ops: EP weight partitioning, expert ID localization, mapping injection. + + Handles ops where args[3:6] are ``List[torch.Tensor]`` (per-expert weight + lists). Stacked-tensor MoE variants (``torch_moe_fused``, + ``torch_moe_dense_mlp``, ``triton_mxfp4_moe``) are NOT registered here -- + see ``StackedMoEShardableNode`` for stacked-tensor variants, and the other two are either + converted to list-based ``torch_moe`` before sharding or left replicated. 
+ """ + + def apply(self, gm: GraphModule, dc: DistConfig, max_num_tokens: int = 0) -> int: + ep_size = dc.moe_ep_size + ep_rank = dc.moe_ep_rank + tp_size = dc.moe_tp_size + tp_rank = dc.moe_tp_rank + enable_alltoall = dc.enable_attention_dp and ep_size > 1 + + if ep_size <= 1 and tp_size <= 1: + return 0 + + [selected_experts, routing_weights, w1_weight, w2_weight, w3_weight] = extract_op_args( + self.node, "selected_experts", "routing_weights", "w1_weight", "w2_weight", "w3_weight" + ) + num_experts = len(w1_weight) + assert num_experts % ep_size == 0, ( + f"num_experts ({num_experts}) must be divisible by ep_size ({ep_size})" + ) + experts_per_rank = num_experts // ep_size + + def get_partition(lst, world_size, rank): + n = len(lst) + per_part = n // world_size + start = rank * per_part + end_idx = n if (rank == world_size - 1) else start + per_part + return lst[start:end_idx], lst[:start] + lst[end_idx:] + + w1_sharded, w1_removed = get_partition(w1_weight, ep_size, ep_rank) + w2_sharded, w2_removed = get_partition(w2_weight, ep_size, ep_rank) + w3_sharded, w3_removed = get_partition(w3_weight, ep_size, ep_rank) + nodes_to_remove = w1_removed + w2_removed + w3_removed + + if tp_size > 1: + for w in w1_sharded + w3_sharded: + shard_weight_tensor( + gm=gm, + weight_tensor=gm.get_parameter(w.target), + param_key=w.target, + dim=SplitDimension.COLUMN, + rank=tp_rank, + world_size=tp_size, + ) + for w in w2_sharded: + shard_weight_tensor( + gm=gm, + weight_tensor=gm.get_parameter(w.target), + param_key=w.target, + dim=SplitDimension.ROW, + rank=tp_rank, + world_size=tp_size, + ) + + set_op_args(self.node, w1_weight=w1_sharded, w2_weight=w2_sharded, w3_weight=w3_sharded) + + # Shard scale lists (quantized MoE ops have per-expert scale lists). + # Unlike Linear/SwiGLU where scales are single buffer tensors handled by + # _shard_scales(), MoE scales are List[Tensor] (one per expert) -- the same + # structure as weights. They must be EP-partitioned identically to weights. + # We use positional args[6:] because scale arg names vary across quantized + # op variants (w1_weight_scale, input_scale, etc.). + args = list(self.node.args) + for i in range(6, len(args)): + if isinstance(args[i], (list, tuple)) and len(args[i]) == num_experts: + sharded, removed = get_partition(list(args[i]), ep_size, ep_rank) + args[i] = sharded + nodes_to_remove.extend(removed) + self.node.args = tuple(args) + + if enable_alltoall: + # mapping and max_num_tokens are needed downstream for MoE all-to-all dispatcher + mapping_config = dc.serialize() + set_op_args(self.node, mapping_config=mapping_config, max_num_tokens=max_num_tokens) + else: + # with pure EP/TP parallelism, global expert indices must be localized + self._localize_expert_indices( + gm, selected_experts, routing_weights, experts_per_rank, ep_rank, ep_size + ) + + ad_logger.debug( + f" sharded MoE: {num_experts} experts, ep={ep_size}, ep_rank={ep_rank}, " + f"tp={tp_size}, tp_rank={tp_rank}, alltoall={enable_alltoall}, " + f"local_experts={len(w1_sharded)}, mapping_config_keys=" + f"[ep={dc.moe_ep_size},tp={dc.moe_tp_size},attn_dp={dc.enable_attention_dp}]" + ) + self._pending_dead_nodes = nodes_to_remove + return 1 + + def _localize_expert_indices( + self, + gm: GraphModule, + selected_experts: Node, + routing_weights: Node, + experts_per_rank: int, + ep_rank: int, + ep_size: int, + ) -> None: + """Remap global expert indices to EP-local indices and mask routing weights. 
+ + Inserts graph nodes that (1) subtract the rank offset from + selected_experts to get local indices, and (2) zero out routing + weights for experts not assigned to this rank. + """ + with gm.graph.inserting_before(self.node): + lower = experts_per_rank * ep_rank + selected_experts_local = gm.graph.create_node( + "call_function", operator.sub, args=(selected_experts, lower), kwargs={} + ) + div_node = gm.graph.create_node( + "call_function", + operator.floordiv, + args=(selected_experts, experts_per_rank), + kwargs={}, + ) + comp_op = torch.ge if ep_rank == ep_size - 1 else torch.eq + rank_mask = gm.graph.create_node( + "call_function", comp_op, args=(div_node, ep_rank), kwargs={} + ) + routing_weights_local = gm.graph.create_node( + "call_function", operator.mul, args=(routing_weights, rank_mask), kwargs={} + ) + set_op_args( + self.node, + selected_experts=selected_experts_local, + routing_weights=routing_weights_local, + ) + + +@ShardableNode.register(torch.ops.auto_deploy.triton_mxfp4_moe) +class StackedMoEShardableNode(ShardableNode): + """Stacked-tensor MoE EP sharding: slice along the expert dimension and rewrite. + + Unlike :class:`MoEShardableNode` which handles list-based expert weights + (``List[Tensor]``), this class handles MoE ops where expert weights are + stacked into 3-D tensors (``Tensor[num_experts, ...]``). Sharding slices + along dim 0 to select the local expert partition. + + Currently the only registered variant is ``triton_mxfp4_moe`` (MXFP4 + quantized), but the approach generalises to any stacked-tensor MoE op. + The op is rewritten to ``triton_mxfp4_moe_ep`` with an explicit + ``all_reduce`` after the node. + """ + + _IDX_GATE_UP_BLOCKS = 4 + _IDX_GATE_UP_BIAS = 5 + _IDX_GATE_UP_SCALES = 6 + _IDX_DOWN_BLOCKS = 9 + _IDX_DOWN_BIAS = 10 + _IDX_DOWN_SCALES = 11 + + def apply(self, gm: GraphModule, dc: DistConfig, max_num_tokens: int = 0) -> int: + ep_size = dc.moe_ep_size + ep_rank = dc.moe_ep_rank + + if ep_size <= 1: + return 0 + + expert_shape = shape(self.node.args[self._IDX_GATE_UP_BLOCKS]) + assert expert_shape is not None, ( + f"Cannot determine num_experts: gate_up_blocks arg has no shape metadata " + f"(node: {self.node.name})" + ) + num_experts = expert_shape[0] + base = num_experts // ep_size + lo = base * ep_rank + hi = num_experts if ep_rank == ep_size - 1 else base * (ep_rank + 1) + + args = list(self.node.args) + for idx in ( + self._IDX_GATE_UP_BLOCKS, + self._IDX_GATE_UP_BIAS, + self._IDX_GATE_UP_SCALES, + self._IDX_DOWN_BLOCKS, + self._IDX_DOWN_BIAS, + self._IDX_DOWN_SCALES, + ): + with gm.graph.inserting_after(args[idx]): + args[idx] = gm.graph.call_function( + torch.ops.aten.slice.Tensor, + args=(args[idx], 0, lo, hi, 1), + ) + + self.node.target = torch.ops.auto_deploy.triton_mxfp4_moe_ep.default + self.node.args = tuple(args) + (int(ep_size), int(ep_rank)) + + _, all_reduce_op = _get_dist_ops("auto") + with gm.graph.inserting_after(self.node): + red = gm.graph.call_function( + all_reduce_op, + args=(self.node, dc.allreduce_strategy), + ) + self.node.replace_all_uses_with(red) + red.replace_input_with(red, self.node) + + ad_logger.debug( + f" sharded MXFP4 MoE: {num_experts} experts, ep={ep_size}, rank slice [{lo}:{hi}]" + ) + return 1 + + +# ============================================================================= +# IR sharding config +# ============================================================================= + + +class IRShardingConfig(TransformConfig): + """Minimal configuration for the hint-driven IR sharding transform. 
+ + This replaces the legacy ``ShardingTransformConfig`` for + ``ApplyShardingHints``, carrying only the fields that the IR path actually + reads. When the legacy sharding path is removed, this is the only sharding + config class. + """ + + allreduce_strategy: AllReduceStrategy = Field( + default=AllReduceStrategy.AUTO, + description="AllReduce strategy for distributed operations.", + ) + simple_shard_only: bool = Field(default=False) + shard_layers: Optional[List[str]] = Field( + default=None, + description="When set, only shard nodes whose layer_type hint is in this list.", + ) + enable_attention_dp: bool = Field(default=False) + dist_mapping: dict[str, int] = Field(default_factory=dict) + dist_config: DistConfig = Field(default_factory=DistConfig) + + @field_validator("allreduce_strategy", mode="before") + @classmethod + def _validate_allreduce_strategy(cls, v): + return validate_allreduce_strategy(v) + + def _init_dist_config(self, rank: int, world_size: int): + """Initialize DistConfig from dist_mapping config (fallback path). + + Called when ``shared_config.dist_config`` is None (e.g. test suites + that construct ``InferenceOptimizer`` without a ``Mapping``). + ``rank`` and ``world_size`` come from ``shared_config``. + """ + self.dist_config = DistConfig( + world_size=world_size, + rank=rank, + tp_size=self.dist_mapping.get("tp", world_size), + moe_tp_size=self.dist_mapping.get("moe_tp", 1), + moe_ep_size=self.dist_mapping.get("moe_ep", world_size), + moe_cluster_size=self.dist_mapping.get("moe_cluster", 1), + enable_attention_dp=self.enable_attention_dp, + allreduce_strategy=self.allreduce_strategy.name, + ) + + +# ============================================================================= +# Standalone helpers +# ============================================================================= + + +def _log_sharding_prelude(dc: DistConfig) -> None: + """Log the sharding configuration before apply_sharding_hints runs.""" + skip = " (skipping)" if dc.tp_size <= 1 else "" + ad_logger.info( + f"apply_sharding_hints{skip}: tp_size={dc.tp_size}, tp_rank={dc.tp_rank}, " + f"moe grid: [ep x tp] = [{dc.moe_ep_size} x {dc.moe_tp_size}], " + f"strategy={dc.allreduce_strategy}" + ) + + +def _log_sharding_result( + dc: DistConfig, + num_updates: int, + num_skipped: int = 0, + *, + shard_layers: Optional[List[str]] = None, +) -> None: + """Log the sharding result after apply_sharding_hints completes.""" + mode = "attention_dp" if dc.enable_attention_dp else "TP + EP" + parts = [f"apply_sharding_hints ({mode}): {num_updates} nodes processed"] + if num_skipped: + parts.append(f"{num_skipped} skipped (shard_layers={shard_layers})") + ad_logger.info(", ".join(parts)) + + +def _apply_simple_shard(gm: GraphModule, dc: DistConfig) -> int: + """Simple shard fallback: column-split every linear weight, bias, and scale, then all_gather. + + Uses the polymorphic ``_shard_scales`` method on the resolved + ``ShardableNode`` subclass for format-aware scale handling (FP4 block + alignment, fine-grained FP8 per-block splits). 
+    """
+    num_updates = 0
+    for node in list(gm.graph.nodes):
+        if not is_any_lin_op(node):
+            continue
+        weight_nodes = extract_weight_nodes(node)
+        if not weight_nodes.weights:
+            continue
+        for wn in weight_nodes.weights:
+            shard_weight_tensor(
+                gm=gm,
+                weight_tensor=wn.tensor,
+                param_key=wn.node_key,
+                dim=SplitDimension.COLUMN,
+                rank=dc.tp_rank,
+                world_size=dc.tp_size,
+            )
+        for bn in weight_nodes.biases:
+            shard_weight_tensor(
+                gm=gm,
+                weight_tensor=bn.tensor,
+                param_key=bn.node_key,
+                dim=SplitDimension.COLUMN,
+                rank=dc.tp_rank,
+                world_size=dc.tp_size,
+            )
+        shardable_node = ShardableNode.from_node(node)
+        if isinstance(shardable_node, LinearShardableNode):
+            shardable_node._shard_scales(
+                gm, dc, weight_nodes, dim=SplitDimension.COLUMN, min_shape=1, fused=None
+            )
+        with gm.graph.inserting_after(node):
+            gather_node = gm.graph.call_function(
+                torch.ops.auto_deploy.torch_dist_all_gather.default,
+                args=(node, -1),
+            )
+            node.replace_all_uses_with(gather_node)
+            gather_node.replace_input_with(gather_node, node)
+        num_updates += 1
+    return num_updates
+
+
+# =============================================================================
+# Transform classes
+# =============================================================================
+
+
+@TransformRegistry.register("strip_sharding_hints")
+class StripShardingHints(BaseTransform):
+    """Strip sharding hints and lower placeholder ops to zero-copy aten equivalents.
+
+    Placeholder ops (``auto_deploy.view``, ``split_with_sizes``, ``all_reduce``)
+    are replaced with native aten ops to eliminate the ``.clone()`` overhead
+    required by PyTorch's custom op framework. Other hinted ops that have no
+    aten equivalent get their hint kwargs stripped so downstream transforms see
+    canonical op signatures.
+    """
+
+    @classmethod
+    def get_config_class(cls) -> Type[TransformConfig]:
+        return TransformConfig
+
+    def _apply(
+        self,
+        gm: GraphModule,
+        cm: CachedSequenceInterface,
+        factory: ModelFactory,
+        shared_config: SharedConfig,
+    ) -> Tuple[GraphModule, TransformInfo]:
+        count = 0
+        for node in list(gm.graph.nodes):
+            if ShardableNode.strip_hints(node):
+                count += 1
+        if count:
+            gm.graph.lint()
+            gm.recompile()
+        return gm, TransformInfo(
+            skipped=(count == 0),
+            num_matches=count,
+            is_clean=(count == 0),
+            has_valid_shapes=True,
+        )
+
+
+@TransformRegistry.register("apply_sharding_hints")
+class ApplyShardingHints(BaseTransform):
+    """Deterministic, node-local sharding transform driven by hint kwargs.
+
+    Iterates graph nodes and applies sharding based on explicit hint arguments
+    (tp_mode, tp_scaled_dim, tp_scale_sizes, etc.) together with the runtime
+    DistConfig. No cross-node propagation, no topology inference.
+    """
+
+    config: IRShardingConfig
+
+    @classmethod
+    def get_config_class(cls) -> Type[TransformConfig]:
+        return IRShardingConfig
+
+    def _apply(
+        self,
+        gm: GraphModule,
+        cm: CachedSequenceInterface,
+        factory: ModelFactory,
+        shared_config: SharedConfig,
+    ) -> Tuple[GraphModule, TransformInfo]:
+        """
+        Apply node-local sharding based on hint kwargs and runtime DistConfig.
+
+        Skips when world_size < 2. Supports shard_layers filtering and simple_shard_only mode.
+        """
+        invalidate_weight_node_cache(gm)
+
+        if shared_config.dist_config is not None:
+            # Intentional alias: single shared DistConfig across all transforms
+            # so mutations (e.g., allreduce_strategy) propagate to downstream fusions.
+ self.config.dist_config = shared_config.dist_config + else: + self.config._init_dist_config(shared_config.local_rank, shared_config.world_size) + + dc = self.config.dist_config + dc.allreduce_strategy = self.config.allreduce_strategy.name + _log_sharding_prelude(dc) + + if shared_config.world_size < 2: + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + + max_num_tokens = cm.info.max_num_tokens if (cm and cm.info) else 0 + + num_updates = 0 + if self.config.simple_shard_only: + num_updates = _apply_simple_shard(gm, dc) + _log_sharding_result(dc, num_updates) + else: + shard_layers = self.config.shard_layers + num_skipped = 0 + all_dead_nodes = [] + + with WeightBiasInfoCache(): + for node in list(gm.graph.nodes): + shardable_node = ShardableNode.from_node(node) + if shardable_node is None: + continue + if dc.enable_attention_dp and not isinstance( + shardable_node, (MoEShardableNode, StackedMoEShardableNode) + ): + continue + if shard_layers is not None: + [lt] = extract_op_args(node, "layer_type") + if lt is not None and lt not in shard_layers: + num_skipped += 1 + continue + + num_updates += shardable_node.apply(gm, dc, max_num_tokens) + + if hasattr(shardable_node, "_pending_dead_nodes"): + all_dead_nodes.extend(shardable_node._pending_dead_nodes) + + if all_dead_nodes: + eliminate_dead_code(gm) + for dead_node in all_dead_nodes: + try: + del_attr_by_name(gm, dead_node.target) + except AttributeError: + pass + + _log_sharding_result(dc, num_updates, num_skipped, shard_layers=shard_layers) + + return gm, TransformInfo( + skipped=False, + num_matches=num_updates, + is_clean=(num_updates == 0), + has_valid_shapes=True, + ) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/short_reshape_attention_output.py b/tensorrt_llm/_torch/auto_deploy/transform/library/short_reshape_attention_output.py index 10bc5d369950..57d560e8592a 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/short_reshape_attention_output.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/short_reshape_attention_output.py @@ -98,8 +98,11 @@ def _find_reshape_attention_output(self, gm: GraphModule) -> List[Node]: matched patterns. 
""" reshape_linear_pairs = [] - reshape_nodes = gm.graph.find_nodes( - op="call_function", target=torch.ops.aten.reshape.default + reshape_nodes = list( + gm.graph.find_nodes(op="call_function", target=torch.ops.aten.reshape.default) + ) + reshape_nodes.extend( + gm.graph.find_nodes(op="call_function", target=torch.ops.auto_deploy.view.default) ) for node in reshape_nodes: diff --git a/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py b/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py index 0bbcf4dca1c7..21ffbe3e3d58 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py @@ -13,6 +13,7 @@ from ..shim.interface import CachedSequenceInterface from ..utils.logger import ad_logger from .interface import ( + DistConfig, InferenceOptimizerConfig, SharedConfig, Stages, @@ -23,15 +24,23 @@ class InferenceOptimizer: - def __init__(self, factory: ModelFactory, config: InferenceOptimizerConfig, mapping=None): + def __init__( + self, + factory: ModelFactory, + config: InferenceOptimizerConfig, + dist_config: Optional[DistConfig] = None, + ): self.factory = factory self.config = self._clean_config(config) if not dist.is_initialized(): local_rank, world_size = 0, 1 else: local_rank, world_size = dist_ad.get_rank_world_size() + self.shared_config = SharedConfig( - local_rank=local_rank, world_size=world_size, mapping=mapping + local_rank=local_rank, + world_size=world_size, + dist_config=dist_config, ) def _clean_config(self, config: InferenceOptimizerConfig) -> StrictInferenceOptimizerConfig: diff --git a/tensorrt_llm/_torch/auto_deploy/utils/dist_config.py b/tensorrt_llm/_torch/auto_deploy/utils/dist_config.py new file mode 100644 index 000000000000..4eed94c79f4e --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/utils/dist_config.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Self-contained distributed configuration for AutoDeploy sharding. + +``DistConfig`` replaces the dependency on ``tensorrt_llm.mapping.Mapping`` +within AutoDeploy. It carries the minimal set of parallelism parameters +needed by the sharding transforms and custom ops, plus serialization +support for graph-level metadata (e.g., MoE all-to-all dispatch). 
+""" + +import json +from typing import Any + +from pydantic import BaseModel, Field, model_validator + + +class DistConfig(BaseModel): + """Distributed parallelism configuration for AutoDeploy.""" + + model_config = {"extra": "allow"} + + world_size: int = Field(default=1, ge=1) + rank: int = Field(default=0, ge=0) + tp_size: int = Field(default=1, ge=1) + pp_size: int = Field(default=1, ge=1) + moe_tp_size: int = Field(default=1, ge=1) + moe_ep_size: int = Field(default=1, ge=1) + moe_cluster_size: int = Field(default=1, ge=1) + enable_attention_dp: bool = Field(default=False) + allreduce_strategy: str = Field(default="NCCL") + + @model_validator(mode="after") + def _validate_grid(self) -> "DistConfig": + if self.rank >= self.world_size: + raise ValueError(f"rank ({self.rank}) must be < world_size ({self.world_size})") + if self.tp_size > self.world_size: + raise ValueError(f"tp_size ({self.tp_size}) must be <= world_size ({self.world_size})") + moe_grid = self.moe_tp_size * self.moe_ep_size * self.moe_cluster_size + if moe_grid != self.tp_size: + raise ValueError( + f"moe_tp_size * moe_ep_size * moe_cluster_size ({moe_grid}) " + f"must equal tp_size ({self.tp_size})" + ) + return self + + @property + def tp_rank(self) -> int: + """Local rank within tensor parallelism (0 .. tp_size - 1).""" + return self.rank % self.tp_size + + @property + def pp_rank(self) -> int: + """Pipeline-parallel stage index for this process.""" + return self.rank // self.tp_size + + @property + def moe_tp_rank(self) -> int: + """MoE tensor-parallel rank within the MoE TP subgroup.""" + return self.tp_rank // (self.moe_ep_size * self.moe_cluster_size) + + @property + def moe_ep_rank(self) -> int: + """Expert-parallel rank derived from the tensor-parallel rank.""" + return self.tp_rank % self.moe_ep_size + + @property + def moe_cluster_rank(self) -> int: + """MoE cluster index derived from the tensor-parallel rank.""" + return self.tp_rank % self.moe_cluster_size + + def to_dict(self) -> dict: + """Return a plain dict of serializable DistConfig fields.""" + return { + "world_size": self.world_size, + "rank": self.rank, + "tp_size": self.tp_size, + "pp_size": self.pp_size, + "moe_tp_size": self.moe_tp_size, + "moe_ep_size": self.moe_ep_size, + "moe_cluster_size": self.moe_cluster_size, + "enable_attention_dp": self.enable_attention_dp, + "allreduce_strategy": self.allreduce_strategy, + } + + @classmethod + def from_dict(cls, d: dict) -> "DistConfig": + """Construct from a dict, ignoring keys that are not DistConfig fields.""" + known = {f for f in cls.model_fields} + filtered = {k: v for k, v in d.items() if k in known} + return cls(**filtered) + + def serialize(self) -> str: + """JSON string for this config (via ``to_dict``).""" + return json.dumps(self.to_dict()) + + @classmethod + def deserialize(cls, s: str) -> "DistConfig": + """Parse a JSON string into a ``DistConfig``.""" + return cls.from_dict(json.loads(s)) + + @staticmethod + def from_mapping(mapping: Any) -> "DistConfig": + """Construct from a ``tensorrt_llm.mapping.Mapping`` instance.""" + return DistConfig( + world_size=mapping.world_size, + rank=mapping.rank, + tp_size=mapping.tp_size, + pp_size=mapping.pp_size, + moe_tp_size=mapping.moe_tp_size, + moe_ep_size=mapping.moe_ep_size, + moe_cluster_size=mapping.moe_cluster_size, + enable_attention_dp=mapping.enable_attention_dp, + ) + + def to_mapping(self) -> Any: + """Convert back to a ``tensorrt_llm.mapping.Mapping`` for C++ op interop.""" + from tensorrt_llm.mapping import Mapping # will be deprecated 
by DistConfig + + return Mapping( + world_size=self.world_size, + rank=self.rank, + tp_size=self.tp_size, + pp_size=self.pp_size, + moe_tp_size=self.moe_tp_size, + moe_ep_size=self.moe_ep_size, + moe_cluster_size=self.moe_cluster_size, + enable_attention_dp=self.enable_attention_dp, + ) + + def print_grid(self) -> str: + """Human-readable summary of the TP / MoE parallelism grid.""" + return ( + f"process grid: [TP, MoE_TP, MoE_EP] = " + f"[{self.tp_size}, {self.moe_tp_size}, {self.moe_ep_size}]" + ) + + def print_rank(self) -> str: + """Human-readable summary of this process's rank assignments.""" + return f"rank: [{self.rank}, {self.moe_tp_rank}, {self.moe_ep_rank}]" diff --git a/tensorrt_llm/_torch/auto_deploy/utils/mapping_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/mapping_utils.py deleted file mode 100644 index ed771cceea27..000000000000 --- a/tensorrt_llm/_torch/auto_deploy/utils/mapping_utils.py +++ /dev/null @@ -1,34 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -from tensorrt_llm.mapping import Mapping - - -def deserialize_mapping(mapping_config: str) -> Mapping: - return Mapping.from_dict(json.loads(mapping_config)) - - -def serialize_mapping(mapping: Mapping) -> str: - return json.dumps(mapping.to_dict()) - - -def print_grid(mapping: Mapping) -> str: - return f"process grid: [TP, MoE_TP, MoE_EP] = [{mapping.tp_size}, {mapping.moe_tp_size}, {mapping.moe_ep_size}]" - - -def print_rank(mapping: Mapping) -> str: - return f"rank: [{mapping.rank}, {mapping.moe_tp_rank}, {mapping.moe_ep_rank}]" diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index a64800782dbf..cb0b37e3f6bf 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -34,7 +34,7 @@ class LayerType(Enum): """Enum for layer type.""" - ATTENTION = "attention" + MHA = "mha" SSM = "ssm" MLP = "mlp" MOE = "moe" @@ -61,8 +61,9 @@ class WeightNode(BaseModel): class WeightNodes(BaseModel): - weights: list[WeightNode] - biases: list[WeightNode] + weights: list[WeightNode] = [] + biases: list[WeightNode] = [] + scales: list[WeightNode] = [] @dataclass @@ -274,142 +275,93 @@ def set_weight_shape(cls, node: Node, shape: Optional[List[int]]): cls._active_instance._weight_shape_cache[node] = shape -def extract_weight_nodes(node: Node) -> WeightNodes: - """Extracts the list of weight node and optional bias node from the given parametrized node""" - gm = node.graph.owning_module +def get_source_nodes( + node: Union[Node, List[Node]], + allowed_ops: Optional[set] = None, +) -> List[Node]: + """Walk backward through a computation chain and return all source (get_attr) nodes. 
- # Use cached param_names to avoid repeated expensive named_parameters/named_buffers calls - param_names = WeightBiasInfoCache.get_param_names(gm) + Args: + node: Starting node or list of starting nodes. + allowed_ops: If provided, only traverse through ``call_function`` nodes + whose ``target`` is in this set. Nodes with targets outside the set + act as traversal boundaries (their inputs are NOT explored). This + prevents cross-layer contamination through linear/conv/view ops when + searching for elementwise parameter chains (e.g., A_log -> exp -> neg). + When ``None``, all ``call_function`` nodes are traversed (original + behaviour). + + Warning: + Unconstrained traversal (``allowed_ops=None``) can explore the entire + backward-reachable subgraph from the starting node(s). This can be + expensive on large graphs and can cross layer boundaries through residual + connections. Callers should invoke this on specific parameter-bearing + argument nodes, not on full compute/activation nodes. + + Recommended alternatives: + * Use :func:`extract_weight_nodes` when the goal is specifically to find + weight/bias nodes for a parametrized op. + * Use the ``allowed_ops`` parameter when traversal should stay within a + narrow wrapper chain (e.g., elementwise parameter transforms like + ``exp``/``neg``). + """ + roots = [node] if isinstance(node, Node) else list(node) + result = [] + visited: set[Node] = set() + stack = list(roots) + while stack: + n = stack.pop() + if n in visited: + continue + visited.add(n) + if n.op == "get_attr": + result.append(n) + elif n.op == "call_function": + if allowed_ops is None or n.target in allowed_ops: + stack.extend(n.all_input_nodes) + return result + + +def _make_weight_node(attr_node: Node, gm: GraphModule) -> WeightNode: + """Construct a ``WeightNode`` from a ``get_attr`` FX node.""" + return WeightNode( + node=attr_node, + node_key=attr_node.target, + tensor=get_param_or_buffer(attr_node.target, gm), + submod=gm.get_submodule(attr_node.target.rpartition(".")[0]), + ) - def find_get_attr_node(weight_node: Node) -> Node: - """Recursively traverse inputs of allowed nodes to find a node with 'get_attr' op.""" - # If node is a get_attr node return node - # List of nodes allowed in between a get_attr node and the matmul node - allowed_ops = { - torch.ops.aten.to.dtype, - torch.ops.aten.view.default, - } - if ( - weight_node.op == "get_attr" - and weight_node.target in param_names - and has_shape(weight_node) - and len(shape(weight_node)) > 0 - ): - return weight_node +def extract_weight_nodes(node: Node) -> WeightNodes: + """Return the weight, bias, and scale ``get_attr`` nodes for a compute node. - # If node is not in the list of allowable ops then return None - if weight_node.target not in allowed_ops: - return None + Uses the precomputed forward-traversal mapping from + :func:`_precompute_weight_node_mapping`. The mapping is built lazily on + first call and cached on the ``GraphModule``. - for input_node in weight_node.all_input_nodes: - result = find_get_attr_node(input_node) - if result: - return result - return None + When *node* is itself a ``get_attr`` parameter node, it is classified and + returned directly (edge case for callers that pass a weight node instead of + a compute node). 
+ """ + gm = node.graph.owning_module - if is_op(node, torch.ops.aten.bmm): - # no bias for bmm - weight_node = find_get_attr_node(node.args[1]) - return WeightNodes( - weights=[ - WeightNode( - node=node.args[1], - node_key=weight_node.target, - tensor=get_param_or_buffer(weight_node.target, gm), - submod=gm.get_submodule(weight_node.target.rpartition(".")[0]), - ) - ], - biases=[], - ) - elif is_fake_quantized_linear_op(node): - # For quantized linear ops (FP8, FP4, etc.), only args[1] is the actual shardable - # weight. Scale buffers (input_scale, weight_scale, alpha, ...) are also registered - # as get_attr nodes in the graph and would otherwise be picked up by the generic - # all_input_nodes scan below -- causing shard_weight_tensor to overwrite them as - # nn.Parameters, which then breaks quantization_cb's get_buffer() call. - # The quantization_cb (QuantizationShardingMixin) is responsible for sharding scales. - weight_node = find_get_attr_node(node.args[1]) - if weight_node is None: - return WeightNodes(weights=[], biases=[]) - biases = [] - if len(node.args) > 2 and isinstance(node.args[2], Node): - b = find_get_attr_node(node.args[2]) - if b is not None and b.target.rsplit(".", 1)[-1] == "bias": - biases = [ - WeightNode( - node=node.args[2], - node_key=b.target, - submod=gm.get_submodule(b.target.rpartition(".")[0]), - tensor=get_param_or_buffer(b.target, gm), - ) - ] - return WeightNodes( - weights=[ - WeightNode( - node=node.args[1], - node_key=weight_node.target, - submod=gm.get_submodule(weight_node.target.rpartition(".")[0]), - tensor=get_param_or_buffer(weight_node.target, gm), - ) - ], - biases=biases, - ) - elif is_weight_node(node): - weights = [] - biases = [] - - if node.target.rsplit(".", 1)[-1] == "bias": - biases = [ - WeightNode( - node=node, - node_key=node.target, - tensor=get_param_or_buffer(node.target, gm), - submod=gm.get_submodule(node.target.rpartition(".")[0]), - ) - ] + if is_weight_node(node): + wn = _make_weight_node(node, gm) + cat = _classify_weight_node(node) + if cat == "bias_nodes": + return WeightNodes(biases=[wn]) + elif cat == "scale_nodes": + return WeightNodes(scales=[wn]) else: - weights = [ - WeightNode( - node=node, - node_key=node.target, - tensor=get_param_or_buffer(node.target, gm), - submod=gm.get_submodule(node.target.rpartition(".")[0]), - ) - ] - return WeightNodes( - weights=weights, - biases=biases, - ) - # for other parametrized nodes, we need to find the weight node - else: - all_weight_nodes = [ - attr_node - for n in node.all_input_nodes - if (attr_node := find_get_attr_node(n)) is not None - ] - # separate weight nodes and bias nodes - bias_nodes = [n for n in all_weight_nodes if n.target.rsplit(".", 1)[-1] == "bias"] - weight_nodes = [n for n in all_weight_nodes if n not in bias_nodes] - weight_nodes = [ - WeightNode( - node=n, - node_key=n.target, - submod=gm.get_submodule(n.target.rpartition(".")[0]), - tensor=get_param_or_buffer(n.target, gm), - ) - for n in weight_nodes - ] - bias_nodes = [ - WeightNode( - node=n, - node_key=n.target, - submod=gm.get_submodule(n.target.rpartition(".")[0]), - tensor=get_param_or_buffer(n.target, gm), - ) - for n in bias_nodes - ] - return WeightNodes(weights=weight_nodes, biases=bias_nodes) + return WeightNodes(weights=[wn]) + + _precompute_weight_node_mapping(gm) + + return WeightNodes( + weights=[_make_weight_node(n, gm) for n in node.meta.get("weight_nodes", [])], + biases=[_make_weight_node(n, gm) for n in node.meta.get("bias_nodes", [])], + scales=[_make_weight_node(n, gm) for n in 
node.meta.get("scale_nodes", [])], + ) def get_weight_node(node: Node) -> Node: @@ -644,6 +596,16 @@ def is_fp4_op(node: Node) -> bool: ) +def is_finegrained_fp8_linear_op(node: Node) -> bool: + return is_op( + node, + [ + torch.ops.auto_deploy.torch_fake_quant_finegrained_fp8_linear, + torch.ops.auto_deploy.trtllm_finegrained_fp8_linear, + ], + ) + + def is_any_moe_op(node: Node) -> bool: return is_op( node, @@ -653,6 +615,8 @@ def is_any_moe_op(node: Node) -> bool: torch.ops.auto_deploy.torch_quant_nvfp4_moe, torch.ops.auto_deploy.torch_quant_finegrained_fp8_moe, torch.ops.auto_deploy.triton_mxfp4_moe, + torch.ops.auto_deploy.torch_moe_fused, + torch.ops.auto_deploy.torch_moe_dense_mlp, ], ) @@ -712,6 +676,30 @@ def is_any_mla_op(node: Node) -> bool: ) +def is_any_view_op(node: Node) -> bool: + """Check if the node is a view/reshape op (aten or auto_deploy variant).""" + return is_op( + node, + [ + torch.ops.aten.view, + torch.ops.aten.reshape, + torch.ops.auto_deploy.view, + ], + ) + + +def is_any_split_op(node: Node) -> bool: + """Check if the node is a split/split_with_sizes op (aten or auto_deploy variant).""" + return is_op( + node, + [ + torch.ops.aten.split, + torch.ops.aten.split_with_sizes, + torch.ops.auto_deploy.split_with_sizes, + ], + ) + + def is_linear_op(node: Node) -> bool: """Check if the node is a linear op. @@ -758,87 +746,99 @@ def is_weight_node(node: Node) -> bool: return node.op == "get_attr" and node.target and has_shape(node) and len(shape(node)) > 0 -# Auxiliary ops that may appear between a weight node and its consumer compute node -_WEIGHT_AUX_OPS = frozenset( - { - torch.ops.aten.to.dtype, - torch.ops.aten.view.default, - } -) +def _classify_weight_node(node: Node) -> str: + """Classify a weight get_attr node by the last segment of its target path. + Returns the metadata key to store this node under on its consumer: + ``"bias_nodes"`` if the attribute name is exactly ``"bias"``, + ``"scale_nodes"`` if it contains ``"scale"``, otherwise ``"weight_nodes"``. + """ + attr_name = node.target.rsplit(".", 1)[-1] + if attr_name in ("bias", "alpha"): + return "bias_nodes" + if "scale" in attr_name: + return "scale_nodes" + return "weight_nodes" -def precompute_weight_node_mapping(gm: GraphModule) -> None: + +def invalidate_weight_node_cache(gm: GraphModule) -> None: + """Clear the cached weight-to-consumer mapping so it is rebuilt on next access. + + Call this at the start of any transform that mutates the graph (adds/removes + nodes, replaces weight tensors) and later needs ``extract_weight_nodes`` to + reflect the mutated state. """ - Pre-compute weight-to-consumer mapping for all weight nodes in the graph. + gm.meta.pop("_weight_mapping_computed", None) + + +def _precompute_weight_node_mapping(gm: GraphModule) -> None: + """Pre-compute weight-to-consumer mapping for all parameter/buffer nodes. - For each weight node (get_attr), finds the consumer compute node by traversing - through auxiliary ops (to.dtype, view.default). Stores the mapping in consumer - node's metadata: - - node.meta["weight_nodes"]: list of weight nodes (non-bias) - - node.meta["bias_nodes"]: list of bias nodes + For each ``get_attr`` node that is a registered parameter or buffer, + traverses forward through **unary ops** (nodes with at most one input) + until reaching a **multi-input consumer** (a node with >1 input nodes, + e.g. a linear or SSM op that combines the weight with activations). 
+ Every node along the chain -- including the terminal consumer -- gets + tagged in its metadata: - This enables O(1) weight node lookup instead of O(depth) backward traversal. - Called automatically on first weight lookup via lazy initialization. + - ``node.meta["weight_nodes"]``: list of weight ``get_attr`` nodes + - ``node.meta["bias_nodes"]``: list of bias ``get_attr`` nodes + - ``node.meta["scale_nodes"]``: list of scale ``get_attr`` nodes - GUARANTEES (verified by assertions for debugging): - - Called exactly once per GraphModule - - No duplicate weight/bias nodes in any consumer's lists - - Each weight node mapped to exactly one consumer + The traversal naturally follows parameter preprocessing chains + (``exp``, ``neg``, ``float()``, ``view``, etc.) without maintaining a + fragile allowlist of passthrough ops. The chain terminates at nodes + with multiple input nodes (the actual consumers), not at nodes with + multiple output users. + + Classification uses the last segment of the parameter path + (``node.target``): exactly ``"bias"`` -> bias, contains ``"scale"`` -> + scale, everything else -> weight. """ - # Early return if already computed if "_weight_mapping_computed" in gm.meta and gm.meta["_weight_mapping_computed"]: return gm.meta["_weight_mapping_computed"] = True + # Clear stale metadata from previous runs before rebuilding for node in gm.graph.nodes: - if not is_weight_node(node): - continue + node.meta.pop("weight_nodes", None) + node.meta.pop("bias_nodes", None) + node.meta.pop("scale_nodes", None) - is_bias = node.target.rsplit(".", 1)[-1] == "bias" + param_names = WeightBiasInfoCache.get_param_names(gm) + + for node in gm.graph.nodes: + if not is_weight_node(node) or node.target not in param_names: + continue - # the weight to user mapping is reflective - the weight node "owns" itself - node.meta["weight_nodes"] = [node] + category = _classify_weight_node(node) - # Find the consumer compute node by traversing through auxiliary ops + # Forward-traverse through unary ops, tagging every node along the + # way (intermediates like exp, neg, to.dtype AND the terminal consumer). + # Stops at multi-input nodes (the actual consumer) or dead ends. 
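+        # Illustrative chain (see the docstring above): A_log (get_attr) -> exp -> neg ->
+        # SSM/GDN compute op; the compute op has >1 input nodes, so the walk stops there.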
current = node - visited = {current} - while True: - # Get users of current node - users = list(current.users.keys()) - if not users: + if category not in current.meta: + current.meta[category] = [] + current.meta[category].append(node) + if len(current.all_input_nodes) > 1: break - - aux_node = None - - for user in users: - if is_bias: - if "bias_nodes" not in user.meta: - user.meta["bias_nodes"] = [] - # ASSERTION: Each weight node should be mapped exactly once - assert node not in user.meta["bias_nodes"], ( - f"Duplicate bias node {node.name} found for consumer {user.name}" - ) - user.meta["bias_nodes"].append(node) - else: - if "weight_nodes" not in user.meta: - user.meta["weight_nodes"] = [] - # ASSERTION: Each weight node should be mapped exactly once - assert node not in user.meta["weight_nodes"], ( - f"Duplicate weight node {node.name} found for consumer {user.name}" - ) - user.meta["weight_nodes"].append(node) - if user.target in _WEIGHT_AUX_OPS: - # This is an auxiliary op, continue traversing - aux_node = user - - if aux_node is not None and aux_node not in visited: - # Continue through auxiliary op - current = aux_node - visited.add(current) - else: - # No more nodes to traverse + if len(current.users) == 0: + ad_logger.debug( + f"Weight node {node.name} has no downstream consumer " + f"(chain ended at {current.name})" + ) break + current = next(iter(current.users)) + + # If the chain could not advance past the get_attr node itself + # (e.g. get_attr has 0 users, or get_attr directly feeds a + # multi-input consumer), tag each direct user as a fallback. + if current == node: + for user in node.users: + if category not in user.meta: + user.meta[category] = [] + user.meta[category].append(node) def get_user_if_pattern_match(node, ops, numusers, user_idx: int = 0): @@ -881,16 +881,22 @@ def identify_regions_between_residuals(gm: GraphModule) -> List[Node]: boundary_nodes = [input_id_node] # find embedding node which we assume to be the first node in a sequence of residual nodes - for n_user in input_id_node.users: - if is_op(n_user, torch.ops.aten.embedding): - break + # NOTE: Qwen's first node is strange: it's name is "inputs_embeds", there is no torch.ops.aten.embedding + # in the graph, and this input_id_node op is "placeholder". Nevertheless, it serves as a proper + # hook for residual identification. 
+ if input_id_node.name == "inputs_embeds": + boundary_nodes.append(input_id_node) else: - # we could not identify any boundary regions via embedding nodes - boundary_nodes.append(output_node) - return boundary_nodes + for n_user in input_id_node.users: + if is_op(n_user, torch.ops.aten.embedding): + break + else: + # we could not identify any boundary regions via embedding nodes + boundary_nodes.append(output_node) + return boundary_nodes - # add embedding node to boundary nodes - boundary_nodes.append(n_user) + # add embedding node to boundary nodes + boundary_nodes.append(n_user) # find residual nodes from here on while True: @@ -955,7 +961,7 @@ def get_all_layer_subgraphs( residuals = identify_regions_between_residuals(gm) # Pre-compute weight-to-consumer mapping for O(1) weight node lookup - precompute_weight_node_mapping(gm) + _precompute_weight_node_mapping(gm) # Cache weight shapes for all linear nodes for lin_node in linear_nodes: @@ -1083,6 +1089,8 @@ def extract_op_args(node: Node, *arg_names): args = list(node.args) kwargs = node.kwargs or {} + _MISSING = object() + def _get(name): if name in kwargs: return kwargs[name] @@ -1091,9 +1099,12 @@ def _get(name): return args[i] if name in defs: return defs[name] + if name not in pos: + return _MISSING raise RuntimeError(f"Could not find a value for '{name}' on op {node.target}") - return [_get(n) for n in arg_names] + result = [_get(n) for n in arg_names] + return [None if v is _MISSING else v for v in result] def set_op_args(node: Node, **name_value_pairs) -> None: @@ -1392,6 +1403,13 @@ def filter_condition(node: Node, dim: int) -> bool: sources=[linear_nodes[start_lin_index]], boundary_condition=lambda n: boundary_condition(n, dim=0), ) + if "layers_39_mlp_shared_expert_gate_torch_linear_simple_389" in [ + n.name for n in forward_subgraph + ]: + forward_subgraph = subgraph( + sources=[linear_nodes[start_lin_index]], + boundary_condition=lambda n: boundary_condition(n, dim=0), + ) lin_nodes_in_subgraph = list( filtered_nodes(forward_subgraph, lambda n: filter_condition(n, dim=0)) ) @@ -1488,7 +1506,7 @@ def classify_layer_type() -> [LayerType, int]: head_size = shape(attention_nodes[0])[-1] if len(intermediate_lin_nodes) > 0: return LayerType.UNKNOWN, 1 - return LayerType.ATTENTION, head_size + return LayerType.MHA, head_size if len(ssm_nodes) == 1: head_size = shape(ssm_nodes[0])[-1] @@ -1532,9 +1550,10 @@ def classify_layer_type() -> [LayerType, int]: min_local_shape=head_size, ) assert linear_nodes[start_lin_index] in opening_linear_nodes, ( - f"Linear node not found in opening linear nodes - " - f"terminating_linear_node:{terminating_linear_node.name}, " + f"Start linear node (index {start_lin_index}) not found in opening linear nodes - " + f"start_linear node: {linear_nodes[start_lin_index].name}, " f"opening_linear_nodes: {[n.name for n in opening_linear_nodes]}" + f"terminating_linear_node:{terminating_linear_node.name}, " ) # return the index of the terminating linear node diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py index 2a65883fe9ed..0ffb9dc32259 100644 --- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py +++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py @@ -1136,3 +1136,132 @@ def test_autodeploy_from_registry(self, model_name, config_overrides, tasks, task.evaluate(llm, **evaluate_kwargs) except (AssertionError, RuntimeError, ValueError) as e: raise type(e)(f"[{task_cls.__name__}] {e}") from None + + +# 
============================================================================= +# IR Sharding Path Tests +# ============================================================================= + +_IR_SHARDING_TRANSFORMS = { + "detect_sharding": { + "enabled": False, + }, + "sharding_transform_executor": { + "enabled": False, + }, + "apply_sharding_hints": { + "enabled": True, + "stage": "sharding", + "run_shape_prop": True, + "allreduce_strategy": "SYMM_MEM", + }, +} + + +class TestNemotronSuperV3_IR(LlmapiAccuracyTestHarness): + """Accuracy tests for Nemotron-Super using the IR sharding path. + + Uses ``apply_sharding_hints`` with sharding-aware IR modeling code + instead of the legacy ``detect_sharding`` + heuristic path. + """ + + MODEL_NAME = "nvidia/Nemotron-Super-V3" + CONFIG_YAML = str( + Path(get_llm_root()) / "examples" / "auto_deploy" / "super_v3.yaml") + MODEL_PATHS = { + "fp8": + hf_id_to_local_model_dir( + "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-FP8"), + } + + def get_default_sampling_params(self): + eos_id = -1 + return SamplingParams(end_id=eos_id, pad_id=eos_id) + + @pytest.mark.skip_less_device_memory(65000) + @pytest.mark.parametrize("world_size", [4, 8]) + @pytest.mark.parametrize("model_id", ["fp8"]) + def test_ir_accuracy(self, model_id, world_size, monkeypatch): + if get_device_count() < world_size: + pytest.skip(f"Not enough devices for world_size={world_size}") + + monkeypatch.setenv("AD_USE_IR_MODELS", "1") + + model_path = self.MODEL_PATHS[model_id] + transforms = dict(_IR_SHARDING_TRANSFORMS) + transforms["apply_sharding_hints"]["dist_mapping"] = { + "tp": world_size, + "moe_ep": world_size, + } + transforms["insert_cached_ssm_attention"] = {"backend": "triton_ssm"} + kwargs = { + "attn_backend": "flashinfer", + "transforms": transforms, + } + + with AutoDeployLLM(model=model_path, + tokenizer=model_path, + world_size=world_size, + yaml_extra=[self.CONFIG_YAML], + trust_remote_code=True, + **kwargs) as llm: + _set_quant_config(llm, model_id) + + sampling_params = self.get_default_sampling_params() + task = MMLU(self.MODEL_NAME) + task.evaluate(llm, sampling_params=sampling_params) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + + +class TestQwen3_5_MoE_IR(LlmapiAccuracyTestHarness): + """Accuracy tests for Qwen3.5 MoE using the IR sharding path. + + Uses ``apply_sharding_hints`` with sharding-aware IR modeling code + instead of the legacy ``detect_sharding`` + heuristic path. 
+ """ + + MODEL_NAME = "Qwen/Qwen3.5-35B-A3B" + CONFIG_YAML = str(_AD_CONFIGS_DIR / "qwen3.5_moe_35b.yaml") + EXTRA_EVALUATOR_KWARGS = dict(chat_template_kwargs=dict( + enable_thinking=False)) + + def get_default_sampling_params(self): + eos_id = -1 + return SamplingParams(end_id=eos_id, pad_id=eos_id) + + @pytest.mark.skip_less_device_memory(32000) + @pytest.mark.parametrize("world_size", [4]) + @pytest.mark.parametrize("model_id", ["fp8"]) + def test_ir_accuracy(self, model_id, world_size, monkeypatch): + if get_device_count() < world_size: + pytest.skip(f"Not enough devices for world_size={world_size}") + + monkeypatch.setenv("AD_USE_IR_MODELS", "1") + monkeypatch.setenv("TRTLLM_ACCURACY_NO_REFERENCE", "1") + + model_path = hf_id_to_local_model_dir("Qwen/Qwen3.5-35B-A3B-FP8") + transforms = dict(_IR_SHARDING_TRANSFORMS) + transforms["apply_sharding_hints"]["dist_mapping"] = { + "tp": world_size, + "moe_ep": world_size, + } + kwargs = { + "attn_backend": "flashinfer", + "transforms": transforms, + } + + with AutoDeployLLM(model=model_path, + tokenizer=model_path, + world_size=world_size, + yaml_extra=[self.CONFIG_YAML], + skip_tokenizer_init=False, + trust_remote_code=True, + **kwargs) as llm: + _set_quant_config(llm, model_id) + + sampling_params = self.get_default_sampling_params() + task = MMLU(self.MODEL_NAME) + task.evaluate(llm, sampling_params=sampling_params) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) diff --git a/tests/unittest/auto_deploy/multigpu/transformations/library/test_apply_sharding_hints.py b/tests/unittest/auto_deploy/multigpu/transformations/library/test_apply_sharding_hints.py new file mode 100644 index 000000000000..c15b4bbea53d --- /dev/null +++ b/tests/unittest/auto_deploy/multigpu/transformations/library/test_apply_sharding_hints.py @@ -0,0 +1,143 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for ``apply_sharding_hints`` (hint-driven TP sharding). + +``test_sharding`` — multi-GPU end-to-end: exports, transforms, and validates + output correctness on real GPUs via ``run_test_transformed_gm``. +``test_apply_hints`` — single-process transform check: verifies graph + rewriting (weight shapes, all_reduce replacement, skip conditions) + without distributed execution. 
+""" + +import pytest +import torch +import torch.nn as nn +from _dist_test_utils import get_device_counts +from _graph_test_helpers import run_test_transformed_gm + +import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 +import tensorrt_llm._torch.auto_deploy.distributed.common as dist_common +import tensorrt_llm._torch.auto_deploy.transform.library # noqa: F401 +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.transform.interface import SharedConfig +from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer +from tensorrt_llm._torch.auto_deploy.utils.dist_config import DistConfig +from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op + +pytestmark = pytest.mark.threadleak(enabled=False) + +FEATURES, HIDDEN = 32, 64 + + +class HintedMLP(nn.Module): + def __init__(self, features=FEATURES, hidden=HIDDEN): + super().__init__() + self.up = nn.Linear(features, hidden, bias=False) + self.down = nn.Linear(hidden, features, bias=False) + + def forward(self, x): + h = torch.ops.auto_deploy.torch_linear_simple(x, self.up.weight, None, tp_mode="colwise") + h = torch.relu(h) + h = torch.ops.auto_deploy.torch_linear_simple(h, self.down.weight, None, tp_mode="rowwise") + h = torch.ops.auto_deploy.all_reduce(h) + return h + + +# --------------------------------------------------------------------------- +# test_sharding — multi-GPU end-to-end (follows test_tp_sharding.py pattern) +# --------------------------------------------------------------------------- + + +def _run_sharding_job(rank: int, world_size: int) -> None: + model = HintedMLP().to(device="cuda", dtype=torch.float16) + x = torch.randn(4, 8, FEATURES, device="cuda", dtype=torch.float16) + + gm = torch_export_to_gm(model, args=(x,), clone=True) + gm_transformed = InferenceOptimizer(None, {"apply_sharding_hints": {"stage": "sharding"}})( + None, gm + ) + + op_ar = torch.ops.auto_deploy.torch_dist_all_reduce + + def check_transformed_graph(gm_mod) -> bool: + has_dist = any(is_op(n, op_ar) for n in gm_mod.graph.nodes) + return has_dist == (world_size > 1) + + run_test_transformed_gm( + model, + x, + gm_transformed, + check_transformed_graph=check_transformed_graph, + _get_expected_num_params=lambda n: n // world_size, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +@pytest.mark.parametrize("world_size", get_device_counts([2])) +def test_sharding(world_size: int): + """Hint-based colwise/rowwise + all_reduce: end-to-end on real GPUs.""" + dist_common.spawn_multiprocess_job(job=_run_sharding_job, size=world_size) + + +# --------------------------------------------------------------------------- +# test_apply_hints — single-process transform checks (no distributed exec) +# --------------------------------------------------------------------------- + + +def _make_optimizer(world_size: int, rank: int = 0): + opt = InferenceOptimizer( + factory=None, + config={"apply_sharding_hints": {"stage": "sharding"}}, + ) + opt.shared_config = SharedConfig( + local_rank=rank, + world_size=world_size, + dist_config=DistConfig( + world_size=world_size, rank=rank, tp_size=world_size, moe_ep_size=world_size + ), + ) + return opt + + +def _export_hinted_mlp(): + model = HintedMLP().cuda() + x = torch.randn(2, FEATURES, device="cuda") + return torch_export_to_gm(model, args=(x,), clone=True), model, x + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +@pytest.mark.parametrize( + "world_size, 
expect_skipped, expect_up_shape, expect_down_shape", + [ + (1, True, (HIDDEN, FEATURES), (FEATURES, HIDDEN)), + (2, False, (HIDDEN // 2, FEATURES), (FEATURES, HIDDEN // 2)), + ], +) +def test_apply_hints(world_size, expect_skipped, expect_up_shape, expect_down_shape): + """Verify graph rewriting without distributed execution.""" + gm, _, _ = _export_hinted_mlp() + gm_out = _make_optimizer(world_size)(None, gm) + + info = gm_out.meta["_autodeploy"]["transform_history"]["apply_sharding_hints"] + assert info.skipped is expect_skipped + + assert gm_out.up.weight.shape == expect_up_shape + assert gm_out.down.weight.shape == expect_down_shape + + has_dist_ar = any( + is_op(n, torch.ops.auto_deploy.torch_dist_all_reduce.default) for n in gm_out.graph.nodes + ) + assert has_dist_ar == (not expect_skipped) diff --git a/tests/unittest/auto_deploy/multigpu/transformations/library/test_ep_sharding.py b/tests/unittest/auto_deploy/multigpu/transformations/library/test_ep_sharding.py index c6b1a340169a..4c0c2a3201cf 100644 --- a/tests/unittest/auto_deploy/multigpu/transformations/library/test_ep_sharding.py +++ b/tests/unittest/auto_deploy/multigpu/transformations/library/test_ep_sharding.py @@ -18,7 +18,7 @@ ) from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.utils._graph import lint, recompile -from tensorrt_llm._torch.auto_deploy.utils.mapping_utils import deserialize_mapping +from tensorrt_llm._torch.auto_deploy.utils.dist_config import DistConfig from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op from tensorrt_llm.functional import AllReduceStrategy @@ -73,9 +73,9 @@ def transform_check(gm): assert "mapping_config" in moe_node.kwargs, ( f"Mapping config not found in MoE node {moe_node.name}" ) - # deserialize the mapping config string to dict - mapping_config = deserialize_mapping(moe_node.kwargs["mapping_config"]) - return mapping_config.enable_attention_dp + # deserialize the mapping config string to DistConfig + dc = DistConfig.deserialize(moe_node.kwargs["mapping_config"]) + return dc.enable_attention_dp else: def transform_check(gm): diff --git a/tests/unittest/auto_deploy/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/auto_deploy/multigpu/transformations/library/test_tp_sharding.py index f1667be8e2cb..d92b8072e83b 100644 --- a/tests/unittest/auto_deploy/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/auto_deploy/multigpu/transformations/library/test_tp_sharding.py @@ -885,7 +885,7 @@ def _run_pattern_detection_job( config=config, dist_op=dist_op, min_local_shape=min_local_shape, - layer_type=LayerType.ATTENTION, + layer_type=LayerType.MHA, ) ) elif model_cls == MLP: diff --git a/tests/unittest/auto_deploy/singlegpu/custom_ops/test_sharding_ops.py b/tests/unittest/auto_deploy/singlegpu/custom_ops/test_sharding_ops.py new file mode 100644 index 000000000000..b275ee182842 --- /dev/null +++ b/tests/unittest/auto_deploy/singlegpu/custom_ops/test_sharding_ops.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for sharding hint custom ops in ``sharding_ops.py``.""" + +import pytest +import torch + +import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 — register ops + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA required for auto_deploy sharding ops" +) + + +def test_view_matches_reshape(): + x = torch.randn(4, 8, device="cuda") + shape = [32] + out = torch.ops.auto_deploy.view(x, shape) + ref = x.reshape(shape).clone() + torch.testing.assert_close(out, ref) + + +def test_view_tp_scaled_dim_passthrough(): + x = torch.randn(4, 8, device="cuda") + shape = [2, 16] + default = torch.ops.auto_deploy.view(x, shape) + explicit_neg1 = torch.ops.auto_deploy.view(x, shape, tp_scaled_dim=-1) + nonzero_dim = torch.ops.auto_deploy.view(x, shape, tp_scaled_dim=0) + torch.testing.assert_close(default, explicit_neg1) + torch.testing.assert_close(default, nonzero_dim) + + +def test_view_accepts_layer_type(): + x = torch.randn(4, 8, device="cuda") + shape = [32] + out = torch.ops.auto_deploy.view(x, shape, layer_type="mha") + ref = x.reshape(shape).clone() + torch.testing.assert_close(out, ref) + + +def test_split_matches_torch_split(): + x = torch.randn(4, 8, device="cuda") + split_sizes = [3, 5] + dim = -1 + out = torch.ops.auto_deploy.split_with_sizes(x, split_sizes, dim) + ref = list(torch.split(x, split_sizes, dim=dim)) + assert len(out) == len(ref) + for a, b in zip(out, ref): + torch.testing.assert_close(a, b) + + +def test_split_enable_sharding_flag(): + x = torch.randn(4, 8, device="cuda") + split_sizes = [2, 2, 4] + dim = -1 + a = torch.ops.auto_deploy.split_with_sizes(x, split_sizes, dim, enable_sharding=False) + b = torch.ops.auto_deploy.split_with_sizes(x, split_sizes, dim, enable_sharding=True) + assert len(a) == len(b) + for u, v in zip(a, b): + torch.testing.assert_close(u, v) + + +def test_split_accepts_layer_type(): + x = torch.randn(4, 8, device="cuda") + split_sizes = [4, 4] + out = torch.ops.auto_deploy.split_with_sizes(x, split_sizes, dim=-1, layer_type="mlp") + ref = list(torch.split(x, split_sizes, dim=-1)) + assert len(out) == len(ref) + for a, b in zip(out, ref): + torch.testing.assert_close(a, b) + + +def test_all_reduce_is_identity(): + x = torch.randn(4, 8, device="cuda") + out = torch.ops.auto_deploy.all_reduce(x) + torch.testing.assert_close(out, x.clone()) + + +def test_all_reduce_accepts_layer_type(): + x = torch.randn(4, 8, device="cuda") + out = torch.ops.auto_deploy.all_reduce(x, layer_type="mha") + torch.testing.assert_close(out, x.clone()) diff --git a/tests/unittest/auto_deploy/singlegpu/utils/test_dist_config.py b/tests/unittest/auto_deploy/singlegpu/utils/test_dist_config.py new file mode 100644 index 000000000000..ed09413f4435 --- /dev/null +++ b/tests/unittest/auto_deploy/singlegpu/utils/test_dist_config.py @@ -0,0 +1,115 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``DistConfig`` (CPU-only).""" + +from tensorrt_llm._torch.auto_deploy.utils.dist_config import DistConfig +from tensorrt_llm.mapping import Mapping + + +def test_defaults(): + cfg = DistConfig() + assert cfg.world_size == 1 + assert cfg.rank == 0 + assert cfg.tp_size == 1 + assert cfg.pp_size == 1 + assert cfg.moe_tp_size == 1 + assert cfg.moe_ep_size == 1 + assert cfg.moe_cluster_size == 1 + assert cfg.enable_attention_dp is False + assert cfg.allreduce_strategy == "NCCL" + + +def test_serialize_deserialize_roundtrip(): + original = DistConfig( + world_size=16, + rank=7, + tp_size=4, + pp_size=2, + moe_tp_size=2, + moe_ep_size=2, + moe_cluster_size=1, + enable_attention_dp=True, + allreduce_strategy="CUSTOM", + ) + restored = DistConfig.deserialize(original.serialize()) + assert restored == original + + +def test_from_dict_ignores_unknown_keys(): + cfg = DistConfig.from_dict( + { + "world_size": 4, + "tp_size": 4, + "moe_ep_size": 4, + "not_a_real_field": "ignore_me", + "another_extra": 123, + } + ) + assert cfg.world_size == 4 + assert cfg.rank == 0 + + +def test_from_mapping_to_mapping_roundtrip(): + m = Mapping( + world_size=8, + rank=5, + tp_size=4, + pp_size=2, + moe_tp_size=2, + moe_ep_size=2, + moe_cluster_size=1, + enable_attention_dp=True, + ) + dist = DistConfig.from_mapping(m) + m2 = dist.to_mapping() + assert m2.world_size == m.world_size + assert m2.rank == m.rank + assert m2.tp_size == m.tp_size + assert m2.pp_size == m.pp_size + assert m2.moe_tp_size == m.moe_tp_size + assert m2.moe_ep_size == m.moe_ep_size + assert m2.moe_cluster_size == m.moe_cluster_size + assert m2.enable_attention_dp == m.enable_attention_dp + + +def test_tp_rank_property(): + assert DistConfig(world_size=8, rank=3, tp_size=4, moe_ep_size=4).tp_rank == 3 + assert DistConfig(world_size=8, rank=5, tp_size=4, moe_ep_size=4).tp_rank == 1 + + +def test_moe_ep_rank_property(): + cfg = DistConfig(world_size=8, rank=3, tp_size=8, moe_ep_size=4, moe_tp_size=2) + assert cfg.tp_rank == 3 + assert cfg.moe_ep_rank == 3 + assert cfg.moe_ep_rank == cfg.tp_rank % cfg.moe_ep_size + + +def test_allreduce_strategy_default(): + assert DistConfig().allreduce_strategy == "NCCL" + + +def test_invalid_rank_raises(): + import pytest + + with pytest.raises(ValueError, match="rank.*must be < world_size"): + DistConfig(world_size=4, rank=5, tp_size=4, moe_ep_size=4) + + +def test_invalid_moe_grid_raises(): + import pytest + + with pytest.raises(ValueError, match="moe_tp_size.*must equal tp_size"): + DistConfig(world_size=8, rank=0, tp_size=8, moe_tp_size=2, moe_ep_size=2) diff --git a/tests/unittest/auto_deploy/singlegpu/utils/test_node_utils_sharding.py b/tests/unittest/auto_deploy/singlegpu/utils/test_node_utils_sharding.py new file mode 100644 index 000000000000..502d9ac6d872 --- /dev/null +++ b/tests/unittest/auto_deploy/singlegpu/utils/test_node_utils_sharding.py @@ -0,0 +1,168 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU-only tests for sharding-related FX node predicates in ``node_utils``.""" + +import operator + +import torch +import torch.fx as fx +import torch.nn as nn +from torch.fx import GraphModule + +import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 — register custom ops +from tensorrt_llm._torch.auto_deploy.transform.library.sharding_ir import ShardableNode +from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_any_split_op, is_any_view_op + + +def _call_function_nodes(gm: GraphModule): + return [n for n in gm.graph.nodes if n.op == "call_function"] + + +def test_is_any_view_op_aten_view(): + class ViewModel(nn.Module): + def forward(self, x): + return x.view(2, 4) + + # ``symbolic_trace`` records ``Tensor.view`` as ``call_method``; ``torch.export`` lowers to + # ``torch.ops.aten.view.default``, which ``is_any_view_op`` matches. + exported = torch.export.export(ViewModel(), (torch.randn(8),)) + gm = exported.module() + assert any(n.target == torch.ops.aten.view.default for n in _call_function_nodes(gm)) + assert any(is_any_view_op(n) for n in _call_function_nodes(gm)), ( + f"Expected aten view in graph, got targets: {[n.target for n in _call_function_nodes(gm)]}" + ) + + +def test_is_any_view_op_auto_deploy(): + graph = fx.Graph() + x = graph.placeholder("x") + out = graph.call_function( + torch.ops.auto_deploy.view.default, + args=(x, [2, 4]), + kwargs={"tp_scaled_dim": -1, "layer_type": "unknown"}, + ) + graph.output(out) + gm = GraphModule(nn.Module(), graph) + view_nodes = [n for n in _call_function_nodes(gm) if is_any_view_op(n)] + assert len(view_nodes) == 1 + assert is_any_view_op(view_nodes[0]) + + +def test_is_any_view_op_negative(): + class AtenLinearOnly(nn.Module): + def __init__(self): + super().__init__() + self.w = nn.Parameter(torch.randn(4, 8)) + self.b = nn.Parameter(torch.randn(4)) + + def forward(self, x): + return torch.ops.aten.linear.default(x, self.w, self.b) + + gm = torch.fx.symbolic_trace(AtenLinearOnly()) + assert not any(is_any_view_op(n) for n in _call_function_nodes(gm)), ( + f"Unexpected view op in linear-only graph: {[n.target for n in _call_function_nodes(gm)]}" + ) + + +def test_is_any_split_op_aten(): + class SplitModel(nn.Module): + def forward(self, x): + a, b = torch.split(x, [2, 2], dim=-1) + return a + b + + exported = torch.export.export(SplitModel(), (torch.randn(2, 4),)) + gm = exported.module() + assert any( + n.target == torch.ops.aten.split_with_sizes.default for n in _call_function_nodes(gm) + ) + assert any(is_any_split_op(n) for n in _call_function_nodes(gm)), ( + f"Expected split op in graph, got: {[n.target for n in _call_function_nodes(gm)]}" + ) + + +def test_is_any_split_op_auto_deploy(): + graph = fx.Graph() + x = graph.placeholder("x") + splits = graph.call_function( + torch.ops.auto_deploy.split_with_sizes.default, + args=(x, [2, 2], -1), + kwargs={"enable_sharding": False, "layer_type": "unknown"}, + ) + first = 
graph.call_function(operator.getitem, args=(splits, 0)) + graph.output(first) + gm = GraphModule(nn.Module(), graph) + split_nodes = [n for n in _call_function_nodes(gm) if is_any_split_op(n)] + assert len(split_nodes) == 1 + assert is_any_split_op(split_nodes[0]) + + +def _minimal_graph_module_for_enable_sharding_linear(): + class Shell(nn.Module): + def __init__(self): + super().__init__() + self.w = nn.Parameter(torch.randn(4, 8)) + self.b = nn.Parameter(torch.randn(4)) + + root = Shell() + graph = fx.Graph() + x = graph.placeholder("x") + w = graph.get_attr("w") + b = graph.get_attr("b") + lin = graph.call_function( + torch.ops.auto_deploy.torch_linear_simple.default, + args=(x, w, b, "none", None, 1, "unknown"), + kwargs={}, + ) + graph.output(lin) + return GraphModule(root, graph) + + +def test_enable_sharding_node_linear(): + gm = _minimal_graph_module_for_enable_sharding_linear() + lin_nodes = [n for n in _call_function_nodes(gm) if ShardableNode.from_node(n) is not None] + assert len(lin_nodes) == 1 + assert ShardableNode.from_node(lin_nodes[0]) is not None + + +def test_enable_sharding_node_view(): + graph = fx.Graph() + x = graph.placeholder("x") + out = graph.call_function( + torch.ops.auto_deploy.view.default, + args=(x, [2, 4]), + kwargs={"tp_scaled_dim": -1, "layer_type": "unknown"}, + ) + graph.output(out) + gm = GraphModule(nn.Module(), graph) + view_nodes = [n for n in _call_function_nodes(gm) if ShardableNode.from_node(n) is not None] + assert len(view_nodes) == 1 + assert ShardableNode.from_node(view_nodes[0]) is not None + + +def test_enable_sharding_node_none_for_aten(): + class AtenLinearOnly(nn.Module): + def __init__(self): + super().__init__() + self.w = nn.Parameter(torch.randn(4, 8)) + self.b = nn.Parameter(torch.randn(4)) + + def forward(self, x): + return torch.ops.aten.linear.default(x, self.w, self.b) + + gm = torch.fx.symbolic_trace(AtenLinearOnly()) + aten_linear = [n for n in _call_function_nodes(gm) if n.target == torch.ops.aten.linear.default] + assert len(aten_linear) == 1 + assert ShardableNode.from_node(aten_linear[0]) is None From 04915adb8028319f4a73cd74781280109bae4be3 Mon Sep 17 00:00:00 2001 From: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Date: Mon, 20 Apr 2026 19:32:16 +0800 Subject: [PATCH 7/8] [None] [chore] Update .github/CODEOWNERS (#13213) Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> --- .github/CODEOWNERS | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 0ee86ed1b171..60fb64a573c0 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -211,6 +211,7 @@ docs/source/performance/perf-benchmarking.md @NVIDIA/trtllm-bench-reviewers ## TensorRT-LLM LLM Disaggregated /examples/disaggregated @NVIDIA/trt-llm-disagg-devs @NVIDIA/trt-llm-doc-owners +/examples/disaggregated/slurm/benchmark @NVIDIA/trt-llm-disagg-devs @NVIDIA/trtllm-bench-reviewers /tensorrt_llm/disaggregated_params.py @NVIDIA/trt-llm-disagg-devs /tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py @NVIDIA/trt-llm-disagg-devs /cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @NVIDIA/trt-llm-disagg-devs From 14539f18ec0d7591810c9d83f0f45293e59e20ee Mon Sep 17 00:00:00 2001 From: Guoming Zhang <137257613+nv-guomingz@users.noreply.github.com> Date: Mon, 20 Apr 2026 22:06:27 +0800 Subject: [PATCH 8/8] =?UTF-8?q?[None][test]=20Fix=20DGX=5FB200=20CI=20time?= =?UTF-8?q?out=20by=20splitting=20multimodal=20tests=20an=E2=80=A6=20(#129?= =?UTF-8?q?78)?= MIME-Version: 1.0 Content-Type: text/plain; 
charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com> --- tests/integration/defs/.test_durations | 17 ++++++++++++++--- .../integration/test_lists/test-db/l0_b200.yml | 7 ++++++- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/tests/integration/defs/.test_durations b/tests/integration/defs/.test_durations index 08b28bbc939e..6c5f8b97a233 100644 --- a/tests/integration/defs/.test_durations +++ b/tests/integration/defs/.test_durations @@ -790,6 +790,7 @@ "test_unittests.py::test_unittests_v2[unittest/_torch/compilation]": 31.94, "test_unittests.py::test_unittests_v2[unittest/_torch/debugger]": 36.69, "test_unittests.py::test_unittests_v2[unittest/_torch/executor]": 170.86, + "test_unittests.py::test_unittests_v2[unittest/_torch/flashinfer/test_trtllm_flashinfer_symbol_collision.py]": 1004.0, "test_unittests.py::test_unittests_v2[unittest/_torch/misc]": 600.5, "test_unittests.py::test_unittests_v2[unittest/_torch/modeling -k \"modeling_llama\"]": 718.749935634085, "test_unittests.py::test_unittests_v2[unittest/_torch/modeling -k \"modeling_mixtral\"]": 208.1838396479725, @@ -801,11 +802,21 @@ "test_unittests.py::test_unittests_v2[unittest/_torch/modeling -k \"modeling_qwen_moe\"]": 401.2630233000382, "test_unittests.py::test_unittests_v2[unittest/_torch/modeling -k \"modeling_vila\"]": 79.90315388399176, "test_unittests.py::test_unittests_v2[unittest/_torch/modules]": 158.5, + "test_unittests.py::test_unittests_v2[unittest/_torch/modules/moe/test_moe_backend.py::test_moe_backend -k \"CUTEDSL\"]": 68.0, + "test_unittests.py::test_unittests_v2[unittest/_torch/modules/moe/test_moe_backend.py::test_moe_backend -k \"CUTLASS\"]": 1598.0, + "test_unittests.py::test_unittests_v2[unittest/_torch/modules/moe/test_moe_backend.py::test_moe_backend -k \"DEEPGEMM\"]": 22.0, + "test_unittests.py::test_unittests_v2[unittest/_torch/modules/moe/test_moe_backend.py::test_moe_backend -k \"DENSEGEMM\"]": 149.0, + "test_unittests.py::test_unittests_v2[unittest/_torch/modules/moe/test_moe_backend.py::test_moe_backend -k \"TRTLLM\"]": 998.0, "test_unittests.py::test_unittests_v2[unittest/_torch/multi_gpu_modeling -k \"deepseek\"]": 393.0210295501165, - "test_unittests.py::test_unittests_v2[unittest/_torch/multimodal]": 23.54, - "test_unittests.py::test_unittests_v2[unittest/_torch/sampler]": 107.66, + "test_unittests.py::test_unittests_v2[unittest/_torch/multimodal/test_external_embedding.py]": 30.0, + "test_unittests.py::test_unittests_v2[unittest/_torch/multimodal/test_find_num_image_tokens.py]": 240.0, + "test_unittests.py::test_unittests_v2[unittest/_torch/multimodal/test_fuse_input_embeds.py]": 60.0, + "test_unittests.py::test_unittests_v2[unittest/_torch/multimodal/test_mm_encoder_standalone.py]": 1800.0, + "test_unittests.py::test_unittests_v2[unittest/_torch/multimodal/test_multimodal_runtime.py]": 180.0, + "test_unittests.py::test_unittests_v2[unittest/_torch/multimodal/test_share_multiparams.py]": 30.0, + "test_unittests.py::test_unittests_v2[unittest/_torch/sampler]": 1020.0, "test_unittests.py::test_unittests_v2[unittest/_torch/speculative]": 1850.16, - "test_unittests.py::test_unittests_v2[unittest/_torch/thop/parallel]": 311.58, + "test_unittests.py::test_unittests_v2[unittest/_torch/thop/parallel]": 1463.0, "test_unittests.py::test_unittests_v2[unittest/_torch/thop/serial]": 18.96, "test_unittests.py::test_unittests_v2[unittest/api_stability]": 33.137137457728386, 
"test_unittests.py::test_unittests_v2[unittest/bindings]": 1119.2564616799355, diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index 3484e91a8e53..3f7e2f7aec9d 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -120,7 +120,12 @@ l0_b200: # ------------- MoE: FlashInfer & TRTLLM symbol collision tests --------------- - unittest/_torch/flashinfer/test_trtllm_flashinfer_symbol_collision.py # --- MoE end - - unittest/_torch/multimodal + - unittest/_torch/multimodal/test_mm_encoder_standalone.py + - unittest/_torch/multimodal/test_multimodal_runtime.py + - unittest/_torch/multimodal/test_find_num_image_tokens.py + - unittest/_torch/multimodal/test_fuse_input_embeds.py + - unittest/_torch/multimodal/test_external_embedding.py + - unittest/_torch/multimodal/test_share_multiparams.py - unittest/_torch/sampler - unittest/_torch/speculative - unittest/_torch/thop/parallel TIMEOUT (90)