import json
import struct

-from tensorrt_llm._torch.pyexecutor.config_utils import load_pretrained_config
+from tensorrt_llm._torch.pyexecutor.config_utils import (
+    load_pretrained_config, get_qwen3_hybrid_layer_types)


def parse_safetensors_file_metadata(model_path, filename):
@@ -113,8 +114,9 @@ def _parse(filename: str) -> None:


class ModelConfig(BaseModel):
-    """ Model specific configurations. The parameters are needed in engine
-    setting calculation.
+    """Model specific configurations.
+
+    The parameters are needed in engine setting calculation.
    """
    name: str
    model_type: str
@@ -254,3 +256,55 @@ def cache_memory_fraction(self, cache_memory_fraction):

    def set_mamba_ssm_cache_dtype(self, mamba_ssm_cache_dtype: str):
        self.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
+
+
+class Qwen3HybridConfig(ModelConfig):
+    """Config for Qwen3 hybrid models (full-attention + linear-attention layers).
+
+    Maps Qwen3.5 linear-attention parameters to the same cache estimation
+    formulas used by NemotronHybridConfig.
+    """
+    linear_key_head_dim: int  # d_state
+    linear_conv_kernel_dim: int  # d_conv
+    linear_num_value_heads: int  # num_heads (mamba_num_heads)
+    linear_num_key_heads: int  # n_groups
+    linear_value_head_dim: int  # head_dim (mamba_head_dim)
+    num_linear_attention_layers: Optional[int] = Field(default=None)
+    mamba_ssm_cache_dtype: Optional[str] = Field(default="auto")
+
+    @model_validator(mode="after")
+    def set_values_if_none(self):
+        """Derive num_attention_layers and num_linear_attention_layers.
+
+        Uses the HF config's layer_types / full_attention_interval.
+        """
+        if self.num_linear_attention_layers is None or self.num_attention_layers is None:
+            pretrained_config = load_pretrained_config(self.name,
+                                                       trust_remote_code=True)
+            layer_types = get_qwen3_hybrid_layer_types(pretrained_config)
+            if self.num_attention_layers is None:
+                self.num_attention_layers = sum(1 for lt in layer_types
+                                                if lt == "full_attention")
+            if self.num_linear_attention_layers is None:
+                self.num_linear_attention_layers = sum(
+                    1 for lt in layer_types if lt == "linear_attention")
+
+        super().set_values_if_none()
+        return self
+
+    def extra_model_cache_in_gb(self, bytes_per_elem, target_seq_len=None):
+        d_inner = self.linear_value_head_dim * self.linear_num_value_heads
+        conv_dim = d_inner + 2 * self.linear_num_key_heads * self.linear_key_head_dim
+        conv_state_elems = conv_dim * (self.linear_conv_kernel_dim - 1)
+        ssm_state_elems = (self.linear_num_value_heads *
+                           self.linear_value_head_dim *
+                           self.linear_key_head_dim)
+        gb_per_cache = bytes_per_elem * self.num_linear_attention_layers * (
+            conv_state_elems + ssm_state_elems) / (1024**3)
+        return gb_per_cache
+
+    def cache_memory_fraction(self, cache_memory_fraction):
+        return cache_memory_fraction**2
+
+    def set_mamba_ssm_cache_dtype(self, mamba_ssm_cache_dtype: str):
+        self.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
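The per-request state-size arithmetic in `extra_model_cache_in_gb` is easy to misread in diff form, so here is a minimal standalone sketch of the same formula. The dimension values in the example are hypothetical placeholders for illustration, not the actual Qwen3.5 configuration.

```python
# Standalone sketch of the per-request linear-attention cache arithmetic above.
# All dimensions below are hypothetical placeholders, not real Qwen3.5 values.

def linear_attn_cache_gb(bytes_per_elem: float,
                         num_linear_layers: int,
                         key_head_dim: int,      # d_state
                         conv_kernel_dim: int,   # d_conv
                         num_value_heads: int,   # mamba_num_heads
                         num_key_heads: int,     # n_groups
                         value_head_dim: int) -> float:  # mamba_head_dim
    # Conv state: a (d_conv - 1)-step window over the projected channels.
    d_inner = value_head_dim * num_value_heads
    conv_dim = d_inner + 2 * num_key_heads * key_head_dim
    conv_state_elems = conv_dim * (conv_kernel_dim - 1)
    # SSM state: one (head_dim x d_state) matrix per value head.
    ssm_state_elems = num_value_heads * value_head_dim * key_head_dim
    per_layer = conv_state_elems + ssm_state_elems
    return bytes_per_elem * num_linear_layers * per_layer / (1024**3)


if __name__ == "__main__":
    # ~0.037 GiB per request with these placeholder dimensions and 2-byte states.
    print(linear_attn_cache_gb(bytes_per_elem=2,
                               num_linear_layers=36,
                               key_head_dim=128,
                               conv_kernel_dim=4,
                               num_value_heads=32,
                               num_key_heads=16,
                               value_head_dim=128))
```

With these placeholder dimensions the linear-attention state adds roughly 0.037 GiB per request, dominated by the SSM-state term. The `cache_memory_fraction` override squares the fraction, mirroring NemotronHybridConfig per the class docstring, presumably to leave extra headroom for state caches that are allocated outside the KV-cache pool.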