
Migrate bmm_fp8 from AOT cuBLASLt to flashinfer JIT #18999

Open
Johnsonms wants to merge 6 commits into sgl-project:main from Johnsonms:bmm-fp8

Conversation

@Johnsonms (Contributor) commented Feb 19, 2026

Motivation

PR: #17865 (comment)
sgl-kernel shipped its own AOT (ahead-of-time) cuBLASLt wrapper for batched
FP8 matrix multiplication (bmm_fp8). FlashInfer already provides
flashinfer.bmm_fp8 via its JIT system with an identical public signature
and additional backend options (cublas / cutlass / cudnn / auto). Keeping a
separate C++ implementation adds maintenance burden with no measurable
benefit.

This PR removes the redundant AOT kernel and routes sgl_kernel.bmm_fp8
directly to flashinfer.bmm_fp8, deleting roughly 200 lines of C++/CUDA code.
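
For reference, a minimal call-site sketch, assuming the existing public
signature bmm_fp8(A, B, A_scale, B_scale, dtype, out=None), per-tensor
float32 scales, and a CUDA device; the shapes below are illustrative and
not taken from the PR:

  import torch
  from sgl_kernel import bmm_fp8

  # A: (b, m, k) FP8 row-major; B: (b, k, n) FP8 column-major.
  a = torch.randn(8, 64, 128, device="cuda").to(torch.float8_e4m3fn)
  b = torch.randn(8, 256, 128, device="cuda").to(torch.float8_e4m3fn)
  b = b.transpose(-2, -1)
  a_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
  b_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)

  # Same call before and after this PR; only the backend behind it changes.
  out = bmm_fp8(a, b, a_scale, b_scale, torch.bfloat16)  # (8, 64, 256) bf16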

Modifications

  • Deleted sgl-kernel/csrc/gemm/bmm_fp8.cu — cuBLASLt CUDA wrapper
  • Removed bmm_fp8 entry from sgl-kernel/CMakeLists.txt
  • Removed bmm_fp8 schema registration and m.impl from
    sgl-kernel/csrc/common_extension.cc
  • Removed void bmm_fp8(...) declaration from
    sgl-kernel/include/sgl_kernel_ops.h
  • Simplified sgl_kernel.gemm.bmm_fp8 to delegate directly to
    flashinfer.bmm_fp8; the public API signature is unchanged (a sketch of
    the resulting wrapper follows this list)
  • Added python/sglang/jit_kernel/tests/test_bmm_fp8.py — parametrized
    correctness tests (FP8 dtypes × output dtypes × shapes)
  • Added python/sglang/jit_kernel/benchmark/bench_bmm_fp8.py — latency
    benchmark using triton.testing.perf_report
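
A minimal sketch of the resulting wrapper, assuming flashinfer.bmm_fp8 keeps
the (A, B, A_scale, B_scale, dtype, out) argument order; the real change
lives in sgl-kernel/python/sgl_kernel/gemm.py:

  from typing import Optional

  import torch
  import flashinfer


  def bmm_fp8(
      A: torch.Tensor,
      B: torch.Tensor,
      A_scale: torch.Tensor,
      B_scale: torch.Tensor,
      dtype: torch.dtype,
      out: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
      # Thin pass-through: workspace handling and backend dispatch now happen
      # inside flashinfer's JIT path instead of the removed cuBLASLt wrapper.
      return flashinfer.bmm_fp8(A, B, A_scale, B_scale, dtype, out=out)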

Accuracy Tests

  Correctness validated against torch.bmm float32 reference across all
  parametrized combinations (48 passed, 16 skipped for invalid e5m2×e5m2):

  pytest python/sglang/jit_kernel/tests/test_bmm_fp8.py -v -s
  48 passed, cosine similarity > 0.99 for all valid dtype combinations
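
  A condensed sketch of such a parametrized test (not the PR's actual test
  file; the flashinfer.bmm_fp8 argument order and the to_float8 helper are
  assumptions, and the parameter grid here is smaller than the real one):

    import pytest
    import torch
    import flashinfer


    def to_float8(x, dtype):
        # Per-tensor quantization: scale so the max magnitude maps to FP8 max.
        finfo = torch.finfo(dtype)
        scale = finfo.max / x.abs().amax().clamp(min=1e-12)
        x_fp8 = (x * scale).clamp(finfo.min, finfo.max).to(dtype)
        return x_fp8, (1.0 / scale).float()


    @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
    @pytest.mark.parametrize("a_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
    @pytest.mark.parametrize("b_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
    @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
    @pytest.mark.parametrize("batch,m,k,n", [(1, 64, 128, 256), (16, 512, 1024, 1024)])
    def test_bmm_fp8(a_dtype, b_dtype, out_dtype, batch, m, k, n):
        if a_dtype == torch.float8_e5m2 and b_dtype == torch.float8_e5m2:
            pytest.skip("e5m2 x e5m2 is not a supported FP8 input combination")
        a_fp8, a_scale = to_float8(torch.randn(batch, m, k, device="cuda"), a_dtype)
        b_fp8, b_scale = to_float8(torch.randn(batch, n, k, device="cuda"), b_dtype)
        b_fp8 = b_fp8.transpose(-2, -1)  # (batch, k, n), column-major as expected
        out = flashinfer.bmm_fp8(a_fp8, b_fp8, a_scale, b_scale, out_dtype)
        # Dequantized float32 reference via torch.bmm.
        ref = torch.bmm(a_fp8.float() * a_scale, b_fp8.float() * b_scale)
        cos = torch.nn.functional.cosine_similarity(
            out.float().flatten(), ref.flatten(), dim=0
        )
        assert cos > 0.99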

Benchmarking and Profiling

  Benchmark run comparing flashinfer JIT vs the original sgl-kernel AOT
  cuBLASLt across batch sizes 1–16, shapes M∈{64,128,512}, K∈{128,512,1024},
  N∈{256,512,1024}:

  - No regression observed — latency difference <1% across all 27 shape ×
  batch combinations
  - Both implementations are thin wrappers over the same underlying cuBLASLt
  path
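
  A condensed sketch of such a latency benchmark (not the PR's actual script;
  the shape values, provider label, and do_bench usage are illustrative
  assumptions):

    import torch
    import triton
    import triton.testing
    import flashinfer

    bench_config = triton.testing.Benchmark(
        x_names=["batch"],
        x_vals=[1, 2, 4, 8, 16],
        line_arg="provider",
        line_vals=["flashinfer"],
        line_names=["flashinfer JIT"],
        ylabel="latency (ms)",
        plot_name="bmm-fp8-m512-k1024-n1024",
        args={"m": 512, "k": 1024, "n": 1024},
    )


    @triton.testing.perf_report([bench_config])
    def bench_bmm_fp8(batch, m, k, n, provider):
        a = torch.randn(batch, m, k, device="cuda").to(torch.float8_e4m3fn)
        b = torch.randn(batch, n, k, device="cuda").to(torch.float8_e4m3fn)
        b = b.transpose(-2, -1)  # (batch, k, n), column-major
        scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
        # Median latency in milliseconds over repeated timed runs.
        return triton.testing.do_bench(
            lambda: flashinfer.bmm_fp8(a, b, scale, scale, torch.bfloat16)
        )


    if __name__ == "__main__":
        bench_bmm_fp8.run(print_data=True)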

  Backup branch with the two-provider benchmark harness is available at
  bmm-fp8-backup for reference.

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist (Contributor) commented:

Summary of Changes

Hello @Johnsonms, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request streamlines the sgl-kernel codebase by eliminating a redundant Ahead-Of-Time (AOT) cuBLASLt wrapper for batched FP8 matrix multiplication (bmm_fp8). Instead, it now delegates calls to the existing and more robust flashinfer.bmm_fp8 implementation. This change reduces maintenance overhead and leverages a well-tested external dependency, all while ensuring no regression in accuracy or performance, as validated by new comprehensive tests and benchmarks.

Highlights

  • Code Simplification: Removed the custom Ahead-Of-Time (AOT) cuBLASLt wrapper for batched FP8 matrix multiplication (bmm_fp8) from sgl-kernel.
  • Dependency Delegation: Rerouted sgl_kernel.bmm_fp8 to directly utilize the existing and more versatile flashinfer.bmm_fp8 implementation, maintaining the public API signature.
  • Testing: Introduced comprehensive correctness tests for bmm_fp8 covering various FP8 dtypes, output dtypes, and shapes.
  • Benchmarking: Added latency benchmarks for bmm_fp8 to compare performance between the original and new implementations, confirming no regression.


Changelog
  • python/sglang/jit_kernel/benchmark/bench_bmm_fp8.py
    • Added a new benchmark script for bmm_fp8 performance using triton.testing.perf_report.
  • python/sglang/jit_kernel/tests/test_bmm_fp8.py
    • Added a new test suite for bmm_fp8 to validate correctness against torch.bmm reference across various configurations.
  • sgl-kernel/CMakeLists.txt
    • Removed the csrc/gemm/bmm_fp8.cu entry from the list of source files, removing it from the build configuration.
  • sgl-kernel/csrc/common_extension.cc
    • Removed the bmm_fp8 function definition and its m.impl registration from the common extension.
  • sgl-kernel/csrc/gemm/bmm_fp8.cu
    • Deleted the custom cuBLASLt CUDA wrapper file for bmm_fp8.
  • sgl-kernel/include/sgl_kernel_ops.h
    • Removed the declaration for the bmm_fp8 function.
  • sgl-kernel/python/sgl_kernel/gemm.py
    • Removed the internal _bmm_fp8_internal function and its related _get_cache_buf import.
    • Modified the bmm_fp8 function to directly call flashinfer.bmm_fp8.
  • sgl-kernel/tests/test_bmm_fp8.py
    • Removed the old test file for bmm_fp8.
Activity
  • Accuracy tests were performed, validating correctness against torch.bmm float32 reference across various FP8 dtypes, output dtypes, and shapes. 48 combinations passed with a cosine similarity > 0.99.
  • Benchmarking was conducted comparing flashinfer JIT against the original sgl-kernel AOT cuBLASLt across different batch sizes and matrix shapes, showing no regression in latency (difference <1%).

@gemini-code-assist bot left a comment:

Code Review

This pull request is a great simplification, removing a custom C++/CUDA implementation of bmm_fp8 and replacing it with the equivalent function from the flashinfer library. This change effectively reduces code duplication and maintenance overhead. The removal of the old C++ source and header files, along with the corresponding build system and registration code, is done correctly. The addition of new, more comprehensive tests and benchmarks is a valuable improvement that ensures the correctness and performance of the new implementation. I have one minor suggestion regarding Python import conventions.

@Johnsonms changed the title from "Bmm fp8" to "Migrate bmm_fp8 from AOT cuBLASLt to flashinfer JIT" on Feb 19, 2026
…enchmark and tests

- sgl_kernel/gemm.py: add backend= param to bmm_fp8(); default keeps AOT
  cuBLASLt path, backend="flashinfer" delegates to flashinfer.bmm_fp8()
- python/sglang/jit_kernel/tests/test_bmm_fp8.py: correctness tests for
  flashinfer JIT vs float reference and vs sgl_kernel AOT (cos_sim checks)
- python/sglang/jit_kernel/benchmark/bench_bmm_fp8.py: perf comparison
  sgl_kernel (AOT cuBLASLt) vs flashinfer (JIT); no regression confirmed
Remove the sgl-kernel C++ bmm_fp8 implementation (cuBLASLt wrapper)
and delegate sgl_kernel.bmm_fp8 directly to flashinfer.bmm_fp8.
Benchmarks showed <1% latency difference across all tested shapes.

Changes:
- Delete sgl-kernel/csrc/gemm/bmm_fp8.cu
- Remove bmm_fp8 from CMakeLists.txt, common_extension.cc, sgl_kernel_ops.h
- Simplify sgl_kernel.gemm.bmm_fp8 to call flashinfer.bmm_fp8
- Move test + benchmark to python/sglang/jit_kernel/{tests,benchmark}/
- Update benchmark to flashinfer-only (sgl_kernel provider removed)