[#11083][feat] Add hardware-aware MLA defaults to get_model_defaults() #12122

wojciech-wais wants to merge 2 commits into NVIDIA:main
Conversation
[#11083][feat] Add hardware-aware MLA defaults to get_model_defaults()

Extend the get_model_defaults() framework to support a pretrained_config parameter, enabling model-specific defaults that are both hardware-aware and model-config-aware.

Changes:

- Add an optional pretrained_config parameter to the get_model_defaults() base class method and all existing overrides (NemotronH, Qwen3Next)
- Pass pretrained_config from ModelLoader.load_config_and_apply_defaults() to get_model_defaults() so models can inspect their architecture details
- Add get_model_defaults() to DeepseekV3ForCausalLM with SM-version-aware defaults: disable KV cache block reuse and chunked prefill on hardware that doesn't support MLA features (requires SM90/100/103/120)
- Add get_model_defaults() to HunYuanDenseV1ForCausalLM with config-aware MLA defaults (only applied when use_mla=True in the pretrained config)
- Add comprehensive unit tests covering SM-version parametrization, user override preservation, and config-aware behavior

These defaults ensure DeepSeek V3 and HunyuanDense MLA models work OOTB on all hardware without manual configuration, while still allowing users to explicitly override settings when needed.

Ref: NVIDIA#11083

Signed-off-by: Wojciech Wais <wojciech.wais@gmail.com>
Address review gaps with additional tests:

- KimiK25 inherits DeepseekV3 MLA defaults (both SM80 and SM90)
- NemotronH and Qwen3Next backward compatibility with the pretrained_config param
- HunyuanDense returns an empty dict when pretrained_config is None
- Partial user override: user sets only enable_block_reuse; the default is still applied for enable_chunked_prefill
- DeepseekV3 returns exactly an empty dict on all supported SM versions

Signed-off-by: Wojciech Wais <wojciech.wais@gmail.com>
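The SM-gated behavior described in the commit messages above can be sketched in a few lines. This is an illustrative mock, not the actual TensorRT-LLM implementation: `get_sm_version` is stubbed here (the real one lives in `tensorrt_llm._utils`), and the default keys simply mirror the PR text.

```python
# Hypothetical sketch of an SM-version-aware get_model_defaults override,
# based on the PR description; names mirror the PR but are illustrative only.

MLA_SUPPORTED_SM = {90, 100, 103, 120}  # SM versions with full MLA kernel support


def get_sm_version() -> int:
    """Stand-in for tensorrt_llm._utils.get_sm_version()."""
    return 80  # pretend we are on an Ampere (SM80) GPU


class DeepseekV3Defaults:
    @classmethod
    def get_model_defaults(cls, llm_args=None, pretrained_config=None) -> dict:
        # On hardware with full MLA support, no overrides are needed.
        if get_sm_version() in MLA_SUPPORTED_SM:
            return {}
        # Otherwise, disable the features MLA cannot support there.
        return {
            "kv_cache_config": {"enable_block_reuse": False},
            "enable_chunked_prefill": False,
        }


print(DeepseekV3Defaults.get_model_defaults())
```

With the stub returning 80, the sketch yields both disabling defaults; switching the stub to return 90 would yield an empty dict, matching the "returns exactly an empty dict on all supported SM versions" test above.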
📝 Walkthrough

The changes introduce an optional `pretrained_config` parameter to get_model_defaults(), threading the loaded pretrained config from the model loader into each model's defaults hook so that defaults can depend on both the hardware (SM version) and the model configuration.
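The config-aware half of that behavior can be illustrated in isolation. This sketch is hypothetical (it deliberately omits the SM-version check and uses a placeholder class name); it only shows the `use_mla` gating described for HunYuanDense.

```python
# Illustrative, simplified sketch of a config-aware defaults hook; the real
# HunYuanDense implementation also checks the SM version.
from types import SimpleNamespace


class HunYuanDefaultsSketch:
    @classmethod
    def get_model_defaults(cls, llm_args=None, pretrained_config=None) -> dict:
        # Without a config we cannot tell whether MLA is in use: no defaults.
        if pretrained_config is None:
            return {}
        # Apply MLA-related defaults only when the config enables MLA.
        if getattr(pretrained_config, "use_mla", False):
            return {"kv_cache_config": {"enable_block_reuse": False}}
        return {}


cfg = SimpleNamespace(use_mla=True)
print(HunYuanDefaultsSketch.get_model_defaults(pretrained_config=cfg))
```

This mirrors the tested behavior above: an empty dict when `pretrained_config` is None, and MLA defaults only when `use_mla=True`.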
Sequence Diagram

```mermaid
sequenceDiagram
    participant Loader as model_loader
    participant AutoModel as AutoModelForCausalLM
    participant Config as pretrained_config
    participant Defaults as get_model_defaults
    participant SMCheck as SM Version Check

    Loader->>AutoModel: load_model(llm_args, config)
    Loader->>AutoModel: get_model_defaults(llm_args, pretrained_config=config)
    AutoModel->>Defaults: execute implementation
    Defaults->>SMCheck: get_sm_version()
    SMCheck-->>Defaults: SM version (90/100/103/120?)
    alt SM version in {90, 100, 103, 120}
        Defaults-->>AutoModel: return {}
    else SM version not supported
        Defaults->>Config: check use_mla or model-specific config
        Config-->>Defaults: config attributes
        Defaults-->>AutoModel: return {kv_cache_config.enable_block_reuse: False, enable_chunked_prefill: False}
    end
    AutoModel-->>Loader: model defaults dict
    Loader->>Loader: merge defaults with llm_args (user overrides take precedence)
    Loader-->>Loader: return configured model
```
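The merge step in the sequence diagram ("user overrides take precedence") can be illustrated with a small stand-alone sketch. `apply_model_defaults` here is a hypothetical helper, not the actual loader code; the returned `applied` dict records only the defaults that actually took effect, which is the shape the PR's tests assert against.

```python
# Illustrative merge-with-user-precedence logic; not the real loader code.

def apply_model_defaults(user_args: dict, defaults: dict) -> dict:
    """Return the subset of defaults actually applied: a default takes effect
    only where the user did not explicitly set a value."""
    applied = {}
    for key, value in defaults.items():
        if key not in user_args:
            user_args[key] = value
            applied[key] = value
        elif isinstance(value, dict) and isinstance(user_args[key], dict):
            # Recurse into nested configs such as kv_cache_config.
            nested = apply_model_defaults(user_args[key], value)
            if nested:
                applied[key] = nested
    return applied


user = {"enable_chunked_prefill": True}  # explicit user override
defaults = {
    "enable_chunked_prefill": False,
    "kv_cache_config": {"enable_block_reuse": False},
}
applied = apply_model_defaults(user, defaults)
```

After the call, the user's explicit `enable_chunked_prefill=True` survives untouched, and only the KV-cache default shows up in `applied` — the "partial user override" scenario exercised in the added tests.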
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks: 2 ✅ passed | 1 ❌ failed (1 warning)
Actionable comments posted: 3
🧹 Nitpick comments (2)

tensorrt_llm/_torch/models/modeling_hunyuan_dense.py (1)

11-11: Prefer a module import for `_utils`.

Line 11 imports `get_sm_version` directly. This works, but it goes against the repo's namespace-preserving Python import rule; importing the module and calling `module.get_sm_version()` would match the project style.

Suggested cleanup

```diff
-from tensorrt_llm._utils import get_sm_version
+import tensorrt_llm._utils as trtllm_utils
 ...
-    sm_version = get_sm_version()
+    sm_version = trtllm_utils.get_sm_version()
```

As per coding guidelines, "When importing in Python, always maintain the namespace. Import the module, not individual classes or functions (e.g., use `from package.subpackage import foo` then `foo.SomeClass()` instead of `from package.subpackage.foo import SomeClass`)."

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `tensorrt_llm/_torch/models/modeling_hunyuan_dense.py` at line 11: Replace the direct function import with a module import to preserve the namespace: change the import of get_sm_version to `import tensorrt_llm._utils as _utils` (or similar) and update all usages of `get_sm_version()` in modeling_hunyuan_dense.py to call `_utils.get_sm_version()` so the code follows the project's namespace-preserving import style.

tests/unittest/llmapi/test_llm_args.py (1)

375-398: Consider adding an assertion for kv_cache_config in the applied dict.

The test correctly verifies that user overrides take precedence, but only asserts that `enable_chunked_prefill` is not in the `applied` dict. For completeness, consider also verifying that `kv_cache_config` (or its `enable_block_reuse` sub-field) is not in `applied`, since the user also overrode that.

♻️ Suggested additional assertion

```diff
     # User overrides should win
     assert llm_args.kv_cache_config.enable_block_reuse is True
     assert llm_args.enable_chunked_prefill is True

     # Applied dict should be empty since user overrode everything
     assert "enable_chunked_prefill" not in applied
+    assert "kv_cache_config" not in applied or \
+        "enable_block_reuse" not in applied.get("kv_cache_config", {})
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `tests/unittest/llmapi/test_llm_args.py` around lines 375-398: Add an assertion to test_deepseekv3_defaults_respect_user_override to verify that the user-provided kv_cache_config was not added to the applied defaults: after computing `applied = apply_model_defaults_to_llm_args(llm_args, defaults)`, assert that either "kv_cache_config" is not in `applied` or that the nested field "kv_cache_config.enable_block_reuse" is not present in `applied` (use whichever shape apply_model_defaults_to_llm_args returns), mirroring the existing check for "enable_chunked_prefill" and ensuring user overrides for enable_block_reuse are respected.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/models/modeling_deepseekv3.py`:
- Around line 1917-1920: Add a TYPE_CHECKING-guarded import for TorchLlmArgs and
use it for the get_model_defaults signature: inside the module add "from typing
import TYPE_CHECKING" and under "if TYPE_CHECKING:" add "from
<module_where_TorchLlmArgs_is_defined> import TorchLlmArgs" (so the runtime
import is avoided but the annotation 'TorchLlmArgs' in get_model_defaults
remains valid). Update any existing import style to match the other model files
(modeling_nemotron_h.py, modeling_qwen3_next.py) so the classmethod
get_model_defaults(llm_args: 'TorchLlmArgs', ...) has a corresponding
TYPE_CHECKING import.
In `@tensorrt_llm/_torch/models/modeling_hunyuan_dense.py`:
- Around line 630-632: The file is missing the TYPE_CHECKING guarded import for
the TorchLlmArgs type used in get_model_defaults; add "from typing import
TYPE_CHECKING" at top and inside an if TYPE_CHECKING: block import TorchLlmArgs
from the module where it is defined (same pattern used in modeling_qwen3_next.py
and modeling_nemotron_h.py) so the annotation 'TorchLlmArgs' resolves without
importing it at runtime; update imports near the other typing guards to
reference the exact symbol name TorchLlmArgs.
In `@tensorrt_llm/_torch/models/modeling_utils.py`:
- Around line 527-529: The get_model_defaults method uses the forward-ref string
'TorchLlmArgs' but doesn't declare it for static type checkers; add "from typing
import TYPE_CHECKING" at top and inside "if TYPE_CHECKING:" import TorchLlmArgs
from its module (the same import pattern used in modeling_nemotron_h.py /
modeling_qwen3_next.py) so the name is defined for linters/type checkers while
avoiding runtime imports.
---
Nitpick comments:
In `@tensorrt_llm/_torch/models/modeling_hunyuan_dense.py`:
- Line 11: Replace the direct function import with a module import to preserve
namespace: change the import of get_sm_version to import tensorrt_llm._utils as
_utils (or similar) and update all usages of get_sm_version() in
modeling_hunyuan_dense.py to call _utils.get_sm_version() so the code follows
the project's namespace-preserving import style.
In `@tests/unittest/llmapi/test_llm_args.py`:
- Around line 375-398: Add an assertion to
test_deepseekv3_defaults_respect_user_override to verify that the user-provided
kv_cache_config was not added to the applied defaults: after computing applied =
apply_model_defaults_to_llm_args(llm_args, defaults), assert that either
"kv_cache_config" not in applied or that the nested field
"kv_cache_config.enable_block_reuse" is not present in applied (use whichever
shape apply_model_defaults_to_llm_args returns) to mirror the existing check for
"enable_chunked_prefill" and ensure user overrides for enable_block_reuse are
respected.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: b3fea2c9-0a82-40b3-8aa8-2b6372c230bc
📒 Files selected for processing (7)

- tensorrt_llm/_torch/models/modeling_deepseekv3.py
- tensorrt_llm/_torch/models/modeling_hunyuan_dense.py
- tensorrt_llm/_torch/models/modeling_nemotron_h.py
- tensorrt_llm/_torch/models/modeling_qwen3_next.py
- tensorrt_llm/_torch/models/modeling_utils.py
- tensorrt_llm/_torch/pyexecutor/model_loader.py
- tests/unittest/llmapi/test_llm_args.py
```python
    @classmethod
    def get_model_defaults(cls,
                           llm_args: 'TorchLlmArgs',
                           pretrained_config=None) -> dict:
```
🧩 Analysis chain

🏁 Scripts executed:

```shell
# First, let's check the current imports in modeling_deepseekv3.py
head -100 tensorrt_llm/_torch/models/modeling_deepseekv3.py | cat -n

# Check if TYPE_CHECKING or TorchLlmArgs is already imported
grep -n "TYPE_CHECKING\|TorchLlmArgs" tensorrt_llm/_torch/models/modeling_deepseekv3.py

# Check the actual method signature around lines 1917-1920
sed -n '1910,1930p' tensorrt_llm/_torch/models/modeling_deepseekv3.py | cat -n

# Verify the pattern in nemotron_h.py
grep -n "TYPE_CHECKING" tensorrt_llm/_torch/models/modeling_nemotron_h.py | head -5

# Verify the pattern in qwen3_next.py
grep -n "TYPE_CHECKING" tensorrt_llm/_torch/models/modeling_qwen3_next.py | head -5

# Check if the TorchLlmArgs import exists with TYPE_CHECKING in the other files
grep -A 3 "TYPE_CHECKING" tensorrt_llm/_torch/models/modeling_nemotron_h.py | head -10
```
Import TorchLlmArgs using the TYPE_CHECKING pattern for consistency.

Line 1919 uses 'TorchLlmArgs' in the method signature, but this file does not import it. The other model files in this PR (modeling_nemotron_h.py and modeling_qwen3_next.py) handle the same annotation using a TYPE_CHECKING import, which is the standard pattern in this codebase for avoiding circular import issues while maintaining proper type hints.

Suggested fix

```diff
-from typing import Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+
+if TYPE_CHECKING:
+    from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
```

🧰 Tools
🪛 Ruff (0.15.5)

- [warning] 1919: Unused class method argument: llm_args (ARG003)
- [error] 1919: Undefined name `TorchLlmArgs` (F821)
- [warning] 1920: Unused class method argument: pretrained_config (ARG003)
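The TYPE_CHECKING pattern the review recommends can be demonstrated in isolation. This toy module (class name and the `= None` default are illustrative additions) shows why the pattern is safe even when tensorrt_llm is not installed: the guarded import never executes at runtime, while the string annotation keeps the hint available to type checkers and satisfies Ruff's F821.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers (mypy, pyright), never at
    # runtime, so this causes no circular or missing import.
    from tensorrt_llm.llmapi.llm_args import TorchLlmArgs


class DemoModel:
    @classmethod
    def get_model_defaults(cls,
                           llm_args: 'TorchLlmArgs' = None,
                           pretrained_config=None) -> dict:
        # The string annotation 'TorchLlmArgs' is never evaluated at runtime.
        return {}


# Runs fine even without tensorrt_llm installed.
print(DemoModel.get_model_defaults())
```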
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/models/modeling_deepseekv3.py` around lines 1917 - 1920,
Add a TYPE_CHECKING-guarded import for TorchLlmArgs and use it for the
get_model_defaults signature: inside the module add "from typing import
TYPE_CHECKING" and under "if TYPE_CHECKING:" add "from
<module_where_TorchLlmArgs_is_defined> import TorchLlmArgs" (so the runtime
import is avoided but the annotation 'TorchLlmArgs' in get_model_defaults
remains valid). Update any existing import style to match the other model files
(modeling_nemotron_h.py, modeling_qwen3_next.py) so the classmethod
get_model_defaults(llm_args: 'TorchLlmArgs', ...) has a corresponding
TYPE_CHECKING import.
```python
    def get_model_defaults(cls,
                           llm_args: 'TorchLlmArgs',
                           pretrained_config=None) -> dict:
```
🧩 Analysis chain

🏁 Scripts executed:

```shell
# Find and examine the file mentioned in the review
fd "modeling_hunyuan_dense.py" -t f

# Check the current imports and the method signature at lines 630-632
cd tensorrt_llm/_torch/models && head -n 640 modeling_hunyuan_dense.py | tail -n 50

# Search for all usages of TorchLlmArgs in the file
rg "TorchLlmArgs" tensorrt_llm/_torch/models/modeling_hunyuan_dense.py -n

# Check if TorchLlmArgs is already imported anywhere in the file
rg "^from|^import" tensorrt_llm/_torch/models/modeling_hunyuan_dense.py -n | head -n 30

# Look for the reference files (NemotronH/Qwen3Next) to see the pattern used
fd "Qwen3Next\|NemotronH" -t d

# Search for files that import TorchLlmArgs
rg "from.*import.*TorchLlmArgs|TorchLlmArgs" --type py -l | head -20

# Check if the module tensorrt_llm/llmapi/llm_args.py exists and has TorchLlmArgs
fd "llm_args.py" -t f

# Look at one file that imports TorchLlmArgs to see the pattern
rg "from.*import.*TorchLlmArgs" --type py -A 2 -B 2 | head -40

# Check the first 30 lines to see the import structure
head -n 30 tensorrt_llm/_torch/models/modeling_hunyuan_dense.py

# Also check modeling_qwen3_next.py to see exact placement of TYPE_CHECKING import
head -n 30 tensorrt_llm/_torch/models/modeling_qwen3_next.py
```
Add the missing TorchLlmArgs TYPE_CHECKING import.

Line 631 uses 'TorchLlmArgs' as a type annotation, but the symbol is never imported or defined in this file. Add the TYPE_CHECKING pattern used in other model files like modeling_qwen3_next.py and modeling_nemotron_h.py:

Suggested fix

```diff
-from typing import Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
+
+if TYPE_CHECKING:
+    from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
```

🧰 Tools
🪛 Ruff (0.15.5)

- [warning] 631: Unused class method argument: llm_args (ARG003)
- [error] 631: Undefined name `TorchLlmArgs` (F821)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/models/modeling_hunyuan_dense.py` around lines 630 - 632,
The file is missing the TYPE_CHECKING guarded import for the TorchLlmArgs type
used in get_model_defaults; add "from typing import TYPE_CHECKING" at top and
inside an if TYPE_CHECKING: block import TorchLlmArgs from the module where it
is defined (same pattern used in modeling_qwen3_next.py and
modeling_nemotron_h.py) so the annotation 'TorchLlmArgs' resolves without
importing it at runtime; update imports near the other typing guards to
reference the exact symbol name TorchLlmArgs.
```python
    def get_model_defaults(cls,
                           llm_args: 'TorchLlmArgs',
                           pretrained_config=None) -> dict:
```
🧩 Analysis chain

🏁 Scripts executed:

```shell
# First, check the imports at the top of the file
head -30 tensorrt_llm/_torch/models/modeling_utils.py

# Check the specific lines mentioned (527-529)
sed -n '520,540p' tensorrt_llm/_torch/models/modeling_utils.py

# Look for TorchLlmArgs imports in the file
rg "TorchLlmArgs" tensorrt_llm/_torch/models/modeling_utils.py

# Check the related model files mentioned in the comment
rg "TYPE_CHECKING" tensorrt_llm/_torch/models/modeling_nemotron_h.py tensorrt_llm/_torch/models/modeling_qwen3_next.py -A 5

# Verify how TorchLlmArgs is imported in these files
rg "from.*TorchLlmArgs|import.*TorchLlmArgs" tensorrt_llm/_torch/models/ -B 2 -A 2
```
Add a TYPE_CHECKING import to resolve the undefined name in the type annotation.

Line 528 uses the string annotation 'TorchLlmArgs' without importing it. While the string quotes prevent a runtime error, this will trigger Ruff F821 (undefined name) in strict linting configurations. The related model files (modeling_nemotron_h.py, modeling_qwen3_next.py) already follow the TYPE_CHECKING pattern for this import. Add the same pattern here to keep the shared base API lint-clean:

Suggested fix

```diff
-from typing import Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union
+from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union
+
+if TYPE_CHECKING:
+    from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union

if TYPE_CHECKING:
    from tensorrt_llm.llmapi.llm_args import TorchLlmArgs

...

    def get_model_defaults(cls,
                           llm_args: 'TorchLlmArgs',
                           pretrained_config=None) -> dict:
```
🧰 Tools

🪛 Ruff (0.15.5)

- [warning] 528: Unused class method argument: llm_args (ARG003)
- [error] 528: Undefined name `TorchLlmArgs` (F821)
- [warning] 529: Unused class method argument: pretrained_config (ARG003)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/models/modeling_utils.py` around lines 527 - 529, The
get_model_defaults method uses the forward-ref string 'TorchLlmArgs' but doesn't
declare it for static type checkers; add "from typing import TYPE_CHECKING" at
top and inside "if TYPE_CHECKING:" import TorchLlmArgs from its module (the same
import pattern used in modeling_nemotron_h.py / modeling_qwen3_next.py) so the
name is defined for linters/type checkers while avoiding runtime imports.
Extend the get_model_defaults() framework to support pretrained_config parameter, enabling model-specific defaults that are both hardware-aware and model-config-aware.
Changes:

- Add an optional pretrained_config parameter to get_model_defaults() in the base class and all existing overrides (NemotronH, Qwen3Next)
- Pass pretrained_config from ModelLoader.load_config_and_apply_defaults() to get_model_defaults()
- Add SM-version-aware get_model_defaults() to DeepseekV3ForCausalLM
- Add config-aware get_model_defaults() to HunYuanDenseV1ForCausalLM
- Add unit tests covering SM-version parametrization, user override preservation, and config-aware behavior
These defaults ensure DeepSeek V3 and HunyuanDense MLA models work OOTB on all hardware without manual configuration, while still allowing users to explicitly override settings when needed.
Ref: #11083
PR Checklist

Please review the following before submitting your PR:

- [ ] PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
- [ ] PR follows the TRT-LLM CODING GUIDELINES to the best of your knowledge.
- [ ] Test cases are provided for new code paths (see test instructions).
- [ ] Any new dependencies have been scanned for license and vulnerabilities.
- [ ] CODEOWNERS updated if ownership changes.
- [ ] Documentation updated as needed.
- [ ] Update the tava architecture diagram if there is a significant design change in the PR.
- [ ] The reviewers assigned automatically/manually are appropriate for the PR.
- [ ] Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help

To see a list of available CI bot commands, please comment `/bot help`.