|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
16 | | -"""Tests for NVFP4QTensor per-block FP8 scale underflow clamping.""" |
| 16 | +"""Tests for NVFP4QTensor per-block FP8 scale clamping (underflow + overflow).""" |
| 17 | + |
| 18 | +from types import SimpleNamespace |
17 | 19 |
|
18 | 20 | import torch |
19 | 21 |
|
20 | 22 | from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor |
21 | 23 |
|
22 | 24 | _FP8_E4M3FN_MIN = 2**-9 # 0.001953125 — smallest positive FP8 E4M3FN subnormal |
| 25 | +_FP8_E4M3FN_MAX = 448.0 |
23 | 26 |
|
24 | 27 |
|
25 | 28 | class TestNVFP4ScaleClamping: |
26 | | - """Per-block weight scales below the FP8 E4M3FN minimum must be clamped, not rounded to zero.""" |
| 29 | + """Per-block weight scales outside the FP8 E4M3FN range must be clamped, not turned into 0/NaN.""" |
27 | 30 |
|
28 | 31 | def test_no_zero_scales_for_tiny_weights(self): |
29 | 32 | """Tiny per-block amax (<<FP8 min) must not underflow to zero after FP8 cast.""" |
@@ -67,3 +70,42 @@ def test_mixed_weight_no_zeros(self): |
67 | 70 | assert (per_block_scale.float() > 0).all(), ( |
68 | 71 | "Zero scales in mixed-magnitude tensor after FP8 cast." |
69 | 72 | ) |
| 73 | + |
| 74 | + def test_helper_clamps_overflow_to_max(self): |
| 75 | + """Values above 448 must saturate to 448, not cast to NaN (fp8_e4m3fn has no Inf).""" |
| 76 | + oversized = torch.tensor([100.0, 448.0, 1e3, 1e6]) |
| 77 | + out = NVFP4QTensor._cast_per_block_scale_to_fp8(oversized).float() |
| 78 | + assert torch.isfinite(out).all(), f"FP8 cast produced non-finite values: {out.tolist()}" |
| 79 | + assert (out <= _FP8_E4M3FN_MAX).all(), f"FP8 cast values exceed 448: {out.tolist()}" |
| 80 | + |
| 81 | + def test_helper_clamps_underflow_to_min(self): |
| 82 | +        """Values below the smallest positive FP8 E4M3FN subnormal must clamp up to it, not collapse to 0."""
| 83 | + tiny = torch.tensor([0.0, 1e-12, 1e-6, _FP8_E4M3FN_MIN / 2]) |
| 84 | + out = NVFP4QTensor._cast_per_block_scale_to_fp8(tiny).float() |
| 85 | + assert (out > 0).all(), f"FP8 cast produced zero scales: {out.tolist()}" |
| 86 | + |
| 87 | + def test_static_path_no_nan_when_block_amax_zero(self): |
| 88 | + """Static path: when a block's amax is 0 (all-zero weights), the `[==0]=1.0` safety net |
| 89 | + and a small global_amax push the pre-cast value above 448. Without the max clamp, |
| 90 | + fp8_e4m3fn would cast it to NaN — regression for the export-time NaN reported on this PR. |
| 91 | + """ |
| 92 | + block_size = 16 |
| 93 | + # global_amax small enough that 1.0 * 448 / (global_amax/6) >> 448. |
| 94 | + global_amax = torch.tensor(0.01) |
| 95 | + # One block with amax=0 (triggers safety net), three normal blocks. |
| 96 | + per_block_amax = torch.tensor([[0.0, 0.005, 0.008, 0.01]]) |
| 97 | + weight = torch.randn(1, 4 * block_size) |
| 98 | + q = SimpleNamespace( |
| 99 | + global_amax=global_amax, |
| 100 | + _amax=per_block_amax, |
| 101 | + block_sizes={-1: block_size}, |
| 102 | + ) |
| 103 | + |
| 104 | + per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor_from_quantizer(q, weight) |
| 105 | + per_block_scale_f32 = per_block_scale.float() |
| 106 | + assert torch.isfinite(per_block_scale_f32).all(), ( |
| 107 | + f"NaN/Inf in exported static per-block scale: {per_block_scale_f32.tolist()}" |
| 108 | + ) |
| 109 | + assert (per_block_scale_f32 <= _FP8_E4M3FN_MAX).all(), ( |
| 110 | + f"Static per-block scale exceeds FP8 max 448: {per_block_scale_f32.tolist()}" |
| 111 | + ) |