
Skip redundant moe_sum_reduce for single-expert routing on XPU#22660

Open
rahulvijayaraghavan wants to merge 1 commit into sgl-project:main from rahulvijayaraghavan:skip-redundant-moe-sum-reduce-xpu

Conversation

@rahulvijayaraghavan
Contributor

When topk_ids.shape[1] == 1 and routed_scaling_factor == 1.0, the second invoke_fused_moe_kernel call already writes its output directly into out_hidden_states, so the subsequent moe_sum_reduce is a no-op reduction over a single element. This adds an early-exit check on the XPU path to skip the unnecessary kernel launch, matching the existing optimization already present in the CUDA path.

This is particularly relevant for Llama-4-Scout models (e.g. Llama-4-Scout-17B-16E-Instruct), which set num_experts_per_tok = 1, meaning this fast path is hit on every MoE layer forward pass.


Comment on lines +681 to +688

```python
if topk_ids.shape[1] == 1 and routed_scaling_factor == 1.0:
    pass  # we write directly into out_hidden_states
else:
    moe_sum_reduce(
        intermediate_cache3.view(*intermediate_cache3.shape),
        out_hidden_states[begin_chunk_idx:end_chunk_idx],
        routed_scaling_factor,
    )
```
Collaborator


make sure we have test cases to cover:

  • topk == 1, routed_scaling_factor == 1.0
  • topk == 1, routed_scaling_factor != 1.0

Additionally, is it possible to move the `topk_ids.shape[1] == 1 and routed_scaling_factor == 1.0` check up and skip the `intermediate_cache3` allocation in the first place?
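The two requested cases could be covered by a unit test along these lines. This is a sketch using a pure-NumPy stand-in for the reduction; a real test would compare the XPU kernel's output against this reference:

```python
import numpy as np

def moe_sum_reduce_ref(cache, routed_scaling_factor):
    # Pure-NumPy reference for the kernel: sum over topk, then scale.
    return cache.sum(axis=1) * routed_scaling_factor

rng = np.random.default_rng(1)
cache = rng.standard_normal((3, 1, 5))  # topk == 1

# Case 1: topk == 1, routed_scaling_factor == 1.0 -> identity, safe to skip.
out_identity = moe_sum_reduce_ref(cache, 1.0)
assert np.allclose(out_identity, cache[:, 0, :])

# Case 2: topk == 1, routed_scaling_factor != 1.0 -> output is rescaled,
# so the reduction must NOT be skipped on this path.
out_scaled = moe_sum_reduce_ref(cache, 0.5)
assert np.allclose(out_scaled, 0.5 * cache[:, 0, :])
```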
