fix: fp32 master weights for custom MoE models under FSDP2 #1896
Draft
Conversation
The MoE code path silently dropped the FSDP2 MixedPrecisionPolicy configured
in FSDP2Config:
* _shard_ep_fsdp() never forwarded model_wrapper.mp_policy into
parallelize_fn, so apply_fsdp() fell back to a hardcoded
output_dtype=torch.bfloat16 default that disagreed with NeMo-RL's
intended output_dtype=torch.float32.
* fully_shard(moe_module.experts, ...) in the ep_shard branch was called
without an mp_policy, so ep-sharded expert weights were never cast to
the forward activation dtype during all-gather.
* For full-EP configs (ep_size == world_size, DP=1), the experts are
excluded from FSDP entirely (see ignored_params), so no FSDP wrapper
ever carries an mp_policy for them. With fp32-stored expert params
(e.g. under fp32 master weights) and bf16 activations propagated by
the block's mp_policy, grouped_gemm.ops.gmm / torch._grouped_mm then
crash with "Expected b.scalar_type() == torch::kBFloat16".
Changes:
* _transformers/infrastructure.py: forward model_wrapper.mp_policy into
parallelize_fn (overriding the moe_parallelizer mp_policy default when
the FSDP2Config policy is present).
* moe/parallelizer.py: pass mp_policy to the experts' own fully_shard
call so experts sharded via ep_shard honour the forward cast.
* moe/experts.py: in GroupedExpertsDeepEP.forward and
GroupedExperts.forward, cast expert weights (and biases) to the
activation dtype before the grouped GEMM. This is a no-op when
weights already match the activation dtype, and rescues the full-EP
case where no FSDP wrapper can carry mp_policy.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
initialize_model_weights() calls model.initialize_weights() with no args,
but several custom models (qwen3_next, qwen3_5_moe, ...) have a signature
of
initialize_weights(self, buffer_device=None, dtype=torch.bfloat16)
that ends with cast_model_to_dtype(self, dtype). That silently re-casts
every floating-point parameter back to bfloat16 on the first rank that
hits this code path, undoing fp32 initialization (e.g. for fp32 master
weights under FSDP2).
Infer the target dtype from the existing (floating-point) parameters and
pass it through when the model accepts a dtype kwarg, falling back to
the no-kwarg call for older signatures. This preserves the dtype chosen
at construction time (bf16 by default, fp32 when the user requested
fp32 master weights) without requiring every model's initialize_weights
to change.
The checkpoint-load path is unaffected: DCP copies tensors into the
model's existing parameters, so dtypes follow the model rather than the
checkpoint.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
initialize_linear_module / initialize_rms_norm_module default
dtype=torch.bfloat16, and MoEConfig.dtype defaults to torch.bfloat16.
Custom MoE models (qwen3_next, qwen3_5_moe) silently accepted those
defaults and ignored config.torch_dtype / torch.get_default_dtype(), so
attention projections, MoE experts, MLP, lm_head and RMSNorm were all
bf16 even when the user requested torch_dtype=float32 (e.g. for fp32
master weights under FSDP2).
Thread the model dtype explicitly from config.torch_dtype at every
helper call site:
* qwen3_next/layers.py Qwen3NextAttention: pass dtype to q/k/v/o_proj.
* qwen3_next/model.py Block: pass dtype to dense MLP; Qwen3NextModel:
pass dtype to MoEConfig, embed_tokens, RMSNorm; Qwen3NextForCausalLM:
pass dtype to lm_head and state_dict_adapter.
* qwen3_5_moe/model.py Qwen3_5MoeTextModelBackend: pass dtype to
MoEConfig and embed_tokens. Qwen3_5MoeForConditionalGeneration:
pass dtype to lm_head.
Also fix a nested-sub-config dtype leak in the VL class. _init_model()
overrides only the top-level hf_config.torch_dtype; for VL configs like
Qwen3_5MoeConfig the nested text_config / vision_config keep their
original (typically bf16) torch_dtype from the checkpoint's config.json.
Without the fix below, the text backend would then read
text_config.torch_dtype=bf16 while HF-native submodules (e.g.
CPAwareGatedDeltaNet, constructed via super().__init__(config) inside
local_torch_dtype(fp32)) used get_default_dtype()=fp32, producing a
mixed-dtype state that crashed with a generic CuBLAS Error at the second
grouped GEMM during validation.
Propagate the user-requested dtype to every nested sub-config exposing
a torch_dtype attribute, before calling super().__init__(config), so the
HF parent, the text backend, and any HF vision / multimodal code that
reads sub-config torch_dtype all agree.
Rationale for threading instead of flipping helper defaults: keeping
initialize_linear_module / MoEConfig bf16 defaults preserves the
existing API contract for third-party callers that construct models
directly (no local_torch_dtype() wrapper) — PyTorch's global default is
fp32, which would otherwise silently make projection modules fp32 while
embeddings stayed bf16.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
Extend the explicit torch_dtype threading pattern established for
qwen3_next / qwen3_5_moe to every other text-only custom MoE (and the
shared MLA attention) so fp32 master weights work consistently across
the full set of custom implementations, regardless of whether the caller
wrapped construction in local_torch_dtype().
Pattern applied per model:
* Attention layers.py: resolve dtype from config.torch_dtype once and
thread it into initialize_linear_module / initialize_rms_norm_module
for every q/k/v/o_proj and q_norm/k_norm call.
* model.py Block: pass dtype to the dense-path MLP(...) and the
per-block RMSNorms.
* model.py Model: resolve model_dtype once and thread it into
MoEConfig(dict dtype=...), nn.Embedding(dtype=...), and the final
initialize_rms_norm_module(dtype=...).
* model.py ForCausalLM: thread model_dtype into initialize_linear_module
for lm_head and into the state_dict_adapter dtype arg.
Models updated:
* qwen3_moe
* gpt_oss
* minimax_m2
* glm4_moe, glm4_moe_lite, glm_moe_dsa
* deepseek_v3, deepseek_v32 (both share MLA; v3.2 also has Indexer)
* step3p5
* nemotron_v3 (final model norm + attention projections)
All changes are purely call-site threading; helper defaults
(initialize_linear_module / initialize_rms_norm_module / MoEConfig.dtype)
remain bfloat16 so third-party direct construction keeps its previous
behaviour.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
Extend the fp32-master-weights fix pattern to the remaining VL MoE
models. These all share the same VL-specific gotcha: _init_model() only
overrides the top-level hf_config.torch_dtype, but the real params live
under nested sub-configs (e.g. config.text_config, config.vision_config,
or config.thinker_config.text_config). Those nested torch_dtype values
retain the checkpoint's original bf16, leaving a mixed-dtype state when
our text backend reads sub_config.torch_dtype while HF-native submodules
pick up get_default_dtype() from local_torch_dtype(). The symptom is
the same as qwen3_5_moe: a generic CuBLAS Error at the second grouped
GEMM during MoE forward.
Pattern applied per VL wrapper:
* Before super().__init__(config), walk vars(config) and propagate
config.torch_dtype to every nested sub-config that exposes a
torch_dtype attribute. For Omni's two-level nesting
(config.thinker_config.text_config) this walk is performed at both
levels.
* Text backend __init__: resolve model_dtype once from
config.torch_dtype and thread it into MoEConfig(dict), nn.Embedding,
and the final initialize_rms_norm_module(dtype=...).
* VL ForConditionalGeneration: thread model_dtype into lm_head's
initialize_linear_module and the state_dict_adapter dtype arg.
Models updated:
* qwen3_vl_moe
* qwen3_omni_moe (two-level sub-config walk)
* gemma4_moe (also removes a dead get_dtype(...) call that was
computing a value and throwing it away; the value now drives MoE
expert dtype)
* mistral4 (top-level Mistral4Model + Mistral4TextModelBackend VL
wrapper + Mistral3ForConditionalGeneration sub-config propagation)
HF-native submodules (attention, Gemma4MLP, Gemma4RMSNorm,
Gemma4TextScaledWordEmbedding, Mistral3MultiModalProjector, vision
tower) continue to inherit their dtype from local_torch_dtype() via
torch.get_default_dtype(), so no changes are needed there.
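The sub-config walk amounts to something like the sketch below (written as a single recursive helper for brevity; the actual code performs the walk explicitly at each nesting level, and the `SimpleNamespace` configs are stand-ins for HF config objects):

```python
import torch
from types import SimpleNamespace

def propagate_torch_dtype(config, dtype: torch.dtype) -> None:
    # Force every nested sub-config exposing torch_dtype to agree with the
    # top-level dtype; recursion covers two-level nestings like Omni's
    # config.thinker_config.text_config.
    config.torch_dtype = dtype
    for value in vars(config).values():
        if hasattr(value, "torch_dtype"):
            propagate_torch_dtype(value, dtype)

# Stand-in for a VL config whose sub-configs kept the checkpoint's bf16:
text = SimpleNamespace(torch_dtype=torch.bfloat16)
thinker = SimpleNamespace(torch_dtype=torch.bfloat16, text_config=text)
cfg = SimpleNamespace(torch_dtype=torch.float32, thinker_config=thinker)
propagate_torch_dtype(cfg, cfg.torch_dtype)  # run before super().__init__(config)
```

After the walk, the HF parent, the text backend, and any vision / multimodal code reading sub-config `torch_dtype` all see the same dtype.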
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
Contributor
/ok to test 16f122a

Contributor
/ok to test aa4aeb4

Contributor
/ok to test 24db51e
After threading `get_dtype(getattr(config, "torch_dtype", None), torch.bfloat16)`
into every custom MoE model's __init__, a number of pre-existing unit tests
started failing with
TypeError: can only concatenate str (not "Mock") to str
The tests build their `config` via `unittest.mock.Mock()`, so
`getattr(config, "torch_dtype", None)` returns an auto-generated Mock rather
than None, which dtype_from_str() then tried to concatenate inside
"torch." + val.lower().
Real config objects only ever carry torch.dtype, str, or the attribute
simply isn't set. Treat anything else (Mock, or other unexpected types) as
"no explicit dtype" and fall back to the supplied default, instead of
crashing with a cryptic TypeError deep inside string coercion.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
Contributor (Author)
/ok to test 7608e15
What does this PR do?
Restore FP32 master weights on the custom-MoE FSDP2 path. Previously every custom MoE model silently collapsed to BF16 at the optimizer, causing slower convergence at low learning rates vs. MCore and Automodel + TE FusedAdam (`master_weights=True`).
Changelog
Five commits, each addressing one layer of silent bf16 casting. Full per-bug analysis in the commit messages.
* fix(moe): plumb mp_policy end-to-end and cast expert weights in forward: `_transformers/infrastructure.py` forwards `FSDP2Config.mp_policy` into the MoE `parallelize_fn`; `moe/parallelizer.py` passes it to `fully_shard(experts, ...)`; `moe/experts.py` (`GroupedExpertsDeepEP` / `GroupedExperts`) casts expert weights to the activation dtype before each grouped GEMM, covering the full-EP case where experts are excluded from FSDP entirely.
* fix(checkpoint): preserve model param dtype across initialize_weights(): infer the target dtype from existing params so `model.initialize_weights()` no longer re-casts fp32 params to bf16 via its default `dtype=torch.bfloat16`.
* fix(models): thread config.torch_dtype explicitly in custom MoE models: `qwen3_next`, `qwen3_5_moe` (also propagates `torch_dtype` to the VL `text_config` / `vision_config` sub-configs before `super().__init__(config)`).
* fix(models): thread dtype in remaining text-only custom MoE models: `qwen3_moe`, `gpt_oss`, `minimax_m2`, `glm4_moe`, `glm4_moe_lite`, `glm_moe_dsa`, `deepseek_v3`, `deepseek_v32`, `step3p5`, `nemotron_v3`.
* fix(models): thread dtype + propagate to sub-configs for VL MoE models: `qwen3_vl_moe`, `qwen3_omni_moe`, `gemma4_moe`, `mistral4`.
Helper defaults (`initialize_linear_module` / `initialize_rms_norm_module` / `MoEConfig.dtype`) are intentionally left at `bfloat16` to keep the implicit API contract for third-party direct construction. Pure-bf16 runs are unaffected: every threaded `dtype` resolves to `bf16` when `config.torch_dtype` isn't overridden, and the expert-forward cast is a no-op.
Before your PR is "Ready for review"
Pre checks:
`torch_dtype=float32` would be a good follow-up.
Verification: end-to-end on `qwen3_5_moe` + `qwen3_next` via NeMo-RL SFT (2×8 GPUs, EP=16): kernel error resolved, `val_loss` now tracks MCore / TE-FusedAdam runs. Remaining models covered by the same mechanical pattern but not individually re-run.
Additional Information