[Feat] Add CountAndGather and MoeReduce operations for FusedMoe #180
zhen8838 wants to merge 2 commits into tile-ai:main
Conversation
Summary of Changes

Hello @zhen8838, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces core functionality for Fused Mixture of Experts (MoE) within TileOPs, focusing on the routing and aggregation phases.
Activity
Code Review
The pull request introduces CountAndGather and MoeReduce operations for FusedMoe, along with their respective TileLang kernels and tests. The new functionality is well-tested with reference implementations and basic performance checks. The overall structure follows the Op and Kernel abstraction, which is good for modularity.
However, there are critical efficiency concerns regarding the TileLang kernel compilation strategy. The current design causes recompilation of the JIT kernels whenever dynamic input shapes change, which will lead to significant performance overhead. Additionally, the high-level functions in top/functions/fuse_moe.py instantiate kernels on every call, exacerbating this recompilation issue. Addressing these recompilation issues is crucial for the performance of these operations.
```python
kernel = CountAndGatherKernel(
    num_seq=num_seq,
    hidden_size=hidden_size,
    num_topk=num_topk,
    num_expert=num_expert,
    config={"tile_m": tile_m},
)
return kernel.count_and_gather(x, topk_ids, rank_ep)
```
The count_and_gather function instantiates a new CountAndGatherKernel on every call. This is highly inefficient as it leads to repeated kernel initialization and potential JIT recompilation overhead, especially if this function is called frequently. Kernels, especially JIT-compiled ones, should ideally be instantiated once and reused to avoid this overhead. Consider refactoring to accept a pre-initialized kernel or Op instance, or to manage the kernel's lifecycle more effectively.
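One common way to avoid re-instantiating a JIT-backed kernel per call is to memoize kernel instances by their shape/config key. The sketch below is illustrative only: `FakeKernel` stands in for `CountAndGatherKernel`, and the constructor arguments are assumed from the snippet above rather than taken from the actual API.

```python
from functools import lru_cache

class FakeKernel:
    """Stand-in for a JIT-compiled kernel; counts how often it is built."""
    compile_count = 0

    def __init__(self, num_seq, hidden_size, num_topk, num_expert, tile_m):
        FakeKernel.compile_count += 1  # in reality: expensive JIT compilation
        self.shape = (num_seq, hidden_size, num_topk, num_expert, tile_m)

@lru_cache(maxsize=None)
def get_kernel(num_seq, hidden_size, num_topk, num_expert, tile_m):
    # Cached by argument tuple: one build per unique shape/config.
    return FakeKernel(num_seq, hidden_size, num_topk, num_expert, tile_m)

k1 = get_kernel(128, 1024, 8, 64, 32)
k2 = get_kernel(128, 1024, 8, 64, 32)   # cache hit, no rebuild
k3 = get_kernel(256, 1024, 8, 64, 32)   # new shape, rebuilds once
assert k1 is k2
assert FakeKernel.compile_count == 2
```

With this pattern the convenience function in `top/functions/fuse_moe.py` could call `get_kernel(...)` instead of constructing the kernel inline, paying compilation cost only on first sight of a shape.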
```python
kernel = MoeReduceKernel(
    num_seq=num_seq,
    hidden_size=hidden_size,
    num_topk=num_topk,
)
return kernel.forward(x, topk_pos, topk_scale, shared_output)
```
Similar to count_and_gather, the reduce function instantiates a new MoeReduceKernel on every call. This will cause repeated kernel initialization and potential JIT recompilation, which is inefficient. For optimal performance, kernels should be instantiated once and reused across multiple calls. Consider passing a pre-initialized kernel or Op instance to this function.
```python
def _count_seq_and_cuseq_kernel(total_num_topk: int, num_expert: int, start_expert: int,
                                end_expert: int, tile_m: int):
```
The _count_seq_and_cuseq_kernel function takes dynamic dimensions (total_num_topk, num_expert, start_expert, end_expert, tile_m) as arguments to its outer Python function. When these values change (e.g., with different input tensor shapes), tilelang.jit will recompile the kernel. This leads to significant performance overhead for dynamic workloads. To avoid recompilation, dynamic dimensions should be passed as T.int32 scalar arguments to the T.prim_func itself, and T.Tensor declarations should use symbolic dimensions or maximum possible dimensions.
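For reference, the semantics this count-and-cumsum stage appears to implement can be sketched in plain NumPy: count how many top-k assignments route to each expert in the `[start_expert, end_expert)` range, then build exclusive prefix sums used to place gathered rows. Names and layout here are assumptions for illustration, not the actual kernel interface.

```python
import numpy as np

def count_seq_and_cuseq_ref(topk_ids, num_expert, start_expert, end_expert):
    # Per-expert assignment counts over all (token, topk) slots.
    counts = np.bincount(topk_ids.ravel(), minlength=num_expert)
    local = counts[start_expert:end_expert]
    # Exclusive prefix sums: cuseq[e] is the output offset of expert e's rows.
    cuseq = np.concatenate([[0], np.cumsum(local)])
    return local, cuseq

topk_ids = np.array([[0, 2], [1, 2], [2, 3]])  # 3 tokens, top-2 routing
counts, cuseq = count_seq_and_cuseq_ref(topk_ids, num_expert=4,
                                        start_expert=0, end_expert=4)
assert counts.tolist() == [1, 1, 3, 1]
assert cuseq.tolist() == [0, 1, 2, 5, 6]
```

A reference like this is also handy in unit tests, independent of how the TileLang kernel handles its dynamic dimensions.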
```python
def _gather_kernel(num_seq: int, hidden_size: int, num_topk: int, total_num_topk: int,
                   num_expert: int, start_expert: int, end_expert: int):
```
The _gather_kernel function also takes dynamic dimensions (num_seq, hidden_size, num_topk, total_num_topk, num_expert, start_expert, end_expert) as arguments to its outer Python function. This will cause tilelang.jit to recompile the kernel whenever these values change, leading to significant performance overhead. Refactor this JIT kernel to accept dynamic dimensions as T.int32 scalar arguments to the T.prim_func itself, using symbolic or maximum dimensions for T.Tensor declarations.
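The gather stage's intent can likewise be expressed as a small NumPy reference: replicate each token row once per top-k assignment, then order the rows by expert id so each expert's inputs are contiguous. A stable sort preserves token order within an expert. This mirrors the intended semantics only; the real kernel's interface and memory layout may differ.

```python
import numpy as np

def gather_by_expert_ref(x, topk_ids):
    num_seq, num_topk = topk_ids.shape
    # Which source token each (token, topk) slot came from.
    token_idx = np.repeat(np.arange(num_seq), num_topk)
    expert_ids = topk_ids.ravel()
    # Stable sort groups slots by expert while keeping token order within one.
    order = np.argsort(expert_ids, kind="stable")
    return x[token_idx[order]], expert_ids[order]

x = np.array([[10.0], [20.0], [30.0]])
topk_ids = np.array([[0, 2], [1, 2], [2, 0]])
gathered, sorted_experts = gather_by_expert_ref(x, topk_ids)
assert sorted_experts.tolist() == [0, 0, 1, 2, 2, 2]
assert gathered.ravel().tolist() == [10.0, 30.0, 20.0, 10.0, 20.0, 30.0]
```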
```python
from top.kernels.kernel import Kernel

def _moe_reduce_kernel(total_num_seq: int, num_seq: int, hidden_size: int, num_topk: int):
```
The _moe_reduce_kernel function takes dynamic dimensions (total_num_seq, num_seq, hidden_size, num_topk) as arguments to its outer Python function. This design causes tilelang.jit to recompile the kernel every time these values change, which is a critical performance bottleneck for dynamic shapes. The kernel should be refactored to accept these dynamic dimensions as T.int32 scalar arguments to the T.prim_func itself, and T.Tensor declarations should use symbolic or maximum possible dimensions.
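As a point of comparison, the reduce step's expected behavior can be captured in a few lines of NumPy: each token's output is the scale-weighted sum of its top-k expert outputs, accumulated in FP32, plus an optional shared-expert output. The argument names (`topk_pos`, `topk_scale`, `shared_output`) follow the snippets in this PR, but the exact semantics are assumed here.

```python
import numpy as np

def moe_reduce_ref(x, topk_pos, topk_scale, shared_output=None):
    # x: (total_num_topk, hidden) expert outputs
    # topk_pos: (num_seq, num_topk) row indices into x
    # topk_scale: (num_seq, num_topk) routing weights
    acc = (x[topk_pos].astype(np.float32) * topk_scale[..., None]).sum(axis=1)
    if shared_output is not None:
        acc += shared_output.astype(np.float32)  # FP32 accumulation throughout
    return acc

x = np.array([[1.0], [2.0], [4.0]], dtype=np.float16)
topk_pos = np.array([[0, 1], [1, 2]])
topk_scale = np.array([[0.5, 0.5], [0.25, 0.75]], dtype=np.float32)
out = moe_reduce_ref(x, topk_pos, topk_scale)
assert out.tolist() == [[1.5], [3.5]]
```

A reference of this shape is what the PR's tests validate the kernel against, regardless of how the dynamic-dimension issue is resolved.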
```python
print(f"Average time per iteration: {avg_time * 1000:.2f} ms")
print(f"Throughput: {throughput:.2f} sequences/sec")

assert avg_time < 0.1
```
The performance assertion assert avg_time < 0.1 uses a magic number 0.1. Hardcoded thresholds can make tests brittle and prone to failure across different environments or as code evolves. It's generally better to use more robust performance checks, such as relative thresholds or named constants to improve maintainability.
Suggested change:

```diff
-assert avg_time < 0.1
+PERFORMANCE_THRESHOLD_MS = 100  # Example: 100ms
+assert avg_time * 1000 < PERFORMANCE_THRESHOLD_MS
```
Pull request overview
This PR implements the Fused MoE routing and aggregation operations in TileLang, adding CountAndGatherOp and MoeReduceOp to support MoE layer computations. The implementation follows the kernel -> op -> function layering pattern, with TileLang kernels handling the low-level compute, ops providing the high-level interface, and functions offering standalone convenience wrappers.
Changes:
- Migrated `count_and_gather` and `moe_reduce` kernels to TileLang with staged logic
- Added Op wrappers (CountAndGatherOp, MoeReduceOp) with input validation and kernel dispatch
- Created standalone function wrappers in `top/functions/fuse_moe.py`
- Added comprehensive unit tests with reference-based validation
- Included placeholder FuseMoePertensorFp8Op for future FP8 quantization support
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| top/kernels/fuse_moe/moe_reduce.py | TileLang kernel implementation for MoE reduce operation with FP32 accumulation |
| top/kernels/fuse_moe/count_and_gather.py | Two-stage TileLang kernel: count tokens per expert and gather inputs by expert |
| top/kernels/fuse_moe/__init__.py | Kernel module exports |
| top/ops/moe_reduce.py | Op wrapper for MoeReduceKernel with input validation and dimension management |
| top/ops/count_and_gather.py | Op wrapper for CountAndGatherKernel with config management |
| top/ops/fuse_moe_pertensor_fp8.py | Placeholder Op for future FP8 quantization support (not implemented) |
| top/ops/__init__.py | Added new Op exports |
| top/functions/fuse_moe.py | Standalone function wrappers for count_and_gather, reduce, and fuse_moe_pertensor_fp8 |
| tests/ops/test_moe_reduce.py | Unit tests for MoeReduceOp with reference implementation and edge cases |
| tests/ops/test_count_and_gather.py | Unit tests for CountAndGatherOp with reference implementation and performance checks |
Description
This PR implements the Fused MoE routing and aggregation path in TileOPs:
fixes #179 #178

- `count_and_gather` migrated to the TileLang kernel path (with staged logic)
- `moe_reduce` migrated to the TileLang kernel path and renamed at the Op level to avoid a generic naming conflict

Type of Change
Checklist
- Ran `pre-commit run --all-files` and fixed all linting issues.
- Added a `Benchmark` class in `benchmarks/`.