Skip to content

Commit 0dc4b4e

Browse files
Fridah-nv and h-guo18 authored
[NVIDIA#4403][autodeploy] Refactor: Move more transformations to new inf optimizer, Add quantization_source to factory interface (NVIDIA#6760)
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Co-authored-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent 7c686ba commit 0dc4b4e

30 files changed

+2189
-1211
lines changed

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,6 @@ transforms:
1919
stage: post_export
2020
cleanup_input_constraints:
2121
stage: post_export
22-
quantize:
23-
stage: pattern_matcher
24-
quantize_moe:
25-
stage: pattern_matcher
2622
match_repeat_kv:
2723
stage: pattern_matcher
2824
match_eager_attention:
@@ -31,3 +27,35 @@ transforms:
3127
stage: pattern_matcher
3228
match_attention_layout:
3329
stage: pattern_matcher
30+
match_moe_pattern:
31+
stage: pattern_matcher
32+
match_rope_pattern:
33+
stage: pattern_matcher
34+
match_rope_layout:
35+
stage: pattern_matcher
36+
eliminate_redundant_transposes:
37+
stage: pattern_matcher
38+
# TODO (lucaslie): let's move this to perf optimization once TP sharding is improved
39+
# see https://github.com/NVIDIA/TensorRT-LLM/pull/3668#discussion_r2052714528
40+
optimize_rope:
41+
stage: pattern_matcher
42+
quantize_from_config:
43+
stage: pattern_matcher
44+
quantize_from_graph:
45+
stage: pattern_matcher
46+
quantize_moe:
47+
stage: pattern_matcher
48+
# TODO: Infer sharding parameters (tp_size, row/column sharding) from the model config.
49+
detect_column_row_shard:
50+
stage: sharding
51+
simple_shard_only: false
52+
detect_ep_shard:
53+
stage: sharding
54+
detect_dp_bmm_shard:
55+
stage: sharding
56+
# TODO: (hg) need to ensure run_shape_prop after sharding.
57+
sharding_transform_executor:
58+
stage: sharding
59+
run_shape_prop: true
60+
load_weights:
61+
stage: weight_load

tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Interface to initialize and load HF models."""
22

3-
import json
43
import os
54
import types
65
from contextlib import contextmanager, nullcontext
@@ -31,6 +30,7 @@
3130
from ..utils._config import deep_merge_dicts
3231
from ..utils.logger import ad_logger
3332
from .factory import ModelFactory, ModelFactoryRegistry
33+
from .quant_config_reader import QuantConfigReader, QuantConfigReaderRegistry
3434

3535

3636
@contextmanager
@@ -84,9 +84,7 @@ def _get_max_position_embeddings_config(self) -> Dict[str, Any]:
8484

8585
def __init__(self, *args, **kwargs):
8686
super().__init__(*args, **kwargs)
87-
88-
self._quant_config: Optional[Dict] = None
89-
87+
self._quant_config_reader: QuantConfigReader | None = None
9088
# Ingest defaults for tokenizer and model kwargs
9189
self.tokenizer_kwargs = deep_merge_dicts(self._tokenizer_defaults, self.tokenizer_kwargs)
9290
self.model_kwargs = deep_merge_dicts(
@@ -156,9 +154,6 @@ def _recursive_update_config(self, config: PretrainedConfig, update_dict: Dict[s
156154

157155
def _build_model(self, device: DeviceLikeType) -> nn.Module:
158156
"""Build the model on the desired device."""
159-
# We only support fp16 to fp4 conversion.
160-
if self._quant_config and self._quant_config.get("quant_algo", None) == "NVFP4":
161-
self.model_kwargs["torch_dtype"] = torch.half
162157

163158
# NOTE (lucaslie): HF doesn't recursively update nested PreTrainedConfig objects. Instead,
164159
# the entire subconfig will be overwritten.
@@ -178,23 +173,27 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module:
178173
model.forward = types.MethodType(self._simple_forward, model)
179174

180175
model.eval()
176+
181177
return model
182178

183179
def get_quant_config(self) -> Dict:
    """Returns the quantization config for this model or an empty dict if not quantized."""
    reader = self._quant_config_reader
    # No reader means no quantized checkpoint was detected.
    return reader.get_config() if reader is not None else {}
185184

186185
def get_cache_config(self):
    """Return the kv-cache dtype configuration derived from the quantization config.

    Returns:
        CacheConfig: dtype=torch.float8_e4m3fn for FP8-quantized kv caches, or
        dtype=None when the checkpoint is not quantized or has no kv-cache quant.
    """
    if not self._quant_config_reader:
        return CacheConfig(dtype=None)

    kv_cache_dtype = self._quant_config_reader.get_config().get("kv_cache_dtype")
    if kv_cache_dtype is None:
        return CacheConfig(dtype=None)
    # NOTE(review): the previous assert compared the *converted* torch dtype against
    # (torch.float8_e4m3fn, None) and so could never fail; validate the raw string
    # instead so an unexpected kv-cache format fails loudly rather than silently
    # mapping to an unquantized cache. (The modelopt reader only ever emits
    # "float8_e4m3fn", so this is behavior-preserving for supported checkpoints.)
    assert kv_cache_dtype == "float8_e4m3fn", (
        f"Unsupported dtype: {kv_cache_dtype}. Only torch.float8_e4m3fn is supported."
    )
    return CacheConfig(dtype=torch.float8_e4m3fn)
198197

199198
def init_tokenizer(self) -> Optional[Any]:
200199
"""Initialize the tokenizer—either a custom name or the model's default."""
@@ -325,22 +324,18 @@ def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType):
325324

326325
def _load_quantization_config(self, fetched_dir: str):
    """Load the quantization config from the model directory if not done already.

    Args:
        fetched_dir: Local directory containing the fetched checkpoint.

    Side effects: sets ``self._quant_config_reader`` and merges any
    reader-provided kwargs into ``self.model_kwargs``.
    """
    if self._quant_config_reader is not None:
        return
    # TODO: specified by user or auto-detect
    reader_cls = QuantConfigReaderRegistry.get("modelopt")
    result = reader_cls.from_file(fetched_dir)
    if result is None:
        # Not a (modelopt-)quantized checkpoint; nothing to do.
        return

    # from_file returned a tuple, so the reader is always non-None here; the
    # previous `if reader is not None:` guard was dead code and has been removed.
    reader, extra_model_kwargs = result
    self._quant_config_reader = reader
    self.model_kwargs = deep_merge_dicts(self.model_kwargs, extra_model_kwargs)
344339

345340

346341
@ModelFactoryRegistry.register("AutoModelForImageTextToText")
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
"""
2+
Quantization Config Reader Registry.
3+
4+
This module defines a registry system for parsing quantization configurations
5+
from various sources (e.g., 'modelopt'). It enables extensible support for different
6+
quantization producers by delegating parsing logic to dedicated subclasses.
7+
"""
8+
9+
import json
10+
import os
11+
from abc import ABC, abstractmethod
12+
from typing import Any, Callable, Dict, Optional, Tuple, Type
13+
14+
15+
class QuantConfigReader(ABC):
    """Base class for reading and parsing quantization config."""

    def __init__(self):
        # Normalized quantization config; populated by read_config().
        self._quant_config: Optional[Dict] = {}

    def get_config(self) -> Dict:
        """Return the parsed quantization config."""
        return self._quant_config

    @abstractmethod
    def read_config(self, config: Dict) -> Dict:
        """Parse and normalize a quantization config dictionary.

        Args:
            config: The raw parsed JSON object.

        Returns:
            A dictionary of extra model kwargs derived from the quantization config.
            Implementations must also populate self._quant_config with the
            normalized quantization config.
        """

    @classmethod
    @abstractmethod
    def from_file(cls, file_path: str) -> Optional[Tuple["QuantConfigReader", Dict[str, Any]]]:
        """Load and parse a quantization config file from disk.

        This method is implemented by each reader to handle loading and parsing logic.

        Args:
            file_path: Path to the quant config JSON file.

        Returns:
            A (reader, extra_model_kwargs) tuple, or None if the file doesn't exist.
        """
55+
56+
57+
class QuantConfigReaderRegistry:
    """Registry mapping producer names to QuantConfigReader subclasses."""

    _registry: Dict[str, Type[QuantConfigReader]] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[Type[QuantConfigReader]], Type[QuantConfigReader]]:
        """Return a class decorator registering a reader class under `name`."""

        def inner(reader_cls: Type[QuantConfigReader]) -> Type[QuantConfigReader]:
            cls._registry[name] = reader_cls
            return reader_cls

        return inner

    @classmethod
    def get(cls, name: str) -> Type[QuantConfigReader]:
        """Look up a registered reader class; raise ValueError if unknown."""
        try:
            return cls._registry[name]
        except KeyError:
            raise ValueError(f"QuantConfigReader for '{name}' not registered.") from None

    @classmethod
    def has(cls, reader_cls: str) -> bool:
        """Return True if a reader is registered under the given name."""
        return reader_cls in cls._registry
77+
78+
79+
@QuantConfigReaderRegistry.register("modelopt")
class ModelOPTQuantConfigReader(QuantConfigReader):
    """Reader for modelopt-produced ``hf_quant_config.json`` checkpoints."""

    def read_config(self, config: Dict) -> Dict:
        """Normalize a raw modelopt quantization config.

        Args:
            config: The raw parsed JSON object.

        Returns:
            Extra model kwargs derived from the config (e.g. a torch_dtype
            override for NVFP4 checkpoints).

        Raises:
            ValueError: if the producer is not 'modelopt' or the kv-cache
                quantization format is unsupported.
        """
        producer = config.get("producer", {}).get("name")
        # sanity check
        if producer != "modelopt":
            raise ValueError(f"Expected producer 'modelopt', got '{producer}'")

        quant_config = config.get("quantization", {})
        # Inject default exclusion, add "model.embed_tokens" for "tie_word_embedding:true" case
        quant_config.setdefault("exclude_modules", ["lm_head", "model.embed_tokens"])

        extra_model_kwargs: Dict[str, Any] = {}
        # We only support fp16 for NVFP4 checkpoints. Record the dtype override in
        # both the normalized config and the model kwargs in a single branch; the
        # previous code checked `quant_algo == "NVFP4"` in two separate places,
        # which could silently drift apart.
        if quant_config.get("quant_algo") == "NVFP4":
            quant_config["torch_dtype"] = "float16"
            extra_model_kwargs["torch_dtype"] = "float16"

        # Handle kv cache
        kv_algo = quant_config.get("kv_cache_quant_algo")
        if kv_algo:
            if kv_algo != "FP8":
                raise ValueError(f"KV cache quantization format {kv_algo} not supported.")
            quant_config["kv_cache_dtype"] = "float8_e4m3fn"

        self._quant_config = quant_config
        return extra_model_kwargs

    @classmethod
    def from_file(
        cls, ckpt_dir: str
    ) -> Optional[Tuple["ModelOPTQuantConfigReader", Dict[str, Any]]]:
        """
        Load and parse a modelopt-style quantization config from a checkpoint directory.

        Args:
            ckpt_dir: Path to the root directory containing the checkpoint.

        Returns:
            An initialized ModelOPTQuantConfigReader instance plus extra model
            kwargs, or None if hf_quant_config.json doesn't exist.
        """
        quant_file = os.path.join(ckpt_dir, "hf_quant_config.json")
        if not os.path.exists(quant_file):
            return None

        with open(quant_file, "r") as f:
            raw = json.load(f)
        reader = cls()
        extra_model_kwargs = reader.read_config(raw)
        return reader, extra_model_kwargs

tensorrt_llm/_torch/auto_deploy/transform/interface.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from ..shim.interface import CachedSequenceInterface
1616
from ..transformations._graph import canonicalize_graph, lift_to_meta
1717
from ..utils.logger import ad_logger
18+
from ..utils.sharding_utils import ShardingConfig
1819

1920

2021
class TransformError(Exception):
@@ -47,6 +48,14 @@ def __lt__(self, other):
4748
return NotImplemented
4849

4950

51+
class SharedConfig(BaseModel):
    """Global config shared between multiple transforms in the inference optimizer."""

    # Mutable sharding state; presumably filled in by the sharding-detection
    # transforms and consumed by the executor — TODO confirm against callers.
    sharding_config: ShardingConfig = Field(default_factory=ShardingConfig)
    # Distributed rank/world-size info available to all transforms.
    local_rank: int = 0
    world_size: int = 1
57+
58+
5059
class TransformConfig(BaseModel):
5160
"""A simple configuration class that can be extended by a transform for configurability."""
5261

@@ -190,14 +199,19 @@ def from_kwargs(cls, **kwargs) -> "BaseTransform":
190199

191200
@final
192201
def __call__(
193-
self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
202+
self,
203+
gm: GraphModule,
204+
cm: CachedSequenceInterface,
205+
factory: ModelFactory,
206+
shared_config: SharedConfig,
194207
) -> GraphModule:
195208
"""Apply the transform to the graph.
196209
197210
Args:
198211
gm: The graph module to apply the transform to.
199212
cm: The cached sequence interface defining the sequence interface.
200213
factory: The model factory used to build the model.
214+
shared_config: Global info shared between multiple transforms.
201215
202216
Returns:
203217
GraphModule: The transformed graph module.
@@ -232,14 +246,14 @@ def __call__(
232246
# run the transform in a error-handling wrapper if desired
233247
if self.config.skip_on_error:
234248
try:
235-
gm, info = self._apply(gm, cm, factory)
249+
gm, info = self._apply(gm, cm, factory, shared_config)
236250
except Exception as e:
237251
error_msg = f"Transform {t_name} failed"
238252
ad_logger.warning(f"{error_msg}: {e}")
239253
info = TransformInfo(skipped=True, num_matches=0)
240254
else:
241255
# handle this here normally to improve debugging and error message
242-
gm, info = self._apply(gm, cm, factory)
256+
gm, info = self._apply(gm, cm, factory, shared_config)
243257

244258
# we cannot say it's clean if the previous wasn't clean even if this one is
245259
# create new info object with updated cleanup status
@@ -346,7 +360,11 @@ def _run_post_cleanup(self, gm: GraphModule, info: TransformInfo) -> TransformIn
346360

347361
@abstractmethod
348362
def _apply(
349-
self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
363+
self,
364+
gm: GraphModule,
365+
cm: CachedSequenceInterface,
366+
factory: ModelFactory,
367+
shared_config: SharedConfig,
350368
) -> Tuple[GraphModule, TransformInfo]:
351369
"""Apply the transform to the graph.
352370

0 commit comments

Comments
 (0)