-
Notifications
You must be signed in to change notification settings - Fork 123
Correctly trigger node limited routing in DeepSeek-V3 JAX path. #1891
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
bbc0645
eee1cb6
8a2bf1a
3418718
b8762ea
2515a0d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -527,7 +527,7 @@ def load_weights(self, weights): | |
| self.weight, | ||
| self.weight_scale_inv, | ||
| (0, 1), | ||
| block_size=self.quant_config.weight_block_size, | ||
| block_size=None, | ||
| ).T | ||
| A, N, qk_nope_head_dim, v_head_dim = self.mla_layer.kv_lora_rank, self.mla_layer.N, self.mla_layer.qk_nope_head_dim, self.mla_layer.v_head_dim | ||
| if dequantized_weight.shape != (A, N * | ||
|
|
@@ -841,6 +841,7 @@ def __init__(self, | |
| num_expert_parallelism, | ||
| moe_backend, | ||
| quant_config, | ||
| scoring_func, | ||
| rng, | ||
| prefix: str = ""): | ||
|
|
||
|
|
@@ -855,9 +856,10 @@ def __init__(self, | |
| routed_scaling_factor=routed_scaling_factor, | ||
| dtype=dtype, | ||
| moe_backend=moe_backend, | ||
| activation_ffw_td=P(ShardingAxisName.MLP_DATA, None), | ||
| ed_sharding=P(None, None), | ||
| e_sharding=P(None, ), | ||
| activation_ffw_td=(ShardingAxisName.MLP_DATA, None), | ||
| ed_sharding=(None, None), | ||
| e_sharding=(None, ), | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess this line is unnecessary, and you can revert line 1078? |
||
| scoring_func=scoring_func, | ||
| quant_config=quant_config) | ||
|
|
||
| # shared experts | ||
|
|
@@ -914,6 +916,7 @@ def __init__(self, | |
| prefix=f"{prefix}.experts", | ||
| router=self.gate, | ||
| shared_experts=self.shared_experts, | ||
| scoring_func=scoring_func, | ||
| routed_scaling_factor=routed_scaling_factor) | ||
|
|
||
| def __call__(self, x_TD: jax.Array): | ||
|
|
@@ -986,6 +989,7 @@ def __init__( | |
| random_init: bool = False, | ||
| quant_config: Optional[QuantizationConfig] = None, | ||
| router_bias_dtype: jnp.dtype = jnp.float32, | ||
| scoring_func: str = "sigmoid", | ||
| moe_backend: MoEBackend = MoEBackend.DENSE_MAT): | ||
| self.hidden_size = hidden_size | ||
| self.num_experts = num_experts | ||
|
|
@@ -1001,6 +1005,7 @@ def __init__( | |
| self.random_init = random_init | ||
| self.quant_config = quant_config | ||
| self.router_bias_dtype = router_bias_dtype | ||
| self.scoring_func = scoring_func | ||
| self.moe_backend = moe_backend | ||
| """Generates the router kernel (weights and bias) for routing.""" | ||
| D = self.hidden_size | ||
|
|
@@ -1040,14 +1045,18 @@ def get_topk_indices(self, scores_TE: Float) -> Float: | |
| scores_TE, (-1, self.n_groups, experts_per_group)) | ||
| group_scores_TG2 = jax.lax.top_k(group_scores_TGM, k=2)[0] | ||
| group_scores_TG = jnp.sum(group_scores_TG2, axis=-1) | ||
| indices = jax.lax.top_k(group_scores_TG, k=self.topk_groups)[1] | ||
| group_indices = jax.lax.top_k(group_scores_TG, | ||
| k=self.topk_groups)[1] | ||
|
|
||
| # Apply mask at the group level before flattening | ||
| mask_TG1 = jax.nn.one_hot( | ||
| group_indices, | ||
| self.n_groups).sum(axis=1)[..., None].astype(jnp.bool_) | ||
|
|
||
| # Apply mask to each group of experts | ||
| group_scores_TGM = jnp.where(mask_TG1, group_scores_TGM, -jnp.inf) | ||
|
|
||
| mask_TG = jnp.any(jnp.arange( | ||
| self.n_groups)[:, None] == indices[..., None, :], | ||
| axis=-1) | ||
| mask_TE = jnp.repeat(mask_TG, | ||
| scores_TE.shape[-1] // mask_TG.shape[-1], -1) | ||
| scores_TE = jnp.where(mask_TE, scores_TE, 0.0) | ||
| scores_TE = jnp.reshape(group_scores_TGM, (-1, self.num_experts)) | ||
|
|
||
| indices_TX = jax.lax.top_k(scores_TE, k=self.num_experts_per_tok)[1] | ||
|
|
||
|
|
@@ -1065,26 +1074,32 @@ def __call__(self, x_TD: Float) -> Tuple[Float, Float]: | |
| - indices: Indices of selected experts, shape (sequence, num_experts_per_tok). | ||
| """ | ||
| x_TD = jnp.asarray(x_TD, self.dtype) | ||
| x_TD = lax.with_sharding_constraint(x_TD, self.activation_ffw_td) | ||
| x_TD = jax.lax.with_sharding_constraint(x_TD, | ||
| P(*self.activation_ffw_td)) | ||
|
|
||
| logits_TE = super().__call__(x_TD).astype(jnp.float32) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add comment with reference of https://github.com/vllm-project/vllm/blob/e89a91d9275cd8ac086fe04476b41675a9ebbd5c/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py#L59 here? |
||
|
|
||
| scores_TE = super().__call__(x_TD) | ||
| scores_TE = nnx.sigmoid(scores_TE) | ||
| # Apply scoring function (Sigmoid/Softmax) to get probabilities | ||
| if self.scoring_func == "sigmoid": | ||
| probs_TE = jax.nn.sigmoid(logits_TE) | ||
| elif self.scoring_func == "softmax": | ||
| probs_TE = jax.nn.softmax(logits_TE, axis=-1) | ||
| else: | ||
| probs_TE = logits_TE | ||
|
|
||
| if self.moe_backend in MoEBackend.fused_moe_backends(): | ||
| return scores_TE | ||
| # Will add Aux-Loss-Free bias the activation outputs during topk selection. | ||
| topk_indices_TX = self.get_topk_indices(probs_TE) | ||
|
|
||
| original_scores_TE = scores_TE | ||
| topk_indices_TX = self.get_topk_indices(scores_TE) | ||
| weights_TX = jnp.take_along_axis(original_scores_TE, | ||
| topk_indices_TX, | ||
| axis=-1) | ||
| # The actual weights do not include the bias terms. | ||
| weights_TX = jnp.take_along_axis(probs_TE, topk_indices_TX, axis=-1) | ||
|
|
||
| if self.norm_topk_prob: | ||
| weights_TX /= jnp.sum(weights_TX, axis=-1)[..., None] + 1e-20 | ||
|
|
||
| # Scale expert weights before taking linear combination of experts. | ||
| weights_TX *= self.routed_scaling_factor | ||
|
|
||
| return weights_TX, topk_indices_TX | ||
| return weights_TX.astype(self.dtype), topk_indices_TX | ||
|
|
||
|
|
||
| @dataclass | ||
|
|
@@ -1122,6 +1137,7 @@ def __init__(self, | |
| self.is_last_rank = get_pp_group().is_last_rank | ||
| hf_config = vllm_config.model_config.hf_config | ||
| dtype = vllm_config.model_config.dtype | ||
| scoring_func = getattr(hf_config, "scoring_func", "sigmoid") | ||
|
|
||
| if self.is_first_rank: | ||
| self.embed_tokens = JaxEmbed( | ||
|
|
@@ -1270,6 +1286,7 @@ def get_decoder_layer(layer_index: int): | |
| num_expert_parallelism=self.num_expert_parallelism, | ||
| moe_backend=self.moe_backend, | ||
| quant_config=quant_config, | ||
| scoring_func=scoring_func, | ||
| rng=rng, | ||
| prefix=f"{prefix}.layers.{layer_index}.mlp") | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure if this change is necessary? wouldn't it be possible to make a change to
`DeepSeekV3Router` to return a value that `fused_moe_func` expects?

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fused_moe_func applies global top-k routing whereas DeepSeek needs to use the custom grouped top-k routing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would you prefer if I passed get_topk_func as an argument? It could replace jax.lax.top_k if passed and I think retain the same logical flow.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so vllm has a concept called monolithic vs. non-monolithic. and i was wondering if we can leverage this api: https://github.com/vllm-project/vllm/blob/bdd8981dab8d8c6ae88a3f605d04ec5243088e5a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py#L505-L525