fix: fp32 master weights for custom MoE models under FSDP2#1896

Draft
zpqiu wants to merge 10 commits into main from zpqiu/fp32-master-weights-custom-moe

Conversation

@zpqiu
Contributor

@zpqiu zpqiu commented Apr 17, 2026

What does this PR do ?

Restore FP32 master weights on the custom-MoE FSDP2 path. Previously every custom MoE model silently collapsed to BF16 at the optimizer, causing slower convergence at low learning rates vs. MCore and Automodel + TE FusedAdam (master_weights=True).

Changelog

Five commits, each addressing one layer of silent bf16 casting. Full per-bug analysis in the commit messages.

  • fix(moe): plumb mp_policy end-to-end and cast expert weights in forward: _transformers/infrastructure.py forwards FSDP2Config.mp_policy into the MoE parallelize_fn; moe/parallelizer.py passes it to fully_shard(experts, ...); moe/experts.py (GroupedExpertsDeepEP / GroupedExperts) casts expert weights to the activation dtype before each grouped GEMM, covering the full-EP case where experts are excluded from FSDP entirely.
  • fix(checkpoint): preserve model param dtype across initialize_weights() — infer the target dtype from existing params so model.initialize_weights() no longer re-casts fp32 params to bf16 via its default dtype=torch.bfloat16.
  • fix(models): thread config.torch_dtype explicitly in the custom MoE models qwen3_next and qwen3_5_moe (also propagates torch_dtype to the VL text_config / vision_config sub-configs before super().__init__(config)).
  • fix(models): thread dtype in the remaining text-only custom MoE models: qwen3_moe, gpt_oss, minimax_m2, glm4_moe, glm4_moe_lite, glm_moe_dsa, deepseek_v3, deepseek_v32, step3p5, nemotron_v3.
  • fix(models): thread dtype and propagate it to sub-configs for the VL MoE models: qwen3_vl_moe, qwen3_omni_moe, gemma4_moe, mistral4.

Helper defaults (initialize_linear_module / initialize_rms_norm_module / MoEConfig.dtype) are intentionally left at bfloat16 to keep the implicit API contract for third-party direct construction. Pure-bf16 runs are unaffected (every threaded dtype resolves to bf16 when config.torch_dtype isn't overridden, and the expert-forward cast is a no-op).

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests? — no new tests; a targeted regression asserting optimizer params are fp32 when torch_dtype=float32 would be a good follow-up.
  • Did you add or update any necessary documentation? — internal fix, no public API change.
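The suggested follow-up regression could be sketched as below. This is only an illustration of the assertion, using a stand-in nn.Linear rather than an Automodel custom MoE; the real test would construct a custom MoE model with torch_dtype=float32 under FSDP2.

```python
import torch
import torch.nn as nn

# Stand-in for a custom MoE model constructed with torch_dtype=float32.
model = nn.Linear(4, 4, dtype=torch.float32)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Before this PR, custom-MoE params had silently collapsed to bf16 by the
# time the optimizer saw them; with the fix they must stay fp32.
for group in opt.param_groups:
    for p in group["params"]:
        assert p.dtype == torch.float32
```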

Verification: end-to-end on qwen3_5_moe + qwen3_next via NeMo-RL SFT (2×8 GPUs, EP=16) — kernel error resolved, val_loss now tracks MCore / TE-FusedAdam runs. Remaining models covered by the same mechanical pattern but not individually re-run.

Additional Information

  • No related issue; surfaced during an internal NeMo-RL SFT convergence comparison across Automodel / MCore / TE-FusedAdam paths.

zpqiu and others added 5 commits April 17, 2026 03:26
The MoE code path silently dropped the FSDP2 MixedPrecisionPolicy configured
in FSDP2Config:

  * _shard_ep_fsdp() never forwarded model_wrapper.mp_policy into
    parallelize_fn, so apply_fsdp() fell back to a hardcoded
    output_dtype=torch.bfloat16 default that disagreed with NeMo-RL's
    intended output_dtype=torch.float32.
  * fully_shard(moe_module.experts, ...) in the ep_shard branch was called
    without an mp_policy, so ep-sharded expert weights were never cast to
    the forward activation dtype during all-gather.
  * For full-EP configs (ep_size == world_size, DP=1), the experts are
    excluded from FSDP entirely (see ignored_params), so no FSDP wrapper
    ever carries an mp_policy for them. With fp32-stored expert params
    (e.g. under fp32 master weights) and bf16 activations propagated by
    the block's mp_policy, grouped_gemm.ops.gmm / torch._grouped_mm then
    crash with "Expected b.scalar_type() == torch::kBFloat16".

Changes:

  * _transformers/infrastructure.py: forward model_wrapper.mp_policy into
    parallelize_fn (overriding the moe_parallelizer mp_policy default when
    the FSDP2Config policy is present).
  * moe/parallelizer.py: pass mp_policy to the experts' own fully_shard
    call so ep_shard-sharded experts honour the forward cast.
  * moe/experts.py: in GroupedExpertsDeepEP.forward and
    GroupedExperts.forward, cast expert weights (and biases) to the
    activation dtype before the grouped GEMM. This is a no-op when
    weights already match the activation dtype, and rescues the full-EP
    case where no FSDP wrapper can carry mp_policy.
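The expert-forward cast can be sketched as follows. cast_to_activation_dtype is a hypothetical helper name for illustration; the actual change lives inline in GroupedExpertsDeepEP.forward / GroupedExperts.forward.

```python
import torch

def cast_to_activation_dtype(weights, activations):
    """Cast expert weights (and biases) to the activation dtype before the
    grouped GEMM. A no-op when dtypes already match; rescues the full-EP
    case where no FSDP wrapper carries an mp_policy for the experts."""
    return [w if w.dtype == activations.dtype else w.to(activations.dtype)
            for w in weights]

x = torch.randn(4, 8, dtype=torch.bfloat16)  # bf16 activations via mp_policy
w = torch.randn(8, 8, dtype=torch.float32)   # fp32-stored expert weight
w_cast, = cast_to_activation_dtype([w], x)
assert w_cast.dtype == torch.bfloat16        # grouped GEMM dtype check passes
```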

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
initialize_model_weights() calls model.initialize_weights() with no args,
but several custom models (qwen3_next, qwen3_5_moe, ...) have a signature
of

    initialize_weights(self, buffer_device=None, dtype=torch.bfloat16)

that ends with cast_model_to_dtype(self, dtype). That silently re-casts
every floating-point parameter back to bfloat16 on the first rank that
hits this code path, undoing fp32 initialization (e.g. for fp32 master
weights under FSDP2).

Infer the target dtype from the existing (floating-point) parameters and
pass it through when the model accepts a dtype kwarg, falling back to
the no-kwarg call for older signatures. This preserves the dtype chosen
at construction time (bf16 by default, fp32 when the user requested
fp32 master weights) without requiring every model's initialize_weights
to change.

The checkpoint-load path is unaffected: DCP copies tensors into the
model's existing parameters, so dtypes follow the model rather than the
checkpoint.
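A sketch of the caller-side fix: initialize_model_weights is the function named above, but the dtype inference and the signature probing via inspect are assumed implementation details.

```python
import inspect

import torch
import torch.nn as nn

def initialize_model_weights(model: nn.Module) -> None:
    # Infer the target dtype from the existing floating-point parameters so
    # initialize_weights() cannot re-cast fp32 params via its bf16 default.
    target = next(
        (p.dtype for p in model.parameters() if p.is_floating_point()),
        torch.bfloat16,
    )
    if "dtype" in inspect.signature(model.initialize_weights).parameters:
        model.initialize_weights(dtype=target)
    else:
        model.initialize_weights()  # older signatures without a dtype kwarg
```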

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
initialize_linear_module / initialize_rms_norm_module default
dtype=torch.bfloat16, and MoEConfig.dtype defaults to torch.bfloat16.
Custom MoE models (qwen3_next, qwen3_5_moe) silently accepted those
defaults and ignored config.torch_dtype / torch.get_default_dtype(), so
attention projections, MoE experts, MLP, lm_head and RMSNorm were all
bf16 even when the user requested torch_dtype=float32 (e.g. for fp32
master weights under FSDP2).

Thread the model dtype explicitly from config.torch_dtype at every
helper call site:

  * qwen3_next/layers.py Qwen3NextAttention: pass dtype to q/k/v/o_proj.
  * qwen3_next/model.py Block: pass dtype to dense MLP; Qwen3NextModel:
    pass dtype to MoEConfig, embed_tokens, RMSNorm; Qwen3NextForCausalLM:
    pass dtype to lm_head and state_dict_adapter.
  * qwen3_5_moe/model.py Qwen3_5MoeTextModelBackend: pass dtype to
    MoEConfig and embed_tokens. Qwen3_5MoeForConditionalGeneration:
    pass dtype to lm_head.

Also fix a nested-sub-config dtype leak in the VL class. _init_model()
overrides only the top-level hf_config.torch_dtype; for VL configs like
Qwen3_5MoeConfig the nested text_config / vision_config keep their
original (typically bf16) torch_dtype from the checkpoint's config.json.
Without the fix below, the text backend would then read
text_config.torch_dtype=bf16 while HF-native submodules (e.g.
CPAwareGatedDeltaNet, constructed via super().__init__(config) inside
local_torch_dtype(fp32)) used get_default_dtype()=fp32, producing a
mixed-dtype state that crashed with a generic CuBLAS Error at the second
grouped GEMM during validation.

Propagate the user-requested dtype to every nested sub-config exposing
a torch_dtype attribute, before calling super().__init__(config), so the
HF parent, the text backend, and any HF vision / multimodal code that
reads sub-config torch_dtype all agree.

Rationale for threading instead of flipping helper defaults: keeping
initialize_linear_module / MoEConfig bf16 defaults preserves the
existing API contract for third-party callers that construct models
directly (no local_torch_dtype() wrapper) — PyTorch's global default is
fp32, which would otherwise silently make projection modules fp32 while
embeddings stayed bf16.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
Extend the explicit torch_dtype threading pattern established for
qwen3_next / qwen3_5_moe to every other text-only custom MoE (and the
shared MLA attention) so fp32 master weights work consistently across
the full set of custom implementations, regardless of whether the caller
wrapped construction in local_torch_dtype().

Pattern applied per model:

  * Attention layers.py: resolve dtype from config.torch_dtype once and
    thread it into initialize_linear_module / initialize_rms_norm_module
    for every q/k/v/o_proj and q_norm/k_norm call.
  * model.py Block: pass dtype to the dense-path MLP(...) and the
    per-block RMSNorms.
  * model.py Model: resolve model_dtype once and thread it into
    MoEConfig(dict dtype=...), nn.Embedding(dtype=...), and the final
    initialize_rms_norm_module(dtype=...).
  * model.py ForCausalLM: thread model_dtype into initialize_linear_module
    for lm_head and into the state_dict_adapter dtype arg.
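The per-model pattern above can be sketched as follows. initialize_linear_module and get_dtype carry this PR's names, but their bodies and the attention shapes here are illustrative assumptions, not the actual implementations.

```python
import torch
import torch.nn as nn

def initialize_linear_module(in_f, out_f, bias=False, dtype=torch.bfloat16):
    # Helper default stays bf16 on purpose; callers must thread dtype.
    return nn.Linear(in_f, out_f, bias=bias, dtype=dtype)

def get_dtype(val, default):
    return val if isinstance(val, torch.dtype) else default

class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Resolve once from config.torch_dtype, then thread it into every
        # helper call instead of silently accepting the bf16 default.
        dtype = get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)
        self.q_proj = initialize_linear_module(64, 64, dtype=dtype)
        self.o_proj = initialize_linear_module(64, 64, dtype=dtype)
```

With torch_dtype=torch.float32 on the config, both projections come out fp32; without it, behaviour is unchanged (bf16), matching the "pure-bf16 runs are unaffected" guarantee.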

Models updated:

  * qwen3_moe
  * gpt_oss
  * minimax_m2
  * glm4_moe, glm4_moe_lite, glm_moe_dsa
  * deepseek_v3, deepseek_v32 (both share MLA; v3.2 also has Indexer)
  * step3p5
  * nemotron_v3 (final model norm + attention projections)

All changes are purely call-site threading; helper defaults
(initialize_linear_module / initialize_rms_norm_module / MoEConfig.dtype)
remain bfloat16 so third-party direct construction keeps its previous
behaviour.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
Extend the fp32-master-weights fix pattern to the remaining VL MoE
models. These all share the same VL-specific gotcha: _init_model() only
overrides the top-level hf_config.torch_dtype, but the real params live
under nested sub-configs (e.g. config.text_config, config.vision_config,
or config.thinker_config.text_config). Those nested torch_dtype values
retain the checkpoint's original bf16, leaving a mixed-dtype state when
our text backend reads sub_config.torch_dtype while HF-native submodules
pick up get_default_dtype() from local_torch_dtype(). The symptom is
the same as qwen3_5_moe: a generic CuBLAS Error at the second grouped
GEMM during MoE forward.

Pattern applied per VL wrapper:

  * Before super().__init__(config), walk vars(config) and propagate
    config.torch_dtype to every nested sub-config that exposes a
    torch_dtype attribute. For Omni's two-level nesting
    (config.thinker_config.text_config) this walk is performed at both
    levels.
  * Text backend __init__: resolve model_dtype once from
    config.torch_dtype and thread it into MoEConfig(dict), nn.Embedding,
    and the final initialize_rms_norm_module(dtype=...).
  * VL ForConditionalGeneration: thread model_dtype into lm_head's
    initialize_linear_module and the state_dict_adapter dtype arg.

Models updated:

  * qwen3_vl_moe
  * qwen3_omni_moe (two-level sub-config walk)
  * gemma4_moe (also removes a dead get_dtype(...) call that was
    computing a value and throwing it away; the value now drives MoE
    expert dtype)
  * mistral4 (top-level Mistral4Model + Mistral4TextModelBackend VL
    wrapper + Mistral3ForConditionalGeneration sub-config propagation)

HF-native submodules (attention, Gemma4MLP, Gemma4RMSNorm,
Gemma4TextScaledWordEmbedding, Mistral3MultiModalProjector, vision
tower) continue to inherit their dtype from local_torch_dtype() via
torch.get_default_dtype(), so no changes are needed there.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
@copy-pr-bot

copy-pr-bot bot commented Apr 17, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@zpqiu zpqiu changed the title fix(moe): fp32 master weights for custom MoE models under FSDP2 fix: fp32 master weights for custom MoE models under FSDP2 Apr 17, 2026
@akoumpa
Contributor

akoumpa commented Apr 18, 2026

/ok to test 16f122a

@akoumpa
Contributor

akoumpa commented Apr 19, 2026

/ok to test aa4aeb4

@akoumpa
Contributor

akoumpa commented Apr 19, 2026

/ok to test 24db51e

After threading `get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)`
into every custom MoE model's __init__, a number of pre-existing unit tests
started failing with

    TypeError: can only concatenate str (not "Mock") to str

The tests build their `config` via `unittest.mock.Mock()`, so
`getattr(config, "torch_dtype", None)` returns an auto-generated Mock rather
than None, which dtype_from_str() then tried to concatenate inside
"torch." + val.lower().

Real config objects only ever carry torch.dtype, str, or the attribute
simply isn't set. Treat anything else (Mock, or other unexpected types) as
"no explicit dtype" and fall back to the supplied default, instead of
crashing with a cryptic TypeError deep inside string coercion.
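A sketch of the hardened resolver: get_dtype is this PR's helper, but the exact body shown here, including the string handling, is assumed.

```python
import torch

def get_dtype(val, default):
    """Resolve a config-provided dtype, treating anything that is not a
    torch.dtype or str (e.g. a unittest.mock.Mock) as 'no explicit dtype'."""
    if isinstance(val, torch.dtype):
        return val
    if isinstance(val, str):
        return getattr(torch, val.lower().removeprefix("torch."))
    return default  # Mock or other unexpected types: fall back, don't crash

from unittest.mock import Mock
assert get_dtype(torch.float32, torch.bfloat16) == torch.float32
assert get_dtype("bfloat16", torch.float32) == torch.bfloat16
assert get_dtype(Mock().torch_dtype, torch.bfloat16) == torch.bfloat16
```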

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
@zpqiu
Contributor Author

zpqiu commented Apr 20, 2026

/ok to test 7608e15

