
feat(recipe): DSV3 GB200 MXFP8 full-iter CG recipe #4226

Merged: dingqingy-nv merged 2 commits into NVIDIA-NeMo:main from dingqingy-nv:dsv3-gb200-mxfp8-fullcg on Jun 10, 2026

@dingqingy-nv dingqingy-nv commented Jun 9, 2026

Summary

Mirrors the GB300 MXFP8 full-iter CG recipe (PR #3983) on GB200. Recipe shape matches GB300 except for recompute_modules=["mla_up_proj"] to fit GB200's smaller HBM budget; GB300 can run with recompute_modules=[].

Changes

scripts/performance/configs/deepseek/deepseek_workload_base_configs.py

  • DEEPSEEK_V3_PRETRAIN_CONFIG_GB200_FP8_MX_V1 no longer aliases the bf16 V1 config. It now enables:
    • cuda_graph_impl="full_iteration", cuda_graph_scope=[]
    • moe_a2a_overlap=True
    • cutedsl_fused_grouped_mlp=True
    • recompute_modules=["mla_up_proj"]
  • V2 (GBS=4096) and VR200 mxfp8 V1/V2 inherit via the existing replace(..., global_batch_size=4096) / alias chain.

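scripts/performance/configs/deepseek/deepseek_llm_pretrain.py

  • deepseek_v3_pretrain_config_gb200 now sets fp8_output_proj=True when the mxfp8 recipe is selected, mirroring the existing GB300 gate.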

Measured impact

64 GB200 nodes / 256 GPUs, DSv3-671B mxfp8, GBS=4096 (V2). Steady-state iters 14-19, averaged.

| variant | iter (s) | TF/s/GPU |
|---|---:|---:|
| partial CG + `[core_attn]` offload + `[mlp]` recompute | 13.89 | 1226 |
| full CG + `[core_attn,attn_proj]` offload + no recompute | 15.10 | 1128 |
| full CG + no offload + `[mla_up_proj]` recompute (this PR) | 12.94 | 1316 |

The third row is what this recipe ships: +7.3% throughput vs the prior partial-CG baseline, and a close match to the MLPerf reference config (~14 s/iter on similar parallelism shape).
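
The +7.3% figure can be cross-checked from the table in two ways: from the TF/s/GPU column and from the iteration times. The sketch below is only arithmetic on the numbers quoted above; the dictionary keys and the helper name `gain_pct` are made up for this illustration.

```python
# (iter time in seconds, TF/s/GPU), copied from the table in this PR.
variants = {
    "partial_cg_baseline": (13.89, 1226),
    "full_cg_with_offload": (15.10, 1128),
    "full_cg_mla_up_proj_recompute": (12.94, 1316),
}


def gain_pct(new: float, base: float) -> float:
    """Relative gain of `new` over `base`, in percent."""
    return (new / base - 1.0) * 100.0


base_iter, base_tflops = variants["partial_cg_baseline"]
new_iter, new_tflops = variants["full_cg_mla_up_proj_recompute"]

# Throughput is inversely proportional to iteration time at a fixed GBS,
# so the iter-time ratio is inverted relative to the TF/s ratio.
gain_from_tflops = gain_pct(new_tflops, base_tflops)
gain_from_iter_time = gain_pct(base_iter, new_iter)

print(f"from TF/s/GPU: {gain_from_tflops:+.1f}%")
print(f"from iter time: {gain_from_iter_time:+.1f}%")
```

Both columns give the same gain to one decimal place, which is a quick consistency check on the reported measurements.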

Test plan

  • Empirical perf measured at 64 GB200 nodes, 256 GPUs (see iter-time / TF/s/GPU table above)
  • Loss curve healthy through iter 20 (iter 20 lm loss 8.149, mtp_1 0.081, grad norm 0.391)
  • `ruff check` and `ruff format --check` clean on both files
  • NeMo CI L0 (will run on push)

Mirror the GB300 MXFP8 full-iter CG recipe (PR NVIDIA-NeMo#3983) on GB200, with
mla_up_proj recompute substituted for the no-recompute / no-offload
strategy GB300 can afford with its larger HBM budget.

GB200 mxfp8 V1 (and by inheritance V2 / VR200 mxfp8 V1/V2):
- cuda_graph_impl=full_iteration, cuda_graph_scope=[]
- moe_a2a_overlap=True, cutedsl_fused_grouped_mlp=True
- recompute_modules=["mla_up_proj"]
- fp8_output_proj=True gated on mxfp8 recipe in
  deepseek_v3_pretrain_config_gb200 (mirrors GB300 gate)

Measured on 64 GB200 nodes / 256 GPUs, DSv3-671B mxfp8, GBS=4096:

| variant                                              | iter (s) | TF/s/GPU |
|------------------------------------------------------|---------:|---------:|
| partial CG + [core_attn] offload + [mlp] recompute   |   13.89  |    1226  |
| full CG  + [core_attn,attn_proj] offload + no recomp |   15.10  |    1128  |
| full CG  + no offload + [mla_up_proj] recompute      |   12.94  |    1316  |

The third row is what this recipe ships: +7.3% throughput vs the
partial-CG baseline (and the closest match to the MLPerf reference
config that runs ~14 s/iter).

Signed-off-by: Dingqing Yang <dingqingy@nvidia.com>
@yaoyu-33 added labels on Jun 9, 2026: area:perf (Performance optimizations and benchmarking); feature (New capabilities, enhancements, or enablement work); needs-review (PR is ready for code review and waiting on a reviewer)
malay-nagda previously approved these changes Jun 9, 2026
V2 was aliased directly to GB200_V2 (transformer_engine CG), bypassing
the full-iteration CG / a2a overlap / cutedsl-fused-grouped-mlp /
mla_up_proj-recompute overrides on FP8_MX_V1. Restore the
replace(FP8_MX_V1, global_batch_size=4096) form so CONFIG_VARIANT=v2
exercises the same recipe as v1, with the V2 batch size.

Signed-off-by: Dingqing Yang <dingqingy@nvidia.com>
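
The difference between the aliased form and the restored `replace(...)` form can be reproduced in miniature. This is a sketch under stated assumptions, not the repository's code: `WorkloadConfig` is a hypothetical stand-in that models only three fields, and the default batch size is a placeholder.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class WorkloadConfig:
    """Hypothetical stand-in for the real workload base-config class."""
    global_batch_size: int = 2048  # placeholder value, not from the PR
    cuda_graph_impl: str = "transformer_engine"
    moe_a2a_overlap: bool = False


# bf16 chain: V2 differs from V1 only in batch size.
GB200_V1 = WorkloadConfig()
GB200_V2 = replace(GB200_V1, global_batch_size=4096)

# mxfp8 V1 carries the recipe overrides on top of the bf16 V1 config.
GB200_FP8_MX_V1 = replace(
    GB200_V1,
    cuda_graph_impl="full_iteration",
    moe_a2a_overlap=True,
)

# Aliased form: points at the bf16 V2 object, so the mxfp8 overrides are lost.
aliased_fp8_mx_v2 = GB200_V2

# Restored form: starts from mxfp8 V1 and changes only the batch size.
fixed_fp8_mx_v2 = replace(GB200_FP8_MX_V1, global_batch_size=4096)

print(aliased_fp8_mx_v2.cuda_graph_impl)  # transformer_engine
print(fixed_fp8_mx_v2.cuda_graph_impl)    # full_iteration
```

Both forms produce the V2 batch size, which is why the aliased version looks correct at a glance; only the inherited overrides reveal the difference.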
@dingqingy-nv

/claude review

Comment on lines +98 to +105
DEEPSEEK_V3_PRETRAIN_CONFIG_GB200_FP8_MX_V1 = replace(
    DEEPSEEK_V3_PRETRAIN_CONFIG_GB200_V1,
    cuda_graph_impl="full_iteration",
    cuda_graph_scope=[],
    moe_a2a_overlap=True,
    cutedsl_fused_grouped_mlp=True,
    recompute_modules=["mla_up_proj"],
)

Question: GB300 MX V1 sets fp8_dot_product_attention=True (line 74) but the new GB200 MX V1 does not. The PR description says "Recipe shape matches GB300 except for recompute_modules=["mla_up_proj"]" — is the omission of fp8_dot_product_attention intentional (e.g. incompatible with GB200's PP=4/VP=4 layout), or should it be added here?
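
A field-by-field diff is one way to surface this kind of divergence mechanically. The sketch below uses a hypothetical `WorkloadConfig` stand-in with only three fields. The GB300 and GB200 values follow what this thread states, and the `False` default for `fp8_dot_product_attention` is an assumption made for the example.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Tuple


@dataclass(frozen=True)
class WorkloadConfig:
    """Hypothetical stand-in; the real config class has many more fields."""
    cuda_graph_impl: str = "full_iteration"
    fp8_dot_product_attention: bool = False  # assumed default
    recompute_modules: Tuple[str, ...] = ()


def diff_configs(a: Any, b: Any) -> Dict[str, Tuple[Any, Any]]:
    """Map each differing field name to its (a, b) pair of values."""
    return {
        f.name: (getattr(a, f.name), getattr(b, f.name))
        for f in fields(a)
        if getattr(a, f.name) != getattr(b, f.name)
    }


gb300_mx_v1 = WorkloadConfig(fp8_dot_product_attention=True)
gb200_mx_v1 = WorkloadConfig(recompute_modules=("mla_up_proj",))

for name, (gb300_value, gb200_value) in diff_configs(gb300_mx_v1, gb200_mx_v1).items():
    print(f"{name}: GB300={gb300_value!r} GB200={gb200_value!r}")
```

With these stand-in values the diff reports two fields, not one: `recompute_modules`, which the PR description calls out, and `fp8_dot_product_attention`, which it does not.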

@claude

claude Bot commented Jun 9, 2026

Light Code Review

Clean, well-scoped change. The GB200 MX V1 config correctly mirrors the GB300 full-iter CG recipe shape with the expected recompute_modules adjustment for GB200's smaller HBM budget, and V2 properly inherits from V1 with only global_batch_size=4096 changed. The fp8_output_proj gate in the GB200 config function mirrors the existing GB300 gate.

Questions

  1. Missing fp8_dot_product_attention=True vs GB300 (inline comment posted) - DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FP8_MX_V1 sets fp8_dot_product_attention=True (line 74) but the new GB200_FP8_MX_V1 does not. The PR description states the recipe matches GB300 except for recompute_modules. Is this omission intentional?

  2. VR200 missing fp8_output_proj gate (pre-existing, but now more impactful) - DEEPSEEK_V3_PRETRAIN_CONFIG_VR200_FP8_MX_V2 (line 218) aliases GB200_FP8_MX_V2, which now carries full-iter CG + mxfp8 settings. However, deepseek_v3_pretrain_config_vr200() (line 144) lacks the fp8_output_proj=True gate that both GB300 (line 81) and GB200 (line 121) have. If fp8_output_proj is needed for correctness with mxfp8, VR200 MX V2 would be missing it. This is pre-existing but worth noting since VR200 MX V2 now actually uses the mxfp8-tuned config rather than a plain alias.
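
The concern in question 2 can be illustrated with a toy pair of config builders, one with the gate and one without. Everything here is hypothetical: `ModelConfig`, both builder functions, and the `"mxfp8"` recipe key are stand-ins, and the real functions' signatures and gating condition may differ.

```python
from dataclasses import dataclass


@dataclass
class ModelConfig:
    """Hypothetical stand-in for the config object a recipe function returns."""
    fp8_output_proj: bool = False


def build_with_gate(fp8_recipe: str) -> ModelConfig:
    """Mimics a config function that has the gate (as described for GB300 and GB200)."""
    cfg = ModelConfig()
    if fp8_recipe == "mxfp8":  # assumed recipe key, for illustration only
        cfg.fp8_output_proj = True
    return cfg


def build_without_gate(fp8_recipe: str) -> ModelConfig:
    """Mimics a config function that lacks the gate (as described for VR200)."""
    return ModelConfig()


print(build_with_gate("mxfp8").fp8_output_proj)     # True
print(build_without_gate("mxfp8").fp8_output_proj)  # False
```

Because the gate lives in the per-platform function, not in the shared workload config, aliasing the workload config across platforms does not carry the gate along.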

Suggested test cases

No perf tests impacted. The changed configs (DEEPSEEK_V3_PRETRAIN_CONFIG_GB200_FP8_MX_V1, V2, and transitive VR200_FP8_MX_V2) are 256-GPU / 128-GPU scale perf recipes with no corresponding L0/L1/L2 functional test launch scripts. The existing test_deepseek_v3_perf_config_instantiation unit test only covers H100 BF16. Consider extending it to cover GB200 mxfp8 instantiation to catch config wiring issues in CI.
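
One possible shape for such a test is sketched below. It does not import the real configs: `WorkloadConfig`, `FP8_MX_V1`, and `FP8_MX_V2` are hypothetical stand-ins, the default values are placeholders, and a tuple replaces the list used in the PR so the frozen stand-in stays hashable. A real test would import the actual config objects.

```python
from dataclasses import dataclass, fields, replace
from typing import Set, Tuple


@dataclass(frozen=True)
class WorkloadConfig:
    """Hypothetical stand-in for the repository's workload config class."""
    global_batch_size: int = 2048  # placeholder value
    cuda_graph_impl: str = "transformer_engine"
    moe_a2a_overlap: bool = False
    cutedsl_fused_grouped_mlp: bool = False
    recompute_modules: Tuple[str, ...] = ()


def differing_fields(a: WorkloadConfig, b: WorkloadConfig) -> Set[str]:
    """Names of the fields whose values differ between two configs."""
    return {f.name for f in fields(a) if getattr(a, f.name) != getattr(b, f.name)}


# Stand-ins for the GB200 mxfp8 V1/V2 configs described in this PR.
FP8_MX_V1 = replace(
    WorkloadConfig(),
    cuda_graph_impl="full_iteration",
    moe_a2a_overlap=True,
    cutedsl_fused_grouped_mlp=True,
    recompute_modules=("mla_up_proj",),
)
FP8_MX_V2 = replace(FP8_MX_V1, global_batch_size=4096)


def test_mxfp8_v1_carries_recipe_overrides() -> None:
    assert FP8_MX_V1.cuda_graph_impl == "full_iteration"
    assert FP8_MX_V1.moe_a2a_overlap
    assert FP8_MX_V1.recompute_modules == ("mla_up_proj",)


def test_mxfp8_v2_differs_from_v1_only_in_batch_size() -> None:
    # Fails if V2 is ever re-aliased to a config without the V1 overrides.
    assert differing_fields(FP8_MX_V1, FP8_MX_V2) == {"global_batch_size"}
```

The second test encodes the invariant stated in the review ("V2 inherits from V1 with only global_batch_size changed"), so it fails if the wiring regresses.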

@dingqingy-nv dingqingy-nv merged commit afb5dd4 into NVIDIA-NeMo:main Jun 10, 2026
86 checks passed
svcnvidia-nemo-ci pushed a commit that referenced this pull request Jun 10, 2026
Signed-off-by: Dingqing Yang <dingqingy@nvidia.com>
ko3n1g pushed a commit that referenced this pull request Jun 10, 2026
#4187)

Signed-off-by: Dingqing Yang <dingqingy@nvidia.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
Co-authored-by: Dingqing Yang <dingqingy@nvidia.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Co-authored-by: malay-nagda <malayn@nvidia.com>
vasunvidia pushed a commit to vasunvidia/Megatron-Bridge that referenced this pull request Jun 10, 2026
Signed-off-by: Dingqing Yang <dingqingy@nvidia.com>
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

3 participants