
fix(moe): align EP expert weight dtype with activation dtype#1913

Open
jQizhang wants to merge 2 commits into NVIDIA-NeMo:main from jQizhang:fix/moe-expert-dtype-1863

Conversation

@jQizhang
Contributor

What does this PR do ?

Fixes issue #1863: EP-sharded MoE expert weights stay in fp32 while the surrounding block's activations are cast to bf16 by FSDP2's MixedPrecisionPolicy, causing grouped_mm to crash with a dtype mismatch. This PR aligns the expert weight dtype with the input activation dtype at the call site.

Changelog

  • nemo_automodel/components/moe/experts.py:
    • GroupedExperts.forward (around L315): cast gate_and_up_projs and down_projs to x.dtype before the EP all-gather, so they match whatever dtype the surrounding block passed in.
    • GroupedExpertsDeepEP.forward (around L696): cast gate_and_up_projs and down_projs to permuted_local_hidden_states.dtype right after .to_local(), so torch._grouped_mm / ops.gmm receive same-dtype operands.
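The cast-at-call-site pattern described above can be sketched as follows. This is an illustrative simplification, not the actual `GroupedExperts.forward` from `nemo_automodel/components/moe/experts.py`; the shapes and the expert MLP math (a SwiGLU-style block) are assumptions for the example.

```python
import torch
import torch.nn.functional as F

def expert_forward(x, gate_and_up_projs, down_projs):
    # FSDP2's MixedPrecisionPolicy may have cast x to bf16 while the
    # EP-sharded weights stayed fp32; align the weights to x.dtype so the
    # grouped_mm-style matmuls receive same-dtype operands.
    gate_and_up = gate_and_up_projs.to(x.dtype)
    down = down_projs.to(x.dtype)
    # Simplified SwiGLU expert MLP: split fused gate/up projection, then
    # project back down. Real code routes tokens through grouped matmuls.
    gate, up = (x @ gate_and_up).chunk(2, dim=-1)
    return (F.silu(gate) * up) @ down
```

With fp32 weights and a bf16 `x`, the uncast version would fail with a dtype-mismatch error; after the cast, both operands of each matmul share `x.dtype`.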

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

…NeMo#1863)

FSDP2's MixedPrecisionPolicy only casts params on its wrap mesh, so
cross-mesh DTensors (EP-sharded MoE experts) stay in fp32 while the
surrounding block's activations are cast to bf16. grouped_mm then
raises `Expected b.scalar_type() == torch::kBFloat16 to be true, but
got false` (see NVIDIA-NeMo#1863).

Fix: in `GroupedExperts.forward` and `GroupedExpertsDeepEP.forward`,
cast the local expert weights to the input activation dtype right
before grouped_mm. This matches grouped_mm's own requirement that
both operands share a dtype. The .data is not mutated, so fp32
master weights remain available to the optimizer.
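The point about `.data` not being mutated can be checked directly: `Tensor.to()` is an out-of-place cast, so the fp32 parameter the optimizer holds is left untouched. A minimal sketch (hypothetical shapes, not the PR's actual code):

```python
import torch

# fp32 parameter standing in for an EP-sharded expert master weight.
w = torch.nn.Parameter(torch.randn(4, 4, dtype=torch.float32))
# Activation already cast to bf16 by the mixed-precision policy.
x = torch.randn(2, 4, dtype=torch.bfloat16)

w_cast = w.to(x.dtype)   # new bf16 tensor; w.data is not mutated
y = x @ w_cast.t()       # same-dtype operands, as grouped_mm requires

assert w.dtype == torch.float32   # fp32 master weight still intact
assert y.dtype == torch.bfloat16
```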

Validation: smoke tests on NeMo-RL GRPO 1n8g + Gemma4 26B-A4B MoE
against Automodel main @ bd942f2.
  * Without this patch: DTensorPolicyWorkerV2.get_logprobs() crashes
    at grouped_gemm.backend.gmm with the exact NVIDIA-NeMo#1863 error string.
  * With this patch: 3 GRPO steps complete, loss -0.007/0.008/-0.014,
    reward 0.70/0.60/0.84.

Signed-off-by: larkzhang-nv <larkz@nvidia.com>
@copy-pr-bot

copy-pr-bot bot commented Apr 20, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@ZhiyuLi-Nvidia
Contributor

/ok to test 31f4157

Contributor

@ZhiyuLi-Nvidia left a comment


Thank you @jQizhang. LGTM!

@ZhiyuLi-Nvidia
Contributor

@hemildesai could you also take a review?
