Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions tpu_inference/layers/common/fused_moe_gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,8 @@ def fused_moe_func(
use_ep: bool,
activation: str,
scoring_fn: str,
topk_weights: jax.Array | None = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if this change is necessary? wouldn't it be possible to change DeepSeekV3Router to return a value that fused_moe_func expects?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fused_moe_func applies global top-k routing whereas DeepSeek needs to use the custom grouped top-k routing.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you prefer it if I passed get_topk_func as an argument? It could replace jax.lax.top_k if passed and, I think, would retain the same logical flow.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vLLM has a concept called monolithic vs. non-monolithic, and I was wondering if we could leverage this API: https://github.com/vllm-project/vllm/blob/bdd8981dab8d8c6ae88a3f605d04ec5243088e5a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py#L505-L525

topk_indices: jax.Array | None = None,
) -> jax.Array:
"""Route tokens in hidden_states into each experts based on routing.

Expand All @@ -321,6 +323,8 @@ def fused_moe_func(
use_ep: use expert parallelism.
activation: activation function to perform on the output of w1.
scoring_fn: scoring function to apply on gating_output.
topk_weights: pre-calculated expert weights.
topk_indices: pre-calculated expert indices.

Returns:
Output of moe operation [num_tokens, hidden_size]
Expand All @@ -333,15 +337,27 @@ def fused_moe_func(
"The kernel requires num_tokens * topk to be a multiple of "
f"16 but got {num_tokens}*{topk}={num_tokens*topk}")

assert gating_output.shape == (num_tokens, global_num_experts)
if topk_weights is None or topk_indices is None:
assert gating_output is not None
assert gating_output.shape == (num_tokens, global_num_experts)

topk_weights = apply_scoring_fn(scoring_fn, gating_output)
# All-gather topk weights for attention dp
topk_weights = jax.lax.with_sharding_constraint(
topk_weights,
NamedSharding(mesh, P(ShardingAxisName.MLP_DATA, None)))
topk_weights, topk_indices = jax.lax.top_k(topk_weights, k=topk)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(axis=-1,
keepdims=True)
else:
topk_weights = jax.lax.with_sharding_constraint(
topk_weights,
NamedSharding(mesh, P(ShardingAxisName.MLP_DATA, None)))
topk_indices = jax.lax.with_sharding_constraint(
topk_indices,
NamedSharding(mesh, P(ShardingAxisName.MLP_DATA, None)))

topk_weights = apply_scoring_fn(scoring_fn, gating_output)
# All-gather topk weights for attention dp
topk_weights = jax.lax.with_sharding_constraint(
topk_weights, NamedSharding(mesh, P(ShardingAxisName.MLP_DATA, None)))
topk_weights, topk_indices = jax.lax.top_k(topk_weights, k=topk)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)
topk_weights = topk_weights.astype(dtype)

def _process_tokens_locally(hidden_states_local, topk_indices_local):
Expand Down
8 changes: 8 additions & 0 deletions tpu_inference/layers/common/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,12 @@ def moe_apply(
**extra_backend_kwargs,
)[:, :actual_hidden_size]
case MoEBackend.GMM_EP | MoEBackend.GMM_TP:
topk_weights = None
topk_indices = None
if isinstance(gating_output, tuple):
topk_weights, topk_indices = gating_output
gating_output = None # Not used if pre-calculated

output = fused_moe_func(
hidden_states=x,
w1=weights.w13_weight,
Expand All @@ -136,6 +142,8 @@ def moe_apply(
use_ep=layer.use_ep,
activation=activation,
scoring_fn=layer.scoring_func,
topk_weights=topk_weights,
topk_indices=topk_indices,
)
case MoEBackend.DENSE_MAT:
# NOTE: circular import avoidance
Expand Down
2 changes: 1 addition & 1 deletion tpu_inference/layers/jax/moe/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class JaxMoE(JaxModule):
num_expert_parallelism: int
random_init: bool = False
moe_backend: MoEBackend = MoEBackend.DENSE_MAT
scoring_func = "softmax"
scoring_func: str = "softmax"

# --- Sparse MoE Specific Attributes ---
num_experts_per_tok: int = 1 # Required for Sparse, optional/derived for Dense
Expand Down
8 changes: 2 additions & 6 deletions tpu_inference/layers/jax/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,10 +517,6 @@ def process_weights_after_loading(self, layer: JaxMoE) -> bool:
w13_weight = jnp.concatenate([w_gate, w_up], axis=1)
w13_weight_scale = jnp.concatenate([s_gate, s_up], axis=1)

weight_block_size = None
if self.weight_block_size is not None:
weight_block_size = tuple(self.weight_block_size)

# TODO (jacobplatin): we should support bias
input_weights = FusedMoEWeights(
w13_weight=w13_weight,
Expand All @@ -535,8 +531,8 @@ def process_weights_after_loading(self, layer: JaxMoE) -> bool:
moe_backend=layer.moe_backend,
mesh=layer.mesh,
activation=layer.activation,
# Convert to tuple so jax jit can hash it
weight_block_size=weight_block_size,
# Source block size should be inferred from scale shape
weight_block_size=None,
)

del layer.kernel_gating_EDF
Expand Down
61 changes: 39 additions & 22 deletions tpu_inference/models/jax/deepseek_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ def load_weights(self, weights):
self.weight,
self.weight_scale_inv,
(0, 1),
block_size=self.quant_config.weight_block_size,
block_size=None,
).T
A, N, qk_nope_head_dim, v_head_dim = self.mla_layer.kv_lora_rank, self.mla_layer.N, self.mla_layer.qk_nope_head_dim, self.mla_layer.v_head_dim
if dequantized_weight.shape != (A, N *
Expand Down Expand Up @@ -841,6 +841,7 @@ def __init__(self,
num_expert_parallelism,
moe_backend,
quant_config,
scoring_func,
rng,
prefix: str = ""):

Expand All @@ -855,9 +856,10 @@ def __init__(self,
routed_scaling_factor=routed_scaling_factor,
dtype=dtype,
moe_backend=moe_backend,
activation_ffw_td=P(ShardingAxisName.MLP_DATA, None),
ed_sharding=P(None, None),
e_sharding=P(None, ),
activation_ffw_td=(ShardingAxisName.MLP_DATA, None),
ed_sharding=(None, None),
e_sharding=(None, ),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this line is unnecessary, and you can revert line 1078?

scoring_func=scoring_func,
quant_config=quant_config)

# shared experts
Expand Down Expand Up @@ -914,6 +916,7 @@ def __init__(self,
prefix=f"{prefix}.experts",
router=self.gate,
shared_experts=self.shared_experts,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor)

def __call__(self, x_TD: jax.Array):
Expand Down Expand Up @@ -986,6 +989,7 @@ def __init__(
random_init: bool = False,
quant_config: Optional[QuantizationConfig] = None,
router_bias_dtype: jnp.dtype = jnp.float32,
scoring_func: str = "sigmoid",
moe_backend: MoEBackend = MoEBackend.DENSE_MAT):
self.hidden_size = hidden_size
self.num_experts = num_experts
Expand All @@ -1001,6 +1005,7 @@ def __init__(
self.random_init = random_init
self.quant_config = quant_config
self.router_bias_dtype = router_bias_dtype
self.scoring_func = scoring_func
self.moe_backend = moe_backend
"""Generates the router kernel (weights and bias) for routing."""
D = self.hidden_size
Expand Down Expand Up @@ -1040,14 +1045,18 @@ def get_topk_indices(self, scores_TE: Float) -> Float:
scores_TE, (-1, self.n_groups, experts_per_group))
group_scores_TG2 = jax.lax.top_k(group_scores_TGM, k=2)[0]
group_scores_TG = jnp.sum(group_scores_TG2, axis=-1)
indices = jax.lax.top_k(group_scores_TG, k=self.topk_groups)[1]
group_indices = jax.lax.top_k(group_scores_TG,
k=self.topk_groups)[1]

# Apply mask at the group level before flattening
mask_TG1 = jax.nn.one_hot(
group_indices,
self.n_groups).sum(axis=1)[..., None].astype(jnp.bool_)

# Apply mask to each group of experts
group_scores_TGM = jnp.where(mask_TG1, group_scores_TGM, -jnp.inf)

mask_TG = jnp.any(jnp.arange(
self.n_groups)[:, None] == indices[..., None, :],
axis=-1)
mask_TE = jnp.repeat(mask_TG,
scores_TE.shape[-1] // mask_TG.shape[-1], -1)
scores_TE = jnp.where(mask_TE, scores_TE, 0.0)
scores_TE = jnp.reshape(group_scores_TGM, (-1, self.num_experts))

indices_TX = jax.lax.top_k(scores_TE, k=self.num_experts_per_tok)[1]

Expand All @@ -1065,26 +1074,32 @@ def __call__(self, x_TD: Float) -> Tuple[Float, Float]:
- indices: Indices of selected experts, shape (sequence, num_experts_per_tok).
"""
x_TD = jnp.asarray(x_TD, self.dtype)
x_TD = lax.with_sharding_constraint(x_TD, self.activation_ffw_td)
x_TD = jax.lax.with_sharding_constraint(x_TD,
P(*self.activation_ffw_td))

logits_TE = super().__call__(x_TD).astype(jnp.float32)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


scores_TE = super().__call__(x_TD)
scores_TE = nnx.sigmoid(scores_TE)
# Apply scoring function (Sigmoid/Softmax) to get probabilities
if self.scoring_func == "sigmoid":
probs_TE = jax.nn.sigmoid(logits_TE)
elif self.scoring_func == "softmax":
probs_TE = jax.nn.softmax(logits_TE, axis=-1)
else:
probs_TE = logits_TE

if self.moe_backend in MoEBackend.fused_moe_backends():
return scores_TE
# Will add Aux-Loss-Free bias the activation outputs during topk selection.
topk_indices_TX = self.get_topk_indices(probs_TE)

original_scores_TE = scores_TE
topk_indices_TX = self.get_topk_indices(scores_TE)
weights_TX = jnp.take_along_axis(original_scores_TE,
topk_indices_TX,
axis=-1)
# The actual weights do not include the bias terms.
weights_TX = jnp.take_along_axis(probs_TE, topk_indices_TX, axis=-1)

if self.norm_topk_prob:
weights_TX /= jnp.sum(weights_TX, axis=-1)[..., None] + 1e-20

# Scale expert weights before taking linear combination of experts.
weights_TX *= self.routed_scaling_factor

return weights_TX, topk_indices_TX
return weights_TX.astype(self.dtype), topk_indices_TX


@dataclass
Expand Down Expand Up @@ -1122,6 +1137,7 @@ def __init__(self,
self.is_last_rank = get_pp_group().is_last_rank
hf_config = vllm_config.model_config.hf_config
dtype = vllm_config.model_config.dtype
scoring_func = getattr(hf_config, "scoring_func", "sigmoid")

if self.is_first_rank:
self.embed_tokens = JaxEmbed(
Expand Down Expand Up @@ -1270,6 +1286,7 @@ def get_decoder_layer(layer_index: int):
num_expert_parallelism=self.num_expert_parallelism,
moe_backend=self.moe_backend,
quant_config=quant_config,
scoring_func=scoring_func,
rng=rng,
prefix=f"{prefix}.layers.{layer_index}.mlp")

Expand Down
Loading