
Commit b161f3b: more reviewers' feedback

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Parent: ab8a162

3 files changed: 70 additions & 119 deletions

modelopt/torch/quantization/qtensor/nvfp4_tensor.py (13 additions & 18 deletions)
@@ -78,6 +78,16 @@ def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer):
         )
         return weight_quantizer._amax.float() / (6.0 * 448.0)
 
+    @classmethod
+    def _cast_per_block_scale_to_fp8(cls, per_block_scale: torch.Tensor) -> torch.Tensor:
+        """Clamp to FP8 E4M3FN representable range, then cast.
+
+        FP8 E4M3FN has no Inf and a smallest positive subnormal of ``2**-9`` (~0.00195).
+        Values below the min silently underflow to 0 (zero outputs at inference); values
+        above 448 cast to NaN.
+        """
+        return per_block_scale.clamp(min=2**-9, max=448.0).to(torch.float8_e4m3fn)
+
     @classmethod
     def get_weights_scaling_factor_from_quantizer(
         cls,
@@ -122,17 +132,9 @@ def get_weights_scaling_factor_from_quantizer(
         expected_shape = (*weight.shape[:-1], num_blocks_per_row)
         per_block_scale = per_block_scale.view(expected_shape)
 
-        # Quantize scales to FP8. Saturate to the fp8_e4m3fn max (448) before the
-        # cast: when the [==0]=1.0 safety net above fires (per_block_amax was zero
-        # for an all-zero weight block) and global_amax is small, the pre-cast value
-        # explodes to ``1.0 * 448 / (global_amax/6)``. fp8_e4m3fn has no Inf, so any
-        # value >= 480 casts to NaN — clamp first to keep the stored byte finite.
         if not keep_high_precision:
-            fp8_e4m3fn_min = 2**-9  # 0.001953125 — smallest positive subnormal
-            per_block_scale = (
-                (per_block_scale * 448.0 / per_block_scale_max)
-                .clamp(min=fp8_e4m3fn_min, max=448.0)
-                .to(torch.float8_e4m3fn)
+            per_block_scale = cls._cast_per_block_scale_to_fp8(
+                per_block_scale * 448.0 / per_block_scale_max
             )
             return per_block_scale, weights_scaling_factor_2
         else:
@@ -172,15 +174,8 @@ def get_weights_scaling_factor(
         )
         # Set all zero values in scale to 1.0
         per_block_scale[per_block_scale == 0] = 1.0
-        # Convert to torch.float8_e4m3fn
        if not keep_high_precision:
-            # Clamp to the minimum positive FP8 E4M3FN subnormal (~0.00195 = 2^-9) before
-            # casting. Without this, blocks whose scale falls below the FP8 representable
-            # range silently underflow to 0, causing those blocks to produce zero output at
-            # inference even when the weights are non-trivial.
-            fp8_e4m3fn_min = 2**-9  # 0.001953125 — smallest positive subnormal
-            per_block_scale = per_block_scale.clamp(min=fp8_e4m3fn_min)
-            per_block_scale = per_block_scale.to(torch.float8_e4m3fn)
+            per_block_scale = cls._cast_per_block_scale_to_fp8(per_block_scale)
         return per_block_scale, weights_scaling_factor_2
 
     @classmethod
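
The refactor centralizes the clamp-then-cast contract in one helper shared by the static and default scale paths. Its effect is easy to reproduce outside the library; a minimal sketch of the two failure modes it guards against (illustrative values; assumes a PyTorch build that ships torch.float8_e4m3fn):

import torch

scales = torch.tensor([1e-6, 1.0, 1e6])  # under-range, in-range, over-range
naive = scales.to(torch.float8_e4m3fn).float()
# -> tensor([0., 1., nan]): the tiny scale silently underflows to 0 and the
#    huge one casts to NaN, since fp8_e4m3fn has no Inf.
safe = scales.clamp(min=2**-9, max=448.0).to(torch.float8_e4m3fn).float()
# -> approximately tensor([0.0020, 1., 448.]): every stored byte stays finite
#    and non-zero, which is what _cast_per_block_scale_to_fp8 enforces.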

modelopt_recipes/general/ptq/nvfp4_experts_only_mse.yaml (13 additions & 99 deletions)
@@ -13,118 +13,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+imports:
+  base_disable_all: configs/ptq/units/base_disable_all
+  default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers
+  nvfp4: configs/numerics/nvfp4
+  nvfp4_static: configs/numerics/nvfp4_static
+
 metadata:
   recipe_type: ptq
-  description: >
-    NVFP4 W4A4 for MoE routed experts only. Static weight scales via MSE + FP8 scale sweep;
-    dynamic activation scales. Supports sequential experts (nn.Linear-based) and fused experts
-    (_QuantFusedExperts, HF transformers 5.0+ 3D nn.Parameter style).
+  description: NVFP4 static weight (MSE FP8-scale sweep) and dynamic activation for expert layers only (W4A4), no KV-cache quantization.
 quantize:
   algorithm:
     method: mse
     fp8_scale_sweep: true
     layerwise: false
   quant_cfg:
-    # ── Disable everything first ─────────────────────────────────────────────
-    - quantizer_name: '*'
-      enable: false
-
-    # ── Sequential experts (nn.Linear per expert) ────────────────────────────
+    - $import: base_disable_all
     - quantizer_name: '*mlp.experts*weight_quantizer'
-      enable: true
       cfg:
-        block_sizes:
-          -1: 16
-        type: static
-        scale_bits: e4m3
-        num_bits: e2m1
+        $import: nvfp4_static
     - quantizer_name: '*mlp.experts*input_quantizer'
-      enable: true
       cfg:
-        block_sizes:
-          -1: 16
-        type: dynamic
-        scale_bits: e4m3
-        num_bits: e2m1
-
-    # ── Sequential experts: Mixtral / block_sparse_moe style ────────────────
+        $import: nvfp4
     - quantizer_name: '*block_sparse_moe*weight_quantizer'
-      enable: true
       cfg:
-        block_sizes:
-          -1: 16
-        type: static
-        scale_bits: e4m3
-        num_bits: e2m1
+        $import: nvfp4_static
     - quantizer_name: '*block_sparse_moe*input_quantizer'
-      enable: true
-      cfg:
-        block_sizes:
-          -1: 16
-        type: dynamic
-        scale_bits: e4m3
-        num_bits: e2m1
-
-    # ── Fused experts (_QuantFusedExperts, HF transformers 5.0+ 3D nn.Parameter style) ──
-    - quantizer_name: '*gate_up_proj_weight_quantizers*'
-      enable: true
       cfg:
-        block_sizes:
-          -1: 16
-        type: static
-        scale_bits: e4m3
-        num_bits: e2m1
-    - quantizer_name: '*gate_up_proj_input_quantizer*'
-      enable: true
-      cfg:
-        block_sizes:
-          -1: 16
-        type: dynamic
-        scale_bits: e4m3
-        num_bits: e2m1
-    - quantizer_name: '*down_proj_weight_quantizers*'
-      enable: true
-      cfg:
-        block_sizes:
-          -1: 16
-        type: static
-        scale_bits: e4m3
-        num_bits: e2m1
-    - quantizer_name: '*down_proj_input_quantizer*'
-      enable: true
-      cfg:
-        block_sizes:
-          -1: 16
-        type: dynamic
-        scale_bits: e4m3
-        num_bits: e2m1
-
-    # ── Exclusions: shared experts, attention, routers, lm_head ─────────────
-    - quantizer_name: '*block_sparse_moe.gate*'
-      enable: false
-    - quantizer_name: '*linear_attn.conv1d*'
-      enable: false
-    - quantizer_name: '*lm_head*'
-      enable: false
-    - quantizer_name: '*mlp.gate.*'
-      enable: false
-    - quantizer_name: '*mlp.shared_expert*'
-      enable: false
-    - quantizer_name: '*mlp.shared_expert_gate.*'
-      enable: false
-    - quantizer_name: '*router*'
-      enable: false
-    - quantizer_name: 'output.*'
-      enable: false
-    - parent_class: 'nn.BatchNorm1d'
-      quantizer_name: '*'
-      enable: false
-    - parent_class: 'nn.BatchNorm2d'
-      quantizer_name: '*'
-      enable: false
-    - parent_class: 'nn.BatchNorm3d'
-      quantizer_name: '*'
-      enable: false
-    - parent_class: 'nn.LeakyReLU'
-      quantizer_name: '*'
-      enable: false
+        $import: nvfp4
+    - $import: default_disabled_quantizers
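
The recipe now composes shared units via $import rather than repeating each quantizer block inline. The unit files are not part of this commit, so their exact contents are an assumption, but judging from the inline settings they replace, nvfp4_static would expand to something like:

# Hypothetical contents of configs/numerics/nvfp4_static, inferred from the
# inline config removed above (not the actual file).
enable: true
block_sizes:
  -1: 16          # 16-element quantization blocks along the last dimension
type: static      # static (calibrated) weight scales
scale_bits: e4m3  # FP8 E4M3 per-block scale factors
num_bits: e2m1    # FP4 E2M1 quantized values

The nvfp4 unit is presumably identical except for type: dynamic, matching the removed activation-quantizer blocks.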

tests/unit/torch/quantization/test_nvfp4_tensor.py (44 additions & 2 deletions)
@@ -13,17 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Tests for NVFP4QTensor per-block FP8 scale underflow clamping."""
+"""Tests for NVFP4QTensor per-block FP8 scale clamping (underflow + overflow)."""
+
+from types import SimpleNamespace
 
 import torch
 
 from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor
 
 _FP8_E4M3FN_MIN = 2**-9  # 0.001953125 — smallest positive FP8 E4M3FN subnormal
+_FP8_E4M3FN_MAX = 448.0
 
 
 class TestNVFP4ScaleClamping:
-    """Per-block weight scales below the FP8 E4M3FN minimum must be clamped, not rounded to zero."""
+    """Per-block weight scales outside the FP8 E4M3FN range must be clamped, not turned into 0/NaN."""
 
     def test_no_zero_scales_for_tiny_weights(self):
         """Tiny per-block amax (<<FP8 min) must not underflow to zero after FP8 cast."""
@@ -67,3 +70,42 @@ def test_mixed_weight_no_zeros(self):
         assert (per_block_scale.float() > 0).all(), (
             "Zero scales in mixed-magnitude tensor after FP8 cast."
         )
+
+    def test_helper_clamps_overflow_to_max(self):
+        """Values above 448 must saturate to 448, not cast to NaN (fp8_e4m3fn has no Inf)."""
+        oversized = torch.tensor([100.0, 448.0, 1e3, 1e6])
+        out = NVFP4QTensor._cast_per_block_scale_to_fp8(oversized).float()
+        assert torch.isfinite(out).all(), f"FP8 cast produced non-finite values: {out.tolist()}"
+        assert (out <= _FP8_E4M3FN_MAX).all(), f"FP8 cast values exceed 448: {out.tolist()}"
+
+    def test_helper_clamps_underflow_to_min(self):
+        """Values below the FP8 subnormal must clamp up, not collapse to 0."""
+        tiny = torch.tensor([0.0, 1e-12, 1e-6, _FP8_E4M3FN_MIN / 2])
+        out = NVFP4QTensor._cast_per_block_scale_to_fp8(tiny).float()
+        assert (out > 0).all(), f"FP8 cast produced zero scales: {out.tolist()}"
+
+    def test_static_path_no_nan_when_block_amax_zero(self):
+        """Static path: when a block's amax is 0 (all-zero weights), the `[==0]=1.0` safety net
+        and a small global_amax push the pre-cast value above 448. Without the max clamp,
+        fp8_e4m3fn would cast it to NaN — regression for the export-time NaN reported on this PR.
+        """
+        block_size = 16
+        # global_amax small enough that 1.0 * 448 / (global_amax/6) >> 448.
+        global_amax = torch.tensor(0.01)
+        # One block with amax=0 (triggers safety net), three normal blocks.
+        per_block_amax = torch.tensor([[0.0, 0.005, 0.008, 0.01]])
+        weight = torch.randn(1, 4 * block_size)
+        q = SimpleNamespace(
+            global_amax=global_amax,
+            _amax=per_block_amax,
+            block_sizes={-1: block_size},
+        )
+
+        per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor_from_quantizer(q, weight)
+        per_block_scale_f32 = per_block_scale.float()
+        assert torch.isfinite(per_block_scale_f32).all(), (
+            f"NaN/Inf in exported static per-block scale: {per_block_scale_f32.tolist()}"
+        )
+        assert (per_block_scale_f32 <= _FP8_E4M3FN_MAX).all(), (
+            f"Static per-block scale exceeds FP8 max 448: {per_block_scale_f32.tolist()}"
+        )
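
The last test pins down a concrete arithmetic failure. Using the test's own numbers, and taking per_block_scale_max = global_amax / 6 as stated in the comment removed from nvfp4_tensor.py, the pre-cast value for the all-zero block works out as follows (a worked sketch, not library code):

global_amax = 0.01
safety_net_scale = 1.0  # the per_block_scale[per_block_scale == 0] = 1.0 net fired
pre_cast = safety_net_scale * 448.0 / (global_amax / 6)
print(pre_cast)  # 268800.0, vastly above the FP8 E4M3FN max of 448

Cast directly to torch.float8_e4m3fn, 268800.0 becomes NaN; clamped to 448.0 first, as _cast_per_block_scale_to_fp8 now does, the exported scale stays finite.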
