Skip to content

Commit f3315e4

Browse files
committed
new design
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
1 parent acd9822 commit f3315e4

8 files changed

Lines changed: 496 additions & 229 deletions

File tree

modelopt/torch/quantization/config.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -716,47 +716,48 @@ class MaxCalibConfig(QuantizeAlgorithmConfig):
716716
),
717717
)
718718

719-
shared_patterns: dict[str, list[str]] | None = ModeloptField(
719+
shared_states: dict[str, dict[str, list[str]]] | None = ModeloptField(
720720
default=None,
721-
title="Regex patterns for groups that share quantization state",
721+
title="Concrete shared quantization states and their grouping patterns",
722722
description=(
723-
"Optional dict keyed by quantizer kind (``'weight'`` and/or ``'input'``), each a list "
724-
"of regexes matched (full-match) against module fully-qualified names. They must list "
725-
"every group you want for that kind. Modules whose match yields the same capture-group "
726-
"tuple form one group; the capture boundary chooses granularity: capture the immediate "
727-
"parent for per-parent / per-expert groups (e.g. ``r'(.*)\\.(?:q_proj|k_proj|v_proj)'``, "
728-
"``r'(.*)\\.(?:w1|w3)'``); leave the expert index uncaptured for one cross-expert group "
729-
"(``r'(.*)\\.experts\\.\\d+\\.(?:w1|w3)'``). Only ``'weight'`` is used today; ``'input'`` is "
730-
"reserved for future input-quantizer sharing. When the ``'weight'`` list is omitted, "
731-
"the default fusible patterns (q/k/v, gate/up, w1/w3) are used — these match exactly "
732-
"the sibling groups export fuses, avoiding the over-grouping a shared-input heuristic "
733-
"would cause (e.g. a ``shared_expert_gate`` that reads the same input but is not fused)."
723+
"Optional dict keyed by shared-state name. ``'weight_global_amax'`` is implemented "
724+
"today and accepts ``{'patterns': [...]}``, where patterns are full-match regexes "
725+
"against module fully-qualified names. Omitted patterns use the state's defaults; "
726+
"an empty pattern list disables that state."
734727
),
735728
)
736729

737-
@field_validator("shared_patterns")
730+
@field_validator("shared_states")
738731
@classmethod
739-
def validate_shared_patterns(cls, v):
740-
"""Reject unknown quantizer kinds and invalid regexes at the config boundary."""
732+
def validate_shared_states(cls, v):
733+
"""Reject unknown shared-state names, fields, and invalid regexes."""
741734
if v is None:
742735
return v
743-
supported = {"weight", "input"}
736+
supported = {"weight_global_amax"}
744737
unknown = set(v) - supported
745738
if unknown:
746739
raise ValueError(
747-
f"shared_patterns has unsupported quantizer kind(s) {sorted(unknown)}; "
740+
f"shared_states has unsupported state(s) {sorted(unknown)}; "
748741
f"expected keys from {sorted(supported)}."
749742
)
750-
offending = ("", "") # (kind, pattern) of the last regex tried; set before each compile
743+
744+
offending = ("", "")
751745
try:
752-
for kind, patterns in v.items():
753-
for pattern in patterns:
754-
offending = (kind, pattern)
746+
for state_name, state_cfg in v.items():
747+
unknown_fields = set(state_cfg) - {"patterns"}
748+
if unknown_fields:
749+
raise ValueError(
750+
f"shared_states[{state_name!r}] has unsupported field(s) "
751+
f"{sorted(unknown_fields)}; expected ['patterns']."
752+
)
753+
for pattern in state_cfg.get("patterns", []):
754+
offending = (state_name, pattern)
755755
re.compile(pattern)
756756
except re.error as e:
757-
bad_kind, bad_pattern = offending
757+
bad_state, bad_pattern = offending
758758
raise ValueError(
759-
f"shared_patterns[{bad_kind!r}] has an invalid regex {bad_pattern!r}: {e}"
759+
f"shared_states[{bad_state!r}]['patterns'] has an invalid regex "
760+
f"{bad_pattern!r}: {e}"
760761
) from e
761762
return v
762763

modelopt/torch/quantization/conversion.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@
4545
TensorQuantizer,
4646
)
4747
from .utils import is_quantized, is_quantized_linear
48+
from .utils.shared_input import (
49+
DEFAULT_WEIGHT_SHARED_PATTERNS,
50+
rebuild_shared_quant_states,
51+
resolve_weight_global_amax_patterns,
52+
shared_quant_states_metadata,
53+
)
4854

4955
__all__ = [
5056
"register",
@@ -105,6 +111,24 @@ def maybe_promote_nvfp4_static_quantizer(module: nn.Module, quantizer_state: dic
105111
NVFP4StaticQuantizer.from_tensor_quantizer(module)
106112

107113

114+
def _restore_shared_quant_state_aliases(
115+
model: nn.Module, config: QuantizeConfig, metadata: MetadataDict
116+
) -> None:
117+
"""Rebuild shared-state ties before checkpoint tensor values are loaded."""
118+
if not metadata.get("shared_quant_states"):
119+
return
120+
method = getattr(config, "method", None)
121+
if method == "max":
122+
patterns = resolve_weight_global_amax_patterns(
123+
shared_states=getattr(config, "shared_states", None)
124+
)
125+
elif method in {"mse", "local_hessian"}:
126+
patterns = DEFAULT_WEIGHT_SHARED_PATTERNS
127+
else:
128+
return
129+
rebuild_shared_quant_states(model, patterns=patterns)
130+
131+
108132
def restore_quantizer_state(model: nn.Module, config: QuantizeConfig, metadata: MetadataDict):
109133
"""Restore the quantizer states from the given state dict.
110134
@@ -146,6 +170,8 @@ def restore_quantizer_state(model: nn.Module, config: QuantizeConfig, metadata:
146170
name = get_unwrapped_name(name, model)
147171
module.modelopt_post_restore(name)
148172

173+
_restore_shared_quant_state_aliases(model, config, metadata)
174+
149175
return model
150176

151177

@@ -176,6 +202,10 @@ def update_quantize_metadata(
176202
) -> None:
177203
"""Update the quantizer state in the metadata dict."""
178204
metadata["quantizer_state"] = quantizer_state(model)
205+
if shared_state_metadata := shared_quant_states_metadata(model):
206+
metadata["shared_quant_states"] = shared_state_metadata
207+
else:
208+
metadata.pop("shared_quant_states", None)
179209

180210

181211
def quantizer_state(model: nn.Module) -> dict[str, Any]:

modelopt/torch/quantization/model_calib.py

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
populate_shared_state,
5555
promote_nvfp4_static_quantizers,
5656
reduce_amax,
57+
resolve_weight_global_amax_patterns,
5758
)
5859
from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper
5960

@@ -112,16 +113,16 @@ def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]:
112113

113114
@torch.no_grad()
114115
def _check_grouped_weight_global_amax_synced(model: nn.Module) -> None:
115-
"""Verify SharedQuantState unified each name-based fusible group's weight global_amax.
116+
"""Verify shared NVFP4 state unified each name-based fusible group's weight global_amax.
116117
117118
The legacy name-based grouping (Q/K/V, gate/up, w1/w3) is kept here as a *check*
118119
rather than performed: after attach/populate/promote, the promoted static-NVFP4 weight
119120
quantizers in each name group must already share one ``global_amax``. This catches the
120-
SharedQuantState path failing to form or sync a group it should have (e.g. a
121+
SharedNVFP4GlobalAmaxState path failing to form or sync a group it should have (e.g. a
121122
:data:`DEFAULT_WEIGHT_SHARED_PATTERNS` regression, or an architecture the regexes miss)
122123
before the MSE per-block search — computed against ``global_amax`` — bakes in the
123124
inconsistency. Run only when the default patterns are in effect (custom
124-
``shared_patterns`` may intentionally group differently). Members whose ``global_amax``
125+
``shared_states`` may intentionally group differently). Members whose ``global_amax``
125126
is not materialized (``None``/meta, e.g. an ``init_empty_weights`` model) are skipped.
126127
"""
127128
for group in _collect_grouped_linears(model):
@@ -132,8 +133,8 @@ def _check_grouped_weight_global_amax_synced(model: nn.Module) -> None:
132133
ref = amaxes[0]
133134
assert all(torch.equal(a, ref) for a in amaxes), (
134135
"A fusible sibling group (q/k/v or gate/up) was not unified to a shared weight "
135-
"global_amax; SharedQuantState failed to sync it, so the per-block MSE scales "
136-
"would be inconsistent across the group."
136+
"global_amax; SharedNVFP4GlobalAmaxState failed to sync it, so the per-block "
137+
"MSE scales would be inconsistent across the group."
137138
)
138139

139140

@@ -148,7 +149,7 @@ def _finalize_with_shared_state(model: nn.Module, weight_patterns: list[str]) ->
148149
populate_shared_state(model)
149150
promote_nvfp4_static_quantizers(model)
150151
# Under the default patterns, verify the fusible name groups were actually synced.
151-
if weight_patterns is DEFAULT_WEIGHT_SHARED_PATTERNS:
152+
if weight_patterns == DEFAULT_WEIGHT_SHARED_PATTERNS:
152153
_check_grouped_weight_global_amax_synced(model)
153154

154155

@@ -264,7 +265,7 @@ def max_calibrate(
264265
forward_loop: ForwardLoop | None = None,
265266
distributed_sync=True,
266267
sync_expert_weight_amax=False,
267-
shared_patterns: Mapping[str, Sequence[str]] | None = None,
268+
shared_states: Mapping[str, Mapping[str, Sequence[str]]] | None = None,
268269
):
269270
"""Calibrate the model using max.
270271
@@ -275,29 +276,19 @@ def max_calibrate(
275276
distributed_sync: Whether to sync input_quantizer amax across distributed processes.
276277
sync_expert_weight_amax: SequentialMLP only — share one weight amax across all experts
277278
in a MoE layer (within-rank sync + EP all-reduce when EP>1).
278-
shared_patterns: Optional dict keyed by quantizer kind (``"weight"``/``"input"``), each a
279-
list of regexes over module FQNs. When the ``"weight"`` list is omitted,
280-
:data:`DEFAULT_WEIGHT_SHARED_PATTERNS` (q/k/v, gate/up, w1/w3) is used. Modules whose
281-
regex match yields the same capture-group tuple form one group — capture the immediate
282-
parent for per-parent (per-expert) grouping, or leave the expert index uncaptured for
283-
cross-expert. Only ``"weight"`` is used today; ``"input"`` is reserved for future
284-
input-quantizer sharing.
279+
shared_states: Optional dict keyed by shared-state name. ``"weight_global_amax"`` is
280+
implemented today and accepts ``{"patterns": [...]}``; omitted patterns use
281+
:data:`DEFAULT_WEIGHT_SHARED_PATTERNS`, while an empty list disables the state.
285282
286283
See :class:`MaxCalibConfig <modelopt.torch.quantization.config.MaxCalibConfig>` for
287284
details on the remaining arguments.
288285
"""
289286
# Discover fusible sibling groups by name regex and attach the (initially empty) shared
290-
# state up front, so the SharedQuantState container exists for the whole calibration —
291-
# forward-time fields can accumulate into it. Discovery is structural (a pattern over the
292-
# module tree), so it needs no ``_amax``; per-member values are aggregated later by
293-
# populate_shared_state, after the forward and any cross-rank ``_amax`` sync. Default to
294-
# q/k/v + gate/up when no "weight" key is given; an explicit (possibly empty) list
295-
# overrides it — key presence, not truthiness, so {"weight": []} disables grouping.
296-
# Only "weight" is consumed today; "input" is reserved.
297-
if shared_patterns is not None and "weight" in shared_patterns:
298-
weight_patterns = list(shared_patterns["weight"])
299-
else:
300-
weight_patterns = DEFAULT_WEIGHT_SHARED_PATTERNS
287+
# state up front, so parent-level runtime hooks can be installed by future concrete
288+
# states. Discovery is structural (a pattern over the module tree), so it needs no
289+
# ``_amax``; per-member values are aggregated later by populate_shared_state, after the
290+
# forward and any cross-rank ``_amax`` sync.
291+
weight_patterns = resolve_weight_global_amax_patterns(shared_states=shared_states)
301292
attach_shared_quant_states(model, patterns=weight_patterns)
302293

303294
# Always run weight calibration on the weight tensor directly so every weight

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,25 @@ class TensorQuantizer(nn.Module):
172172
"pre_bwd_fn",
173173
# quantizer cache for custom backends, like luts
174174
"_quantizer_cache",
175-
# Runtime-only back-reference to a sibling group's SharedQuantState; it is
176-
# re-established during calibration and must not be serialized (it points to a
177-
# live module whose dynamic QuantLinear members are not picklable).
178-
"_shared_quant_state_ref",
175+
# Runtime-only references to concrete shared-state owners; they are re-established
176+
# during calibration and must not be serialized.
177+
"_shared_quant_state_refs",
178+
# Runtime-only set of storage attributes tied to shared state. The tied
179+
# aliases are rebuilt from calibration config and tensor state during restore.
180+
"_shared_quant_tied_attrs",
179181
}
180182

183+
def __setattr__(self, name, value):
184+
tied = self.__dict__.get("_shared_quant_tied_attrs", set())
185+
if name in tied:
186+
current = self._buffers.get(name, None) if "_buffers" in self.__dict__ else None
187+
if value is not current:
188+
raise RuntimeError(
189+
f"{name} is tied shared quant state; update it in-place or replace it "
190+
"through the owning shared-state object."
191+
)
192+
return super().__setattr__(name, value)
193+
181194
def __init__(
182195
self,
183196
quant_attribute_cfg=None,
@@ -1368,8 +1381,11 @@ def _preserve_amax_in_fp32(self):
13681381
if amax is not None:
13691382
self._amax = amax.to(dtype=torch.float32)
13701383
global_amax = getattr(self, "_global_amax", None)
1371-
if global_amax is not None:
1372-
self._global_amax = global_amax.to(dtype=torch.float32)
1384+
if global_amax is not None and global_amax.dtype != torch.float32:
1385+
if "_global_amax" in self.__dict__.get("_shared_quant_tied_attrs", set()):
1386+
global_amax.data = global_amax.to(dtype=torch.float32)
1387+
else:
1388+
self._global_amax = global_amax.to(dtype=torch.float32)
13731389

13741390
def _amax_setter_helper(self, value):
13751391
super()._amax_setter_helper(value)

modelopt/torch/quantization/utils/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
__all__ = [
2424
"DEFAULT_WEIGHT_SHARED_PATTERNS",
2525
"EXPORT_MODE",
26-
"SharedQuantState",
26+
"SharedNVFP4GlobalAmaxState",
2727
"attach_shared_quant_states",
2828
"convert_quantization_axis_to_reduce_axis",
2929
"export_torch_mode",
@@ -32,11 +32,15 @@
3232
"is_quantized_column_parallel_linear",
3333
"is_quantized_linear",
3434
"is_quantized_row_parallel_linear",
35+
"iter_shared_quant_states",
3536
"populate_shared_state",
37+
"rebuild_shared_quant_states",
3638
"reduce_amax",
3739
"reduce_sum",
3840
"replace_function",
3941
"representative_weight_quantizer",
42+
"resolve_weight_global_amax_patterns",
43+
"shared_quant_states_metadata",
4044
"update_quant_cfg_with_kv_cache_quant",
4145
"weight_attr_names",
4246
]

modelopt/torch/quantization/utils/core_utils.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -954,24 +954,26 @@ def promote_nvfp4_static_quantizers(model: nn.Module) -> int:
954954
need to be promoted so they use the two-level scaling path (global amax +
955955
per-block amax) instead of the generic E4M3 path.
956956
957-
If the quantizer has a ``_shared_quant_state_ref`` with a populated
958-
``weight_global_amax`` (sibling group) whose owning state lives within ``model``,
959-
that shared value is used instead of this quantizer's own ``_amax`` reduction,
960-
keeping siblings on a common FP8 grid.
957+
If the quantizer has a shared-state reference with a populated
958+
``global_amax`` (sibling group) whose owning state lives within ``model``, the
959+
promoted quantizer's ``_global_amax`` buffer is tied to that canonical state
960+
buffer instead of receiving an independent copy.
961961
962962
Returns the number of quantizers converted.
963963
"""
964964
from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer
965+
from modelopt.torch.quantization.utils.shared_input import (
966+
SharedNVFP4GlobalAmaxState,
967+
iter_shared_quant_states,
968+
)
965969

966970
# Shared states owned within THIS promotion root. This function also runs on
967971
# submodules / individual linears; a quantizer may still carry a back-reference from
968-
# an earlier full-model calibration whose owning ``_shared_quant_state`` is outside
969-
# ``model``. Only trust refs reachable here — otherwise the global_amax would come
972+
# an earlier full-model calibration whose owning state is outside ``model``. Only
973+
# trust refs reachable here — otherwise the global_amax would come
970974
# from an unrelated prior run; fall back to the quantizer's own amax instead.
971975
valid_shared_states = {
972-
id(state)
973-
for owner in model.modules()
974-
if (state := getattr(owner, "_shared_quant_state", None)) is not None
976+
id(state) for state in iter_shared_quant_states(model, SharedNVFP4GlobalAmaxState)
975977
}
976978

977979
converted = 0
@@ -984,19 +986,23 @@ def promote_nvfp4_static_quantizers(model: nn.Module) -> int:
984986
if amax is None:
985987
continue
986988

987-
# Grouped siblings share one ``weight_global_amax`` (common FP8 grid);
988-
# otherwise fall back to this quantizer's own per-block amax.
989+
# Grouped siblings share one canonical global_amax (common FP8 grid); otherwise
990+
# fall back to this quantizer's own per-block amax.
989991
already_promoted = isinstance(module, NVFP4StaticQuantizer)
990-
shared = getattr(module, "_shared_quant_state_ref", None)
992+
shared_refs = module.__dict__.get("_shared_quant_state_refs", {})
993+
shared = shared_refs.get(SharedNVFP4GlobalAmaxState.state_name)
991994
if (
992-
shared is not None
995+
isinstance(shared, SharedNVFP4GlobalAmaxState)
993996
and id(shared) in valid_shared_states
994-
and shared.weight_global_amax is not None
997+
and shared.global_amax is not None
995998
):
996-
global_amax = shared.weight_global_amax
999+
NVFP4StaticQuantizer.from_tensor_quantizer(module)
1000+
shared.tie_member_quantizer(module)
9971001
else:
1002+
if isinstance(shared, SharedNVFP4GlobalAmaxState):
1003+
shared.untie_member_quantizer(module)
9981004
global_amax = reduce_amax(amax.clone().detach(), axis=None)
999-
NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)
1005+
NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)
10001006
if not already_promoted:
10011007
converted += 1
10021008
return converted

0 commit comments

Comments
 (0)