[jit_kernel] Add moe_sum JIT kernel (port from sgl-kernel AOT)#19057
Johnsonms wants to merge 6 commits into sgl-project:main from
Conversation
Port sgl-kernel/csrc/moe/moe_sum.cu to the lightweight JIT kernel system, matching the sgl_kernel.moe_sum call signature.
- csrc/moe/moe_sum.cuh: one-CTA-per-token kernel with static dispatch for TOPK 1-9 (unrolled) plus a general fallback for larger topk (see the sketch below)
- moe_sum.py: cache_once/load_jit wrapper, register_custom_op
- tests/test_moe_sum.py: 283 correctness tests vs PyTorch reference and optional cross-validation vs AOT sgl_kernel
- benchmark/bench_moe_sum.py: JIT vs AOT throughput comparison
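For orientation, a minimal sketch of the one-CTA-per-token shape described above; the kernel name, argument list, and row-major [num_tokens, TOPK, hidden] layout are illustrative assumptions, not the PR's exact code.

```cuda
#include <cuda_fp16.h>
#include <cuda_bf16.h>

// One block per token; each thread strides over the hidden dimension and
// accumulates the TOPK expert outputs in fp32 before casting back to T.
// T is assumed to be float, __half, or __nv_bfloat16 (types with __ldg overloads).
template <typename T, int TOPK>
__global__ void moe_sum_sketch(const T* __restrict__ input,  // [num_tokens, TOPK, hidden]
                               T* __restrict__ output,       // [num_tokens, hidden]
                               int hidden_size) {
  const int token = blockIdx.x;
  for (int h = threadIdx.x; h < hidden_size; h += blockDim.x) {
    float acc = 0.0f;
#pragma unroll
    for (int k = 0; k < TOPK; ++k) {  // compile-time bound: fully unrolled
      acc += static_cast<float>(__ldg(&input[(token * TOPK + k) * hidden_size + h]));
    }
    output[token * hidden_size + h] = static_cast<T>(acc);
  }
}
```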
Code Review
The pull request introduces a JIT kernel for moe_sum, porting functionality from an AOT kernel. The changes include the CUDA kernel implementation, Python bindings, and comprehensive tests and benchmarks. The code appears well-structured and follows good practices for CUDA kernel development and Python integration. The benchmarking and testing setup is thorough, ensuring correctness and performance comparison with the AOT version. I've identified a few areas for improvement regarding error handling and consistency in the C++ code.
Apply clang-format and black fixes flagged by pre-commit hooks:
- moe_sum.cuh: align comment spacing, collapse the kernel_general signature
- test_moe_sum.py: reformat assert torch.allclose() calls (line length)
/tag-and-rerun-ci
Root cause: the AOT kernel only dispatches topk=2,3,4 to the custom CUDA
kernel (bf16 accumulation); for topk>4 it falls through to at::sum_out
which uses float32 accumulation internally. The JIT dispatched all topk
values to the custom kernel with bf16 accumulation, causing a large
mismatch (max_err=1.25e-01) for topk=8 with bfloat16.
Fix: accumulate in float32 in both the static-dispatch and general-fallback
kernels, matching at::sum_out precision for all topk values.
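A minimal sketch of the accumulation change, assuming a helper of this shape (not the PR's exact diff):

```cuda
// Accumulate the TOPK contributions in fp32 and cast once on the store,
// instead of accumulating in the element type T (the old bf16 path).
// Helper name and signature are assumptions for illustration.
template <typename T, int TOPK>
__device__ __forceinline__ T sum_topk_fp32(const T* __restrict__ row, int stride) {
  float acc = 0.0f;  // fp32 accumulator, matching at::sum_out precision
#pragma unroll
  for (int k = 0; k < TOPK; ++k) {
    acc += static_cast<float>(row[k * stride]);  // upcast each fp16/bf16 element
  }
  return static_cast<T>(acc);  // single rounding back to T
}
```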
Test updates:
- test_moe_sum_vs_ref: tighten atol/rtol from 0.05/1e-2 to 1e-3/1e-3
(fp32 accumulation is now much closer to the float32 reference)
- test_moe_sum_vs_aot: split into two tests:
- test_moe_sum_vs_aot_fp32: float32 dtype only for topk=2,3,4,8
(AOT bf16 accum == JIT fp32 accum for float32 inputs)
- test_moe_sum_vs_aot_large_topk: fp16/bf16 for topk=8
(AOT uses at::sum_out fp32, matches JIT fp32 exactly)
Benchmark: calculate_diff() now compares JIT against float32 reference
instead of AOT to avoid false mismatches from AOT's inconsistent precision.
Note: the tokens=1, hidden=7168, topk=8 performance regression vs AOT is
expected — AOT uses at::sum_out which is better optimized for the degenerate
single-token case. For all realistic batch sizes (64+) JIT is faster.
…den dim

Previously the grid was (num_tokens,) with one block per token, leaving tokens=1 with only 1 block (1 SM) for hidden=7168 and causing a ~54% regression vs AOT's at::sum_out, which launches ~7 blocks for the same case.

Fix: decompose the grid as (num_tokens * hidden_blocks,) where hidden_blocks = ceil(hidden_size / 1024), so each block handles a 1024-wide slice of the hidden dimension. Each thread now processes exactly one output element, eliminating the stride loop and maximizing parallelism.

Result for tokens=1, hidden=7168, topk=8:
- Before: JIT=3.29µs, AOT=2.14µs (54% regression)
- After: JIT=1.44µs, AOT=2.15µs (33% faster than AOT)
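A sketch of the decomposed launch described in this commit; names are illustrative and the next commit refines this further with a stride loop.

```cuda
// grid = (num_tokens * hidden_blocks,), block = 1024 threads,
// hidden_blocks = ceil(hidden_size / 1024); one output element per thread.
template <typename T, int TOPK>
__global__ void moe_sum_2d_sketch(const T* __restrict__ input,  // [num_tokens, TOPK, hidden]
                                  T* __restrict__ output,       // [num_tokens, hidden]
                                  int hidden_size, int hidden_blocks) {
  const int token = blockIdx.x / hidden_blocks;                           // which token
  const int h = (blockIdx.x % hidden_blocks) * blockDim.x + threadIdx.x;  // which element
  if (h >= hidden_size) return;  // last slice may be partial
  float acc = 0.0f;
#pragma unroll
  for (int k = 0; k < TOPK; ++k) {
    acc += static_cast<float>(input[(token * TOPK + k) * hidden_size + h]);
  }
  output[token * hidden_size + h] = static_cast<T>(acc);
}
```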
LaunchKernel uses cudaLaunchKernelEx with RuntimeDeviceCheck for automatic error checking after each kernel launch, consistent with other JIT kernels.
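For reference, a generic illustration of launching through cudaLaunchKernelEx so the launch status can be checked immediately. This is plain CUDA runtime API usage, not the JIT system's LaunchKernel/RuntimeDeviceCheck wrapper, and the kernel and launcher names are assumptions.

```cuda
#include <cuda_runtime.h>

// Kernel with the same shape as the sketch above (declaration only).
template <typename T, int TOPK>
__global__ void moe_sum_2d_sketch(const T*, T*, int hidden_size, int hidden_blocks);

template <typename T, int TOPK>
cudaError_t launch_moe_sum_sketch(const T* input, T* output, int num_tokens,
                                  int hidden_size, int hidden_blocks,
                                  cudaStream_t stream) {
  cudaLaunchConfig_t cfg = {};
  cfg.gridDim = dim3(num_tokens * hidden_blocks);
  cfg.blockDim = dim3(1024);
  cfg.dynamicSmemBytes = 0;
  cfg.stream = stream;
  // Unlike <<<...>>>, cudaLaunchKernelEx returns the launch status directly,
  // so failures surface right after the launch rather than at the next sync.
  return cudaLaunchKernelEx(&cfg, moe_sum_2d_sketch<T, TOPK>, input, output,
                            hidden_size, hidden_blocks);
}
```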
…blocks and stride loop

Two changes to moe_sum.cuh:
1. Adaptive hidden_blocks: cap the total grid at TARGET_MAX_BLOCKS (256) so large token counts get hidden_blocks=1 (each thread strides over the full hidden slice), while small token counts keep the full 2-D decomposition for parallelism.
2. Stride loop in both kernels: each thread loops over its hidden slice with stride = hidden_blocks * blockDim.x, so the work is correct for any hidden_blocks. When hidden_blocks=1 this is identical to the original 1-D stride loop.

Previously the 2-D grid (always hidden_blocks = ceil(hidden/1024)) caused a 20-68% slowdown at large token counts (tokens≥128, hidden=4096, topk=2) because each thread had only 2 loads, insufficient to hide memory latency. After this fix:
- topk=2: 5-13% residual gap (down from 20-68%)
- topk=4: at parity with AOT
- topk=8: 7-32% faster than AOT across all token counts

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
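A sketch under these assumptions: TARGET_MAX_BLOCKS = 256 and the stride formula come from the commit message, while the capping heuristic and all names are illustrative guesses, not the PR's exact code.

```cuda
constexpr int kThreads = 1024;
constexpr int kTargetMaxBlocks = 256;  // TARGET_MAX_BLOCKS from the commit message

// Host-side choice of hidden_blocks: full 2-D split for small batches,
// collapse toward 1 (pure stride loop) once the grid budget is exhausted.
inline int pick_hidden_blocks(int num_tokens, int hidden_size) {
  int full = (hidden_size + kThreads - 1) / kThreads;  // ceil(hidden / 1024)
  int budget = kTargetMaxBlocks / (num_tokens > 0 ? num_tokens : 1);
  if (budget < 1) budget = 1;
  return budget < full ? budget : full;
}

// Device-side stride loop, valid for any hidden_blocks. With hidden_blocks == 1
// this degenerates to the original 1-D stride loop (stride == blockDim.x).
template <typename T, int TOPK>
__global__ void moe_sum_strided_sketch(const T* __restrict__ input, T* __restrict__ output,
                                       int hidden_size, int hidden_blocks) {
  const int token = blockIdx.x / hidden_blocks;
  const int start = (blockIdx.x % hidden_blocks) * blockDim.x + threadIdx.x;
  const int stride = hidden_blocks * blockDim.x;
  for (int h = start; h < hidden_size; h += stride) {
    float acc = 0.0f;
#pragma unroll
    for (int k = 0; k < TOPK; ++k) {
      acc += static_cast<float>(input[(token * TOPK + k) * hidden_size + h]);
    }
    output[token * hidden_size + h] = static_cast<T>(acc);
  }
}
```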
Motivation
Part of tracking issue #17865 — migrate sgl-kernel AOT kernels to the lightweight python/sglang/jit_kernel/ system. This PR ports sgl-kernel/csrc/moe/moe_sum.cu to a JIT kernel, matching the existing sgl_kernel.moe_sum call signature so it can serve as a drop-in replacement.

Changes
python/sglang/jit_kernel/csrc/moe/moe_sum.cuh
- moe_sum_kernel<T, TOPK> with __ldg loads and compile-time loop unrolling via #pragma unroll
- moe_sum_kernel_general<T> fallback for TOPK > 9 (with #pragma unroll 1 to avoid register pressure; sketched after this list)
- moe_sum<T>(TensorView, TensorView) host launcher with tvm-ffi interface

python/sglang/jit_kernel/moe_sum.py
- cache_once/load_jit wrapper; one JIT module per dtype (fp32/fp16/bf16)
- moe_sum_out registered as a custom op (destination-passing style)
- moe_sum(input_tensor, output_tensor) public API matching sgl_kernel.moe_sum

python/sglang/jit_kernel/tests/test_moe_sum.py
- test_moe_sum_vs_ref: JIT vs. x.sum(dim=1) PyTorch reference across all dtype/token/topk/hidden combinations
- test_output_shape_and_dtype: shape and dtype preservation
- test_general_fallback: exercises the topk > 4 general kernel path
- test_moe_sum_vs_aot: optional cross-validation against sgl_kernel.moe_sum when available

python/sglang/jit_kernel/benchmark/bench_moe_sum.py
- sgl_kernel.moe_sum throughput comparison using triton.testing.perf_report
- is_in_ci())
- calculate_diff() quick correctness sanity check at startup
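A hypothetical sketch of the moe_sum_kernel_general fallback listed above; the signature and layout are assumptions, and only the runtime topk bound and #pragma unroll 1 come from the PR description.

```cuda
// Runtime-bounded topk loop, deliberately kept rolled (#pragma unroll 1) so this
// rarely-used path does not inflate register pressure.
template <typename T>
__global__ void moe_sum_kernel_general_sketch(const T* __restrict__ input,  // [num_tokens, topk, hidden]
                                              T* __restrict__ output,       // [num_tokens, hidden]
                                              int topk, int hidden_size) {
  const int token = blockIdx.x;
  for (int h = threadIdx.x; h < hidden_size; h += blockDim.x) {
    float acc = 0.0f;
#pragma unroll 1
    for (int k = 0; k < topk; ++k) {
      acc += static_cast<float>(__ldg(&input[(token * topk + k) * hidden_size + h]));
    }
    output[token * hidden_size + h] = static_cast<T>(acc);
  }
}
```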
Test Results

Performance Notes
Static dispatch for TOPK 1–9 ensures full loop unrolling for all common MoE configurations (topk=2 standard, topk=4 common, topk=8 DeepSeek-style). The general fallback (topk > 9) uses moe_sum_kernel_general<T> with #pragma unroll 1 to avoid register pressure on rarely-used paths.
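A sketch of what the host-side static dispatch could look like; the switch/macro structure and kernel signatures are illustrative assumptions, and only the TOPK 1–9 range plus the general fallback come from the text above.

```cuda
// Forward declarations with assumed signatures.
template <typename T, int TOPK>
__global__ void moe_sum_kernel(const T* input, T* output, int hidden_size);
template <typename T>
__global__ void moe_sum_kernel_general(const T* input, T* output, int topk, int hidden_size);

template <typename T>
void dispatch_moe_sum(const T* input, T* output, int topk, int hidden_size,
                      dim3 grid, dim3 block, cudaStream_t stream) {
  switch (topk) {
#define MOE_SUM_CASE(K)                                                           \
  case K:                                                                         \
    moe_sum_kernel<T, K><<<grid, block, 0, stream>>>(input, output, hidden_size); \
    break;
    MOE_SUM_CASE(1) MOE_SUM_CASE(2) MOE_SUM_CASE(3)
    MOE_SUM_CASE(4) MOE_SUM_CASE(5) MOE_SUM_CASE(6)
    MOE_SUM_CASE(7) MOE_SUM_CASE(8) MOE_SUM_CASE(9)
#undef MOE_SUM_CASE
    default:  // topk > 9: runtime-bounded loop, kept rolled via #pragma unroll 1
      moe_sum_kernel_general<T><<<grid, block, 0, stream>>>(input, output, topk, hidden_size);
      break;
  }
}
```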