Skip to content

Commit bde162a

Browse files
cjluo-nvclaude
andauthored
feat(deepseek): add --cast_mxfp4_to_nvfp4 to deepseek_v4 quantize step (#1653)
### What does this PR do? Type of change: new feature Brings the GPT-OSS lossless MXFP4 → NVFP4 cast (#1372) to DeepSeek V4's routed-expert export by adding a `--cast_mxfp4_to_nvfp4` flag to `examples/deepseek/deepseek_v4/quantize_to_nvfp4.py`. To avoid duplicating the closed-form math, the shared numerics — `mxfp4_to_nvfp4_global_amax`, `mxfp4_to_nvfp4_per_block_amax`, and the E2M1/E4M3/E8M0 constants — are **hoisted out of the GPT-OSS example cast into the library** at `modelopt/torch/quantization/utils/numeric_utils.py`. Both the GPT-OSS cast (`examples/llm_ptq/cast_mxfp4_to_nvfp4.py`) and the new DeepSeek path now import them from there. DeepSeek V4's routed experts ship as MXFP4 (E2M1 nibbles + a power-of-two E8M0 scale per 32-element block). By default the export dequantizes them to BF16 and re-quantizes to NVFP4 using the calibrated per-tensor weight amax, which re-derives per-block scales from the data and is therefore lossy. With the flag, the cast pins `scale_2 = 2^(k_max-8)` and each per-block E4M3 scale to `2^(k_j-m)` straight from the source E8M0 scales, so `per_block_scale * scale_2 = 2^k_j` and the NVFP4 nibbles equal the source MXFP4 nibbles bit-for-bit (for every block whose `k_j` lands in E4M3's representable window; rare out-of-range blocks clamp). The one V4-specific addition is that w1/w3 share a single `scale_2` for the fused GEMM1, so `k_max` is taken over both projections. The flag only affects routed-expert **weights** — activation `input_scale` still comes from `--amax_path` calibration. ### Usage ```bash python deepseek_v4/quantize_to_nvfp4.py \ --amax_path ${AMAX} \ --source_ckpt ${DS_V4} \ --output_ckpt ${HF_NVFP4_PATH} \ --cast_mxfp4_to_nvfp4 ``` ### Testing - The hoisted numerics get unit tests in `tests/unit/torch/quantization/test_numeric_utils.py` (10 cases: per-tensor global_amax, per-block amax incl. out-of-range, magnitude-table cache) — 10/10 pass. The example test `tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py` keeps the cast-specific cases (quantizer naming, `build_amax_map`, `apply_to_model`). - Validated on real DeepSeek-V4-Flash expert tensors (incl. the on-disk `float8_e8m0fnu` scale dtype): 23.5M blocks, 100% lossless, 0 error. - Generated a full NVFP4 checkpoint for DeepSeek-V4-Flash (43 layers, 256 routed experts) end-to-end: `[cast] lossless MXFP4->NVFP4 blocks: 8,657,043,456/8,657,043,456 (100.0000%)`. Output weights match an independently-produced reference cast byte-for-byte (`weight_scale`, `weight_scale_2`, packed nibbles modulo the harmless sign-of-zero). ### Before your PR is "*Ready for review*" - Is this change backward compatible?: ✅ (new opt-in flag; default export behavior unchanged; hoist re-exports through the existing example module) - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: ✅ N/A (no new deps; shared numerics moved into the library rather than duplicated) - Did you write any new necessary tests?: ✅ (library numerics covered by `tests/unit/torch/quantization/test_numeric_utils.py`; end-to-end validated on a real DeepSeek-V4 checkpoint) - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ✅ - Did you get Claude approval on this PR?: ❌ (will run `/claude review`) ### Additional Information Mirrors and reuses #1372 (GPT-OSS MXFP4 → NVFP4 cast); the closed-form numerics are now shared via `modelopt.torch.quantization.utils.numeric_utils`. 🤖 Generated with [Claude Code](https://claude.com/claude-code) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added `--cast_mxfp4_to_nvfp4` flag to perform a closed-form, mostly lossless MXFP4→NVFP4 conversion for routed-expert weights with aggregated lossless/block statistics. * **Documentation** * Updated DeepSeek V4 export instructions and README to document the new flag and clarify calibration behavior for activation scales. * **Chores** * Exposed shared numeric quantization utilities for MXFP4→NVFP4 casting. * **Tests** * Added and updated tests to validate the new numeric helpers and conversion behavior. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Chenjie Luo <chenjiel@nvidia.com> Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
1 parent 111b7eb commit bde162a

7 files changed

Lines changed: 532 additions & 294 deletions

File tree

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ Changelog
4141
- Add quantization examples for the Megatron-Bridge framework: post-training quantization (`quantize.py <https://github.com/NVIDIA/Model-Optimizer/blob/main/examples/megatron_bridge/quantize.py>`_), export to a deployable HuggingFace checkpoint (`export.py <https://github.com/NVIDIA/Model-Optimizer/blob/main/examples/megatron_bridge/export.py>`_), and Quantization Aware Distillation (extend existing `distill.py <https://github.com/NVIDIA/Model-Optimizer/blob/main/examples/megatron_bridge/distill.py>`_).
4242
- Add end-to-end optimization tutorial for Minitron pruning + two-phase distillation (80B @ 8K + 20B @ 32K long-context = 100B tokens) + FP8 PTQ + vLLM deployment for Nemotron-3-Nano-30B-A3B-BF16 (MoE + Mamba-Transformer hybrid) → Pruned 22B/A3.0B active params, along with data blend preparation steps (with tool-calling data) and detailed pruning / data-blend / long-context ablations. See `examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/>`_ for details.
4343
- Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_ptq#mxfp4--nvfp4-cast-for-gpt-oss>`__ for usage.
44+
- Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/deepseek/deepseek_v4/quantize_to_nvfp4.py`` for closed-form, bit-exact MXFP4 → NVFP4 conversion of DeepSeek V4 routed-expert weights (mirrors the GPT-OSS cast; w1/w3 share one per-tensor ``scale_2`` for the fused GEMM1). Activation ``input_scale`` still comes from ``--amax_path`` calibration.
4445
- DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``.
4546
- Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress.
4647
- Add FP8 KV-cache cast variants for the partial-NVFP4 and weight-only general PTQ recipes: ``general/ptq/nvfp4_mlp_only-kv_fp8_cast``, ``general/ptq/nvfp4_experts_only-kv_fp8_cast``, ``general/ptq/nvfp4_omlp_only-kv_fp8_cast``, and ``general/ptq/nvfp4_weight_only-kv_fp8_cast``. These compose the same model-quant configs as their ``-kv_fp8`` siblings with the ``kv_fp8_cast`` unit (constant-amax FP8 KV cache, no KV calibration forward pass).

examples/deepseek/README.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,25 @@ python deepseek_v4/quantize_to_nvfp4.py \
174174
The output includes an updated `model.safetensors.index.json`, a `config.json`
175175
with `quantization_config.moe_quant_algo = "NVFP4"`, and `hf_quant_config.json`
176176
describing the mixed NVFP4 expert layers.
177+
178+
When the source routed experts are MXFP4 (as in the V4 release), add
179+
`--cast_mxfp4_to_nvfp4` for a lossless weight conversion — recommended over the
180+
default lossy dequant/re-quant path. See below.
181+
182+
#### Lossless MXFP4 → NVFP4 weight cast (`--cast_mxfp4_to_nvfp4`)
183+
184+
The routed experts in the source checkpoint are already MXFP4 (E2M1 nibbles +
185+
a power-of-two E8M0 scale per 32-element block). Without the flag, the export
186+
dequantizes them to BF16 and re-quantizes to NVFP4 using the calibrated
187+
per-tensor weight amax, which re-derives the per-block scales from the data and
188+
is therefore lossy. With `--cast_mxfp4_to_nvfp4`, the per-tensor `scale_2` is
189+
pinned to `2^(k_max - 8)` and each per-block E4M3 scale to `2^(k_j - m)` straight
190+
from the source E8M0 scales, so `per_block_scale * scale_2 = 2^k_j` and the NVFP4
191+
nibbles equal the source MXFP4 nibbles bit-for-bit (for every block whose `k_j`
192+
lands in E4M3's representable window; the rare out-of-range block falls back to a
193+
data-derived scale). The flag only affects routed-expert **weights** — activation
194+
`input_scale` still comes from `${AMAX}` calibration — and the run prints a
195+
`[cast] lossless MXFP4->NVFP4 blocks: …` summary. This mirrors the GPTOSS cast in
196+
[`examples/llm_ptq/cast_mxfp4_to_nvfp4.py`](../llm_ptq/cast_mxfp4_to_nvfp4.py); the
197+
V4 twist is that w1/w3 share one `scale_2` (fused GEMM1), so `k_max` is taken over
198+
both projections.

examples/deepseek/deepseek_v4/quantize_to_nvfp4.py

Lines changed: 162 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,20 @@
6363
for the same projection. If no calibrated expert exists for that
6464
projection, export fails.
6565
66+
Lossless weight cast (``--cast_mxfp4_to_nvfp4``): the source routed experts are
67+
already MXFP4 (E2M1 nibbles + a power-of-two E8M0 scale per 32-element block).
68+
By default this script dequantizes them to BF16 and re-quantizes to NVFP4 with
69+
the calibrated per-tensor weight amax, which re-derives per-block scales from
70+
the data and is therefore lossy. With ``--cast_mxfp4_to_nvfp4`` we instead pin
71+
``scale_2 = 2^(k_max - 8)`` and the per-block E4M3 scale to ``2^(k_j - m)``
72+
straight from the source E8M0 scales, so ``per_block_scale * scale_2 = 2^k_j``
73+
and the NVFP4 nibbles equal the source MXFP4 nibbles bit-for-bit (for every
74+
block whose ``k_j`` lands in E4M3's representable window). The flag only affects
75+
routed-expert *weights*; activation ``input_scale`` still comes from
76+
``--amax_path`` calibration. This mirrors the GPTOSS cast in
77+
``examples/llm_ptq/cast_mxfp4_to_nvfp4.py`` (PR #1372); the V4 twist is that
78+
w1/w3 share one ``scale_2`` (fused GEMM1), so ``k_max`` is taken over both.
79+
6680
Usage (single compute node, CPU-default; dequant+requant math is cheap
6781
relative to shard I/O):
6882
@@ -91,6 +105,17 @@
91105

92106
from modelopt.torch.quantization.qtensor import MXFP4QTensor, NVFP4QTensor
93107

108+
# Closed-form MXFP4 -> NVFP4 numerics shared with the GPT-OSS cast (PR #1372).
109+
from modelopt.torch.quantization.utils.numeric_utils import (
110+
E2M1_MAX,
111+
E4M3_KMAX,
112+
E4M3_KMIN,
113+
E4M3_MAX,
114+
E8M0_BIAS,
115+
mxfp4_to_nvfp4_global_amax,
116+
mxfp4_to_nvfp4_per_block_amax,
117+
)
118+
94119
# Routed-expert weights in regular MoE layers. MTP experts remain in source format.
95120
_EXPERT_WEIGHT_RE = re.compile(r"^layers\.\d+\.ffn\.experts\.\d+\.w[123]\.weight$")
96121
_EXPERT_PROJ_RE = re.compile(r"^(?P<experts>layers\.\d+\.ffn\.experts)\.\d+\.w[123]$")
@@ -233,6 +258,98 @@ def _quantize_weight_nvfp4(
233258
return q_tensor._quantized_data, weight_scale, weight_scale_2, synthesized
234259

235260

261+
# ---------------------------------------------------------------------------
262+
# Lossless MXFP4 -> NVFP4 weight cast (``--cast_mxfp4_to_nvfp4``).
263+
#
264+
# NVFP4 uses the same E2M1 nibble grid as MXFP4 with 16-element blocks and a
265+
# two-level scale ``per_block_scale (E4M3) * scale_2 (fp32)``. Pinning
266+
# ``scale_2 = 2^m`` (``m = k_max - 8``) and ``per_block_scale = 2^(k_j - m)``
267+
# makes ``per_block_scale * scale_2 = 2^k_j`` exactly, so each NVFP4 nibble
268+
# equals the source MXFP4 nibble verbatim — bit-exact for every block whose
269+
# ``k_j`` lands in E4M3's window (``k_max - k_j <= 17``). The closed-form
270+
# per-block amax and the format constants are reused from the GPT-OSS cast
271+
# (``cast_mxfp4_to_nvfp4``, PR #1372); the V4 twist is that w1/w3 share one
272+
# ``scale_2`` (fused GEMM1), so ``k_max`` is taken over both projections.
273+
# ---------------------------------------------------------------------------
274+
_NVFP4_BLOCK = 16 # NVFP4 block size (elements)
275+
_MXFP4_BYTES_PER_BLOCK = 16 # 32 E2M1 nibbles packed 2-per-byte
276+
277+
278+
def _kmax_from_mxfp4_scale(mxfp4_scale: torch.Tensor, device: str = "cpu") -> int:
279+
"""Largest non-zero E8M0 exponent ``k_j = e8m0 - 127`` (0 if all-zero).
280+
281+
Delegates to the GPT-OSS cast's ``k_max`` logic, which excludes the
282+
all-zero sentinel (``e8m0 == 0`` => ``k == -127``).
283+
"""
284+
e8m0 = mxfp4_scale.to(device).contiguous().view(torch.uint8)
285+
return mxfp4_to_nvfp4_global_amax(e8m0)[1]["k_max"]
286+
287+
288+
def _build_w13_kmax_overrides(f, expert_weight_keys: list[str], device: str) -> dict[str, int]:
289+
"""Shared ``k_max`` per w1/w3 pair so the fused GEMM1 gets one ``scale_2``."""
290+
groups: dict[str, dict[str, str]] = defaultdict(dict)
291+
for key in expert_weight_keys:
292+
expert_path = key[: -len(".weight")]
293+
base, proj = expert_path.rsplit(".", 1)
294+
if proj in {"w1", "w3"}:
295+
groups[base][proj] = expert_path
296+
297+
overrides: dict[str, int] = {}
298+
for paths in groups.values():
299+
if "w1" not in paths or "w3" not in paths:
300+
continue
301+
k1 = _kmax_from_mxfp4_scale(f.get_tensor(paths["w1"] + ".scale"), device)
302+
k3 = _kmax_from_mxfp4_scale(f.get_tensor(paths["w3"] + ".scale"), device)
303+
shared = max(k1, k3)
304+
overrides[paths["w1"]] = shared
305+
overrides[paths["w3"]] = shared
306+
return overrides
307+
308+
309+
def _quantize_weight_nvfp4_lossless(
310+
mxfp4_weight: torch.Tensor,
311+
mxfp4_scale: torch.Tensor,
312+
k_max: int,
313+
device: str,
314+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
315+
"""Closed-form bit-exact MXFP4 -> NVFP4 weight conversion.
316+
317+
Pins ``scale_2 = 2^(k_max - 8)`` and the per-block E4M3 scale to
318+
``2^(k_j - m)`` so the NVFP4 nibbles equal the source MXFP4 nibbles for
319+
every in-range block. ``k_max`` is shared across w1/w3 (fused GEMM1), so it
320+
is passed in rather than derived per tensor. The closed-form per-block amax
321+
(``6 * 2^k_j`` in range, data-derived out of range) is independent of
322+
``k_max``, so we reuse the GPT-OSS helper directly. Returns
323+
``(packed, weight_scale, weight_scale_2, n_blocks, n_lossless)``.
324+
"""
325+
bf16 = _dequantize_mxfp4_to_bf16(mxfp4_weight, mxfp4_scale, device)
326+
e8m0 = mxfp4_scale.to(bf16.device).contiguous().view(torch.uint8) # (out, nblk32)
327+
packed = mxfp4_weight.to(bf16.device).contiguous().view(torch.uint8) # (out, nblk32*16)
328+
blocks = packed.view(*packed.shape[:-1], e8m0.shape[-1], _MXFP4_BYTES_PER_BLOCK)
329+
per_block_amax = mxfp4_to_nvfp4_per_block_amax(blocks, e8m0) # (out, nblk16) fp32
330+
331+
m = k_max - E4M3_KMAX
332+
weight_scale_2 = torch.tensor(2.0**m, dtype=torch.float32, device=bf16.device).reshape(())
333+
per_block_scale = (
334+
(per_block_amax / (E2M1_MAX * weight_scale_2))
335+
.clamp(min=2**-9, max=E4M3_MAX)
336+
.to(torch.float8_e4m3fn)
337+
)
338+
339+
# Lossless accounting against the (possibly shared) k_max. A block is lossy
340+
# only if k_max - k_j > 17; all-zero blocks (e8m0 == 0) reconstruct to 0
341+
# regardless of scale and so are always lossless.
342+
k = e8m0.to(torch.int32) - E8M0_BIAS
343+
lossless = (k >= (k_max - (E4M3_KMAX - E4M3_KMIN))) | (e8m0 == 0)
344+
n_blocks = k.numel()
345+
n_lossless = int(lossless.sum().item())
346+
347+
q_tensor, weight_scale, _ = NVFP4QTensor.quantize(
348+
bf16, _NVFP4_BLOCK, per_block_scale, weight_scale_2, try_tensorrt=False
349+
)
350+
return q_tensor._quantized_data, weight_scale, weight_scale_2, n_blocks, n_lossless
351+
352+
236353
def _build_w13_weight_amax_overrides(
237354
f,
238355
expert_weight_keys: list[str],
@@ -279,6 +396,7 @@ def convert_shard(
279396
input_fallback: dict[str, torch.Tensor],
280397
device: str,
281398
stats: dict[str, int],
399+
cast: bool = False,
282400
) -> tuple[list[str], list[str]]:
283401
"""Rewrite one HF-style shard and return index deltas."""
284402
out: dict[str, torch.Tensor] = {}
@@ -289,9 +407,16 @@ def convert_shard(
289407
all_keys = list(f.keys())
290408
expert_weight_keys = [k for k in all_keys if _EXPERT_WEIGHT_RE.match(k)]
291409
expert_weight_key_set = set(expert_weight_keys)
292-
w13_weight_amax, w13_synth_paths = _build_w13_weight_amax_overrides(
293-
f, expert_weight_keys, amax, device
294-
)
410+
if cast:
411+
# Closed-form weight cast derives scales from the source E8M0
412+
# exponents, not from calibrated weight amax. w1/w3 share k_max.
413+
w13_kmax = _build_w13_kmax_overrides(f, expert_weight_keys, device)
414+
w13_weight_amax, w13_synth_paths = {}, set()
415+
else:
416+
w13_kmax = {}
417+
w13_weight_amax, w13_synth_paths = _build_w13_weight_amax_overrides(
418+
f, expert_weight_keys, amax, device
419+
)
295420
scale_siblings = {
296421
k[: -len(".weight")] + ".scale"
297422
for k in expert_weight_keys
@@ -335,9 +460,22 @@ def convert_shard(
335460

336461
w = f.get_tensor(key)
337462
s = f.get_tensor(scale_key)
338-
packed, weight_scale, weight_scale_2, weight_synth = _quantize_weight_nvfp4(
339-
w, s, weight_amax, device=device
340-
)
463+
if cast:
464+
k_max = w13_kmax.get(expert_path)
465+
if k_max is None:
466+
k_max = _kmax_from_mxfp4_scale(s, device)
467+
packed, weight_scale, weight_scale_2, n_blk, n_lossless = (
468+
_quantize_weight_nvfp4_lossless(w, s, k_max, device)
469+
)
470+
weight_synth = False
471+
stats["cast_blocks_total"] += n_blk
472+
stats["cast_blocks_lossless"] += n_lossless
473+
if n_lossless < n_blk:
474+
stats[f"cast_oor_tensors_{block_kind}"] += 1
475+
else:
476+
packed, weight_scale, weight_scale_2, weight_synth = _quantize_weight_nvfp4(
477+
w, s, weight_amax, device=device
478+
)
341479
input_scale = _amax_to_nvfp4_scale_2(input_amax).to(weight_scale_2.device)
342480

343481
out[key] = packed.cpu()
@@ -607,6 +745,17 @@ def main():
607745
action="store_true",
608746
help="replace an existing non-empty output checkpoint directory",
609747
)
748+
p.add_argument(
749+
"--cast_mxfp4_to_nvfp4",
750+
action="store_true",
751+
help=(
752+
"losslessly cast the source MXFP4 routed-expert weights to NVFP4 "
753+
"(pin scale_2 = 2^(k_max-8) and per-block scale = 2^(k_j-m) from the "
754+
"source E8M0 scales) instead of dequant + re-quant with calibrated "
755+
"weight amax. Only affects weights; input_scale still comes from "
756+
"--amax_path calibration."
757+
),
758+
)
610759
args = p.parse_args()
611760

612761
_validate_paths(args.source_ckpt, args.output_ckpt)
@@ -639,6 +788,7 @@ def main():
639788
input_fallback,
640789
args.device,
641790
stats,
791+
args.cast_mxfp4_to_nvfp4,
642792
)
643793
shard_updates[src.name] = (added, removed)
644794

@@ -647,6 +797,12 @@ def main():
647797
for k in sorted(stats.keys()):
648798
_log(f" {k:40s} {stats[k]}")
649799

800+
if args.cast_mxfp4_to_nvfp4:
801+
tot = stats.get("cast_blocks_total", 0)
802+
loss = stats.get("cast_blocks_lossless", 0)
803+
pct = 100.0 * loss / tot if tot else 100.0
804+
_log(f"[cast] lossless MXFP4->NVFP4 blocks: {loss}/{tot} ({pct:.4f}%)")
805+
650806
quantized: set[str] = set()
651807
for _added, _removed in shard_updates.values():
652808
for a in _added:

0 commit comments

Comments
 (0)