5454 populate_shared_state ,
5555 promote_nvfp4_static_quantizers ,
5656 reduce_amax ,
57+ resolve_weight_global_amax_patterns ,
5758)
5859from .utils .calib_utils import _GPTQ_HELPER_REGISTRY , GPTQHelper
5960
@@ -112,16 +113,16 @@ def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]:
112113
113114@torch .no_grad ()
114115def _check_grouped_weight_global_amax_synced (model : nn .Module ) -> None :
115- """Verify SharedQuantState unified each name-based fusible group's weight global_amax.
116+ """Verify shared NVFP4 state unified each name-based fusible group's weight global_amax.
116117
117118 The legacy name-based grouping (Q/K/V, gate/up, w1/w3) is kept here as a *check*
118119 rather than performed: after attach/populate/promote, the promoted static-NVFP4 weight
119120 quantizers in each name group must already share one ``global_amax``. This catches the
120- SharedQuantState path failing to form or sync a group it should have (e.g. a
121+ SharedNVFP4GlobalAmaxState path failing to form or sync a group it should have (e.g. a
121122 :data:`DEFAULT_WEIGHT_SHARED_PATTERNS` regression, or an architecture the regexes miss)
122123 before the MSE per-block search — computed against ``global_amax`` — bakes in the
123124 inconsistency. Run only when the default patterns are in effect (custom
124- ``shared_patterns `` may intentionally group differently). Members whose ``global_amax``
125+ ``shared_states `` may intentionally group differently). Members whose ``global_amax``
125126 is not materialized (``None``/meta, e.g. an ``init_empty_weights`` model) are skipped.
126127 """
127128 for group in _collect_grouped_linears (model ):
@@ -132,8 +133,8 @@ def _check_grouped_weight_global_amax_synced(model: nn.Module) -> None:
132133 ref = amaxes [0 ]
133134 assert all (torch .equal (a , ref ) for a in amaxes ), (
134135 "A fusible sibling group (q/k/v or gate/up) was not unified to a shared weight "
135- "global_amax; SharedQuantState failed to sync it, so the per-block MSE scales "
136- "would be inconsistent across the group."
136+ "global_amax; SharedNVFP4GlobalAmaxState failed to sync it, so the per-block "
137+ "MSE scales would be inconsistent across the group."
137138 )
138139
139140
@@ -148,7 +149,7 @@ def _finalize_with_shared_state(model: nn.Module, weight_patterns: list[str]) ->
148149 populate_shared_state (model )
149150 promote_nvfp4_static_quantizers (model )
150151 # Under the default patterns, verify the fusible name groups were actually synced.
151- if weight_patterns is DEFAULT_WEIGHT_SHARED_PATTERNS :
152+ if weight_patterns == DEFAULT_WEIGHT_SHARED_PATTERNS :
152153 _check_grouped_weight_global_amax_synced (model )
153154
154155
@@ -264,7 +265,7 @@ def max_calibrate(
264265 forward_loop : ForwardLoop | None = None ,
265266 distributed_sync = True ,
266267 sync_expert_weight_amax = False ,
267- shared_patterns : Mapping [str , Sequence [str ]] | None = None ,
268+ shared_states : Mapping [str , Mapping [ str , Sequence [str ] ]] | None = None ,
268269):
269270 """Calibrate the model using max.
270271
@@ -275,29 +276,19 @@ def max_calibrate(
275276 distributed_sync: Whether to sync input_quantizer amax across distributed processes.
276277 sync_expert_weight_amax: SequentialMLP only — share one weight amax across all experts
277278 in a MoE layer (within-rank sync + EP all-reduce when EP>1).
278- shared_patterns: Optional dict keyed by quantizer kind (``"weight"``/``"input"``), each a
279- list of regexes over module FQNs. When the ``"weight"`` list is omitted,
280- :data:`DEFAULT_WEIGHT_SHARED_PATTERNS` (q/k/v, gate/up, w1/w3) is used. Modules whose
281- regex match yields the same capture-group tuple form one group — capture the immediate
282- parent for per-parent (per-expert) grouping, or leave the expert index uncaptured for
283- cross-expert. Only ``"weight"`` is used today; ``"input"`` is reserved for future
284- input-quantizer sharing.
279+ shared_states: Optional dict keyed by shared-state name. ``"weight_global_amax"`` is
280+ implemented today and accepts ``{"patterns": [...]}``; omitted patterns use
281+ :data:`DEFAULT_WEIGHT_SHARED_PATTERNS`, while an empty list disables the state.
285282
286283 See :class:`MaxCalibConfig <modelopt.torch.quantization.config.MaxCalibConfig>` for
287284 details on the remaining arguments.
288285 """
289286 # Discover fusible sibling groups by name regex and attach the (initially empty) shared
290- # state up front, so the SharedQuantState container exists for the whole calibration —
291- # forward-time fields can accumulate into it. Discovery is structural (a pattern over the
292- # module tree), so it needs no ``_amax``; per-member values are aggregated later by
293- # populate_shared_state, after the forward and any cross-rank ``_amax`` sync. Default to
294- # q/k/v + gate/up when no "weight" key is given; an explicit (possibly empty) list
295- # overrides it — key presence, not truthiness, so {"weight": []} disables grouping.
296- # Only "weight" is consumed today; "input" is reserved.
297- if shared_patterns is not None and "weight" in shared_patterns :
298- weight_patterns = list (shared_patterns ["weight" ])
299- else :
300- weight_patterns = DEFAULT_WEIGHT_SHARED_PATTERNS
287+ # state up front, so parent-level runtime hooks can be installed by future concrete
288+ # states. Discovery is structural (a pattern over the module tree), so it needs no
289+ # ``_amax``; per-member values are aggregated later by populate_shared_state, after the
290+ # forward and any cross-rank ``_amax`` sync.
291+ weight_patterns = resolve_weight_global_amax_patterns (shared_states = shared_states )
301292 attach_shared_quant_states (model , patterns = weight_patterns )
302293
303294 # Always run weight calibration on the weight tensor directly so every weight
0 commit comments