|
52 | 52 | promote_nvfp4_static_quantizers, |
53 | 53 | quantizer_attr_names, |
54 | 54 | reduce_amax, |
55 | | - weight_attr_names, |
56 | 55 | ) |
57 | 56 | from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper |
58 | 57 |
|
|
64 | 63 | "max_calibrate", |
65 | 64 | "smoothquant", |
66 | 65 | "svdquant", |
67 | | - "sync_grouped_weight_global_amax", |
68 | 66 | ] |
69 | 67 |
|
70 | 68 |
|
71 | | -# Sibling weight-quantizer name groups whose ``global_amax`` should share an |
72 | | -# FP8 scale-of-scales. All members of a group sit under the same parent module |
73 | | -# (e.g. one self-attention or one MLP block) and either consume the same input |
74 | | -# tensor or get fused at deployment, so a divergent global_amax across siblings |
75 | | -# would split their FP8 grids and skew the round. |
76 | | -_GROUPED_WEIGHT_QUANTIZER_PATTERNS: tuple[tuple[str, ...], ...] = ( |
77 | | - # Standard self-attention (skipped for fused qkv_proj — single weight). |
78 | | - ("q_proj", "k_proj", "v_proj"), |
79 | | - # Gated MLP, modern naming (Llama / Qwen / Mistral / etc.). |
80 | | - ("gate_proj", "up_proj"), |
81 | | - # Gated MLP, older Mixtral-style naming. |
82 | | - ("w1", "w3"), |
83 | | -) |
84 | | - |
85 | | - |
86 | | -def _is_calibrated_nvfp4_static_weight_quantizer(q) -> bool: |
87 | | - """True for an NVFP4-static weight quantizer that ``max_calibrate`` already |
88 | | - populated with a per-block ``_amax`` and that is currently enabled. |
89 | | - """ |
| 69 | +def _is_calibrated_nvfp4_static(q) -> bool: |
| 70 | + """True iff ``q`` is an enabled NVFP4-static weight quantizer with ``_amax`` set.""" |
90 | 71 | return ( |
91 | 72 | isinstance(q, TensorQuantizer) |
92 | 73 | and not q._disabled |
93 | 74 | and q.is_nvfp4_static |
94 | | - and hasattr(q, "_amax") |
95 | | - and q._amax is not None |
| 75 | + and getattr(q, "_amax", None) is not None |
96 | 76 | ) |
97 | 77 |
|
98 | 78 |
|
99 | 79 | def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]: |
100 | | - """Find groups of Linear-like submodules whose NVFP4-static weight quantizers |
101 | | - should share ``global_amax`` (Q/K/V under one attention parent; gate/up under |
102 | | - one MLP parent). |
103 | | - """ |
| 80 | + """Collect sibling groups (Q/K/V, gate/up) with calibrated NVFP4-static weight quantizers.""" |
| 81 | +    # Import inline: a top-level import would form a layer_utils → quant_utils → model_calib cycle.
| 82 | + from modelopt.torch.export.layer_utils import _GATE_UP_PAIRS |
| 83 | + |
| 84 | +    # Reuse export's gate/up pairs and add Q/K/V (export has no equivalent
| 85 | +    # constant); a single source for the gate/up half avoids parallel lists.
| 86 | + patterns: tuple[tuple[str, ...], ...] = (("q_proj", "k_proj", "v_proj"), *_GATE_UP_PAIRS) |
104 | 87 | groups: list[list[nn.Module]] = [] |
105 | 88 | wq_attr = quantizer_attr_names("weight").weight_quantizer |
106 | 89 | for parent in model.modules(): |
107 | | - for sibling_names in _GROUPED_WEIGHT_QUANTIZER_PATTERNS: |
108 | | - members: list[nn.Module] = [] |
109 | | - for n in sibling_names: |
110 | | - child = getattr(parent, n, None) |
111 | | - if child is None: |
112 | | - continue |
113 | | - wq = getattr(child, wq_attr, None) |
114 | | - if _is_calibrated_nvfp4_static_weight_quantizer(wq): |
115 | | - members.append(child) |
| 90 | + for sibling_names in patterns: |
| 91 | + members = [ |
| 92 | + child |
| 93 | + for child in (getattr(parent, n, None) for n in sibling_names) |
| 94 | + if child is not None and _is_calibrated_nvfp4_static(getattr(child, wq_attr, None)) |
| 95 | + ] |
116 | 96 | if len(members) >= 2: |
117 | 97 | groups.append(members) |
118 | 98 | return groups |
119 | 99 |
|
120 | 100 |
|
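As a point of reference, the name-based grouping above can be sketched standalone in plain PyTorch; in this toy version ordinary `nn.Linear` children stand in for the calibrated NVFP4-static weight-quantizer check, and `ToyAttention`, `collect_sibling_groups`, and the hard-coded `PATTERNS` tuple are illustrative names rather than modelopt code:

```python
import torch.nn as nn

# Hypothetical sibling-name patterns; the real helper pulls the gate/up pairs from
# modelopt.torch.export.layer_utils._GATE_UP_PAIRS and only hard-codes q/k/v itself.
PATTERNS = (("q_proj", "k_proj", "v_proj"), ("gate_proj", "up_proj"), ("w1", "w3"))


def collect_sibling_groups(model: nn.Module) -> list[list[nn.Module]]:
    """Group children that share a parent and match a sibling-name pattern."""
    groups: list[list[nn.Module]] = []
    for parent in model.modules():
        for names in PATTERNS:
            members = [
                m
                for m in (getattr(parent, n, None) for n in names)
                if isinstance(m, nn.Linear)  # stand-in for the calibrated-quantizer check
            ]
            if len(members) >= 2:  # partial matches (e.g. a fused qkv_proj) are skipped
                groups.append(members)
    return groups


class ToyAttention(nn.Module):
    def __init__(self, d: int = 8):
        super().__init__()
        self.q_proj = nn.Linear(d, d)
        self.k_proj = nn.Linear(d, d)
        self.v_proj = nn.Linear(d, d)
        self.o_proj = nn.Linear(d, d)  # matches no pattern, never grouped


print(len(collect_sibling_groups(ToyAttention())))  # -> 1 (the q/k/v group)
```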
121 | 101 | @torch.no_grad() |
122 | | -def sync_grouped_weight_global_amax(model: nn.Module) -> int: |
123 | | - """Sync ``global_amax`` across sibling NVFP4-static weight quantizers. |
124 | | -
|
125 | | - For each group of siblings (Q/K/V projections under one attention parent; |
126 | | - gate/up — a.k.a. ``w1``/``w3`` — under one MLP parent) unifies the |
127 | | - NVFP4 ``global_amax`` so the per-block FP8 round picks scales against a |
128 | | - consistent FP8 grid across the group during MSE / local-Hessian search. |
129 | | -
|
130 | | - Reuses :func:`modelopt.torch.export.quant_utils.preprocess_linear_fusion` |
131 | | - (whose ``NVFP4StaticQuantizer`` branch performs the same |
132 | | - ``max(stack(global_amax))`` unification at export time). To call it before |
133 | | - MSE, this helper first promotes each grouped weight quantizer to |
134 | | - :class:`NVFP4StaticQuantizer` with its local ``global_amax`` (= |
135 | | - ``reduce_amax(_amax)``); ``preprocess_linear_fusion`` then unifies in |
136 | | - place. |
137 | | -
|
138 | | - Must be called after ``max_calibrate`` has populated each weight |
139 | | - quantizer's ``_amax``. Idempotent. Returns the number of groups synced. |
| 102 | +def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int: |
| 103 | + """Re-run weight calibration on the weight tensor for quantizers missing ``_amax``. |
| 104 | +
|
| 105 | +    Covers MoE experts that ``max_calibrate`` skipped (no routed tokens), so MSE
| 106 | +    doesn't drop them and break the gate==up ``weight_scale_2`` export invariant.
| 107 | +    Activation quantizers on those modules remain uncalibrated; a warning is emitted.
| 108 | + """ |
| 109 | + name_to_module = dict(model.named_modules()) |
| 110 | + n = 0 |
| 111 | + for module in name_to_module.values(): |
| 112 | + if not isinstance(module, QuantModule): |
| 113 | + continue |
| 114 | + with enable_weight_access_and_writeback(module, model, name_to_module): |
| 115 | + for weight, q in module.iter_weights_for_calibration(): |
| 116 | + if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic: |
| 117 | + continue |
| 118 | + if q._calibrator is None: |
| 119 | + continue |
| 120 | + if getattr(q, "_amax", None) is not None and not torch.all(q._amax == 0): |
| 121 | + continue |
| 122 | + q.disable_quant() |
| 123 | + q.enable_calib() |
| 124 | + q(weight) |
| 125 | + if q._calibrator.compute_amax() is not None: |
| 126 | + q.load_calib_amax() |
| 127 | + q.enable_quant() |
| 128 | + q.disable_calib() |
| 129 | + if hasattr(q._calibrator, "reset"): |
| 130 | + q._calibrator.reset() |
| 131 | + n += 1 |
| 132 | + if n > 0: |
| 133 | + warnings.warn( |
| 134 | + f"Bootstrapped {n} weight quantizer(s) with no routed calibration tokens; " |
| 135 | + f"their activation quantizers (if any) remain uncalibrated. " |
| 136 | + f"Increase calib size/seq len to activate all experts.", |
| 137 | + stacklevel=2, |
| 138 | + ) |
| 139 | + return n |
| 140 | + |
| 141 | + |
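To illustrate the bootstrap case (a routed expert that never saw calibration tokens), here is a toy sketch in plain PyTorch; `ToyWeightQuantizer` and `bootstrap_from_weight` are made-up stand-ins for the real `TensorQuantizer`/calibrator machinery, and the per-block reduction only mimics a blockwise weight amax:

```python
import torch
import torch.nn as nn


class ToyWeightQuantizer:
    """Made-up stand-in; not the modelopt TensorQuantizer API."""

    def __init__(self):
        self.amax = None  # stays None when calibration never observed any data

    def bootstrap_from_weight(self, weight: torch.Tensor, block: int = 16) -> None:
        # Blockwise max-abs over the last dim, mimicking a per-block weight amax.
        blocks = weight.reshape(*weight.shape[:-1], -1, block)
        self.amax = blocks.abs().amax(dim=-1)


dead_expert = nn.Linear(32, 4)  # pretend the router never sent it a token
q = ToyWeightQuantizer()
if q.amax is None:  # the case the bootstrap helper above targets
    q.bootstrap_from_weight(dead_expert.weight)
print(q.amax.shape)  # torch.Size([4, 2])
```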
| 142 | +@torch.no_grad() |
| 143 | +def _sync_grouped_weight_global_amax(model: nn.Module) -> int: |
| 144 | + """Unify NVFP4 ``global_amax`` across Q/K/V and gate/up sibling weight quantizers. |
| 145 | +
|
| 146 | + Run after ``max_calibrate``. Sibling discovery is name-based via |
| 147 | + ``_collect_grouped_linears``; non-matching architectures (wqkv, fused |
| 148 | + qkv_proj, DeepSeek variants, single-Linear fused gate_up_proj) silently |
| 149 | + fall back to per-module global_amax. Fused-experts containers already |
| 150 | + share a single quantizer across gate/up halves and need no sync. |
140 | 151 | """ |
| 152 | +    # quant_utils imports back from this module; a top-level import would create a cycle.
141 | 153 | from modelopt.torch.export.quant_utils import preprocess_linear_fusion |
142 | 154 |
|
| 155 | + wq_attr = quantizer_attr_names("weight").weight_quantizer |
143 | 156 | n_groups = 0 |
144 | 157 | for group in _collect_grouped_linears(model): |
145 | | - # Promote each member's weight quantizer so `preprocess_linear_fusion` |
146 | | - # sees post-conversion NVFP4StaticQuantizers (its NVFP4 branch reads |
147 | | - # `global_amax`, which only exists post-promotion). |
148 | | - wq_attr = quantizer_attr_names("weight").weight_quantizer |
149 | 158 | for child in group: |
150 | 159 | wq = getattr(child, wq_attr) |
151 | 160 | if not isinstance(wq, NVFP4StaticQuantizer): |
152 | | - local_global = reduce_amax(wq._amax, axis=None) |
153 | | - NVFP4StaticQuantizer.from_tensor_quantizer(wq, global_amax=local_global) |
| 161 | + NVFP4StaticQuantizer.from_tensor_quantizer( |
| 162 | + wq, global_amax=reduce_amax(wq._amax, axis=None) |
| 163 | + ) |
154 | 164 | preprocess_linear_fusion(group) |
155 | 165 | n_groups += 1 |
156 | 166 | return n_groups |
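Numerically, the unification amounts to every sibling adopting the group-wise maximum of its local `global_amax`. A minimal sketch, assuming a toy 16-element block size and random weights (this is an illustration, not the modelopt implementation):

```python
import torch


def per_block_amax(w: torch.Tensor, block: int = 16) -> torch.Tensor:
    """Max-abs per contiguous block along the last dim (toy stand-in for _amax)."""
    return w.reshape(*w.shape[:-1], -1, block).abs().amax(dim=-1)


torch.manual_seed(0)
siblings = {name: torch.randn(8, 64) for name in ("q_proj", "k_proj", "v_proj")}
block_amax = {n: per_block_amax(w) for n, w in siblings.items()}

# Local scale-of-scales, i.e. what reduce_amax(_amax, axis=None) gives each weight...
local_global = {n: a.amax() for n, a in block_amax.items()}
# ...versus the single value all three siblings share after the group-wise unification.
unified = torch.stack(list(local_global.values())).amax()
print({n: round(v.item(), 3) for n, v in local_global.items()}, "->", round(unified.item(), 3))
```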
@@ -436,37 +446,24 @@ def mse_calibrate( |
436 | 446 | See :class:`MseCalibConfig <modelopt.torch.quantization.config.MseCalibConfig>` for |
437 | 447 | details on the remaining arguments. |
438 | 448 | """ |
439 | | - # Step 1: First get initial amax using max calibration |
| 449 | +    # Step 1: max calibrate, bootstrap dead-expert weight quantizers, then
| 450 | +    # unify grouped NVFP4 global_amax so MSE sees a consistent FP8 grid.
440 | 451 | max_calibrate(model, forward_loop, distributed_sync) |
| 452 | + _bootstrap_uncalibrated_weight_quantizers(model) |
| 453 | + _sync_grouped_weight_global_amax(model) |
441 | 454 |
|
442 | | - # Step 1b: Sync global_amax across sibling NVFP4-static weight quantizers |
443 | | - # (q/k/v_proj under one attention block; gate/up — a.k.a. w1/w3 — under one |
444 | | - # MLP block) so their FP8 scale-of-scales matches and the per-block FP8 |
445 | | - # round uses a consistent grid. No-op when there are no sibling groups |
446 | | - # (e.g. fused QKV / fused gate_up_proj). |
447 | | - sync_grouped_weight_global_amax(model) |
448 | | - |
449 | | - # Step 2: Replace calibrators with MseCalibrator for enabled quantizers |
450 | | - # and identify weight quantizers |
451 | | - weight_quantizers = [] |
452 | | - seen_modules = set() |
453 | | - |
| 455 | + # Step 2: replace calibrators with MseCalibrator for enabled quantizers. |
454 | 456 | for name, module in list(model.named_modules()): |
455 | 457 | if isinstance(module, TensorQuantizer) and not module._disabled: |
456 | 458 | if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): |
457 | | - # Get the initial amax from max calibration |
458 | 459 | initial_amax = module._amax.clone().detach() |
459 | | - |
460 | 460 | is_nvfp4_static = module.is_nvfp4_static |
461 | 461 |
|
462 | | - if is_nvfp4_static: |
463 | | - # If sync_grouped_weight_global_amax already promoted this |
464 | | - # quantizer (it's a sibling in a Q/K/V or gate/up group), |
465 | | - # its global_amax has been unified across the group; just |
466 | | - # leave it. Otherwise convert + set local global_amax. |
467 | | - if not isinstance(module, NVFP4StaticQuantizer): |
468 | | - global_amax = reduce_amax(initial_amax, axis=None) |
469 | | - NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) |
| 462 | +                # Promote standalone NVFP4-static quantizers; grouped siblings were
| 463 | +                # already promoted by _sync_grouped_weight_global_amax above.
| 464 | + if is_nvfp4_static and not isinstance(module, NVFP4StaticQuantizer): |
| 465 | + global_amax = reduce_amax(initial_amax, axis=None) |
| 466 | + NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) |
470 | 467 |
|
471 | 468 | if fp8_scale_sweep: |
472 | 469 | # Check if backend has a registered custom calibrator factory. |
@@ -506,52 +503,48 @@ def mse_calibrate( |
506 | 503 | quant_func=partial(_mse_quant_func, quantizer=module), |
507 | 504 | ) |
508 | 505 |
|
509 | | - # Identify weight quantizers by checking if they have corresponding weight parameters |
| 506 | + # Step 3: calibrate weight quantizers via iter_weights_for_calibration. |
510 | 507 | name_to_module = dict(model.named_modules()) |
| 508 | + seen_modules: set[int] = set() |
| 509 | + pbar = tqdm(desc="MSE weight calibration") |
| 510 | + n_calibrated = 0 |
511 | 511 | for parent_module in name_to_module.values(): |
512 | | - if parent_module in seen_modules: |
| 512 | + if id(parent_module) in seen_modules or not isinstance(parent_module, QuantModule): |
513 | 513 | continue |
514 | | - for weight_name in weight_attr_names(parent_module): |
515 | | - weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer |
516 | | - weight_quantizer = getattr(parent_module, weight_quantizer_name, None) |
517 | | - if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled: |
518 | | - if getattr(weight_quantizer, "_calibrator", None) is not None: |
519 | | - weight_quantizers.append((parent_module, weight_name, weight_quantizer)) |
520 | | - seen_modules.add(parent_module) |
521 | | - |
522 | | - # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation |
523 | | - # This prevents massive memory accumulation seen in large models |
524 | | - for idx, (parent_module, weight_name, weight_quantizer) in enumerate( |
525 | | - tqdm(weight_quantizers, desc="MSE weight calibration") |
526 | | - ): |
527 | | - # Enable calibration mode for the weight quantizer |
528 | | - weight_quantizer.disable_quant() |
529 | | - weight_quantizer.enable_calib() |
| 514 | + seen_modules.add(id(parent_module)) |
530 | 515 | with enable_weight_access_and_writeback(parent_module, model, name_to_module): |
531 | | - weight = getattr(parent_module, weight_name) |
532 | | - weight_quantizer(weight) |
| 516 | + for weight, weight_quantizer in parent_module.iter_weights_for_calibration(): |
| 517 | + if not ( |
| 518 | + isinstance(weight_quantizer, TensorQuantizer) |
| 519 | + and weight_quantizer.is_enabled |
| 520 | + and getattr(weight_quantizer, "_calibrator", None) is not None |
| 521 | + ): |
| 522 | + continue |
| 523 | + weight_quantizer.disable_quant() |
| 524 | + weight_quantizer.enable_calib() |
| 525 | + weight_quantizer(weight) |
533 | 526 |
|
534 | | - # IMMEDIATELY compute amax and reset calibrator to free memory |
535 | | - cal = getattr(weight_quantizer, "_calibrator", None) |
536 | | - if cal is not None and cal.compute_amax() is not None: |
537 | | - weight_quantizer.load_calib_amax() |
| 527 | + cal = weight_quantizer._calibrator |
| 528 | + if cal.compute_amax() is not None: |
| 529 | + weight_quantizer.load_calib_amax() |
538 | 530 |
|
539 | | - weight_quantizer.enable_quant() |
540 | | - weight_quantizer.disable_calib() |
| 531 | + weight_quantizer.enable_quant() |
| 532 | + weight_quantizer.disable_calib() |
541 | 533 |
|
542 | | - # Synchronize ALL CUDA devices before resetting to ensure all async operations complete |
543 | | - # This is critical for multi-GPU setups where tensors may be on different devices |
544 | | - if torch.cuda.is_available(): |
545 | | - for dev_id in range(torch.cuda.device_count()): |
546 | | - torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) |
| 534 | + if torch.cuda.is_available(): |
| 535 | + for dev_id in range(torch.cuda.device_count()): |
| 536 | + torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) |
547 | 537 |
|
548 | | - if cal is not None and hasattr(cal, "reset"): |
549 | | - cal.reset() |
| 538 | + if hasattr(cal, "reset"): |
| 539 | + cal.reset() |
550 | 540 |
|
551 | | - if (idx + 1) % 10 == 0 and torch.cuda.is_available(): |
552 | | - for dev_id in range(torch.cuda.device_count()): |
553 | | - torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) |
554 | | - torch.cuda.empty_cache() |
| 541 | + pbar.update(1) |
| 542 | + n_calibrated += 1 |
| 543 | + if n_calibrated % 10 == 0 and torch.cuda.is_available(): |
| 544 | + for dev_id in range(torch.cuda.device_count()): |
| 545 | + torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) |
| 546 | + torch.cuda.empty_cache() |
| 547 | + pbar.close() |
555 | 548 |
|
556 | 549 | if torch.cuda.is_available(): |
557 | 550 | for dev_id in range(torch.cuda.device_count()): |
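The Step 3 loop above follows a one-weight-at-a-time pattern with periodic device synchronization and cache flushing; a generic sketch of that memory discipline, with made-up helper names, looks like this:

```python
import torch


def sync_all_cuda_devices() -> None:
    """Synchronize every visible CUDA device (no-op on CPU-only machines)."""
    if torch.cuda.is_available():
        for dev_id in range(torch.cuda.device_count()):
            torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))


def calibrate_one_at_a_time(weights, calibrate_fn, flush_every: int = 10) -> None:
    for i, w in enumerate(weights, start=1):
        calibrate_fn(w)          # observe + load amax for this single weight
        sync_all_cuda_devices()  # async kernels must finish before state is released
        if i % flush_every == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()


calibrate_one_at_a_time(
    (torch.randn(64, 64) for _ in range(25)),
    calibrate_fn=lambda w: w.abs().amax(),
)
```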
@@ -706,10 +699,7 @@ def forward(self, input, *args, **kwargs): |
706 | 699 | print_rank_0("local_hessian: Running max calibration for all quantizers...") |
707 | 700 | max_calibrate(model, forward_loop, distributed_sync) |
708 | 701 |
|
709 | | - # Sync global_amax across sibling NVFP4-static weight quantizers |
710 | | - # (q/k/v_proj, gate/up_proj a.k.a. w1/w3) so the FP8 scale-of-scales |
711 | | - # is consistent across the group. Idempotent; no-op when fused. |
712 | | - sync_grouped_weight_global_amax(model) |
| 702 | + _sync_grouped_weight_global_amax(model) |
713 | 703 |
|
714 | 704 | # Setup helpers for all quantized linear modules |
715 | 705 | name_to_module = dict(model.named_modules()) |
|