Skip to content

Commit c758197

Browse files
committed
Refine quantization config schema handling
Signed-off-by: Shengliang Xu <shengliangx@nvidia.com>
1 parent d9eccf3 commit c758197

5 files changed

Lines changed: 231 additions & 44 deletions

File tree

modelopt/torch/quantization/algorithms.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -104,39 +104,40 @@ def estimate_quant_compression_for_quantizer(quantizer_attr_cfg):
104104

105105

106106
QuantRecipeConfig = str | Mapping[str, Any] | QuantizeConfig | None
107+
QuantizationFormatConfig = QuantRecipeConfig
108+
NamedQuantRecipeConfig = tuple[str | Mapping[str, Any] | QuantizeConfig, str]
107109

108110

109111
class QuantRecipe(CustomHPType):
110112
"""A subclass of QuantizeConfig enabling auto_quantize specific configurations.
111113
112114
Args:
113115
quant_cfg: str, QuantizeConfig, mapping, or None. A mapping is used for custom quantization formats.
114-
name: name for custom quantization formats. Only used if quantization format is a custom
115-
format not available in :mod:`modelopt.torch.quantization.config`.
116+
name: Required display/search name when ``quant_cfg`` is not ``None``. Must be
117+
``None`` when ``quant_cfg=None``, which uses the built-in ``"NONE"`` recipe name.
116118
"""
117119

118120
def __init__(self, quant_cfg: QuantRecipeConfig = None, name: str | None = None):
119121
"""Initialize the QuantRecipe with the quantization configuration."""
120-
name = self.get_auto_name_for_config(quant_cfg) or name
121-
122122
if quant_cfg is None:
123+
if name is not None:
124+
raise ValueError("name must be None when quant_cfg is None")
125+
name = "NONE"
123126
self.config = mtq_config.QuantizeConfig(
124127
quant_cfg=[mtq_config.QuantizerCfgEntry(quantizer_name="*", enable=False)]
125128
)
126129
else:
130+
if name is None:
131+
raise ValueError("name must be provided when quant_cfg is not None")
127132
if isinstance(quant_cfg, str):
128133
assert hasattr(mtq_config, quant_cfg), f"Unknown quantization format {quant_cfg}"
129134
quant_cfg = getattr(mtq_config, quant_cfg)
130-
elif not isinstance(quant_cfg, QuantizeConfig) and name is None:
131-
raise ValueError("name must be provided for custom quantization formats")
132135

133136
self.config = (
134137
quant_cfg.model_copy(deep=True)
135138
if isinstance(quant_cfg, QuantizeConfig)
136139
else mtq_config.QuantizeConfig.model_validate(quant_cfg)
137140
)
138-
if name is None:
139-
raise ValueError("name must be provided for custom quantization formats")
140141

141142
# Disable KV Cache quantization
142143
# Currently KV Cache quantization is enabled for some quantization formats and disabled for others
@@ -211,6 +212,31 @@ def fold_pqs_to_weights(model):
211212
model_calib._apply_weight_pre_quant_scale(module, weight_pqs)
212213

213214

215+
def _validate_named_auto_quantize_formats(
216+
quantization_formats: Any,
217+
):
218+
"""Validate the internal AutoQuantize format protocol."""
219+
error_msg = (
220+
"`quantization_formats` must be a list of (quant_cfg, name) tuples. "
221+
"Normalize public inputs before calling the AutoQuantize searcher."
222+
)
223+
if not isinstance(quantization_formats, list):
224+
raise TypeError(error_msg)
225+
226+
for entry in quantization_formats:
227+
if not isinstance(entry, tuple) or len(entry) != 2:
228+
raise TypeError(error_msg)
229+
230+
quant_cfg, name = entry
231+
if quant_cfg is None or not isinstance(quant_cfg, str | QuantizeConfig | Mapping):
232+
raise TypeError(
233+
"Each named quantization format must contain a string, mapping, or "
234+
"QuantizeConfig as the first tuple item."
235+
)
236+
if not isinstance(name, str) or not name:
237+
raise TypeError("Each named quantization format must provide a non-empty name.")
238+
239+
214240
class QuantRecipeHparam(Hparam):
215241
"""An Hparam for quantization recipes.
216242
@@ -231,7 +257,7 @@ def __init__(
231257
quant_module_names: list[str] | None = None,
232258
) -> None:
233259
"""Initializes Hparam with original value and choices."""
234-
choices = sorted({*(choices if choices else []), QuantRecipe(quant_cfg=None)})
260+
choices = sorted({*(choices or []), QuantRecipe(quant_cfg=None)})
235261
super().__init__(choices, original=choices[0])
236262

237263
self.name = name
@@ -398,7 +424,10 @@ class _AutoQuantizeBaseSearcher(BaseSearcher, ABC):
398424
def default_search_config(self):
399425
"""Get the default config for the searcher."""
400426
return {
401-
"quantization_formats": ["NVFP4_DEFAULT_CFG", "FP8_DEFAULT_CFG"],
427+
"quantization_formats": [
428+
("NVFP4_DEFAULT_CFG", "NVFP4_DEFAULT_CFG"),
429+
("FP8_DEFAULT_CFG", "FP8_DEFAULT_CFG"),
430+
],
402431
"data_loader": None,
403432
"num_calib_steps": 512,
404433
"num_score_steps": 128,
@@ -428,6 +457,7 @@ def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
428457
assert config["forward_step"] is not None, (
429458
"`forward_step` must be provided for `auto_quantize`."
430459
)
460+
_validate_named_auto_quantize_formats(config["quantization_formats"])
431461
return config
432462

433463
def load_search_checkpoint(self) -> bool:
@@ -440,13 +470,11 @@ def _is_auto_quantize_module(module):
440470
) and isinstance(module, QuantModule)
441471

442472
@staticmethod
443-
def _get_search_recipes(quantization_formats):
473+
def _get_search_recipes(quantization_formats: Sequence[NamedQuantRecipeConfig]):
444474
return sorted(
445475
{
446-
QuantRecipe(quant_cfg=q[0], name=q[1])
447-
if isinstance(q, tuple)
448-
else QuantRecipe(quant_cfg=q)
449-
for q in quantization_formats
476+
QuantRecipe(quant_cfg=quant_cfg, name=name)
477+
for quant_cfg, name in quantization_formats
450478
}
451479
)
452480

modelopt/torch/quantization/model_quant.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import inspect
2020
import os
2121
import warnings
22-
from collections.abc import Callable, Iterable, Mapping
22+
from collections.abc import Callable, Iterable, Mapping, Sequence
2323
from typing import Any
2424

2525
import torch
@@ -36,7 +36,13 @@
3636
)
3737
from modelopt.torch.utils import atomic_print
3838

39-
from .algorithms import AutoQuantizeGradientSearcher, AutoQuantizeKLDivSearcher, QuantRecipe
39+
from .algorithms import (
40+
AutoQuantizeGradientSearcher,
41+
AutoQuantizeKLDivSearcher,
42+
NamedQuantRecipeConfig,
43+
QuantizationFormatConfig,
44+
QuantRecipe,
45+
)
4046
from .algorithms import get_auto_quantize_config as _get_auto_quantize_config
4147
from .config import QuantizeAlgoCfgType
4248
from .mode import QuantizeModeRegistry, get_modelike_from_algo_cfg
@@ -270,10 +276,38 @@ def forward_loop(model) -> None:
270276
}
271277

272278

279+
def _normalize_auto_quantize_formats(
    quantization_formats: Sequence[QuantizationFormatConfig],
) -> list[NamedQuantRecipeConfig]:
    """Normalize public auto_quantize format inputs into named search entries.

    Args:
        quantization_formats: Raw format configs as passed to ``auto_quantize``.
            ``None`` entries are dropped.

    Returns:
        A list of ``(quant_cfg, name)`` tuples, one per non-``None`` input, in input
        order. ``name`` is whatever ``QuantRecipe.get_auto_name_for_config`` returns,
        or ``CUSTOM_<index>`` when that lookup returns ``None``.

    Raises:
        TypeError: If an entry is already a tuple; named tuples are an internal format.
        AssertionError: If no entry is left after dropping ``None`` values.
    """
    named_formats: list[NamedQuantRecipeConfig] = []
    for i, quant_cfg in enumerate(quantization_formats):
        if quant_cfg is None:
            continue
        if isinstance(quant_cfg, tuple):
            raise TypeError(
                "Named quantization format tuples are internal to AutoQuantize search; "
                "pass raw configs to auto_quantize()."
            )

        name = QuantRecipe.get_auto_name_for_config(quant_cfg)
        if name is None:
            # The index is the position in the caller's list, so skipped ``None``
            # entries still advance the numbering.
            name = f"CUSTOM_{i}"
            warnings.warn(
                "Received custom quantization formats for search, auto_quantize results "
                f"may not be optimal. This config will be displayed as {name}",
                # Skip this helper and its caller so the warning is attributed to the
                # user's call site (assumes the usual path through ``auto_quantize``).
                stacklevel=3,
            )

        named_formats.append((quant_cfg, name))

    # Raised explicitly rather than via ``assert`` so the check survives ``python -O``;
    # the exception type is kept for callers that already expect AssertionError.
    if not named_formats:
        raise AssertionError("`quantization_formats` should not be empty")
    return named_formats
return processed_quantization_formats
305+
306+
273307
def auto_quantize(
274308
model: nn.Module,
275309
constraints: dict[str, float | str] = {"effective_bits": 4.8},
276-
quantization_formats: list[QuantizeConfig | Mapping[str, Any] | str | None] = [
310+
quantization_formats: list[QuantizationFormatConfig] = [
277311
mtq.NVFP4_AWQ_LITE_CFG,
278312
mtq.FP8_DEFAULT_CFG,
279313
],
@@ -319,6 +353,7 @@ def auto_quantize(
319353
Each config dictionary should be valid as a ``config`` argument in
320354
:meth:`quantize <modelopt.torch.quantization.model_quant.quantize>`.
321355
The supported quantization format names are as listed by :attr:`modelopt.torch.quantization.config.choices`.
356+
Custom configs without a built-in name are assigned ``CUSTOM_<index>`` display names internally.
322357
323358
Internally we always add "do not quantize" as a choice. Therefore, it is possible that a layer is
324359
not quantized by any of the quantization formats.
@@ -484,21 +519,7 @@ def forward_backward_step(model, batch) -> None:
484519
might not be readily deployable to TensorRT-LLM yet.
485520
486521
"""
487-
processed_quantization_formats = []
488-
for i, quant_cfg in enumerate(quantization_formats):
489-
if quant_cfg is None:
490-
continue
491-
492-
name = QuantRecipe.get_auto_name_for_config(quant_cfg)
493-
if name is None:
494-
name = f"CUSTOM_{i}"
495-
warnings.warn(
496-
f"Received custom quantization formats for search, auto_quantize results may not be optimal. "
497-
f"This config will be displayed as {name}"
498-
)
499-
processed_quantization_formats.append((quant_cfg, name))
500-
501-
assert len(processed_quantization_formats) > 0, "`quantization_formats` should not be empty"
522+
processed_quantization_formats = _normalize_auto_quantize_formats(quantization_formats)
502523

503524
for quant_cfg, name in processed_quantization_formats:
504525
algo = QuantRecipe(quant_cfg, name=name).config.algorithm

tests/unit/torch/quantization/test_autoquant.py

Lines changed: 73 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,13 @@
2424
import modelopt.torch.opt as mto
2525
import modelopt.torch.quantization as mtq
2626
from modelopt.torch.quantization.algorithms import (
27+
AutoQuantizeGradientSearcher,
2728
QuantRecipe,
2829
QuantRecipeHparam,
2930
estimate_quant_compression,
3031
)
3132
from modelopt.torch.quantization.config import _base_disable_all, _default_disabled_quantizer_cfg
33+
from modelopt.torch.quantization.model_quant import _normalize_auto_quantize_formats
3234
from modelopt.torch.utils import safe_load
3335
from modelopt.torch.utils.distributed import DistributedProcessGroup
3436

@@ -62,6 +64,11 @@ def get_input(self):
6264
return torch.randn(1, 4, 32)
6365

6466

67+
def _recipe(quant_cfg):
    """Build a ``QuantRecipe`` for ``quant_cfg``, supplying the name the constructor needs.

    ``None`` is passed through with ``name=None``; any other config is named with
    ``QuantRecipe.get_auto_name_for_config``.
    """
    if quant_cfg is None:
        return QuantRecipe(quant_cfg, name=None)
    auto_name = QuantRecipe.get_auto_name_for_config(quant_cfg)
    return QuantRecipe(quant_cfg, name=auto_name)
70+
71+
6572
@pytest.mark.parametrize(
6673
("quant_cfg", "other_quant_cfg", "is_less_than"),
6774
[
@@ -71,30 +78,85 @@ def get_input(self):
7178
],
7279
)
7380
def test_quant_recipe(quant_cfg, other_quant_cfg, is_less_than):
74-
qr_this = QuantRecipe(quant_cfg)
75-
qr_other = QuantRecipe(other_quant_cfg)
81+
qr_this = _recipe(quant_cfg)
82+
qr_other = _recipe(other_quant_cfg)
7683
assert (qr_this < qr_other) == is_less_than
7784

78-
qr_this_duplicate = QuantRecipe(quant_cfg)
85+
qr_this_duplicate = _recipe(quant_cfg)
7986
assert qr_this_duplicate in {qr_this}
8087

8188

82-
def test_quant_recipe_custom_quantize_config_requires_name():
83-
custom_cfg = mtq.QuantizeConfig(
89+
def _custom_quantize_config(path):
90+
return mtq.QuantizeConfig(
8491
quant_cfg=[
8592
mtq.QuantizerCfgEntry(
86-
quantizer_name="*weight_quantizer",
93+
quantizer_name=path,
8794
cfg=mtq.QuantizerAttributeConfig(num_bits=8, axis=None),
8895
)
8996
]
9097
)
9198

99+
100+
def test_quant_recipe_custom_quantize_config_requires_name():
101+
custom_cfg = _custom_quantize_config("*custom_weight_quantizer")
102+
92103
with pytest.raises(ValueError, match="name must be provided"):
93104
QuantRecipe(custom_cfg)
94105

95106
assert str(QuantRecipe(custom_cfg, name="custom_cfg")).startswith("custom_cfg(")
96107

97108

109+
def test_quant_recipe_none_requires_no_name():
    """A ``None`` config gets the built-in ``NONE`` name and rejects an explicit one."""
    unnamed_recipe = QuantRecipe(quant_cfg=None)
    assert str(unnamed_recipe).startswith("NONE(")

    # Passing any name together with ``quant_cfg=None`` must be refused.
    with pytest.raises(ValueError, match="name must be None"):
        QuantRecipe(quant_cfg=None, name="NONE")
114+
115+
116+
def test_quant_recipe_honors_explicit_name():
    """An explicitly supplied name is the one shown in the recipe's string form."""
    aliased_recipe = QuantRecipe(mtq.INT8_DEFAULT_CFG, name="int8_alias")
    assert str(aliased_recipe).startswith("int8_alias(")
118+
119+
120+
def test_auto_quantize_search_config_requires_named_formats():
    """The searcher accepts only normalized ``(quant_cfg, name)`` tuples.

    Covers three things: normalization assigns ``CUSTOM_<index>`` names (with a
    warning) to unnamed custom configs, the normalizer refuses pre-named tuples, and
    the searcher's config sanitizer refuses raw, un-normalized configs.
    """

    def _search_config(formats):
        # Minimal search config; only ``quantization_formats`` varies between cases.
        return {
            "quantization_formats": formats,
            "data_loader": [torch.randn(1)],
            "forward_step": lambda model, data: data,
            "loss_func": lambda output, data: output.sum(),
        }

    custom_a = _custom_quantize_config("*custom_weight_quantizer_a")
    custom_b = _custom_quantize_config("*custom_weight_quantizer_b")
    searcher = AutoQuantizeGradientSearcher()

    with pytest.warns(UserWarning) as records:
        quantization_formats = _normalize_auto_quantize_formats([custom_a, custom_b])

    assert quantization_formats == [(custom_a, "CUSTOM_0"), (custom_b, "CUSTOM_1")]
    warning_texts = [str(record.message) for record in records]
    assert any("CUSTOM_0" in text for text in warning_texts)
    assert any("CUSTOM_1" in text for text in warning_texts)

    config = searcher.sanitize_search_config(_search_config(quantization_formats))
    assert config["quantization_formats"] == quantization_formats

    # Already-named tuples are not part of the public input format.
    with pytest.raises(TypeError, match="Named quantization format tuples are internal"):
        _normalize_auto_quantize_formats([(custom_a, "custom_a")])

    # Raw configs must be normalized before they reach the searcher.
    with pytest.raises(TypeError, match="must be a list of"):
        searcher.sanitize_search_config(_search_config([custom_a]))

    recipes = AutoQuantizeGradientSearcher._get_search_recipes(config["quantization_formats"])
    recipe_names = {str(recipe).partition("(")[0] for recipe in recipes}
    assert recipe_names == {"CUSTOM_0", "CUSTOM_1"}
    assert len(set(recipes)) == 2
158+
159+
98160
def test_quant_recipe_hparam():
99161
model_test = torch.nn.Linear(4, 16)
100162
model_ref = torch.nn.Linear(4, 16)
@@ -104,20 +166,20 @@ def test_quant_recipe_hparam():
104166
model_ref = mtq.quantize(model_ref, mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG)
105167

106168
search_recipes = [
107-
QuantRecipe(mtq.INT8_DEFAULT_CFG),
108-
QuantRecipe(mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG),
169+
_recipe(mtq.INT8_DEFAULT_CFG),
170+
_recipe(mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG),
109171
]
110172
hparam = QuantRecipeHparam(
111173
search_recipes,
112174
quant_modules=[model_test],
113175
)
114176
model_test._register_hparam("quant_recipe", hparam)
115-
assert model_test.quant_recipe == QuantRecipe(mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG)
177+
assert model_test.quant_recipe == _recipe(mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG)
116178
assert model_test.get_hparam("quant_recipe").choices == sorted(
117179
[*search_recipes, QuantRecipe(quant_cfg=None)]
118180
)
119181

120-
model_test.quant_recipe = QuantRecipe(mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG)
182+
model_test.quant_recipe = _recipe(mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG)
121183
inputs = torch.randn(1, 4, 4)
122184
output_test = model_test(inputs)
123185
output_ref = model_ref(inputs)
@@ -244,7 +306,7 @@ def test_auto_quantize_disabled_layers_no_poison():
244306

245307
assert not best_model.mlp.input_quantizer.is_enabled
246308
hparam = best_model.attn.q_proj.get_hparam("quant_recipe")
247-
assert QuantRecipe(mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG) in hparam.choices
309+
assert _recipe(mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG) in hparam.choices
248310

249311

250312
INT4INT8_AWQ_CFG = {

0 commit comments

Comments
 (0)