feat: Add cuBLASLt backend for mm_bf16 and enable multi-tactic autotuning for FP8/MXFP8 runners #2914

vadiklyutiy wants to merge 24 commits into flashinfer-ai:main

Conversation
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Note: Reviews paused. This branch appears to be under active development, so CodeRabbit has automatically paused this review to avoid overwhelming you with review comments due to an influx of new commits.
📝 Walkthrough

Adds cuBLASLt-backed BF16 GEMM and FP8 BMM algorithm enumeration with serialized algo buffers and tactic-based execution, exposes new FFI bindings and a JIT/AOT module for cuBLASLt BF16 GEMM, and extends the CLI/tests.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    participant App as Application / Test
    participant Runner as GEMM/BMM Runner
    participant Cache as Algo Cache (CPU)
    participant Enumerator as Algo Enumerator (FFI)
    participant cuBLASLt as cuBLASLt (Heuristics & Executor)
    App->>Runner: forward(inputs, tactic? / auto_tune?)
    Runner->>Cache: lookup(shape, dtype)
    alt cache hit
        Cache-->>Runner: algo_buffer, count
    else cache miss
        Runner->>Enumerator: get_algorithms(A, B, workspace)
        Enumerator->>cuBLASLt: query heuristics (workspace limit)
        cuBLASLt-->>Enumerator: [algo_t ...]
        Enumerator->>Enumerator: serialize algos -> algo_buffer (CPU)
        Enumerator-->>Runner: algo_count, algo_buffer
        Runner->>Cache: store(shape, dtype) -> algo_buffer
    end
    alt tactic >= 0 or autotune loop
        loop try tactics 0..N-1
            Runner->>cuBLASLt: run_with_algo(algo_buf, idx, workspace, stream)
            cuBLASLt-->>Runner: status/result
        end
        Runner-->>App: output
    else default
        Runner->>cuBLASLt: run_with_algo(algo_buf, idx=0)
        cuBLASLt-->>Runner: result
        Runner-->>App: output
    end
```
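The cache/enumerate/execute flow in the diagram above can be sketched in a few lines of Python. This is an illustrative stand-in, not the PR's actual code: `AlgoCache`, `forward`, and `run_with_algo` are hypothetical names mirroring the diagram's participants.

```python
# Illustrative sketch of the diagram's flow; AlgoCache, forward, and
# run_with_algo are hypothetical stand-ins, not the PR's API.
from typing import Callable, Dict, List, Tuple

Key = Tuple[int, int, int, str]  # (M, N, K, dtype)

class AlgoCache:
    """CPU-side cache of serialized cuBLASLt algorithm buffers per shape/dtype."""

    def __init__(self, enumerate_algos: Callable[[Key], List[bytes]]):
        self._enumerate = enumerate_algos  # FFI enumerator in the real code
        self._store: Dict[Key, List[bytes]] = {}

    def get(self, key: Key) -> List[bytes]:
        if key not in self._store:
            # Cache miss: query heuristics once, serialize, store on CPU.
            self._store[key] = self._enumerate(key)
        return self._store[key]

def forward(cache: AlgoCache, key: Key,
            run_with_algo: Callable[[bytes], object],
            tactic: int = -1) -> object:
    algos = cache.get(key)
    if not algos:
        raise RuntimeError("no cuBLASLt algorithms available for this shape")
    # tactic >= 0 selects a specific algorithm; the default path uses index 0.
    idx = tactic if 0 <= tactic < len(algos) else 0
    return run_with_algo(algos[idx])
```

An autotuner would call `forward` once per tactic in `range(len(algos))` and keep the fastest; the enumeration cost is paid only on the first call for each (shape, dtype).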
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Code Review
This pull request introduces a new cublaslt backend for BF16 GEMM operations, enabling heuristic algorithm selection and caching to minimize runtime overhead. It also extends the FP8 BMM implementation with similar algorithm selection capabilities and updates the autotuner to handle non-hashable values. Review feedback correctly identified a cache key collision risk in the algorithm caching logic and a shape mismatch in the new test cases for non-square matrices.
/gemini review
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/aot.py (1)
490-498: ⚠️ Potential issue | 🟠 Major: Register the new AOT module for SM103 too.

Line 497 is currently nested under `if has_sm100:`, so an AOT build targeting only `compute_103` never packages `mm_bf16_cublaslt` even though this PR adds that backend for SM103 as well. `has_sm103` is already computed in this function, so the new append needs its own `if has_sm100 or has_sm103:` guard.

♻️ Proposed change
```diff
     if has_sm100:
         jit_specs.append(gen_fp4_quantization_sm100_module())
         jit_specs.append(gen_cutlass_fused_moe_sm100_module())
         jit_specs.append(gen_gemm_sm100_module())
         jit_specs.append(gen_gemm_sm100_module_cutlass_fp4())
         jit_specs.append(gen_gemm_sm100_module_cutlass_fp8())
         jit_specs.append(gen_gemm_sm100_module_cutlass_mxfp8())
-        jit_specs.append(gen_mm_bf16_cublaslt_module())
+    if has_sm100 or has_sm103:
+        jit_specs.append(gen_mm_bf16_cublaslt_module())
+    if has_sm100:
         # Add TGV GEMM modules for both bf16 and fp16
         jit_specs.append(
             gen_tgv_gemm_sm10x_module(torch.bfloat16, use_sm_100f=False)
         )
```

As per coding guidelines, `flashinfer/aot.py` should "Register new operations in `flashinfer/aot.py` for AOT (Ahead-of-Time) compilation into pre-compiled packages".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/aot.py` around lines 490 - 498, The mm_bf16 cublaslt AOT module registration is incorrectly placed inside the has_sm100-only block so builds targeting compute_103 skip it; change the logic around jit_specs.append(gen_mm_bf16_cublaslt_module()) so it is executed when either has_sm100 or has_sm103 is true (i.e., wrap or move that append under an if has_sm100 or has_sm103 guard), keeping the other SM100-only appends (gen_fp4_quantization_sm100_module, gen_cutlass_fused_moe_sm100_module, gen_gemm_sm100_module*, etc.) unchanged.
🧹 Nitpick comments (1)
flashinfer/jit/gemm/core.py (1)
53-60: Scope this JIT spec to SM10x builds.

The new backend is SM100/SM103-gated, but lines 54-59 don't pass any arch-scoped NVCC flags. Mixed-arch builds will compile/package `mm_bf16_cublaslt` for every target in `FLASHINFER_CUDA_ARCH_LIST`, unlike the other GEMM generators in this file.

♻️ Proposed change
```diff
 def gen_mm_bf16_cublaslt_module() -> JitSpec:
+    nvcc_flags = current_compilation_context.get_nvcc_flags_list(
+        supported_major_versions=[10]
+    )
     return gen_jit_spec(
         "mm_bf16_cublaslt",
         [
             jit_env.FLASHINFER_CSRC_DIR / "mm_bf16_cublaslt.cu",
         ],
+        extra_cuda_cflags=nvcc_flags,
         extra_ldflags=["-lcublas", "-lcublasLt"],
     )
```

As per coding guidelines, `flashinfer/jit/**/*.py` should "Specify supported NVIDIA SM major versions in JIT modules using `supported_major_versions` parameter to limit compilation to specific GPU architectures".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/jit/gemm/core.py` around lines 53 - 60, The JIT spec gen_mm_bf16_cublaslt_module currently calls gen_jit_spec without arch scoping, causing mixed-arch builds; update the gen_jit_spec invocation in gen_mm_bf16_cublaslt_module to pass supported_major_versions=[10] (or the list containing SM10x major version) so the module is only compiled for SM100/SM103-class GPUs (refer to gen_mm_bf16_cublaslt_module, gen_jit_spec, and JitSpec to locate the change).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/mm_bf16_cublaslt.cu`:
- Around line 56-76: Reject invalid tensor residency and dtype before crossing
host/device: ensure algo_buffer is host (CPU) memory, contiguous, and uint8
(check algo_buffer.device().is_cpu() and algo_buffer.dtype()==torch::kUInt8 and
keep CHECK_CONTIGUOUS(algo_buffer)), and ensure workspace_buffer is CUDA device
memory (check workspace_buffer.device().is_cuda()) before passing
algo_buffer.data_ptr() to host-side memcpy helpers or
workspace_buffer.data_ptr() to cublasLt calls; use the same TVM_FFI_ICHECK (or
TVM_FFI_ICHECK_EQ) style used for other checks to return clear errors, and keep
get_algorithms / cublasLtMatmul calls unchanged otherwise so pointers are safe.
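The validation order this prompt describes can be illustrated with a small Python sketch. The `Tensor` record and `check_buffers` function are hypothetical stand-ins for the real torch/TVM-FFI types and `TVM_FFI_ICHECK` macros; only the ordering of the checks mirrors the comment.

```python
# Hypothetical stand-in for the residency/dtype checks the review asks for;
# the real code would use torch tensors and TVM_FFI_ICHECK macros instead.
from dataclasses import dataclass

@dataclass
class Tensor:
    device: str        # "cpu" or "cuda"
    dtype: str         # e.g. "uint8"
    contiguous: bool

def check_buffers(algo_buffer: Tensor, workspace_buffer: Tensor) -> None:
    # algo_buffer is read by host-side memcpy helpers: it must live on the
    # CPU, hold raw uint8 bytes, and be contiguous.
    if algo_buffer.device != "cpu":
        raise ValueError("algo_buffer must be CPU (host) memory")
    if algo_buffer.dtype != "uint8":
        raise ValueError("algo_buffer must be uint8")
    if not algo_buffer.contiguous:
        raise ValueError("algo_buffer must be contiguous")
    # workspace_buffer is handed to cublasLtMatmul: it must be device memory.
    if workspace_buffer.device != "cuda":
        raise ValueError("workspace_buffer must be CUDA device memory")
```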
In `@flashinfer/gemm/gemm_base.py`:
- Around line 980-983: The BF16 cuBLASLt algorithm cache key in _get_algos is
using b.shape[0] (K) twice, so N is never included and different N values
collide; change the key construction in _get_algos to use b.shape[1] for N
(since mm_bf16 receives b in (K, N) layout) and make the same correction in the
other equivalent cache-key site (the second occurrence around the
mm_bf16-related logic) so the tuple becomes (M, N, K, compute_dt) (or equivalent
ordering used elsewhere) to uniquely key by N.
- Around line 165-179: The current path assumes at least one algorithm exists
and forces tactic=0 when tactic >= count, but if self._get_algos(inputs) returns
count==0 you must avoid calling module.bmm_fp8_run_with_algo with an
uninitialized algo buffer; update the block around self._get_algos(inputs) to
check if count == 0 and handle it (e.g., raise a clear RuntimeError or fall back
to the non-algo execution path) before computing/adjusting tactic and before
calling module.bmm_fp8_run_with_algo; refer to _get_algos and
module.bmm_fp8_run_with_algo (and the local variable tactic) to locate where to
add the guard.
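The two gemm_base.py fixes described above (key the cache on N instead of duplicating K, and guard against `count == 0`) can be sketched together in illustrative Python; every name here is a hypothetical stand-in for the PR's actual helpers.

```python
# Illustrative sketch of the two fixes the prompts above describe; all
# names are hypothetical stand-ins for the PR's actual code.

def algo_cache_key(a_shape, b_shape, compute_dt):
    """Key the cuBLASLt algo cache by (M, N, K, dtype).

    Assumes a is (M, K) and b is (K, N), as the review states; the bug was
    using b.shape[0] (K) twice, so N never entered the key.
    """
    m, k = a_shape
    k2, n = b_shape
    assert k == k2, "inner dimensions must match"
    return (m, n, k, compute_dt)

def run_with_tactic(get_algos, run_with_algo, fallback_run, inputs, tactic):
    """Guard against an empty algorithm list before touching the algo buffer."""
    count, algo_buffer = get_algos(inputs)
    if count == 0:
        # Never call run_with_algo with an uninitialized buffer: fall back.
        return fallback_run(inputs)
    if tactic < 0 or tactic >= count:
        tactic = 0  # clamp out-of-range tactics to the default algorithm
    return run_with_algo(inputs, algo_buffer, tactic)
```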
---
Outside diff comments:
In `@flashinfer/aot.py`:
- Around line 490-498: The mm_bf16 cublaslt AOT module registration is
incorrectly placed inside the has_sm100-only block so builds targeting
compute_103 skip it; change the logic around
jit_specs.append(gen_mm_bf16_cublaslt_module()) so it is executed when either
has_sm100 or has_sm103 is true (i.e., wrap or move that append under an if
has_sm100 or has_sm103 guard), keeping the other SM100-only appends
(gen_fp4_quantization_sm100_module, gen_cutlass_fused_moe_sm100_module,
gen_gemm_sm100_module*, etc.) unchanged.
---
Nitpick comments:
In `@flashinfer/jit/gemm/core.py`:
- Around line 53-60: The JIT spec gen_mm_bf16_cublaslt_module currently calls
gen_jit_spec without arch scoping, causing mixed-arch builds; update the
gen_jit_spec invocation in gen_mm_bf16_cublaslt_module to pass
supported_major_versions=[10] (or the list containing SM10x major version) so
the module is only compiled for SM100/SM103-class GPUs (refer to
gen_mm_bf16_cublaslt_module, gen_jit_spec, and JitSpec to locate the change).
📒 Files selected for processing (13)
- benchmarks/routines/gemm.py
- csrc/bmm_fp8.cu
- csrc/flashinfer_gemm_binding.cu
- csrc/mm_bf16_cublaslt.cu
- flashinfer/aot.py
- flashinfer/autotuner.py
- flashinfer/gemm/gemm_base.py
- flashinfer/jit/gemm/__init__.py
- flashinfer/jit/gemm/core.py
- include/flashinfer/gemm/bmm_fp8.cuh
- include/flashinfer/gemm/mm_bf16_cublaslt.cuh
- tests/gemm/test_bmm_bf16.py
- tests/gemm/test_mm_bf16.py
I looked into several failures. It looks like the failures are not related to this PR: they are in norm, quantization, topk, etc. @dhiraj113, how do you usually proceed in such cases?
@vadiklyutiy Could you resolve the conflicts and rebase/merge with `main`?
done
@gemini-code-assist Can you summarize the code review and the changes done so far?
Summary of Changes

This pull request expands the GEMM backend capabilities by integrating cuBLASLt for BF16 operations and enabling more flexible algorithm exploration for FP8/MXFP8 runners. By allowing the autotuner to enumerate and select from multiple tactics rather than relying on hardcoded defaults, the system can achieve better performance across diverse hardware configurations. The changes also include necessary FFI bindings, JIT/AOT module updates, and improvements to the autotuner's stability.

Highlights
Activity
dhiraj113 left a comment:

Thanks for the contribution.
@dhiraj113 Could you please advise how we can proceed with merging?
@dhiraj113 Could you help me understand why auto-merge hasn't worked?
Head branch was pushed to by a user without write access
Summary
- `mm_bf16`: new `backend="cublaslt"` option (gated to SM100/SM103), with autotuning across all available cuBLASLt algorithms via `get_valid_tactics()`.
- `CublasFp8GemmRunner`, `CudnnFp8GemmRunner`, and `CudnnMxfp8GemmRunner` previously hardcoded a single tactic (`return [0]` / `return [-1]`), preventing the autotuner from exploring better algorithms (the same is already done for FP4 and FP16). Now all three enumerate available algorithms/plans via `get_valid_tactics()` and pass the selected tactic through to execution.
- Tests: added `cublaslt` and `auto` backends plus an `auto_tuning` parameter to `test_mm_bf16.py`, an `auto` backend to `test_bmm_bf16.py`, and a dedicated edge-case test for zero-algorithm handling.

Test Results
All GEMM/BMM test suites pass with 0 failures:
- test_mm_bf16.py
- test_mm_fp8.py
- test_mm_fp4.py
- test_mm_mxfp8.py
- test_bmm_bf16.py
- test_bmm_fp8.py
- test_bmm_mxfp8.py

Test additions:
- `test_mm_bf16.py`: added `cublaslt` and `auto` to the backend parametrize, added an `auto_tuning` parameter, added a `test_cublaslt_bf16_runner_zero_algos` edge-case test.
- `test_bmm_bf16.py`: added the `auto` backend.

Summary by CodeRabbit
New Features
Bug Fixes
Tests
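The multi-tactic autotuning this PR enables boils down to timing each tactic a runner enumerates and keeping the fastest. A minimal sketch, with hypothetical names and host-side timing standing in for the real autotuner's machinery:

```python
# Hypothetical sketch of tactic autotuning; in the real runners the tactic
# list comes from get_valid_tactics() and timing uses GPU-side measurement.
import time
from typing import Callable, Iterable, Optional

def autotune(run: Callable[[int], None], tactics: Iterable[int]) -> Optional[int]:
    """Time each candidate tactic once and return the fastest one."""
    best_tactic, best_time = None, float("inf")
    for tactic in tactics:
        start = time.perf_counter()
        run(tactic)  # execute the kernel with this tactic
        elapsed = time.perf_counter() - start
        if elapsed < best_time:
            best_tactic, best_time = tactic, elapsed
    return best_tactic
```

Real autotuners also warm up and average over multiple runs; the selected tactic is then passed through to execution, as the summary above describes.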