
Correctly trigger node limited routing in DeepSeek-V3 JAX path.#1891

Open
gpolovets1 wants to merge 6 commits into main from gpolovets/fix_ds_router_logic

Conversation

Collaborator

@gpolovets1 gpolovets1 commented Mar 10, 2026

Description

Previously, node-limited routing was skipped whenever a fused MoE backend was called.
This PR also fixes the following issues:

  • Router activations are now applied exactly once (previously they were applied an extra time inside fused_moe_func).
  • Expert scores outside the selected groups are masked to a large negative value instead of 0 before top-k.
  • Top-k selection now runs in float32 precision (matching the vLLM reference implementation).
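The masking fix matters because DeepSeek-V3 adds a (possibly negative) correction bias to the sigmoid scores before selection, so a masked-out expert left at 0 could still outrank a valid in-group expert whose biased score is negative. A minimal sketch of the corrected selection (function name and shapes are illustrative, not the PR's actual code):

```python
import jax
import jax.numpy as jnp

def masked_topk(scores_TE, group_mask_TE, k):
    # Do the selection in float32, matching the vLLM reference path.
    scores_TE = scores_TE.astype(jnp.float32)
    # Mask ineligible experts with a large negative value, not 0:
    # bias-adjusted sigmoid scores can be negative, and a 0 mask
    # value would outrank them in top-k.
    masked = jnp.where(group_mask_TE, scores_TE, jnp.finfo(jnp.float32).min)
    topk_weights, topk_indices = jax.lax.top_k(masked, k)
    return topk_weights, topk_indices
```

With a 0 mask, the masked expert at index 2 below would incorrectly beat the in-group expert whose biased score is -0.1.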

Tests

Locally verified that MMLU improved from 67 to 80 while keeping performance within 1% of mainline.

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

@github-actions

Description

Start with a short description of what the PR does and how this is a change from
the past.

The rest of the description includes relevant details and context, examples:

  • why is this change being made,
  • the problem being solved and any relevant context,
  • why this is a good solution,
  • some information about the specific implementation,
  • shortcomings of the solution and possible future improvements.

If the change fixes a Github issue, please include a link, e.g.,:
FIXES: #123456

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

@gpolovets1 gpolovets1 requested review from gxd3, kyuyeunk and lk-chen and removed request for bzgoogle and vipannalla March 10, 2026 00:03
@gpolovets1 gpolovets1 force-pushed the gpolovets/fix_ds_router_logic branch from 0babcc6 to d160a78 Compare March 10, 2026 00:21
…Also instead of applying sigmoids on the router outputs, passing the router assignments to the MoE layer and applying the activations on the router within the MoE layer.

Signed-off-by: George Polovets <gpolovets@gmail.com>
…sn't being triggered.

Signed-off-by: George Polovets <gpolovets@gmail.com>
…ring requant, correctly applying yarn scale, and using FP32 and correct order of operations in get_topk_indices (since it also applies activations and bias term).

Signed-off-by: George Polovets <gpolovets@gmail.com>
…rge negative rather than 0.

Signed-off-by: George Polovets <gpolovets@gmail.com>
…e factor to be applied on the expert weights instead of hidden_states. This improved perf by about 3% and now perf is within 1% of mainline but with much improved MMLU.

Signed-off-by: George Polovets <gpolovets@gmail.com>
Signed-off-by: George Polovets <gpolovets@gmail.com>
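The scaling change in the last commit can be sketched as follows: multiplying the small [T, K] top-k weight tensor by the routed scaling factor is algebraically identical to rescaling the full [T, D] hidden states, but touches far less data, which is consistent with the reported ~3% perf gain. Names and shapes here are illustrative, not the PR's code:

```python
import jax.numpy as jnp

def combine_expert_outputs(expert_out_TKD, topk_weights_TK, routed_scaling_factor):
    # Fold the routed scaling factor into the [T, K] weights instead of
    # rescaling the much larger [T, D] hidden-state tensor.
    w_TK = topk_weights_TK * routed_scaling_factor
    return jnp.einsum('tkd,tk->td', expert_out_TKD, w_TK)
```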
@gpolovets1 gpolovets1 force-pushed the gpolovets/fix_ds_router_logic branch from d160a78 to 2515a0d Compare March 10, 2026 00:24
@gpolovets1 gpolovets1 added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 10, 2026
e_sharding=P(None, ),
activation_ffw_td=(ShardingAxisName.MLP_DATA, None),
ed_sharding=(None, None),
e_sharding=(None, ),
Collaborator

I guess this line is unnecessary, and you can revert line 1078?

x_TD = jax.lax.with_sharding_constraint(x_TD,
P(*self.activation_ffw_td))

logits_TE = super().__call__(x_TD).astype(jnp.float32)
Collaborator

use_ep: bool,
activation: str,
scoring_fn: str,
topk_weights: jax.Array | None = None,
Collaborator

Not sure this change is necessary? Wouldn't it be possible to change DeepSeekV3Router to return a value that fused_moe_func expects?

Collaborator Author

fused_moe_func applies global top-k routing whereas DeepSeek needs to use the custom grouped top-k routing.
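For context, a rough sketch of the grouped (node-limited) top-k that DeepSeek-V3 needs, as opposed to the single global jax.lax.top_k that fused_moe_func performs. Scoring each group by the sum of its two best experts follows the DeepSeek-V3 reference; all names and shapes here are illustrative:

```python
import jax
import jax.numpy as jnp

def grouped_topk(scores_TE, n_groups, topk_groups, k):
    scores_TE = scores_TE.astype(jnp.float32)
    T, E = scores_TE.shape
    group_size = E // n_groups
    g = scores_TE.reshape(T, n_groups, group_size)
    # Score each group by the sum of its two strongest experts.
    group_scores = jax.lax.top_k(g, 2)[0].sum(axis=-1)
    # Keep only the best topk_groups groups eligible.
    _, top_groups = jax.lax.top_k(group_scores, topk_groups)
    group_mask = jnp.zeros((T, n_groups), jnp.float32).at[
        jnp.arange(T)[:, None], top_groups].set(1.0)
    expert_mask = jnp.repeat(group_mask, group_size, axis=-1)
    # Run the ordinary top-k over the surviving experts only.
    masked = jnp.where(expert_mask > 0, scores_TE, jnp.finfo(jnp.float32).min)
    return jax.lax.top_k(masked, k)
```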

Collaborator Author

Would you prefer if I passed get_topk_func as an argument? It could replace jax.lax.top_k when provided and, I think, retain the same logical flow.

Collaborator

vLLM has a concept of monolithic vs. non-monolithic MoE execution, and I was wondering if we can leverage this API: https://github.com/vllm-project/vllm/blob/bdd8981dab8d8c6ae88a3f605d04ec5243088e5a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py#L505-L525


Labels

ready ONLY add when PR is ready to merge/full CI is needed


3 participants