|
52 | 52 | promote_nvfp4_static_quantizers, |
53 | 53 | quantizer_attr_names, |
54 | 54 | reduce_amax, |
55 | | - weight_attr_names, |
56 | 55 | ) |
57 | 56 | from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper |
58 | 57 |
|
|
68 | 67 | ] |
69 | 68 |
|
70 | 69 |
|
71 | | -# Sibling weight-quantizer name groups whose ``global_amax`` should share an |
72 | | -# FP8 scale-of-scales. All members of a group sit under the same parent module |
73 | | -# (e.g. one self-attention or one MLP block) and either consume the same input |
74 | | -# tensor or get fused at deployment, so a divergent global_amax across siblings |
75 | | -# would split their FP8 grids and skew the round. |
| 70 | +# Sibling groups that share an FP8 scale-of-scales: members feed the same input |
| 71 | +# (Q/K/V) or get fused at deployment (gate/up), so divergent global_amax would |
| 72 | +# split their FP8 grids. |
76 | 73 | _GROUPED_WEIGHT_QUANTIZER_PATTERNS: tuple[tuple[str, ...], ...] = ( |
77 | | - # Standard self-attention (skipped for fused qkv_proj — single weight). |
78 | 74 | ("q_proj", "k_proj", "v_proj"), |
79 | | - # Gated MLP, modern naming (Llama / Qwen / Mistral / etc.). |
80 | | - ("gate_proj", "up_proj"), |
81 | | - # Gated MLP, older Mixtral-style naming. |
82 | | - ("w1", "w3"), |
| 75 | + ("gate_proj", "up_proj"), # Llama/Qwen/Mistral |
| 76 | + ("w1", "w3"), # Mixtral |
83 | 77 | ) |
84 | 78 |
|
85 | 79 |
|
86 | | -def _is_calibrated_nvfp4_static_weight_quantizer(q) -> bool: |
87 | | - """True for an NVFP4-static weight quantizer that ``max_calibrate`` already |
88 | | - populated with a per-block ``_amax`` and that is currently enabled. |
89 | | - """ |
90 | | - return ( |
91 | | - isinstance(q, TensorQuantizer) |
92 | | - and not q._disabled |
93 | | - and q.is_nvfp4_static |
94 | | - and hasattr(q, "_amax") |
95 | | - and q._amax is not None |
96 | | - ) |
97 | | - |
98 | | - |
99 | 80 | def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]: |
100 | | - """Find groups of Linear-like submodules whose NVFP4-static weight quantizers |
101 | | - should share ``global_amax`` (Q/K/V under one attention parent; gate/up under |
102 | | - one MLP parent). |
103 | | - """ |
| 81 | + """Collect sibling groups (Q/K/V, gate/up) with calibrated NVFP4-static weight quantizers.""" |
104 | 82 | groups: list[list[nn.Module]] = [] |
105 | 83 | wq_attr = quantizer_attr_names("weight").weight_quantizer |
106 | 84 | for parent in model.modules(): |
107 | 85 | for sibling_names in _GROUPED_WEIGHT_QUANTIZER_PATTERNS: |
108 | | - members: list[nn.Module] = [] |
| 86 | + members = [] |
109 | 87 | for n in sibling_names: |
110 | 88 | child = getattr(parent, n, None) |
111 | | - if child is None: |
112 | | - continue |
113 | | - wq = getattr(child, wq_attr, None) |
114 | | - if _is_calibrated_nvfp4_static_weight_quantizer(wq): |
| 89 | + wq = getattr(child, wq_attr, None) if child is not None else None |
| 90 | + if ( |
| 91 | + isinstance(wq, TensorQuantizer) |
| 92 | + and not wq._disabled |
| 93 | + and wq.is_nvfp4_static |
| 94 | + and getattr(wq, "_amax", None) is not None |
| 95 | + ): |
115 | 96 | members.append(child) |
116 | 97 | if len(members) >= 2: |
117 | 98 | groups.append(members) |
118 | 99 | return groups |
119 | 100 |
|
120 | 101 |
|
| 102 | +@torch.no_grad() |
| 103 | +def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int: |
| 104 | + """Populate ``_amax`` from weights for quantizers the forward pass didn't reach. |
| 105 | +
|
| 106 | + Dead MoE experts that received no tokens are otherwise skipped by |
| 107 | + ``mse_calibrate``, leaving export to derive separate per-half amax for |
| 108 | + gate/up and break the gate==up ``weight_scale_2`` invariant. |
| 109 | + """ |
| 110 | + n = 0 |
| 111 | + for module in model.modules(): |
| 112 | + if not isinstance(module, QuantModule): |
| 113 | + continue |
| 114 | + for weight, q in module.iter_weights_for_calibration(): |
| 115 | + if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic: |
| 116 | + continue |
| 117 | + if q._calibrator is None: |
| 118 | + continue |
| 119 | + if getattr(q, "_amax", None) is not None and not torch.all(q._amax == 0): |
| 120 | + continue |
| 121 | + q.disable_quant() |
| 122 | + q.enable_calib() |
| 123 | + q(weight) |
| 124 | + if q._calibrator.compute_amax() is not None: |
| 125 | + q.load_calib_amax() |
| 126 | + q.enable_quant() |
| 127 | + q.disable_calib() |
| 128 | + if hasattr(q._calibrator, "reset"): |
| 129 | + q._calibrator.reset() |
| 130 | + n += 1 |
| 131 | + return n |
| 132 | + |
| 133 | + |
121 | 134 | @torch.no_grad() |
122 | 135 | def sync_grouped_weight_global_amax(model: nn.Module) -> int: |
123 | | - """Sync ``global_amax`` across sibling NVFP4-static weight quantizers. |
124 | | -
|
125 | | - For each group of siblings (Q/K/V projections under one attention parent; |
126 | | - gate/up — a.k.a. ``w1``/``w3`` — under one MLP parent) unifies the |
127 | | - NVFP4 ``global_amax`` so the per-block FP8 round picks scales against a |
128 | | - consistent FP8 grid across the group during MSE / local-Hessian search. |
129 | | -
|
130 | | - Reuses :func:`modelopt.torch.export.quant_utils.preprocess_linear_fusion` |
131 | | - (whose ``NVFP4StaticQuantizer`` branch performs the same |
132 | | - ``max(stack(global_amax))`` unification at export time). To call it before |
133 | | - MSE, this helper first promotes each grouped weight quantizer to |
134 | | - :class:`NVFP4StaticQuantizer` with its local ``global_amax`` (= |
135 | | - ``reduce_amax(_amax)``); ``preprocess_linear_fusion`` then unifies in |
136 | | - place. |
137 | | -
|
138 | | - Must be called after ``max_calibrate`` has populated each weight |
139 | | - quantizer's ``_amax``. Idempotent. Returns the number of groups synced. |
| 136 | + """Unify NVFP4 ``global_amax`` across Q/K/V and gate/up sibling weight quantizers. |
| 137 | +
|
| 138 | + Reuses ``preprocess_linear_fusion`` (which performs the same unification at |
| 139 | + export time) to keep the FP8 scale-of-scales consistent across siblings |
| 140 | + during MSE / local-Hessian search. Must run after ``max_calibrate``. |
140 | 141 | """ |
| 142 | + # Inline: quant_utils imports enable_stats_collection/finish_stats_collection/svd |
| 143 | + # from this module, so top-level would deadlock the cycle. |
141 | 144 | from modelopt.torch.export.quant_utils import preprocess_linear_fusion |
142 | 145 |
|
| 146 | + wq_attr = quantizer_attr_names("weight").weight_quantizer |
143 | 147 | n_groups = 0 |
144 | 148 | for group in _collect_grouped_linears(model): |
145 | | - # Promote each member's weight quantizer so `preprocess_linear_fusion` |
146 | | - # sees post-conversion NVFP4StaticQuantizers (its NVFP4 branch reads |
147 | | - # `global_amax`, which only exists post-promotion). |
148 | | - wq_attr = quantizer_attr_names("weight").weight_quantizer |
149 | 149 | for child in group: |
150 | 150 | wq = getattr(child, wq_attr) |
151 | 151 | if not isinstance(wq, NVFP4StaticQuantizer): |
152 | | - local_global = reduce_amax(wq._amax, axis=None) |
153 | | - NVFP4StaticQuantizer.from_tensor_quantizer(wq, global_amax=local_global) |
| 152 | + NVFP4StaticQuantizer.from_tensor_quantizer( |
| 153 | + wq, global_amax=reduce_amax(wq._amax, axis=None) |
| 154 | + ) |
154 | 155 | preprocess_linear_fusion(group) |
155 | 156 | n_groups += 1 |
156 | 157 | return n_groups |
@@ -436,37 +437,26 @@ def mse_calibrate( |
436 | 437 | See :class:`MseCalibConfig <modelopt.torch.quantization.config.MseCalibConfig>` for |
437 | 438 | details on the remaining arguments. |
438 | 439 | """ |
439 | | - # Step 1: First get initial amax using max calibration |
| 440 | + # Step 1: max calibration; then populate _amax for dead experts so step 3 |
| 441 | + # doesn't skip them, and unify NVFP4 global_amax across Q/K/V and gate/up |
| 442 | + # siblings so MSE searches against a consistent FP8 grid. |
440 | 443 | max_calibrate(model, forward_loop, distributed_sync) |
441 | | - |
442 | | - # Step 1b: Sync global_amax across sibling NVFP4-static weight quantizers |
443 | | - # (q/k/v_proj under one attention block; gate/up — a.k.a. w1/w3 — under one |
444 | | - # MLP block) so their FP8 scale-of-scales matches and the per-block FP8 |
445 | | - # round uses a consistent grid. No-op when there are no sibling groups |
446 | | - # (e.g. fused QKV / fused gate_up_proj). |
| 444 | + _bootstrap_uncalibrated_weight_quantizers(model) |
447 | 445 | sync_grouped_weight_global_amax(model) |
448 | 446 |
|
449 | | - # Step 2: Replace calibrators with MseCalibrator for enabled quantizers |
450 | | - # and identify weight quantizers |
451 | | - weight_quantizers = [] |
452 | | - seen_modules = set() |
453 | | - |
| 447 | + # Step 2: replace calibrators with MseCalibrator for enabled quantizers. |
454 | 448 | for name, module in list(model.named_modules()): |
455 | 449 | if isinstance(module, TensorQuantizer) and not module._disabled: |
456 | 450 | if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): |
457 | | - # Get the initial amax from max calibration |
458 | 451 | initial_amax = module._amax.clone().detach() |
459 | | - |
460 | 452 | is_nvfp4_static = module.is_nvfp4_static |
461 | 453 |
|
462 | | - if is_nvfp4_static: |
463 | | - # If sync_grouped_weight_global_amax already promoted this |
464 | | - # quantizer (it's a sibling in a Q/K/V or gate/up group), |
465 | | - # its global_amax has been unified across the group; just |
466 | | - # leave it. Otherwise convert + set local global_amax. |
467 | | - if not isinstance(module, NVFP4StaticQuantizer): |
468 | | - global_amax = reduce_amax(initial_amax, axis=None) |
469 | | - NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) |
| 454 | + # sync_grouped_weight_global_amax may have already promoted + |
| 455 | + # unified global_amax across the sibling group; only promote |
| 456 | + # standalone (non-grouped) NVFP4-static quantizers here. |
| 457 | + if is_nvfp4_static and not isinstance(module, NVFP4StaticQuantizer): |
| 458 | + global_amax = reduce_amax(initial_amax, axis=None) |
| 459 | + NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) |
470 | 460 |
|
471 | 461 | if fp8_scale_sweep: |
472 | 462 | # Check if backend has a registered custom calibrator factory. |
@@ -506,52 +496,50 @@ def mse_calibrate( |
506 | 496 | quant_func=partial(_mse_quant_func, quantizer=module), |
507 | 497 | ) |
508 | 498 |
|
509 | | - # Identify weight quantizers by checking if they have corresponding weight parameters |
| 499 | + # Step 3: calibrate weight quantizers via iter_weights_for_calibration. |
| 500 | + # The fused-experts override yields one pair per expert per projection, so |
| 501 | + # every per-expert quantizer is MSE-calibrated (not just routed ones). |
510 | 502 | name_to_module = dict(model.named_modules()) |
| 503 | + seen_modules: set[int] = set() |
| 504 | + pbar = tqdm(desc="MSE weight calibration") |
| 505 | + n_calibrated = 0 |
511 | 506 | for parent_module in name_to_module.values(): |
512 | | - if parent_module in seen_modules: |
| 507 | + if id(parent_module) in seen_modules or not isinstance(parent_module, QuantModule): |
513 | 508 | continue |
514 | | - for weight_name in weight_attr_names(parent_module): |
515 | | - weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer |
516 | | - weight_quantizer = getattr(parent_module, weight_quantizer_name, None) |
517 | | - if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled: |
518 | | - if getattr(weight_quantizer, "_calibrator", None) is not None: |
519 | | - weight_quantizers.append((parent_module, weight_name, weight_quantizer)) |
520 | | - seen_modules.add(parent_module) |
521 | | - |
522 | | - # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation |
523 | | - # This prevents massive memory accumulation seen in large models |
524 | | - for idx, (parent_module, weight_name, weight_quantizer) in enumerate( |
525 | | - tqdm(weight_quantizers, desc="MSE weight calibration") |
526 | | - ): |
527 | | - # Enable calibration mode for the weight quantizer |
528 | | - weight_quantizer.disable_quant() |
529 | | - weight_quantizer.enable_calib() |
| 509 | + seen_modules.add(id(parent_module)) |
530 | 510 | with enable_weight_access_and_writeback(parent_module, model, name_to_module): |
531 | | - weight = getattr(parent_module, weight_name) |
532 | | - weight_quantizer(weight) |
| 511 | + for weight, weight_quantizer in parent_module.iter_weights_for_calibration(): |
| 512 | + if not ( |
| 513 | + isinstance(weight_quantizer, TensorQuantizer) |
| 514 | + and weight_quantizer.is_enabled |
| 515 | + and getattr(weight_quantizer, "_calibrator", None) is not None |
| 516 | + ): |
| 517 | + continue |
| 518 | + weight_quantizer.disable_quant() |
| 519 | + weight_quantizer.enable_calib() |
| 520 | + weight_quantizer(weight) |
533 | 521 |
|
534 | | - # IMMEDIATELY compute amax and reset calibrator to free memory |
535 | | - cal = getattr(weight_quantizer, "_calibrator", None) |
536 | | - if cal is not None and cal.compute_amax() is not None: |
537 | | - weight_quantizer.load_calib_amax() |
| 522 | + cal = weight_quantizer._calibrator |
| 523 | + if cal.compute_amax() is not None: |
| 524 | + weight_quantizer.load_calib_amax() |
538 | 525 |
|
539 | | - weight_quantizer.enable_quant() |
540 | | - weight_quantizer.disable_calib() |
| 526 | + weight_quantizer.enable_quant() |
| 527 | + weight_quantizer.disable_calib() |
541 | 528 |
|
542 | | - # Synchronize ALL CUDA devices before resetting to ensure all async operations complete |
543 | | - # This is critical for multi-GPU setups where tensors may be on different devices |
544 | | - if torch.cuda.is_available(): |
545 | | - for dev_id in range(torch.cuda.device_count()): |
546 | | - torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) |
| 529 | + if torch.cuda.is_available(): |
| 530 | + for dev_id in range(torch.cuda.device_count()): |
| 531 | + torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) |
547 | 532 |
|
548 | | - if cal is not None and hasattr(cal, "reset"): |
549 | | - cal.reset() |
| 533 | + if hasattr(cal, "reset"): |
| 534 | + cal.reset() |
550 | 535 |
|
551 | | - if (idx + 1) % 10 == 0 and torch.cuda.is_available(): |
552 | | - for dev_id in range(torch.cuda.device_count()): |
553 | | - torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) |
554 | | - torch.cuda.empty_cache() |
| 536 | + pbar.update(1) |
| 537 | + n_calibrated += 1 |
| 538 | + if n_calibrated % 10 == 0 and torch.cuda.is_available(): |
| 539 | + for dev_id in range(torch.cuda.device_count()): |
| 540 | + torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) |
| 541 | + torch.cuda.empty_cache() |
| 542 | + pbar.close() |
555 | 543 |
|
556 | 544 | if torch.cuda.is_available(): |
557 | 545 | for dev_id in range(torch.cuda.device_count()): |
@@ -706,9 +694,6 @@ def forward(self, input, *args, **kwargs): |
706 | 694 | print_rank_0("local_hessian: Running max calibration for all quantizers...") |
707 | 695 | max_calibrate(model, forward_loop, distributed_sync) |
708 | 696 |
|
709 | | - # Sync global_amax across sibling NVFP4-static weight quantizers |
710 | | - # (q/k/v_proj, gate/up_proj a.k.a. w1/w3) so the FP8 scale-of-scales |
711 | | - # is consistent across the group. Idempotent; no-op when fused. |
712 | 697 | sync_grouped_weight_global_amax(model) |
713 | 698 |
|
714 | 699 | # Setup helpers for all quantized linear modules |
|
0 commit comments