[jit_kernel] Add moe_sum JIT kernel (port from sgl-kernel AOT) #19057

Draft

Johnsonms wants to merge 6 commits into sgl-project:main from Johnsonms:moe-sum-jit

Conversation

Johnsonms (Contributor) commented Feb 20, 2026

Motivation

Part of tracking issue #17865 — migrate sgl-kernel AOT kernels to the lightweight python/sglang/jit_kernel/ system.

This PR ports sgl-kernel/csrc/moe/moe_sum.cu to a JIT kernel, matching the existing sgl_kernel.moe_sum call signature so it can serve as a drop-in replacement.

Changes

python/sglang/jit_kernel/csrc/moe/moe_sum.cuh

  • One-CTA-per-token reduction kernel (reference semantics sketched after this list)
  • moe_sum_kernel<T, TOPK> with __ldg loads and compile-time loop unrolling via #pragma unroll
  • Static dispatch for TOPK = 1–9 (covering all common MoE topk values including DeepSeek-style topk=8/9)
  • moe_sum_kernel_general<T> fallback for TOPK > 9 (with #pragma unroll 1 to avoid register pressure)
  • moe_sum<T>(TensorView, TensorView) host launcher with tvm-ffi interface
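
For reference, the kernel computes out[t, d] = sum over k of input[t, k, d], reducing the topk dimension. A minimal PyTorch sketch of the same semantics (not the CUDA implementation itself; the float32 accumulation here matches the precision behavior described in the commits further down):

```python
import torch

def moe_sum_reference(x: torch.Tensor, out: torch.Tensor) -> None:
    """Reference semantics of the kernel: out[t, d] = sum_k x[t, k, d].

    x:   [num_tokens, topk, hidden_size]
    out: [num_tokens, hidden_size], same dtype as x
    """
    # Accumulate in float32, then cast back to the input dtype.
    out.copy_(x.float().sum(dim=1).to(x.dtype))

# Example: topk=8 (DeepSeek-style), hidden=4096
x = torch.randn(64, 8, 4096, dtype=torch.bfloat16, device="cuda")
out = torch.empty(64, 4096, dtype=torch.bfloat16, device="cuda")
moe_sum_reference(x, out)
```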

python/sglang/jit_kernel/moe_sum.py

  • cache_once/load_jit wrapper; one JIT module per dtype (fp32/fp16/bf16)
  • moe_sum_out registered as a custom op (destination-passing style)
  • moe_sum(input_tensor, output_tensor) public API matching sgl_kernel.moe_sum (usage sketched after this list)
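
A hedged usage sketch of the public entry point (the import path follows this PR's file layout and is an assumption):

```python
import torch
from sglang.jit_kernel.moe_sum import moe_sum  # assumed module path per the file layout above

num_tokens, topk, hidden = 256, 8, 4096
x = torch.randn(num_tokens, topk, hidden, dtype=torch.bfloat16, device="cuda")
out = torch.empty(num_tokens, hidden, dtype=torch.bfloat16, device="cuda")

# Destination-passing style, same call shape as sgl_kernel.moe_sum(input, output)
moe_sum(x, out)

# Sanity check against the PyTorch reference
torch.testing.assert_close(out, x.float().sum(dim=1).to(x.dtype), atol=1e-3, rtol=1e-3)
```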

python/sglang/jit_kernel/tests/test_moe_sum.py

  • 283 correctness tests (reduced parameter set when running in CI, full combinatorial range otherwise)
  • test_moe_sum_vs_ref: JIT vs. x.sum(dim=1) PyTorch reference across all dtype/token/topk/hidden combinations (a minimal version is sketched after this list)
  • test_output_shape_and_dtype: shape and dtype preservation
  • test_general_fallback: exercises the general-fallback kernel path (topk beyond the static-dispatch range)
  • test_moe_sum_vs_aot: optional cross-validation against sgl_kernel.moe_sum when available
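
A minimal version of the reference-comparison test, assuming the import path above and a much smaller parameter grid than the real 283-case matrix:

```python
import pytest
import torch

from sglang.jit_kernel.moe_sum import moe_sum  # assumed module path

@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("topk", [1, 2, 4, 8, 9, 16])
@pytest.mark.parametrize("num_tokens,hidden", [(1, 7168), (64, 4096), (256, 2048)])
def test_moe_sum_vs_ref(dtype, topk, num_tokens, hidden):
    x = torch.randn(num_tokens, topk, hidden, dtype=dtype, device="cuda")
    out = torch.empty(num_tokens, hidden, dtype=dtype, device="cuda")
    moe_sum(x, out)
    ref = x.float().sum(dim=1).to(dtype)  # float32-accumulated reference
    torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
```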

python/sglang/jit_kernel/benchmark/bench_moe_sum.py

  • JIT vs. AOT sgl_kernel.moe_sum throughput comparison using triton.testing.perf_report (a minimal harness is sketched after this list)
  • CI-friendly reduced configurations (guarded by is_in_ci())
  • calculate_diff() quick correctness sanity check at startup
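
A minimal shape of the harness, assuming the JIT and AOT entry points named in this PR (sketch only; the real script adds CI-reduced configurations and calculate_diff()):

```python
import torch
import triton
import triton.testing

import sgl_kernel  # AOT baseline
from sglang.jit_kernel.moe_sum import moe_sum as moe_sum_jit  # assumed module path

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=[1, 64, 256, 1024],
        line_arg="provider",
        line_vals=["jit", "aot"],
        line_names=["JIT", "AOT sgl_kernel"],
        ylabel="us",
        plot_name="moe_sum",
        args={"topk": 8, "hidden": 4096},
    )
)
def bench(num_tokens, provider, topk, hidden):
    x = torch.randn(num_tokens, topk, hidden, dtype=torch.bfloat16, device="cuda")
    out = torch.empty(num_tokens, hidden, dtype=torch.bfloat16, device="cuda")
    fn = (lambda: moe_sum_jit(x, out)) if provider == "jit" else (lambda: sgl_kernel.moe_sum(x, out))
    return triton.testing.do_bench(fn) * 1000.0  # do_bench returns ms; convert to us

if __name__ == "__main__":
    bench.run(print_data=True)
```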

Test Results

283 passed in 16.6s

Performance Notes

Static dispatch for TOPK 1–9 ensures full loop unrolling for all common MoE configurations (topk=2 standard, topk=4 common, topk=8 DeepSeek-style). The general fallback (topk > 9) uses moe_sum_kernel_general<T> with #pragma unroll 1 to avoid register pressure on rarely-used paths.

Port sgl-kernel/csrc/moe/moe_sum.cu to the lightweight JIT kernel
system, matching the sgl_kernel.moe_sum call signature.

- csrc/moe/moe_sum.cuh: one-CTA-per-token kernel with static dispatch
  for TOPK 1-9 (unrolled) plus a general fallback for larger topk
- moe_sum.py: cache_once/load_jit wrapper, register_custom_op
- tests/test_moe_sum.py: 283 correctness tests vs PyTorch reference
  and optional cross-validation vs AOT sgl_kernel
- benchmark/bench_moe_sum.py: JIT vs AOT throughput comparison
gemini-code-assist bot commented

Summary of Changes

Hello @Johnsonms, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request migrates the moe_sum kernel from an AOT (Ahead-Of-Time) compiled sgl-kernel implementation to a new lightweight JIT (Just-In-Time) compilation system. The goal is to provide a drop-in replacement that matches the existing call signature, improving flexibility and maintainability while ensuring correctness and performance.

Highlights

  • JIT Kernel Implementation: Implemented a one-CTA-per-token reduction kernel for moe_sum in CUDA, supporting static dispatch for common TOPK values (1-9) and a general fallback for TOPK > 9.
  • Python Integration: Created Python wrappers to load and cache the JIT-compiled moe_sum kernel, registering it as a custom operation with a tvm-ffi interface and a public API matching the existing sgl_kernel.moe_sum.
  • Comprehensive Testing: Added extensive correctness tests against a PyTorch reference, covering various configurations and data types, including specific tests for output shape/dtype preservation and the general fallback mechanism.
  • Performance Benchmarking: Introduced a benchmark script to compare the throughput of the new JIT moe_sum kernel against the legacy AOT sgl_kernel implementation, including a quick correctness diff check.
  • Optimized Performance: Utilized static dispatch for TOPK values 1-9 to ensure full loop unrolling and employed #pragma unroll 1 in the general fallback to avoid register pressure on less common paths.

Changelog
  • python/sglang/jit_kernel/benchmark/bench_moe_sum.py
    • Added a new benchmark script to compare JIT and AOT moe_sum performance.
    • Included a quick correctness diff check between JIT and AOT implementations.
  • python/sglang/jit_kernel/csrc/moe/moe_sum.cuh
    • Added the CUDA kernel for moe_sum, featuring one-CTA-per-token reduction.
    • Implemented static dispatch for TOPK values from 1 to 9 and a general fallback kernel.
    • Included a host launcher with a tvm-ffi interface for moe_sum.
  • python/sglang/jit_kernel/moe_sum.py
    • Added Python logic to cache and load the JIT moe_sum module.
    • Registered moe_sum_out as a custom op.
    • Provided a public moe_sum API matching the sgl_kernel signature.
  • python/sglang/jit_kernel/tests/test_moe_sum.py
    • Added 283 correctness tests for the moe_sum JIT kernel.
    • Included tests for JIT vs. PyTorch reference, output shape/dtype preservation, general fallback, and optional cross-validation against AOT.
Activity
  • Initial implementation of the moe_sum JIT kernel and its Python bindings.
  • Added comprehensive unit tests, with 283 tests passing in 16.6 seconds.
  • Introduced a performance benchmark to compare JIT and AOT versions.

gemini-code-assist bot left a comment

Code Review

The pull request introduces a JIT kernel for moe_sum, porting functionality from an AOT kernel. The changes include the CUDA kernel implementation, Python bindings, and comprehensive tests and benchmarks. The code appears well-structured and follows good practices for CUDA kernel development and Python integration. The benchmarking and testing setup is thorough, ensuring correctness and performance comparison with the AOT version. I've identified a few areas for improvement regarding error handling and consistency in the C++ code.

Apply clang-format and black fixes flagged by pre-commit hooks:
- moe_sum.cuh: align comment spacing, collapse kernel_general signature
- test_moe_sum.py: reformat assert torch.allclose() calls (line length)
yuan-luo (Collaborator) commented

/tag-and-rerun-ci

Johnsonms and others added 4 commits February 20, 2026 23:26
Root cause: the AOT kernel only dispatches topk=2,3,4 to the custom CUDA
kernel (bf16 accumulation); for topk>4 it falls through to at::sum_out
which uses float32 accumulation internally. The JIT dispatched all topk
values to the custom kernel with bf16 accumulation, causing a large
mismatch (max_err=1.25e-01) for topk=8 with bfloat16.

Fix: accumulate in float32 in both the static-dispatch and general-fallback
kernels, matching at::sum_out precision for all topk values.

Test updates:
- test_moe_sum_vs_ref: tighten atol/rtol from 0.05/1e-2 to 1e-3/1e-3
  (fp32 accumulation is now much closer to the float32 reference)
- test_moe_sum_vs_aot: split into two tests:
  - test_moe_sum_vs_aot_fp32: float32 dtype only for topk=2,3,4,8
    (AOT bf16 accum == JIT fp32 accum for float32 inputs)
  - test_moe_sum_vs_aot_large_topk: fp16/bf16 for topk=8
    (AOT uses at::sum_out fp32, matches JIT fp32 exactly)

Benchmark: calculate_diff() now compares JIT against float32 reference
instead of AOT to avoid false mismatches from AOT's inconsistent precision.

Note: the tokens=1, hidden=7168, topk=8 performance regression vs AOT is
expected — AOT uses at::sum_out which is better optimized for the degenerate
single-token case. For all realistic batch sizes (64+) JIT is faster.
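
The accumulation-precision gap described in this commit is easy to reproduce in plain PyTorch (illustrative only; error magnitudes depend on the data):

```python
import torch

torch.manual_seed(0)
x = torch.randn(256, 8, 4096, dtype=torch.bfloat16, device="cuda")  # topk=8

# float32 accumulation, cast back at the end (what at::sum_out and the fixed kernel do)
ref = x.float().sum(dim=1).to(torch.bfloat16)

# bf16 accumulation: every partial sum is rounded back to bfloat16
acc = torch.zeros(x.shape[0], x.shape[2], dtype=torch.bfloat16, device="cuda")
for k in range(x.shape[1]):
    acc = acc + x[:, k, :]

print("max abs diff, bf16 accumulation vs fp32 accumulation:",
      (acc.float() - ref.float()).abs().max().item())
```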
…den dim

Previously the grid was (num_tokens,) with one block per token, leaving
tokens=1 with only 1 block (1 SM) for hidden=7168 and causing a ~54%
regression vs AOT's at::sum_out which launches ~7 blocks for the same case.

Fix: decompose the grid as (num_tokens * hidden_blocks,) where
hidden_blocks = ceil(hidden_size / 1024), so each block handles a 1024-wide
slice of the hidden dimension. Each thread now processes exactly one output
element, eliminating the stride loop and maximizing parallelism.

Result for tokens=1, hidden=7168, topk=8:
  Before: JIT=3.29µs  AOT=2.14µs  (54% regression)
  After:  JIT=1.44µs  AOT=2.15µs  (33% faster than AOT)
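
The grid arithmetic behind this change, written out as a hypothetical helper (the 1024-thread block size is taken from the commit text above):

```python
import math

def grid_size(num_tokens: int, hidden_size: int, block_threads: int = 1024) -> int:
    """One block per (token, 1024-wide hidden slice): grid = num_tokens * hidden_blocks."""
    hidden_blocks = math.ceil(hidden_size / block_threads)
    return num_tokens * hidden_blocks

print(grid_size(1, 7168))   # 7 blocks instead of 1, so more SMs stay busy in the single-token case
print(grid_size(64, 4096))  # 256 blocks
```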
LaunchKernel uses cudaLaunchKernelEx with RuntimeDeviceCheck for automatic
error checking after each kernel launch, consistent with other JIT kernels.
…blocks and stride loop

Two changes to moe_sum.cuh:

1. Adaptive hidden_blocks: cap total grid to TARGET_MAX_BLOCKS (256) so large
   token counts get hidden_blocks=1 (each thread strides over the full hidden
   slice) while small token counts keep full 2-D decomposition for parallelism.

2. Stride loop in both kernels: each thread loops over its hidden slice with
   stride = hidden_blocks * blockDim.x, so work is correct for any hidden_blocks.
   When hidden_blocks=1 this is identical to the original 1-D stride loop.

Previously the 2D grid (always hidden_blocks=ceil(hidden/1024)) caused 20-68%
slowdown at large token counts (tokens≥128, hidden=4096, topk=2) because each
thread had only 2 loads, insufficient to hide memory latency.  After this fix:
- topk=2: 5-13% residual gap (down from 20-68%)
- topk=4: at parity with AOT
- topk=8: 7-32% faster than AOT across all token counts

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
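
A sketch of the adaptive launch heuristic described in this commit (Python arithmetic only; TARGET_MAX_BLOCKS and the 1024-thread block size come from the commit text, while the exact capping rule is an assumption):

```python
import math

TARGET_MAX_BLOCKS = 256
BLOCK_THREADS = 1024

def launch_shape(num_tokens: int, hidden_size: int):
    """Return (total_blocks, per-thread stride over a token's hidden slice)."""
    hidden_blocks = math.ceil(hidden_size / BLOCK_THREADS)
    # Assumed capping rule: shrink hidden_blocks until the grid fits in
    # TARGET_MAX_BLOCKS, bottoming out at 1 (one block per token); the total
    # can still exceed the cap when num_tokens alone does.
    hidden_blocks = max(1, min(hidden_blocks, TARGET_MAX_BLOCKS // max(num_tokens, 1)))
    stride = hidden_blocks * BLOCK_THREADS  # stride used by the in-kernel loop
    return num_tokens * hidden_blocks, stride

print(launch_shape(1, 7168))    # (7, 7168): small batch keeps the 2-D split
print(launch_shape(512, 4096))  # (512, 1024): large batch, one block per token, 1-D stride loop
```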
Johnsonms marked this pull request as draft February 20, 2026 23:52