Correctly trigger node limited routing in DeepSeek-V3 JAX path.#1891
gpolovets1 wants to merge 6 commits into main from
Conversation
Force-pushed from 0babcc6 to d160a78
…Also instead of applying sigmoids on the router outputs, passing the router assignments to the MoE layer and applying the activations on the router within the MoE layer. Signed-off-by: George Polovets <gpolovets@gmail.com>
…sn't being triggered. Signed-off-by: George Polovets <gpolovets@gmail.com>
…ring requant, correctly applying yarn scale, and using FP32 and correct order of operations in get_topk_indices (since it also applies activations and bias term). Signed-off-by: George Polovets <gpolovets@gmail.com>
…rge negative rather than 0. Signed-off-by: George Polovets <gpolovets@gmail.com>
…e factor to be applied on the expert weights instead of hidden_states. This improved perf by about 3% and now perf is within 1% of mainline but with much improved MMLU. Signed-off-by: George Polovets <gpolovets@gmail.com>
Signed-off-by: George Polovets <gpolovets@gmail.com>
Force-pushed from d160a78 to 2515a0d
e_sharding=P(None, ),
activation_ffw_td=(ShardingAxisName.MLP_DATA, None),
ed_sharding=(None, None),
e_sharding=(None, ),
I think this line is unnecessary; can you revert line 1078?
x_TD = jax.lax.with_sharding_constraint(x_TD,
                                        P(*self.activation_ffw_td))
logits_TE = super().__call__(x_TD).astype(jnp.float32)
Could you add a comment here referencing https://github.com/vllm-project/vllm/blob/e89a91d9275cd8ac086fe04476b41675a9ebbd5c/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py#L59?
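For context on the diff line above: the pattern under discussion is computing router logits in FP32 before sigmoid scoring, with the bias term applied only when selecting experts. A minimal numpy sketch of that idea (hypothetical names and a simplification of the actual JAX/vLLM code, not this PR's implementation):

```python
import numpy as np

def score_for_selection(logits, e_score_correction_bias):
    # Cast to float32 before the sigmoid to avoid bf16 precision loss
    # in the router scores (the motivation for .astype(jnp.float32) above).
    scores = 1.0 / (1.0 + np.exp(-logits.astype(np.float32)))
    # The bias-corrected scores are used only to pick the top-k experts;
    # the unbiased scores are what get normalized into routing weights.
    return scores, scores + e_score_correction_bias
```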
use_ep: bool,
activation: str,
scoring_fn: str,
topk_weights: jax.Array | None = None,
Not sure this change is necessary? Wouldn't it be possible to change DeepSeekV3Router to return a value that fused_moe_func expects?
fused_moe_func applies global top-k routing whereas DeepSeek needs to use the custom grouped top-k routing.
Would you prefer if I passed get_topk_func as an argument? It could replace jax.lax.top_k if passed and I think retain the same logical flow.
vLLM has a concept of monolithic vs. non-monolithic MoE, and I was wondering if we can leverage this API: https://github.com/vllm-project/vllm/blob/bdd8981dab8d8c6ae88a3f605d04ec5243088e5a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py#L505-L525
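To make the distinction in this thread concrete: global top-k picks the k highest-scoring experts anywhere, while DeepSeek-V3's node-limited routing first keeps only the best groups and then takes top-k inside them. A numpy sketch of grouped top-k, assuming the top-2-sum group metric from the DeepSeek-V3 design (function and argument names are illustrative, not this PR's code):

```python
import numpy as np

def grouped_topk(scores, n_groups, topk_groups, topk):
    # scores: (T, E) per-token expert scores (post-sigmoid, bias-corrected)
    T, E = scores.shape
    group_scores = scores.reshape(T, n_groups, E // n_groups)
    # Rank groups by the sum of their top-2 expert scores.
    top2 = np.sort(group_scores, axis=-1)[..., -2:].sum(-1)   # (T, n_groups)
    keep = np.argsort(-top2, axis=-1)[:, :topk_groups]        # (T, topk_groups)
    # Mask experts in unselected groups with a large negative value,
    # not 0, so a bias-corrected score near 0 can never beat the mask.
    mask = np.full((T, n_groups), -1e9, dtype=scores.dtype)
    np.put_along_axis(mask, keep, 0.0, axis=-1)
    masked = (group_scores + mask[..., None]).reshape(T, E)
    return np.argsort(-masked, axis=-1)[:, :topk]
```

With 4 groups of 2 experts and topk_groups=2, experts from the two losing groups are excluded from the final top-k even if one of them individually outscores a winner.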
Description
Previously, node limited routing was being skipped when fused MoE backends were called.
Also fixed the following issues:
- Router activations and the bias term are now applied inside the MoE layer, in FP32 and with the correct order of operations in get_topk_indices.
- Correctly applying the yarn scale.
- Masking experts in unselected groups with a large negative value rather than 0.
- Applying the e_score_correction factor on the expert weights instead of hidden_states (improved perf by about 3%).
Tests
Locally verified that MMLU went up from 67 to 80 while keeping performance within 1% of mainline.
Checklist
Before submitting this PR, please make sure: