[NPU] Add NPU Fused MoE kernel#1183
Conversation
@Tcc0403 This PR is ready for review.
Tcc0403
left a comment
LGTM, just a tiny issue
```diff
-torch.cuda.synchronize()
+if device == "cuda":
+    torch.cuda.synchronize()
+elif device == "npu":
+    torch.npu.synchronize()
```
Great catch, we also have CPU support. Could you add it?
```python
if device == "cuda":
    torch.cuda.synchronize()
elif device == "npu":
    torch.npu.synchronize()
else:
    torch.cpu.synchronize()
```
@Tcc0403 Thanks for the suggestion. torch provides a cpu equivalent, so I've added it here.
Sorry, typo: meant to be xpu, not cpu 😅
Got it. Added torch.xpu.synchronize, please take another look.
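
For reference, a minimal sketch of the device-dispatch pattern this thread converged on (the helper name `sync_device` is hypothetical, and the `device` string is assumed to come from the test's device inference):

```python
import torch

def sync_device(device: str) -> None:
    # Hypothetical test helper mirroring the dispatch discussed above.
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "npu":
        # requires torch_npu so that the npu backend is registered
        torch.npu.synchronize()
    elif device == "xpu":
        torch.xpu.synchronize()
    # plain CPU tensors need no explicit synchronization
```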
```python
def compute_routing_metadata(topk_indices: torch.Tensor, E: int, block_m_token: int = BLOCK_M_TOKEN):
    """Compute token→expert routing permutation metadata via 3 Triton kernels.

    Also computes GPU tile metadata (tile_row_start, tile_expert) inside
    Kernel 3 — no CPU loop, one .item() sync for num_m_tiles allocation.

    Args:
        topk_indices: (T, K) int32 — pre-computed top-k expert indices per token
        E: number of experts
        block_m_token: BLOCK_M for token-dimension tiling (default BLOCK_M_TOKEN)

    Returns:
        expert_token_count: (E,) int32
        expert_start_idx: (E+1,) int32
        x_gather_idx: (TK,) int32
        s_scatter_idx: (TK,) int32
        s_reverse_scatter_idx: (TK,) int32
        tile_row_start: (num_m_tiles,) int32 — absolute row_start per M-tile
        tile_expert: (num_m_tiles,) int32 — expert index per M-tile
    """
```
```diff
-from liger_kernel.ops.fused_moe import LigerFusedMoEFunction
-from liger_kernel.ops.fused_moe import compute_routing_metadata
+from liger_kernel.ops import LigerFusedMoEFunction
+from liger_kernel.ops import compute_routing_metadata
```
Hi @zheliuyu. Is this change required for testing on NPU? I'd prefer not to export compute_routing_metadata externally, as there is no use case for this function outside LigerFusedMoEFunction.
Is there any workaround, such as a test-level import redirect based on infer_device?
This is currently breaking the unit test on other devices (GPU), since we don't export compute_routing_metadata in the non-Ascend backend.
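
One possible test-level redirect along the lines suggested above (a sketch only, not the fix that landed in linkedin#1209; it assumes `infer_device` is importable from `liger_kernel.utils` as elsewhere in the test suite, and that the Ascend backend re-exports both symbols from `liger_kernel.ops`):

```python
from liger_kernel.utils import infer_device

if infer_device() == "npu":
    # Ascend backend re-exports the symbols at the package root (assumed)
    from liger_kernel.ops import LigerFusedMoEFunction
    from liger_kernel.ops import compute_routing_metadata
else:
    # CUDA/XPU backends keep them in the fused_moe module
    from liger_kernel.ops.fused_moe import LigerFusedMoEFunction
    from liger_kernel.ops.fused_moe import compute_routing_metadata
```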
…on device type (linkedin#1209)

## Summary
As Title. Fix: linkedin#1183 (comment)

- Hardware Type: npu
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence






Motivation
This PR ports `fused_moe.py` and `fused_moe_kernels.py` to an NPU-affine implementation while preserving the original math. The computational definition is unchanged: forward remains `W1 (gate/up) -> SwiGLU -> W2 -> token-weighted gather`, and backward still follows `dA' = dO @ W2^T` to produce `d_pre_act / dS / dW2 / dX / dW1`. The main changes are execution-strategy optimizations for NPU.
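
As a reading aid, a tiny eager-mode reference of the unchanged forward math described above (per-token loop, with an assumed weight layout packing gate/up into `w1`; the actual kernels permute and tile tokens per expert rather than looping):

```python
import torch
import torch.nn.functional as F

def reference_moe_forward(x, w1, w2, topk_idx, topk_weight):
    """x: (T, D); w1: (E, D, 2H) gate/up packed; w2: (E, H, D);
    topk_idx / topk_weight: (T, K) routing computed upstream."""
    T, K = topk_idx.shape
    out = torch.zeros_like(x)
    for t in range(T):
        for k in range(K):
            e = int(topk_idx[t, k])
            a = x[t] @ w1[e]                           # W1 (gate/up)
            gate, up = a.chunk(2, dim=-1)
            h = F.silu(gate) * up                      # SwiGLU
            out[t] += topk_weight[t, k] * (h @ w2[e])  # W2, token-weighted combine
    return out
```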
Note: Use the Skill
For this fused_moe kernel migration, we followed the skill document from #1197.
Testing Done
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence

🤖 Generated with: cursor.