
feat: Add cuBLASLt backend for mm_bf16 and enable multi-tactic autotuning for FP8/MXFP8 runners #2914

Open
vadiklyutiy wants to merge 24 commits into flashinfer-ai:main from vadiklyutiy:mm-bf16-cublaslt

Conversation

@vadiklyutiy
Contributor

@vadiklyutiy vadiklyutiy commented Mar 30, 2026

Summary

  • Add cuBLASLt backend for mm_bf16: new backend="cublaslt" option (gated to SM100/SM103). Autotuning across all available cuBLASLt algorithms via get_valid_tactics().
  • Enable multi-tactic autotuning for single-tactic GEMM runners: CublasFp8GemmRunner, CudnnFp8GemmRunner, and CudnnMxfp8GemmRunner previously hardcoded a single tactic (return [0] / return [-1]), preventing the autotuner from exploring better algorithms (the same change was already made for the FP4 and FP16 runners). Now all three enumerate available algorithms/plans via get_valid_tactics() and pass the selected tactic through to execution.
  • Improve test coverage: added cublaslt and auto backends + auto_tuning parameter to test_mm_bf16.py, auto backend to test_bmm_bf16.py, and a dedicated edge-case test for zero-algorithm handling.
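The single-tactic-to-multi-tactic change described above can be sketched with a minimal, self-contained runner. This is an illustrative assumption of the pattern, not FlashInfer's actual `TunableRunner` code; the class and method names merely mirror the interface named in the bullets.

```python
# Hypothetical sketch of the multi-tactic runner pattern: previously runners
# returned [0] (or [-1]) from get_valid_tactics(), so the autotuner could
# never explore alternatives. Algorithm enumeration is simulated here rather
# than backed by a real cuBLASLt/cuDNN heuristic query.
from typing import Any, List


class MultiTacticRunner:
    """Runner that exposes every available algorithm to the autotuner."""

    def __init__(self, num_algos: int):
        self._num_algos = num_algos

    def get_valid_tactics(self, inputs: List[Any]) -> List[int]:
        # Enumerate all algorithms instead of hardcoding a single tactic.
        return list(range(self._num_algos))

    def forward(self, inputs: List[Any], tactic: int = -1) -> int:
        # A real runner would dispatch to the algorithm selected by `tactic`;
        # here we just echo the choice (tactic -1 means "default/first algo").
        return 0 if tactic < 0 else tactic


runner = MultiTacticRunner(num_algos=4)
print(runner.get_valid_tactics([]))  # [0, 1, 2, 3]
```

The autotuner can then time each tactic returned by `get_valid_tactics()` and cache the fastest one per problem shape.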

Test Results

All GEMM/BMM test suites pass with 0 failures:

| Test File | Passed | Failed |
|---|---|---|
| test_mm_bf16.py | 1441 | 0 |
| test_mm_fp8.py | 30 | 0 |
| test_mm_fp4.py | 1440 | 0 |
| test_mm_mxfp8.py | 1843 | 0 |
| test_bmm_bf16.py | 144 | 0 |
| test_bmm_fp8.py | 1188 | 0 |
| test_bmm_mxfp8.py | 288 | 0 |
| **Total** | **6374** | **0** |

Test additions:

  • test_mm_bf16.py: added cublaslt and auto to the backend parametrization, added an auto_tuning parameter, and added a test_cublaslt_bf16_runner_zero_algos edge-case test.
  • test_bmm_bf16.py: added the auto backend.
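The new test dimensions above might be wired up roughly as follows. This is a hedged sketch: the backend list and `auto_tuning` flag come from the PR description, but `run_mm_bf16` is a hypothetical stand-in for the real test harness, not the actual `test_mm_bf16.py` code.

```python
# Illustrative pytest parametrization for the expanded backend/auto_tuning
# matrix. run_mm_bf16 is a placeholder for the real GEMM call + check.
import pytest


def run_mm_bf16(backend: str, auto_tuning: bool) -> bool:
    # Stand-in: a real test would run the GEMM with the given backend
    # (optionally inside an autotuning context) and compare against torch.
    return backend in ("cublas", "cudnn", "cublaslt", "auto")


@pytest.mark.parametrize("backend", ["cublas", "cudnn", "cublaslt", "auto"])
@pytest.mark.parametrize("auto_tuning", [False, True])
def test_mm_bf16(backend, auto_tuning):
    assert run_mm_bf16(backend, auto_tuning)
```

Stacking the two `parametrize` decorators yields the full backend × auto_tuning cross-product, which is how the pass count grows as backends are added.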

Summary by CodeRabbit

  • New Features

    • Added a cuBLASLt BF16 GEMM backend with selectable algorithm tactics; module generated for supported hardware.
    • Exposed APIs to enumerate and run FP8/BF16 GEMM algorithms for explicit tactic selection.
    • Expanded backend choices with "cublaslt" and broader "auto" routing.
  • Bug Fixes

    • Hardened autotuner hashing to avoid failures on unhashable attributes.
  • Tests

    • Extended tests for new backends, autotuning options, auto routing, and zero-algorithm failure handling.

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
@coderabbitai
Contributor

coderabbitai Bot commented Mar 30, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds cuBLASLt-backed BF16 GEMM and FP8 BMM algorithm enumeration with serialized algo buffers and tactic-based execution, exposes new FFI bindings and JIT/AOT module for cublasLt BF16 GEMM, extends CLI/tests with cublaslt/auto, and hardens autotuner hashing.

Changes

| Cohort / File(s) | Summary |
|---|---|
| **CLI & Tests**<br>`benchmarks/routines/gemm.py`, `tests/gemm/test_bmm_bf16.py`, `tests/gemm/test_mm_bf16.py` | Added `cublaslt` and `auto` backend choices, added an `auto_tuning` test dimension, adjusted backend skip logic, and added a zero-algorithms test for the cublaslt BF16 runner. |
| **FP8 BMM runtime & bindings**<br>`csrc/bmm_fp8.cu`, `csrc/flashinfer_gemm_binding.cu`, `include/flashinfer/gemm/bmm_fp8.cuh` | Changed workspace length units to bytes for runtime calls; added an algorithm enumeration API (`bmm_fp8_get_algos`) and a run-with-algo API (`bmm_fp8_run_with_algo`); added cuBLASLt FP8 descriptor helpers and serialized algo storage. |
| **BF16 cuBLASLt implementation**<br>`csrc/mm_bf16_cublaslt.cu`, `include/flashinfer/gemm/mm_bf16_cublaslt.cuh` | New cuBLASLt-backed BF16 GEMM: descriptor factories, algorithm enumeration (`get_algorithms`), serialized algo buffer format, run-with-algo execution, and FFI-exported getter/runner functions. |
| **GEMM routing & runners**<br>`flashinfer/gemm/gemm_base.py` | Added the `cublaslt` backend with SM100/SM103 gating; implemented the cublasLt BF16 runner with algo enumeration/cache and tactic-based execution; updated FP8 and cuDNN runners to enumerate and accept explicit tactics. |
| **JIT / AOT module generation**<br>`flashinfer/jit/gemm/core.py`, `flashinfer/jit/gemm/__init__.py`, `flashinfer/aot.py` | Added `gen_mm_bf16_cublaslt_module`, exported it, and wired it into AOT/JIT generation for SM100/SM103 with `-lcublas`/`-lcublasLt` linking. |
| **Autotuner robustness**<br>`flashinfer/autotuner.py` | Made `TunableRunner.__hash__` resilient to unhashable attributes by skipping `*_cache` fields and falling back to `id(v)` for unhashable values. |
| **FFI exports**<br>`csrc/flashinfer_gemm_binding.cu` | Registered new FFI exports: `bmm_fp8_get_algos`, `bmm_fp8_run_with_algo`, `mm_bf16_cublaslt_get_algos`, and `mm_bf16_cublaslt_run_with_algo`. |
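The autotuner-hashing hardening listed above can be sketched in a few lines. The attribute names below are assumptions for illustration; the real `flashinfer/autotuner.py` implementation may differ in detail.

```python
# Illustrative sketch of a __hash__ that skips *_cache attributes (which are
# typically unhashable dicts that vary per instance) and falls back to id()
# for any other unhashable value instead of raising TypeError.
class TunableRunner:
    def __init__(self):
        self.backend = "cublaslt"   # hashable, participates in the hash
        self._algo_cache = {}       # unhashable dict, deliberately skipped

    def __hash__(self):
        parts = []
        for name, value in sorted(vars(self).items()):
            if name.endswith("_cache"):
                continue  # per-instance caches would break hash stability
            try:
                parts.append((name, hash(value)))
            except TypeError:
                # Unhashable non-cache attribute: fall back to identity.
                parts.append((name, id(value)))
        return hash(tuple(parts))


print(hash(TunableRunner()) == hash(TunableRunner()))  # True
```

Two runners with identical non-cache attributes now hash equally even after their caches diverge, which keeps autotuner lookups stable.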

Sequence Diagram

sequenceDiagram
    participant App as Application / Test
    participant Runner as GEMM/BMM Runner
    participant Cache as Algo Cache (CPU)
    participant Enumerator as Algo Enumerator (FFI)
    participant cuBLASLt as cuBLASLt (Heuristics & Executor)

    App->>Runner: forward(inputs, tactic? / auto_tune?)
    Runner->>Cache: lookup(shape, dtype)
    alt cache hit
        Cache-->>Runner: algo_buffer, count
    else cache miss
        Runner->>Enumerator: get_algorithms(A,B,workspace) 
        Enumerator->>cuBLASLt: query heuristics (workspace limit)
        cuBLASLt-->>Enumerator: [algo_t ...]
        Enumerator->>Enumerator: serialize algos -> algo_buffer (CPU)
        Enumerator-->>Runner: algo_count, algo_buffer
        Runner->>Cache: store(shape,dtype)->algo_buffer
    end
    alt tactic >= 0 or autotune loop
        loop try tactics 0..N-1
            Runner->>cuBLASLt: run_with_algo(algo_buf, idx, workspace, stream)
            cuBLASLt-->>Runner: status/result
        end
        Runner->>App: output
    else default
        Runner->>cuBLASLt: run_with_algo(algo_buf, idx=0)
        cuBLASLt-->>Runner: result
        Runner-->>App: output
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

op: gemm, run-ci

Suggested reviewers

  • bkryu
  • nvmbreughe
  • jimmyzho
  • yzh119
  • jiahanc
  • yongwww
  • cyx-6

Poem

🐇 I bounded through code with a curious twitch,
I cached all the algos inside my little niche.
cuBLASLt hummed, I picked which to try,
I hop, I run — the fastest one flies high. ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 16.95%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly and specifically describes the main changes: adding a cuBLASLt backend for mm_bf16 and enabling multi-tactic autotuning for FP8/MXFP8 runners, matching the core objectives. |
| Description check | ✅ Passed | The PR description provides a clear summary section covering all major changes, includes comprehensive test results showing 6374 tests passed with 0 failures, and documents specific test additions; however, it lacks the required pre-commit checks, related issues link, and reviewer notes sections from the template. |


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a new cublaslt backend for BF16 GEMM operations, enabling heuristic algorithm selection and caching to minimize runtime overhead. It also extends the FP8 BMM implementation with similar algorithm selection capabilities and updates the autotuner to handle non-hashable values. Review feedback correctly identified a cache key collision risk in the algorithm caching logic and a shape mismatch in the new test cases for non-square matrices.

Comment thread flashinfer/gemm/gemm_base.py Outdated
Comment thread tests/gemm/test_mm_bf16.py Outdated
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
@vadiklyutiy
Contributor Author

/gemini review

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/aot.py (1)

490-498: ⚠️ Potential issue | 🟠 Major

Register the new AOT module for SM103 too.

Line 497 is currently nested under if has_sm100:, so an AOT build targeting only compute_103 never packages mm_bf16_cublaslt even though this PR adds that backend for SM103 as well. has_sm103 is already computed in this function, so the new append needs its own if has_sm100 or has_sm103: guard.

♻️ Proposed change
         if has_sm100:
             jit_specs.append(gen_fp4_quantization_sm100_module())
             jit_specs.append(gen_cutlass_fused_moe_sm100_module())
             jit_specs.append(gen_gemm_sm100_module())
             jit_specs.append(gen_gemm_sm100_module_cutlass_fp4())
             jit_specs.append(gen_gemm_sm100_module_cutlass_fp8())
             jit_specs.append(gen_gemm_sm100_module_cutlass_mxfp8())
-            jit_specs.append(gen_mm_bf16_cublaslt_module())
+        if has_sm100 or has_sm103:
+            jit_specs.append(gen_mm_bf16_cublaslt_module())
+        if has_sm100:
             # Add TGV GEMM modules for both bf16 and fp16
             jit_specs.append(
                 gen_tgv_gemm_sm10x_module(torch.bfloat16, use_sm_100f=False)
             )

As per coding guidelines, flashinfer/aot.py should "Register new operations in flashinfer/aot.py for AOT (Ahead-of-Time) compilation into pre-compiled packages".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/aot.py` around lines 490 - 498, The mm_bf16 cublaslt AOT module
registration is incorrectly placed inside the has_sm100-only block so builds
targeting compute_103 skip it; change the logic around
jit_specs.append(gen_mm_bf16_cublaslt_module()) so it is executed when either
has_sm100 or has_sm103 is true (i.e., wrap or move that append under an if
has_sm100 or has_sm103 guard), keeping the other SM100-only appends
(gen_fp4_quantization_sm100_module, gen_cutlass_fused_moe_sm100_module,
gen_gemm_sm100_module*, etc.) unchanged.
🧹 Nitpick comments (1)
flashinfer/jit/gemm/core.py (1)

53-60: Scope this JIT spec to SM10x builds.

The new backend is SM100/SM103-gated, but Lines 54-59 don't pass any arch-scoped NVCC flags. Mixed-arch builds will compile/package mm_bf16_cublaslt for every target in FLASHINFER_CUDA_ARCH_LIST, unlike the other GEMM generators in this file.

♻️ Proposed change
 def gen_mm_bf16_cublaslt_module() -> JitSpec:
+    nvcc_flags = current_compilation_context.get_nvcc_flags_list(
+        supported_major_versions=[10]
+    )
     return gen_jit_spec(
         "mm_bf16_cublaslt",
         [
             jit_env.FLASHINFER_CSRC_DIR / "mm_bf16_cublaslt.cu",
         ],
+        extra_cuda_cflags=nvcc_flags,
         extra_ldflags=["-lcublas", "-lcublasLt"],
     )

As per coding guidelines, flashinfer/jit/**/*.py should "Specify supported NVIDIA SM major versions in JIT modules using supported_major_versions parameter to limit compilation to specific GPU architectures".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/jit/gemm/core.py` around lines 53 - 60, The JIT spec
gen_mm_bf16_cublaslt_module currently calls gen_jit_spec without arch scoping,
causing mixed-arch builds; update the gen_jit_spec invocation in
gen_mm_bf16_cublaslt_module to pass supported_major_versions=[10] (or the list
containing SM10x major version) so the module is only compiled for
SM100/SM103-class GPUs (refer to gen_mm_bf16_cublaslt_module, gen_jit_spec, and
JitSpec to locate the change).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/mm_bf16_cublaslt.cu`:
- Around line 56-76: Reject invalid tensor residency and dtype before crossing
host/device: ensure algo_buffer is host (CPU) memory, contiguous, and uint8
(check algo_buffer.device().is_cpu() and algo_buffer.dtype()==torch::kUInt8 and
keep CHECK_CONTIGUOUS(algo_buffer)), and ensure workspace_buffer is CUDA device
memory (check workspace_buffer.device().is_cuda()) before passing
algo_buffer.data_ptr() to host-side memcpy helpers or
workspace_buffer.data_ptr() to cublasLt calls; use the same TVM_FFI_ICHECK (or
TVM_FFI_ICHECK_EQ) style used for other checks to return clear errors, and keep
get_algorithms / cublasLtMatmul calls unchanged otherwise so pointers are safe.

In `@flashinfer/gemm/gemm_base.py`:
- Around line 980-983: The BF16 cuBLASLt algorithm cache key in _get_algos is
using b.shape[0] (K) twice, so N is never included and different N values
collide; change the key construction in _get_algos to use b.shape[1] for N
(since mm_bf16 receives b in (K, N) layout) and make the same correction in the
other equivalent cache-key site (the second occurrence around the
mm_bf16-related logic) so the tuple becomes (M, N, K, compute_dt) (or equivalent
ordering used elsewhere) to uniquely key by N.
- Around line 165-179: The current path assumes at least one algorithm exists
and forces tactic=0 when tactic >= count, but if self._get_algos(inputs) returns
count==0 you must avoid calling module.bmm_fp8_run_with_algo with an
uninitialized algo buffer; update the block around self._get_algos(inputs) to
check if count == 0 and handle it (e.g., raise a clear RuntimeError or fall back
to the non-algo execution path) before computing/adjusting tactic and before
calling module.bmm_fp8_run_with_algo; refer to _get_algos and
module.bmm_fp8_run_with_algo (and the local variable tactic) to locate where to
add the guard.
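The two `gemm_base.py` findings above (the N-dropping cache key and the missing zero-algorithm guard) can be condensed into a small sketch. Shapes and helper names here are assumptions for illustration, not the actual FlashInfer code.

```python
# Illustrative fixes: (1) build the algo-cache key from b.shape[1] so N is
# included -- keying on b.shape[0] twice means every N collides for a given
# (M, K); (2) guard against an empty algorithm list before dispatching.
def make_cache_key(a_shape, b_shape, compute_dt):
    m, k = a_shape
    k2, n = b_shape  # mm_bf16 receives b in (K, N) layout
    assert k == k2, "inner dimensions must match"
    return (m, n, k, compute_dt)  # uniquely keyed by M, N, K, and dtype


def select_tactic(count, tactic):
    if count == 0:
        # Never touch an uninitialized algo buffer.
        raise RuntimeError("cuBLASLt returned no algorithms for this problem")
    if tactic < 0 or tactic >= count:
        tactic = 0  # clamp out-of-range / default tactics to the first algo
    return tactic
```

With the corrected key, (M=8, K=16) problems with N=32 and N=64 occupy distinct cache slots instead of sharing one serialized algo buffer.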

---

Outside diff comments:
In `@flashinfer/aot.py`:
- Around line 490-498: The mm_bf16 cublaslt AOT module registration is
incorrectly placed inside the has_sm100-only block so builds targeting
compute_103 skip it; change the logic around
jit_specs.append(gen_mm_bf16_cublaslt_module()) so it is executed when either
has_sm100 or has_sm103 is true (i.e., wrap or move that append under an if
has_sm100 or has_sm103 guard), keeping the other SM100-only appends
(gen_fp4_quantization_sm100_module, gen_cutlass_fused_moe_sm100_module,
gen_gemm_sm100_module*, etc.) unchanged.

---

Nitpick comments:
In `@flashinfer/jit/gemm/core.py`:
- Around line 53-60: The JIT spec gen_mm_bf16_cublaslt_module currently calls
gen_jit_spec without arch scoping, causing mixed-arch builds; update the
gen_jit_spec invocation in gen_mm_bf16_cublaslt_module to pass
supported_major_versions=[10] (or the list containing SM10x major version) so
the module is only compiled for SM100/SM103-class GPUs (refer to
gen_mm_bf16_cublaslt_module, gen_jit_spec, and JitSpec to locate the change).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: cbf109a0-b993-446a-8aa7-c78df2ecc3d2

📥 Commits

Reviewing files that changed from the base of the PR and between 4941606 and 4616733.

📒 Files selected for processing (13)
  • benchmarks/routines/gemm.py
  • csrc/bmm_fp8.cu
  • csrc/flashinfer_gemm_binding.cu
  • csrc/mm_bf16_cublaslt.cu
  • flashinfer/aot.py
  • flashinfer/autotuner.py
  • flashinfer/gemm/gemm_base.py
  • flashinfer/jit/gemm/__init__.py
  • flashinfer/jit/gemm/core.py
  • include/flashinfer/gemm/bmm_fp8.cuh
  • include/flashinfer/gemm/mm_bf16_cublaslt.cuh
  • tests/gemm/test_bmm_bf16.py
  • tests/gemm/test_mm_bf16.py

Comment thread csrc/mm_bf16_cublaslt.cu
Comment thread flashinfer/autotuner.py
Comment thread flashinfer/gemm/gemm_base.py
Comment thread flashinfer/gemm/gemm_base.py Outdated
@vadiklyutiy
Contributor Author

[FAILED] Pipeline #47392419: 6/20 passed

I looked into several of the failures. They don't appear to be related to this PR: they are in norm, quantization, topk, etc.

@dhiraj113 How do you usually proceed in such cases?

@dhiraj113
Collaborator

@vadiklyutiy Could you resolve the conflicts and rebase/merge with main branch again. There is a lot of diff that is showing that I don't think is from your PR.

@vadiklyutiy
Contributor Author

> @vadiklyutiy Could you resolve the conflicts and rebase/merge with main branch again. There is a lot of diff that is showing that I don't think is from your PR.

done

Comment thread benchmarks/routines/gemm.py
Comment thread include/flashinfer/gemm/bmm_fp8.cuh
Comment thread include/flashinfer/gemm/mm_bf16_cublaslt.cuh
Comment thread flashinfer/autotuner.py
Comment thread flashinfer/gemm/gemm_base.py Outdated
Comment thread flashinfer/gemm/gemm_base.py Outdated
Comment thread flashinfer/gemm/gemm_base.py Outdated
Comment thread flashinfer/gemm/gemm_base.py Outdated
@dhiraj113
Collaborator

@gemini-code-assist - Can you summarize the code review and changes done so far

@gemini-code-assist
Contributor

Summary of Changes

This pull request expands the GEMM backend capabilities by integrating cuBLASLt for BF16 operations and enabling more flexible algorithm exploration for FP8/MXFP8 runners. By allowing the autotuner to enumerate and select from multiple tactics rather than relying on hardcoded defaults, the system can achieve better performance across diverse hardware configurations. The changes also include necessary FFI bindings, JIT/AOT module updates, and improvements to the autotuner's stability.

Highlights

  • cuBLASLt Backend for BF16 GEMM: Introduced a new cuBLASLt-backed BF16 GEMM implementation, specifically gated for SM100/SM103 hardware, featuring heuristic-based algorithm selection and multi-tactic autotuning.
  • Enhanced Autotuning for FP8/MXFP8: Enabled multi-tactic autotuning for single-tactic GEMM runners (CublasFp8GemmRunner, CudnnFp8GemmRunner, and CudnnMxfp8GemmRunner) by exposing APIs to enumerate and execute available algorithms.
  • Autotuner Robustness: Improved the TunableRunner hashing mechanism to ignore unhashable cache attributes, ensuring consistent autotuner behavior across different runner instances.
  • Expanded Test Coverage: Added comprehensive test cases for the new cuBLASLt backend, including auto-routing, autotuning parameters, and edge-case handling for zero-algorithm scenarios.


Activity
  • Pull request created by vadiklyutiy.
  • Automated review paused and resumed by CodeRabbit.
  • Initial review comments provided by gemini-code-assist and coderabbitai regarding cache keys, tensor residency, and device handling.
  • Author addressed review feedback, including fixing cache key logic, adding device guards for CUDA operations, and correcting precision handling for FP16 outputs.
  • CI pipeline failures were investigated and determined to be unrelated to the PR changes.
  • PR rebased and conflicts resolved against the main branch.
  • Final review comments addressed regarding precision and documentation.

Collaborator

@dhiraj113 dhiraj113 left a comment


Thanks for the contribution.

@vadiklyutiy
Contributor Author

@dhiraj113 could you please advise how we can proceed with merging?

@dhiraj113 dhiraj113 enabled auto-merge (squash) April 16, 2026 21:40
@vadiklyutiy
Contributor Author

@dhiraj113 could you help me understand why auto-merge hasn't worked?

auto-merge was automatically disabled April 22, 2026 00:43

Head branch was pushed to by a user without write access

@dhiraj113 dhiraj113 enabled auto-merge (squash) April 22, 2026 00:48
auto-merge was automatically disabled April 22, 2026 00:59

Head branch was pushed to by a user without write access

