Commit be1f6f5

[https://nvbugs/6095953][fix] Fix cache memory estimation for Qwen3 hybrid models in trtllm-bench (#13268)
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
1 parent d1731cb commit be1f6f5

4 files changed: 70 additions & 14 deletions

File tree

tensorrt_llm/bench/benchmark/utils/general.py
tensorrt_llm/bench/build/build.py
tensorrt_llm/bench/build/dataclasses.py
tensorrt_llm/bench/build/tuning.py

tensorrt_llm/bench/benchmark/utils/general.py

Lines changed: 5 additions & 5 deletions
@@ -11,7 +11,8 @@
     validate_and_set_kv_cache_quant
 from tensorrt_llm.bench.build.build import (get_benchmark_engine_settings,
                                             get_model_config)
-from tensorrt_llm.bench.build.dataclasses import NemotronHybridConfig
+from tensorrt_llm.bench.build.dataclasses import (NemotronHybridConfig,
+                                                  Qwen3HybridConfig)
 from tensorrt_llm.bench.dataclasses.general import (DatasetMetadata,
                                                     InferenceRequest)
 from tensorrt_llm.logger import logger
@@ -111,10 +112,9 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
     else:
         model_config = get_model_config(model, model_path)
 
-    if isinstance(
-            model_config,
-            NemotronHybridConfig) and mamba_ssm_cache_dtype not in (None,
-                                                                    "auto"):
+    if isinstance(model_config,
+                  (NemotronHybridConfig, Qwen3HybridConfig
+                   )) and mamba_ssm_cache_dtype not in (None, "auto"):
         model_config.set_mamba_ssm_cache_dtype(mamba_ssm_cache_dtype)
 
     from tensorrt_llm._torch.model_config import ModelConfig
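
The dtype gate above only forwards an explicit user override: None and "auto" fall through and leave the config's default in place. A minimal standalone illustration of that behavior (the dtype values are illustrative):

    # Only an explicit dtype (e.g. "float32") reaches set_mamba_ssm_cache_dtype();
    # None and "auto" keep the model config's default.
    for mamba_ssm_cache_dtype in (None, "auto", "float32"):
        if mamba_ssm_cache_dtype not in (None, "auto"):
            print(f"override SSM cache dtype -> {mamba_ssm_cache_dtype}")
        else:
            print("keep model config default")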

tensorrt_llm/bench/build/build.py

Lines changed: 4 additions & 2 deletions
@@ -5,7 +5,7 @@
 import click
 from click_option_group import AllOptionGroup, optgroup
 
-from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid, load_pretrained_config
+from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid, is_qwen3_hybrid, load_pretrained_config
 from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
 from tensorrt_llm.bench.utils.data import create_dataset_from_stream, initialize_tokenizer
 from tensorrt_llm.bench.utils import VALID_QUANT_ALGOS
@@ -14,7 +14,7 @@
 from tensorrt_llm.llmapi.llm_utils import QuantConfig
 from tensorrt_llm.logger import logger
 from tensorrt_llm.quantization.mode import QuantAlgo
-from tensorrt_llm.bench.build.dataclasses import ModelConfig, NemotronHybridConfig
+from tensorrt_llm.bench.build.dataclasses import ModelConfig, NemotronHybridConfig, Qwen3HybridConfig
 from tensorrt_llm.bench.build.tuning import calc_engine_setting
 
 TUNED_QUANTS = {
@@ -89,6 +89,8 @@ def get_model_config(model_name: str, model_path: Path = None) -> ModelConfig:
                                      trust_remote_code=True)
     if is_nemotron_hybrid(pretrained_config):
         return NemotronHybridConfig.from_hf(model_name, model_path)
+    if is_qwen3_hybrid(pretrained_config):
+        return Qwen3HybridConfig.from_hf(model_name, model_path)
     return ModelConfig.from_hf(model_name, model_path)
 
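is_qwen3_hybrid itself is not part of this diff; the following is a hedged sketch of the kind of check it plausibly performs, inferred from the layer_types handling in dataclasses.py below (the attribute access and the function body are assumptions, not the actual implementation):

    def is_qwen3_hybrid(pretrained_config) -> bool:
        # Assumption: a Qwen3 hybrid model mixes "full_attention" and
        # "linear_attention" layers, exposed via the HF config's layer_types.
        layer_types = getattr(pretrained_config, "layer_types", None) or []
        return "linear_attention" in layer_types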

tensorrt_llm/bench/build/dataclasses.py

Lines changed: 57 additions & 3 deletions
@@ -13,7 +13,8 @@
 import json
 import struct
 
-from tensorrt_llm._torch.pyexecutor.config_utils import load_pretrained_config
+from tensorrt_llm._torch.pyexecutor.config_utils import (
+    load_pretrained_config, get_qwen3_hybrid_layer_types)
 
 
 def parse_safetensors_file_metadata(model_path, filename):
@@ -113,8 +114,9 @@ def _parse(filename: str) -> None:
 
 
 class ModelConfig(BaseModel):
-    """ Model specific configurations. The parameters are needed in engine
-    setting calculation.
+    """Model specific configurations.
+
+    The parameters are needed in engine setting calculation.
     """
     name: str
     model_type: str
@@ -254,3 +256,55 @@ def cache_memory_fraction(self, cache_memory_fraction):
 
     def set_mamba_ssm_cache_dtype(self, mamba_ssm_cache_dtype: str):
         self.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
+
+
+class Qwen3HybridConfig(ModelConfig):
+    """Config for Qwen3 hybrid models (full-attention + linear-attention layers).
+
+    Maps Qwen3.5 linear-attention parameters to the same cache estimation
+    formulas used by NemotronHybridConfig.
+    """
+    linear_key_head_dim: int  # d_state
+    linear_conv_kernel_dim: int  # d_conv
+    linear_num_value_heads: int  # num_heads (mamba_num_heads)
+    linear_num_key_heads: int  # n_groups
+    linear_value_head_dim: int  # head_dim (mamba_head_dim)
+    num_linear_attention_layers: Optional[int] = Field(default=None)
+    mamba_ssm_cache_dtype: Optional[str] = Field(default="auto")
+
+    @model_validator(mode="after")
+    def set_values_if_none(self):
+        """Derive num_attention_layers and num_linear_attention_layers.
+
+        Uses the HF config's layer_types / full_attention_interval.
+        """
+        if self.num_linear_attention_layers is None or self.num_attention_layers is None:
+            pretrained_config = load_pretrained_config(self.name,
+                                                       trust_remote_code=True)
+            layer_types = get_qwen3_hybrid_layer_types(pretrained_config)
+            if self.num_attention_layers is None:
+                self.num_attention_layers = sum(1 for lt in layer_types
+                                                if lt == "full_attention")
+            if self.num_linear_attention_layers is None:
+                self.num_linear_attention_layers = sum(
+                    1 for lt in layer_types if lt == "linear_attention")
+
+        super().set_values_if_none()
+        return self
+
+    def extra_model_cache_in_gb(self, bytes_per_elem, target_seq_len=None):
+        d_inner = self.linear_value_head_dim * self.linear_num_value_heads
+        conv_dim = d_inner + 2 * self.linear_num_key_heads * self.linear_key_head_dim
+        conv_state_elems = conv_dim * (self.linear_conv_kernel_dim - 1)
+        ssm_state_elems = (self.linear_num_value_heads *
+                           self.linear_value_head_dim *
+                           self.linear_key_head_dim)
+        gb_per_cache = bytes_per_elem * self.num_linear_attention_layers * (
+            conv_state_elems + ssm_state_elems) / (1024**3)
+        return gb_per_cache
+
+    def cache_memory_fraction(self, cache_memory_fraction):
+        return cache_memory_fraction**2
+
+    def set_mamba_ssm_cache_dtype(self, mamba_ssm_cache_dtype: str):
+        self.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
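
To make extra_model_cache_in_gb concrete, here is a worked example with hypothetical linear-attention dimensions (illustrative values, not taken from any shipped Qwen3 checkpoint):

    # Hypothetical Qwen3-style linear-attention dimensions (illustrative only).
    linear_key_head_dim = 128         # d_state
    linear_conv_kernel_dim = 4        # d_conv
    linear_num_value_heads = 32       # num_heads
    linear_num_key_heads = 16         # n_groups
    linear_value_head_dim = 128       # head_dim
    num_linear_attention_layers = 36
    bytes_per_elem = 2                # fp16/bf16 cache

    d_inner = linear_value_head_dim * linear_num_value_heads             # 4096
    conv_dim = d_inner + 2 * linear_num_key_heads * linear_key_head_dim  # 8192
    conv_state_elems = conv_dim * (linear_conv_kernel_dim - 1)           # 24576
    ssm_state_elems = (linear_num_value_heads * linear_value_head_dim *
                       linear_key_head_dim)                              # 524288
    gb_per_cache = (bytes_per_elem * num_linear_attention_layers *
                    (conv_state_elems + ssm_state_elems)) / 1024**3
    print(f"{gb_per_cache:.4f} GiB per sequence")                        # ~0.0368

Unlike the KV cache of the full-attention layers, this conv/SSM state is a fixed per-sequence cost, independent of sequence length, which is why target_seq_len goes unused in the method.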

tensorrt_llm/bench/build/tuning.py

Lines changed: 4 additions & 4 deletions
@@ -6,7 +6,7 @@
 from tensorrt_llm.llmapi.llm_utils import QuantConfig
 from tensorrt_llm.logger import logger
 from tensorrt_llm.quantization.mode import QuantAlgo
-from tensorrt_llm.bench.build.dataclasses import ModelConfig, NemotronHybridConfig
+from tensorrt_llm.bench.build.dataclasses import ModelConfig, NemotronHybridConfig, Qwen3HybridConfig
 from .utils import get_device_memory
 import math
 
@@ -82,7 +82,7 @@ def calc_engine_setting(
                                         kv_cache_gpu_mem_fraction)
 
     bytes_per_elem = BYTES_PER_ELEM.get(QuantAlgo.NO_QUANT)
-    if isinstance(model_config, NemotronHybridConfig):
+    if isinstance(model_config, (NemotronHybridConfig, Qwen3HybridConfig)):
         mamba_ssm_cache_dtype = model_config.mamba_ssm_cache_dtype
         if mamba_ssm_cache_dtype != "auto":
             if str_dtype_to_torch(mamba_ssm_cache_dtype) == torch.float32:
@@ -110,8 +110,8 @@ def calc_engine_setting(
         target_input_len,
         target_output_len,
        pp_size,
-        disable_optimistic_tuning=isinstance(model_config,
-                                             NemotronHybridConfig))
+        disable_optimistic_tuning=isinstance(
+            model_config, (NemotronHybridConfig, Qwen3HybridConfig)))
 
     # Functional and performance
     if total_gpu_memory < engine_size:
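
One subtle knob consumed here is cache_memory_fraction: both hybrid configs override it to square the requested fraction, leaving extra headroom for the per-sequence conv/SSM state that the plain KV-cache estimate does not account for. A quick illustration of the effect (the free-memory figure is illustrative):

    # Hybrid models square the user-requested KV-cache fraction.
    requested_fraction = 0.90
    effective_fraction = requested_fraction ** 2   # 0.81

    free_gpu_mem_gb = 40.0                         # illustrative
    print(free_gpu_mem_gb * requested_fraction)    # 36.0 for a plain model
    print(free_gpu_mem_gb * effective_fraction)    # 32.4 for a hybrid model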
