
[feat] Init true on policy with qwen_moe #3

Draft
maocheng23 wants to merge 6 commits into feat/true_on_policy_qwen_dense from feat/true_on_policy_qwen_moe

Conversation


@maocheng23 maocheng23 commented May 1, 2026

Summary

Adds the Qwen3-MoE true-on-policy contract on top of the dense stack. Stacked on feat/true_on_policy_qwen_dense — the diff in this PR is only the MoE delta. This is a fork-side review; the upstream PR will retarget to sgl-project:main once dense lands.

This is one of three tightly-coupled MoE PRs that must land together — they share a single contract identifier qwen3_moe_true_on_policy_v1 defined by a vendored schema in each repo.

Companion PRs (MoE, must land in lockstep):

Stacked on dense (must land first):

Target

Bit-identical (exact-zero) logprob parity between the SGLang rollout engine and the Megatron trainer for every scored response token at TP=1/EP=4/CP=2 for Qwen3-30B-A3B (MoE).

Validated on H200 x8: train_rollout_logprob_abs_diff = 0.0 for 3 steps with both the full deterministic decode path and the prefill-only fast-decode path.
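"Bit-identical" here means bitwise equality of the floating-point values, not closeness within a tolerance. A minimal sketch of such a parity gate (function names are illustrative, not the harness's actual API):

```python
import struct

def bits(x: float) -> int:
    """Raw IEEE-754 bit pattern of a float64 value."""
    return struct.unpack("<Q", struct.pack("<d", x))[0]

def exact_zero_gate(rollout_logprobs, train_logprobs) -> float:
    """Assert bitwise equality per scored token; return the max abs diff."""
    assert len(rollout_logprobs) == len(train_logprobs)
    for r, t in zip(rollout_logprobs, train_logprobs):
        # Bitwise comparison also catches sign-of-zero or NaN-payload
        # drift that a plain `r == t` would hide.
        assert bits(r) == bits(t), f"parity break: {r!r} vs {t!r}"
    return max(
        (abs(r - t) for r, t in zip(rollout_logprobs, train_logprobs)),
        default=0.0,
    )

# exact_zero_gate([-0.2502, -0.2335], [-0.2502, -0.2335]) -> 0.0
```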

Design

The MoE contract reuses the dense three-layer architecture (Miles -> SGLang -> Megatron), adding a new qwen3_moe_true_on_policy_v1 contract object. SGLang remains the numerical source of truth — Megatron's training forward reproduces SGLang's MoE numerics in a differentiable surface that delegates the no-grad inference path to SGLang's fused-experts kernel.

New runtime-policy fields used to keep the rollout/train forward identical under EP:

  • deterministic_moe_routing — fp32 router pre-softmax, deterministic top-k tie-break
  • moe_topk_tiebreak — explicit policy hook for tie-break ordering
  • deterministic_moe_dispatch — deterministic permute (no permute fusion under contract)
  • deterministic_moe_combine — deterministic combine path
  • ep_invariant_moe — engages when ep_size > 1
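To make the routing fields concrete, here is a pure-Python sketch of fp32 pre-softmax routing with an index-ordered tie-break (illustrative only; the real kernels live in SGLang and Megatron, and the tie-break policy name is an assumption):

```python
import math

def route_topk(router_logits, k):
    """Deterministic MoE routing sketch: softmax over router logits in
    fp32-style arithmetic, then top-k with ties broken by lower expert
    index so every backend selects the same experts in the same order."""
    logits = [float(x) for x in router_logits]
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]   # numerically stable shift
    denom = sum(exps)
    probs = [e / denom for e in exps]
    # Sorting by (-prob, expert_index) makes equal probabilities resolve
    # identically regardless of the backend's sort algorithm.
    order = sorted(range(len(probs)), key=lambda i: (-probs[i], i))
    chosen = order[:k]
    norm = sum(probs[i] for i in chosen)
    return chosen, [probs[i] / norm for i in chosen]

# Experts 1 and 3 tie exactly; expert 1 is deterministically listed first.
experts, weights = route_topk([0.0, 2.0, 1.0, 2.0], k=2)
```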

In this PR (SGLang)

  • python/sglang/srt/true_on_policy/:
    • schema.py — adds QWEN3_MOE_TRUE_ON_POLICY_V1_SCHEMA and the qwen3_moe_sglang_math kernel contract (byte-identical with the Megatron and Miles copies)
    • contracts.py — adds QWEN3_MOE_TRUE_ON_POLICY_V1 registry entry; policy_for(server_args) derives ep_invariant_moe from ep_size > 1 and emits MoE-only fields when model_family == 'qwen3_moe'
    • config.py — adds prefill-only-deterministic helpers and per-forward-pass runtime-policy scope so decode CUDA-graph capture can run non-deterministic while logprob-recompute prefill stays under the contract
  • python/sglang/srt/tp_invariant_ops/ — exposes the deterministic K-block matmul under a stable name for MoE expert grouped GEMMs
  • python/sglang/srt/models/qwen3_moe.py:
    • QKV, Q/K, router, and expert inputs cast to projection/kernel weight dtype before graph-captured decode paths use them (fixes FP32->BF16 mismatch under fast decode)
    • MoE routing under contract uses fp32 pre-softmax with deterministic tie-break
  • python/sglang/srt/model_executor/model_runner.py, forward_batch_info.py — patch_prefill_only_deterministic_attention_backend(...) temporarily forces FA3 prefill num_splits=1 during deterministic prefill and restores it afterwards; decode stays graph-captured and non-deterministic
  • python/sglang/srt/server_args.py — --enable-prefill-only-deterministic-inference no longer auto-upgrades to full enable_deterministic_inference
  • python/sglang/srt/managers/tp_worker.py, layers/communicator.py, layers/dp_attention.py, layers/quantization/unquant.py, distributed/communication_op.py — minimal call-site changes to honour the per-forward runtime-policy scope
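The deterministic K-block contract mentioned above pins down the reduction order along the K dimension. Since floating-point addition is not associative, two engines agree bit-for-bit only if both commit to the same blocking and reduction order; a pure-Python sketch of the idea (block size and names are illustrative, not the kernel's actual interface):

```python
def kblock_dot(a, b, k_block=4):
    """Dot product accumulated in fixed K-block order: partial sums are
    computed per block, then reduced strictly left-to-right. Any
    implementation honouring this order reproduces the result
    bit-for-bit, independent of how work is split across devices."""
    assert len(a) == len(b)
    partials = []
    for start in range(0, len(a), k_block):
        s = 0.0
        for i in range(start, min(start + k_block, len(a))):
            s += a[i] * b[i]
        partials.append(s)
    total = 0.0
    for p in partials:  # fixed serial reduction: no tree/atomic ordering
        total += p
    return total
```

The point of exposing this under a stable name is that the MoE grouped GEMMs on both the rollout and training sides can call the same fixed-order reduction, which is what makes exact-zero logprob parity attainable at all.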

Validation

Remote H200 x8, container miles-maocheng, Qwen3-30B-A3B, TP=1/EP=4/CP=2:

| Mode | Step | rollout_logp | train_logp | abs_diff | grad_norm | rollout tok/GPU/s |
|---|---|---|---|---|---|---|
| Full deterministic | 0 | -0.2502 | -0.2502 | 0.0 | 0.0342 | 224.8 |
| Full deterministic | 1 | -0.2335 | -0.2335 | 0.0 | 0.0465 | (post-update) |
| Fast decode (no fusion) | 0 | -0.2452 | -0.2452 | 0.0 | 0.0391 | 1000.4 |
| Fast decode (no fusion) | 1 | -0.2350 | -0.2350 | 0.0 | 0.0318 | 990.7 |
| Fast decode (no fusion) | 2 | -0.2467 | -0.2467 | 0.0 | 0.0459 | 1010.6 |

Source: recovery/qwen3_moe_clean/journal/2026-04-30-moe-onpolicy-normal-validation.md.

CPU/unit tests:

  • test/registered/core/test_on_policy_wiring.py — extended with MoE coverage (graph-capture policy scope, qwen3_moe_attention_uses_dense_qk_dtype_contract, qwen3_moe_experts_use_weight_dtype_under_deterministic_routing, prefill-only flag wiring)
  • test/registered/core/test_tp_invariant_ops.py — extended for MoE grouped-GEMM under the deterministic K-block contract
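As a sketch of the kind of wiring these tests cover, the contract rule for MoE policy fields can be re-derived in a few lines (field values and the dict shape are assumptions; only the field names come from this PR):

```python
def policy_fields(model_family: str, ep_size: int) -> dict:
    """Illustrative re-derivation of the contract rule: MoE-only fields
    are emitted only for qwen3_moe, and ep_invariant_moe engages only
    when expert parallelism is actually sharded (ep_size > 1)."""
    if model_family != "qwen3_moe":
        return {}
    return {
        "deterministic_moe_routing": True,
        "moe_topk_tiebreak": "lowest_index",  # hypothetical policy value
        "deterministic_moe_dispatch": True,
        "deterministic_moe_combine": True,
        "ep_invariant_moe": ep_size > 1,
    }

assert policy_fields("qwen3_moe", ep_size=4)["ep_invariant_moe"] is True
assert policy_fields("qwen3_dense", ep_size=4) == {}
```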

Out of scope

  • Qwen3-Next contract — additive after this stack lands
  • Permute-fusion compatibility — current MoE contract requires MODEL_ARGS_DISABLE_MOE_PERMUTE_FUSION=1 (Miles enforces); a permute-fusion-equivalent deterministic path can land separately
  • Alignment harness extensions for MoE layer dumps — separate follow-up

Test plan

  • CPU unit tests pass in CI
  • GPU exact-zero E2E gate at TP=1/EP=4/CP=2 (full deterministic) — done locally, needs CI replay
  • GPU exact-zero E2E gate at TP=1/EP=4/CP=2 (fast decode, no fusion) — done locally, needs CI replay
  • 100-step on/off-policy comparison run

🤖 Generated with Claude Code

